Refactor AVX code for separate compilation and runtime dispatch

This commit is contained in:
Ken Cooke 2016-02-06 10:55:23 -08:00
parent 08e299f724
commit 3d684dd0e7
2 changed files with 140 additions and 96 deletions

View file

@ -64,82 +64,7 @@ static const float crossfadeTable[HRTF_BLOCK] = {
//
#if defined(_M_IX86) || defined(_M_X64) || defined(__i386__) || defined(__x86_64__)
#include <immintrin.h> // AVX
#if 0 // AVX disabled for now..
// 1 channel input, 4 channel output
static void FIR_1x4_AVX(float* src, float* dst0, float* dst1, float* dst2, float* dst3, float coef[4][HRTF_TAPS], int numFrames) {
float* coef0 = coef[0] + HRTF_TAPS - 1; // process backwards
float* coef1 = coef[1] + HRTF_TAPS - 1;
float* coef2 = coef[2] + HRTF_TAPS - 1;
float* coef3 = coef[3] + HRTF_TAPS - 1;
assert(numFrames % 8 == 0);
for (int i = 0; i < numFrames; i += 8) {
__m256 acc0 = _mm256_setzero_ps();
__m256 acc1 = _mm256_setzero_ps();
__m256 acc2 = _mm256_setzero_ps();
__m256 acc3 = _mm256_setzero_ps();
float* ps = &src[i - HRTF_TAPS + 1]; // process forwards
assert(HRTF_TAPS % 8 == 0);
for (int k = 0; k < HRTF_TAPS; k += 8) {
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_broadcast_ss(&coef0[-k-0]), _mm256_loadu_ps(&ps[k+0])));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_broadcast_ss(&coef1[-k-0]), _mm256_loadu_ps(&ps[k+0])));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_broadcast_ss(&coef2[-k-0]), _mm256_loadu_ps(&ps[k+0])));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_broadcast_ss(&coef3[-k-0]), _mm256_loadu_ps(&ps[k+0])));
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_broadcast_ss(&coef0[-k-1]), _mm256_loadu_ps(&ps[k+1])));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_broadcast_ss(&coef1[-k-1]), _mm256_loadu_ps(&ps[k+1])));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_broadcast_ss(&coef2[-k-1]), _mm256_loadu_ps(&ps[k+1])));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_broadcast_ss(&coef3[-k-1]), _mm256_loadu_ps(&ps[k+1])));
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_broadcast_ss(&coef0[-k-2]), _mm256_loadu_ps(&ps[k+2])));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_broadcast_ss(&coef1[-k-2]), _mm256_loadu_ps(&ps[k+2])));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_broadcast_ss(&coef2[-k-2]), _mm256_loadu_ps(&ps[k+2])));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_broadcast_ss(&coef3[-k-2]), _mm256_loadu_ps(&ps[k+2])));
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_broadcast_ss(&coef0[-k-3]), _mm256_loadu_ps(&ps[k+3])));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_broadcast_ss(&coef1[-k-3]), _mm256_loadu_ps(&ps[k+3])));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_broadcast_ss(&coef2[-k-3]), _mm256_loadu_ps(&ps[k+3])));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_broadcast_ss(&coef3[-k-3]), _mm256_loadu_ps(&ps[k+3])));
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_broadcast_ss(&coef0[-k-4]), _mm256_loadu_ps(&ps[k+4])));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_broadcast_ss(&coef1[-k-4]), _mm256_loadu_ps(&ps[k+4])));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_broadcast_ss(&coef2[-k-4]), _mm256_loadu_ps(&ps[k+4])));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_broadcast_ss(&coef3[-k-4]), _mm256_loadu_ps(&ps[k+4])));
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_broadcast_ss(&coef0[-k-5]), _mm256_loadu_ps(&ps[k+5])));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_broadcast_ss(&coef1[-k-5]), _mm256_loadu_ps(&ps[k+5])));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_broadcast_ss(&coef2[-k-5]), _mm256_loadu_ps(&ps[k+5])));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_broadcast_ss(&coef3[-k-5]), _mm256_loadu_ps(&ps[k+5])));
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_broadcast_ss(&coef0[-k-6]), _mm256_loadu_ps(&ps[k+6])));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_broadcast_ss(&coef1[-k-6]), _mm256_loadu_ps(&ps[k+6])));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_broadcast_ss(&coef2[-k-6]), _mm256_loadu_ps(&ps[k+6])));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_broadcast_ss(&coef3[-k-6]), _mm256_loadu_ps(&ps[k+6])));
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_broadcast_ss(&coef0[-k-7]), _mm256_loadu_ps(&ps[k+7])));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_broadcast_ss(&coef1[-k-7]), _mm256_loadu_ps(&ps[k+7])));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_broadcast_ss(&coef2[-k-7]), _mm256_loadu_ps(&ps[k+7])));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_broadcast_ss(&coef3[-k-7]), _mm256_loadu_ps(&ps[k+7])));
}
_mm256_storeu_ps(&dst0[i], acc0);
_mm256_storeu_ps(&dst1[i], acc1);
_mm256_storeu_ps(&dst2[i], acc2);
_mm256_storeu_ps(&dst3[i], acc3);
}
_mm256_zeroupper();
}
#endif
#include <emmintrin.h>
// 1 channel input, 4 channel output
static void FIR_1x4_SSE(float* src, float* dst0, float* dst1, float* dst2, float* dst3, float coef[4][HRTF_TAPS], int numFrames) {
@ -193,23 +118,22 @@ static void FIR_1x4_SSE(float* src, float* dst0, float* dst1, float* dst2, float
}
//
// Runtime CPU dispatch
// Detect AVX/AVX2 support
//
#if 0
//#if defined(_MSC_VER)
#if defined(_MSC_VER)
#include <intrin.h>
// detect AVX support
static bool cpuSupportsAVX() {
static bool cpuSupportsAVX() {
int info[4];
int mask = (1<<27) | (1<<28); // OSXSAVE and AVX
int mask = (1 << 27) | (1 << 28); // OSXSAVE and AVX
__cpuidex(info, 0x1, 0);
bool result = false;
if ((info[2] & mask) == mask) {
if ((_xgetbv(_XCR_XFEATURE_ENABLED_MASK) & 0x6) == 0x6) {
result = true;
}
@ -217,33 +141,57 @@ static bool cpuSupportsAVX() {
return result;
}
typedef void (*t_FIR_1x4)(float* src, float* dst0, float* dst1, float* dst2, float* dst3, float coef[4][HRTF_TAPS], int numFrames);
static bool cpuSupportsAVX2() {
int info[4];
int mask = (1 << 5); // AVX2
// dispatch stub
static void FIR_1x4(float* src, float* dst0, float* dst1, float* dst2, float* dst3, float coef[4][HRTF_TAPS], int numFrames) {
static t_FIR_1x4 f = cpuSupportsAVX() ? FIR_1x4_AVX : FIR_1x4_SSE; // init on first call
(*f)(src, dst0, dst1, dst2, dst3, coef, numFrames); // dispatch
bool result = false;
if (cpuSupportsAVX()) {
__cpuidex(info, 0x7, 0);
if ((info[1] & mask) == mask) {
result = true;
}
}
return result;
}
//#elif defined(__GNU__)
#elif defined(__GNU__) || defined(__clang__)
typedef void (*t_FIR_1x4)(float* src, float* dst0, float* dst1, float* dst2, float* dst3, float coef[4][HRTF_TAPS], int numFrames);
static bool cpuSupportsAVX() {
return __builtin_cpu_supports("avx");
}
// dispatch stub
static void FIR_1x4(float* src, float* dst0, float* dst1, float* dst2, float* dst3, float coef[4][HRTF_TAPS], int numFrames) {
static t_FIR_1x4 f = __builtin_cpu_supports("avx") ? FIR_1x4_AVX : FIR_1x4_SSE; // init on first call
(*f)(src, dst0, dst1, dst2, dst3, coef, numFrames); // dispatch
static bool cpuSupportsAVX2() {
return __builtin_cpu_supports("avx2");
}
#else
// always use SSE version
static void FIR_1x4(float* src, float* dst0, float* dst1, float* dst2, float* dst3, float coef[4][HRTF_TAPS], int numFrames) {
FIR_1x4_SSE(src, dst0, dst1, dst2, dst3, coef, numFrames);
static bool cpuSupportsAVX() {
return false;
}
static bool cpuSupportsAVX2() {
return false;
}
#endif
//
// Runtime CPU dispatch
//
typedef void FIR_1x4_t(float* src, float* dst0, float* dst1, float* dst2, float* dst3, float coef[4][HRTF_TAPS], int numFrames);
FIR_1x4_t FIR_1x4_AVX; // separate compilation with VEX-encoding enabled
static void FIR_1x4(float* src, float* dst0, float* dst1, float* dst2, float* dst3, float coef[4][HRTF_TAPS], int numFrames) {
static FIR_1x4_t* f = cpuSupportsAVX() ? FIR_1x4_AVX : FIR_1x4_SSE; // init on first call
(*f)(src, dst0, dst1, dst2, dst3, coef, numFrames); // dispatch
}
// 4 channel planar to interleaved
static void interleave_4x4(float* src0, float* src1, float* src2, float* src3, float* dst, int numFrames) {

View file

@ -0,0 +1,96 @@
//
// AudioHRTF_avx.cpp
// libraries/audio/src/avx
//
// Created by Ken Cooke on 1/17/16.
// Copyright 2016 High Fidelity, Inc.
//
// Distributed under the Apache License, Version 2.0.
// See the accompanying file LICENSE or http://www.apache.org/licenses/LICENSE-2.0.html
//
#if defined(_M_IX86) || defined(_M_X64) || defined(__i386__) || defined(__x86_64__)
#include <assert.h>
#include <immintrin.h>
#include "../AudioHRTF.h"
#ifndef __AVX__
#error Must be compiled with /arch:AVX or -mavx.
#endif
// 1 channel input, 4 channel output
void FIR_1x4_AVX(float* src, float* dst0, float* dst1, float* dst2, float* dst3, float coef[4][HRTF_TAPS], int numFrames) {
float* coef0 = coef[0] + HRTF_TAPS - 1; // process backwards
float* coef1 = coef[1] + HRTF_TAPS - 1;
float* coef2 = coef[2] + HRTF_TAPS - 1;
float* coef3 = coef[3] + HRTF_TAPS - 1;
assert(numFrames % 8 == 0);
for (int i = 0; i < numFrames; i += 8) {
__m256 acc0 = _mm256_setzero_ps();
__m256 acc1 = _mm256_setzero_ps();
__m256 acc2 = _mm256_setzero_ps();
__m256 acc3 = _mm256_setzero_ps();
float* ps = &src[i - HRTF_TAPS + 1]; // process forwards
assert(HRTF_TAPS % 8 == 0);
for (int k = 0; k < HRTF_TAPS; k += 8) {
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_broadcast_ss(&coef0[-k-0]), _mm256_loadu_ps(&ps[k+0])));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_broadcast_ss(&coef1[-k-0]), _mm256_loadu_ps(&ps[k+0])));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_broadcast_ss(&coef2[-k-0]), _mm256_loadu_ps(&ps[k+0])));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_broadcast_ss(&coef3[-k-0]), _mm256_loadu_ps(&ps[k+0])));
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_broadcast_ss(&coef0[-k-1]), _mm256_loadu_ps(&ps[k+1])));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_broadcast_ss(&coef1[-k-1]), _mm256_loadu_ps(&ps[k+1])));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_broadcast_ss(&coef2[-k-1]), _mm256_loadu_ps(&ps[k+1])));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_broadcast_ss(&coef3[-k-1]), _mm256_loadu_ps(&ps[k+1])));
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_broadcast_ss(&coef0[-k-2]), _mm256_loadu_ps(&ps[k+2])));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_broadcast_ss(&coef1[-k-2]), _mm256_loadu_ps(&ps[k+2])));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_broadcast_ss(&coef2[-k-2]), _mm256_loadu_ps(&ps[k+2])));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_broadcast_ss(&coef3[-k-2]), _mm256_loadu_ps(&ps[k+2])));
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_broadcast_ss(&coef0[-k-3]), _mm256_loadu_ps(&ps[k+3])));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_broadcast_ss(&coef1[-k-3]), _mm256_loadu_ps(&ps[k+3])));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_broadcast_ss(&coef2[-k-3]), _mm256_loadu_ps(&ps[k+3])));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_broadcast_ss(&coef3[-k-3]), _mm256_loadu_ps(&ps[k+3])));
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_broadcast_ss(&coef0[-k-4]), _mm256_loadu_ps(&ps[k+4])));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_broadcast_ss(&coef1[-k-4]), _mm256_loadu_ps(&ps[k+4])));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_broadcast_ss(&coef2[-k-4]), _mm256_loadu_ps(&ps[k+4])));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_broadcast_ss(&coef3[-k-4]), _mm256_loadu_ps(&ps[k+4])));
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_broadcast_ss(&coef0[-k-5]), _mm256_loadu_ps(&ps[k+5])));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_broadcast_ss(&coef1[-k-5]), _mm256_loadu_ps(&ps[k+5])));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_broadcast_ss(&coef2[-k-5]), _mm256_loadu_ps(&ps[k+5])));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_broadcast_ss(&coef3[-k-5]), _mm256_loadu_ps(&ps[k+5])));
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_broadcast_ss(&coef0[-k-6]), _mm256_loadu_ps(&ps[k+6])));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_broadcast_ss(&coef1[-k-6]), _mm256_loadu_ps(&ps[k+6])));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_broadcast_ss(&coef2[-k-6]), _mm256_loadu_ps(&ps[k+6])));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_broadcast_ss(&coef3[-k-6]), _mm256_loadu_ps(&ps[k+6])));
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_broadcast_ss(&coef0[-k-7]), _mm256_loadu_ps(&ps[k+7])));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_broadcast_ss(&coef1[-k-7]), _mm256_loadu_ps(&ps[k+7])));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_broadcast_ss(&coef2[-k-7]), _mm256_loadu_ps(&ps[k+7])));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_broadcast_ss(&coef3[-k-7]), _mm256_loadu_ps(&ps[k+7])));
}
_mm256_storeu_ps(&dst0[i], acc0);
_mm256_storeu_ps(&dst1[i], acc1);
_mm256_storeu_ps(&dst2[i], acc2);
_mm256_storeu_ps(&dst3[i], acc3);
}
_mm256_zeroupper();
}
#endif