Merge pull request #13927 from kencooke/audio-hrtf-avx2-more

Low-level performance improvements to audio-mixer
commit eaa5e2ea9a
John Conklin II, 2018-09-04 12:06:32 -07:00, committed by GitHub
3 changed files with 211 additions and 27 deletions

View file

@@ -240,7 +240,7 @@ static void FIR_1x4_SSE(float* src, float* dst0, float* dst1, float* dst2, float
float* ps = &src[i - HRTF_TAPS + 1]; // process forwards
-     assert(HRTF_TAPS % 4 == 0);
+     static_assert(HRTF_TAPS % 4 == 0, "HRTF_TAPS must be a multiple of 4");
for (int k = 0; k < HRTF_TAPS; k += 4) {
@@ -276,23 +276,8 @@ static void FIR_1x4_SSE(float* src, float* dst0, float* dst1, float* dst2, float
}
}
- //
- // Runtime CPU dispatch
- //
- #include "CPUDetect.h"
- void FIR_1x4_AVX2(float* src, float* dst0, float* dst1, float* dst2, float* dst3, float coef[4][HRTF_TAPS], int numFrames);
- void FIR_1x4_AVX512(float* src, float* dst0, float* dst1, float* dst2, float* dst3, float coef[4][HRTF_TAPS], int numFrames);
- static void FIR_1x4(float* src, float* dst0, float* dst1, float* dst2, float* dst3, float coef[4][HRTF_TAPS], int numFrames) {
-     static auto f = cpuSupportsAVX512() ? FIR_1x4_AVX512 : (cpuSupportsAVX2() ? FIR_1x4_AVX2 : FIR_1x4_SSE);
-     (*f)(src, dst0, dst1, dst2, dst3, coef, numFrames); // dispatch
- }
// 4 channel planar to interleaved
- static void interleave_4x4(float* src0, float* src1, float* src2, float* src3, float* dst, int numFrames) {
+ static void interleave_4x4_SSE(float* src0, float* src1, float* src2, float* src3, float* dst, int numFrames) {
assert(numFrames % 4 == 0);
@@ -323,7 +308,7 @@ static void interleave_4x4(float* src0, float* src1, float* src2, float* src3, f
// process 2 cascaded biquads on 4 channels (interleaved)
// biquads computed in parallel, by adding one sample of delay
- static void biquad2_4x4(float* src, float* dst, float coef[5][8], float state[3][8], int numFrames) {
+ static void biquad2_4x4_SSE(float* src, float* dst, float coef[5][8], float state[3][8], int numFrames) {
// enable flush-to-zero mode to prevent denormals
unsigned int ftz = _MM_GET_FLUSH_ZERO_MODE();
@@ -388,7 +373,7 @@ static void biquad2_4x4(float* src, float* dst, float coef[5][8], float state[3]
}
// crossfade 4 inputs into 2 outputs with accumulation (interleaved)
- static void crossfade_4x2(float* src, float* dst, const float* win, int numFrames) {
+ static void crossfade_4x2_SSE(float* src, float* dst, const float* win, int numFrames) {
assert(numFrames % 4 == 0);
@@ -435,12 +420,12 @@ static void crossfade_4x2(float* src, float* dst, const float* win, int numFrame
}
// linear interpolation with gain
- static void interpolate(float* dst, const float* src0, const float* src1, float frac, float gain) {
+ static void interpolate_SSE(const float* src0, const float* src1, float* dst, float frac, float gain) {
__m128 f0 = _mm_set1_ps(gain * (1.0f - frac));
__m128 f1 = _mm_set1_ps(gain * frac);
-     assert(HRTF_TAPS % 4 == 0);
+     static_assert(HRTF_TAPS % 4 == 0, "HRTF_TAPS must be a multiple of 4");
for (int k = 0; k < HRTF_TAPS; k += 4) {
@@ -453,6 +438,44 @@ static void interpolate(float* dst, const float* src0, const float* src1, float
}
}
+ //
+ // Runtime CPU dispatch
+ //
+ #include "CPUDetect.h"
+ void FIR_1x4_AVX2(float* src, float* dst0, float* dst1, float* dst2, float* dst3, float coef[4][HRTF_TAPS], int numFrames);
+ void FIR_1x4_AVX512(float* src, float* dst0, float* dst1, float* dst2, float* dst3, float coef[4][HRTF_TAPS], int numFrames);
+ void interleave_4x4_AVX2(float* src0, float* src1, float* src2, float* src3, float* dst, int numFrames);
+ void biquad2_4x4_AVX2(float* src, float* dst, float coef[5][8], float state[3][8], int numFrames);
+ void crossfade_4x2_AVX2(float* src, float* dst, const float* win, int numFrames);
+ void interpolate_AVX2(const float* src0, const float* src1, float* dst, float frac, float gain);
+ static void FIR_1x4(float* src, float* dst0, float* dst1, float* dst2, float* dst3, float coef[4][HRTF_TAPS], int numFrames) {
+     static auto f = cpuSupportsAVX512() ? FIR_1x4_AVX512 : (cpuSupportsAVX2() ? FIR_1x4_AVX2 : FIR_1x4_SSE);
+     (*f)(src, dst0, dst1, dst2, dst3, coef, numFrames); // dispatch
+ }
+ static void interleave_4x4(float* src0, float* src1, float* src2, float* src3, float* dst, int numFrames) {
+     static auto f = cpuSupportsAVX2() ? interleave_4x4_AVX2 : interleave_4x4_SSE;
+     (*f)(src0, src1, src2, src3, dst, numFrames); // dispatch
+ }
+ static void biquad2_4x4(float* src, float* dst, float coef[5][8], float state[3][8], int numFrames) {
+     static auto f = cpuSupportsAVX2() ? biquad2_4x4_AVX2 : biquad2_4x4_SSE;
+     (*f)(src, dst, coef, state, numFrames); // dispatch
+ }
+ static void crossfade_4x2(float* src, float* dst, const float* win, int numFrames) {
+     static auto f = cpuSupportsAVX2() ? crossfade_4x2_AVX2 : crossfade_4x2_SSE;
+     (*f)(src, dst, win, numFrames); // dispatch
+ }
+ static void interpolate(const float* src0, const float* src1, float* dst, float frac, float gain) {
+     static auto f = cpuSupportsAVX2() ? interpolate_AVX2 : interpolate_SSE;
+     (*f)(src0, src1, dst, frac, gain); // dispatch
+ }
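All five kernels now share the same dispatch idiom: a function-local static is initialized exactly once (thread-safe since C++11), so the CPU feature check runs on the first call and every later call is a single indirect jump. A minimal standalone sketch of the pattern, with hypothetical probe and kernel names:

#include <cstdio>

// stand-in for cpuSupportsAVX2(); the real probe queries CPUID
static bool probeAVX2() { return false; }

static void kernel_SSE(float x)  { std::printf("SSE:  %f\n", x); }
static void kernel_AVX2(float x) { std::printf("AVX2: %f\n", x); }

static void kernel(float x) {
    // resolved once, on first call
    static auto f = probeAVX2() ? kernel_AVX2 : kernel_SSE;
    (*f)(x);    // dispatch
}

int main() {
    kernel(1.0f);   // selects and calls the implementation
    kernel(2.0f);   // reuses the cached function pointer
}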
#else // portable reference code
// 1 channel input, 4 channel output
@@ -489,7 +512,7 @@ static void FIR_1x4(float* src, float* dst0, float* dst1, float* dst2, float* ds
float* ps = &src[i - HRTF_TAPS + 1]; // process forwards
-     assert(HRTF_TAPS % 4 == 0);
+     static_assert(HRTF_TAPS % 4 == 0, "HRTF_TAPS must be a multiple of 4");
for (int k = 0; k < HRTF_TAPS; k += 4) {
@@ -715,7 +738,7 @@ static void crossfade_4x2(float* src, float* dst, const float* win, int numFrame
}
// linear interpolation with gain
- static void interpolate(float* dst, const float* src0, const float* src1, float frac, float gain) {
+ static void interpolate(const float* src0, const float* src1, float* dst, float frac, float gain) {
float f0 = gain * (1.0f - frac);
float f1 = gain * frac;
@@ -967,8 +990,8 @@ static void setFilters(float firCoef[4][HRTF_TAPS], float bqCoef[5][8], int dela
azimuthToIndex(azimuth, az0, az1, frac);
// interpolate FIR
-     interpolate(firCoef[channel+0], ir_table_table[index][azL0][0], ir_table_table[index][azL1][0], fracL, gain * gainL);
-     interpolate(firCoef[channel+1], ir_table_table[index][azR0][1], ir_table_table[index][azR1][1], fracR, gain * gainR);
+     interpolate(ir_table_table[index][azL0][0], ir_table_table[index][azL1][0], firCoef[channel+0], fracL, gain * gainL);
+     interpolate(ir_table_table[index][azR0][1], ir_table_table[index][azR1][1], firCoef[channel+1], fracR, gain * gainR);
// interpolate ITD
float itd = (1.0f - frac) * itd_table_table[index][az0] + frac * itd_table_table[index][az1];

View file

@@ -44,7 +44,7 @@ void FIR_1x4_AVX2(float* src, float* dst0, float* dst1, float* dst2, float* dst3
float* ps = &src[i - HRTF_TAPS + 1]; // process forwards
-     assert(HRTF_TAPS % 4 == 0);
+     static_assert(HRTF_TAPS % 4 == 0, "HRTF_TAPS must be a multiple of 4");
for (int k = 0; k < HRTF_TAPS; k += 4) {
@@ -87,4 +87,165 @@ void FIR_1x4_AVX2(float* src, float* dst0, float* dst1, float* dst2, float* dst3
_mm256_zeroupper();
}
+ // 4 channel planar to interleaved
+ void interleave_4x4_AVX2(float* src0, float* src1, float* src2, float* src3, float* dst, int numFrames) {
+     assert(numFrames % 8 == 0);
+     for (int i = 0; i < numFrames; i += 8) {
+         __m256 x0 = _mm256_loadu_ps(&src0[i]);
+         __m256 x1 = _mm256_loadu_ps(&src1[i]);
+         __m256 x2 = _mm256_loadu_ps(&src2[i]);
+         __m256 x3 = _mm256_loadu_ps(&src3[i]);
+         // interleave (4x4 matrix transpose)
+         __m256 t0 = _mm256_unpacklo_ps(x0, x1);
+         __m256 t1 = _mm256_unpackhi_ps(x0, x1);
+         __m256 t2 = _mm256_unpacklo_ps(x2, x3);
+         __m256 t3 = _mm256_unpackhi_ps(x2, x3);
+         x0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1,0,1,0));
+         x1 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3,2,3,2));
+         x2 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(1,0,1,0));
+         x3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3,2,3,2));
+         t0 = _mm256_permute2f128_ps(x0, x1, 0x20);
+         t1 = _mm256_permute2f128_ps(x2, x3, 0x20);
+         t2 = _mm256_permute2f128_ps(x0, x1, 0x31);
+         t3 = _mm256_permute2f128_ps(x2, x3, 0x31);
+         _mm256_storeu_ps(&dst[4*i+0], t0);
+         _mm256_storeu_ps(&dst[4*i+8], t1);
+         _mm256_storeu_ps(&dst[4*i+16], t2);
+         _mm256_storeu_ps(&dst[4*i+24], t3);
+     }
+     _mm256_zeroupper();
+ }
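The unpacklo/unpackhi and shuffle steps transpose a 4x4 block within each 128-bit lane, and the permute2f128 steps then stitch the lanes back into frame order, emitting 8 interleaved frames per iteration. Functionally the routine is just a strided copy; a scalar sketch of the same contract:

// scalar reference: 4 planar channels -> one 4-channel interleaved stream
static void interleave_4x4_ref(const float* src0, const float* src1,
                               const float* src2, const float* src3,
                               float* dst, int numFrames) {
    for (int i = 0; i < numFrames; i++) {
        dst[4*i+0] = src0[i];
        dst[4*i+1] = src1[i];
        dst[4*i+2] = src2[i];
        dst[4*i+3] = src3[i];
    }
}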
+ // process 2 cascaded biquads on 4 channels (interleaved)
+ // biquads are computed in parallel, by adding one sample of delay
+ void biquad2_4x4_AVX2(float* src, float* dst, float coef[5][8], float state[3][8], int numFrames) {
+     // enable flush-to-zero mode to prevent denormals
+     unsigned int ftz = _MM_GET_FLUSH_ZERO_MODE();
+     _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
+     // restore state
+     __m256 x0 = _mm256_setzero_ps();
+     __m256 y0 = _mm256_loadu_ps(state[0]);
+     __m256 w1 = _mm256_loadu_ps(state[1]);
+     __m256 w2 = _mm256_loadu_ps(state[2]);
+     // biquad coefs
+     __m256 b0 = _mm256_loadu_ps(coef[0]);
+     __m256 b1 = _mm256_loadu_ps(coef[1]);
+     __m256 b2 = _mm256_loadu_ps(coef[2]);
+     __m256 a1 = _mm256_loadu_ps(coef[3]);
+     __m256 a2 = _mm256_loadu_ps(coef[4]);
+     for (int i = 0; i < numFrames; i++) {
+         // x0 = (first biquad output << 128) | input
+         x0 = _mm256_insertf128_ps(_mm256_permute2f128_ps(y0, y0, 0x01), _mm_loadu_ps(&src[4*i]), 0);
+         // transposed Direct Form II
+         y0 = _mm256_fmadd_ps(x0, b0, w1);
+         w1 = _mm256_fmadd_ps(x0, b1, w2);
+         w2 = _mm256_mul_ps(x0, b2);
+         w1 = _mm256_fnmadd_ps(y0, a1, w1);
+         w2 = _mm256_fnmadd_ps(y0, a2, w2);
+         _mm_storeu_ps(&dst[4*i], _mm256_extractf128_ps(y0, 1));    // second biquad output
+     }
+     // save state
+     _mm256_storeu_ps(state[0], y0);
+     _mm256_storeu_ps(state[1], w1);
+     _mm256_storeu_ps(state[2], w2);
+     _MM_SET_FLUSH_ZERO_MODE(ftz);
+     _mm256_zeroupper();
+ }
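Here one 256-bit register carries both stages of the cascade for all 4 channels: lanes 0-3 run the first biquad, lanes 4-7 the second, and the permute2f128/insertf128 step loads each new input into the low lanes while moving the first stage's previous output into the high lanes, which is the one sample of delay the comment refers to. Flush-to-zero matters because recursive filter state decays toward the denormal range, where x86 float arithmetic is dramatically slower. The per-sample recurrence each lane computes, in scalar form (transposed Direct Form II, one biquad, one channel):

// scalar sketch of one TDF2 biquad step; w[] is the 2-element state
static float biquadTDF2(float x, const float b[3], const float a[2], float w[2]) {
    float y = b[0] * x + w[0];               // y0 = fmadd(x0, b0, w1)
    w[0]    = b[1] * x + w[1] - a[0] * y;    // w1 = fmadd, then fnmadd
    w[1]    = b[2] * x - a[1] * y;           // w2 = mul, then fnmadd
    return y;
}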
+ // crossfade 4 inputs into 2 outputs with accumulation (interleaved)
+ void crossfade_4x2_AVX2(float* src, float* dst, const float* win, int numFrames) {
+     assert(numFrames % 8 == 0);
+     for (int i = 0; i < numFrames; i += 8) {
+         __m256 f0 = _mm256_loadu_ps(&win[i]);
+         __m256 x0 = _mm256_castps128_ps256(_mm_loadu_ps(&src[4*i+0]));
+         __m256 x1 = _mm256_castps128_ps256(_mm_loadu_ps(&src[4*i+4]));
+         __m256 x2 = _mm256_castps128_ps256(_mm_loadu_ps(&src[4*i+8]));
+         __m256 x3 = _mm256_castps128_ps256(_mm_loadu_ps(&src[4*i+12]));
+         x0 = _mm256_insertf128_ps(x0, _mm_loadu_ps(&src[4*i+16]), 1);
+         x1 = _mm256_insertf128_ps(x1, _mm_loadu_ps(&src[4*i+20]), 1);
+         x2 = _mm256_insertf128_ps(x2, _mm_loadu_ps(&src[4*i+24]), 1);
+         x3 = _mm256_insertf128_ps(x3, _mm_loadu_ps(&src[4*i+28]), 1);
+         __m256 y0 = _mm256_loadu_ps(&dst[2*i+0]);
+         __m256 y1 = _mm256_loadu_ps(&dst[2*i+8]);
+         // deinterleave (4x4 matrix transpose)
+         __m256 t0 = _mm256_unpacklo_ps(x0, x1);
+         __m256 t1 = _mm256_unpackhi_ps(x0, x1);
+         __m256 t2 = _mm256_unpacklo_ps(x2, x3);
+         __m256 t3 = _mm256_unpackhi_ps(x2, x3);
+         x0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1,0,1,0));
+         x1 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3,2,3,2));
+         x2 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(1,0,1,0));
+         x3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3,2,3,2));
+         // crossfade
+         x0 = _mm256_sub_ps(x0, x2);
+         x1 = _mm256_sub_ps(x1, x3);
+         x2 = _mm256_fmadd_ps(f0, x0, x2);
+         x3 = _mm256_fmadd_ps(f0, x1, x3);
+         // interleave
+         t0 = _mm256_unpacklo_ps(x2, x3);
+         t1 = _mm256_unpackhi_ps(x2, x3);
+         x0 = _mm256_permute2f128_ps(t0, t1, 0x20);
+         x1 = _mm256_permute2f128_ps(t0, t1, 0x31);
+         // accumulate
+         y0 = _mm256_add_ps(y0, x0);
+         y1 = _mm256_add_ps(y1, x1);
+         _mm256_storeu_ps(&dst[2*i+0], y0);
+         _mm256_storeu_ps(&dst[2*i+8], y1);
+     }
+     _mm256_zeroupper();
+ }
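After the deinterleave, x0/x1 hold channels 0-1 and x2/x3 channels 2-3; the sub/fmadd pair computes the linear blend b + f*(a - b), and the result is added into dst rather than stored. A scalar sketch of the same contract (channels 0-1 weighted by win[i], channels 2-3 by 1 - win[i], as the arithmetic above implies):

// scalar reference: crossfade 4 interleaved channels into 2, mixing into dst
static void crossfade_4x2_ref(const float* src, float* dst,
                              const float* win, int numFrames) {
    for (int i = 0; i < numFrames; i++) {
        float f = win[i];
        // lerp written as b + f*(a - b), matching the sub/fmadd form above
        dst[2*i+0] += src[4*i+2] + f * (src[4*i+0] - src[4*i+2]);
        dst[2*i+1] += src[4*i+3] + f * (src[4*i+1] - src[4*i+3]);
    }
}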
+ // linear interpolation with gain
+ void interpolate_AVX2(const float* src0, const float* src1, float* dst, float frac, float gain) {
+     __m256 f0 = _mm256_set1_ps(gain * (1.0f - frac));
+     __m256 f1 = _mm256_set1_ps(gain * frac);
+     static_assert(HRTF_TAPS % 8 == 0, "HRTF_TAPS must be a multiple of 8");
+     for (int k = 0; k < HRTF_TAPS; k += 8) {
+         __m256 x0 = _mm256_loadu_ps(&src0[k]);
+         __m256 x1 = _mm256_loadu_ps(&src1[k]);
+         x0 = _mm256_mul_ps(f0, x0);
+         x0 = _mm256_fmadd_ps(f1, x1, x0);
+         _mm256_storeu_ps(&dst[k], x0);
+     }
+     _mm256_zeroupper();
+ }
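As in the SSE version, the gain is folded into the two blend weights, so each tap costs one multiply plus one FMA. A scalar sketch of the contract (the real function fixes the length at HRTF_TAPS):

// scalar reference: dst = gain * ((1 - frac) * src0 + frac * src1)
static void interpolate_ref(const float* src0, const float* src1,
                            float* dst, float frac, float gain, int n) {
    float f0 = gain * (1.0f - frac);
    float f1 = gain * frac;
    for (int k = 0; k < n; k++) {
        dst[k] = f0 * src0[k] + f1 * src1[k];
    }
}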
#endif

View file

@@ -44,7 +44,7 @@ void FIR_1x4_AVX512(float* src, float* dst0, float* dst1, float* dst2, float* ds
float* ps = &src[i - HRTF_TAPS + 1]; // process forwards
-     assert(HRTF_TAPS % 4 == 0);
+     static_assert(HRTF_TAPS % 4 == 0, "HRTF_TAPS must be a multiple of 4");
for (int k = 0; k < HRTF_TAPS; k += 4) {