Multiply-add on AVX and FMA

DSP, Plugin and Host development discussion.
Post Reply New Topic
RELATED
PRODUCTS

Post

Hi all.
I don't know whether any of you have a use for this, but I have now written this code for the 3rd time, for yet another project.
It's not much code, but I thought I'd post it here, so that I have a place to copy & paste it from when I need it for a 4th project, and so that people on this forum who are not familiar with AVX can see how it's done :D (and adapt it to their needs ;) )
(AVX, or SIMD in general, is about processing multiple values with a single instruction - i.e. the AVX and FMA code below processes 8 floats at a time, instead of 1 float 8 times.)

Notes:
- get_madd_func will return a function pointer that either maps to FMA3 (vfmadd132ps), AVX (vmulps+vaddps) or C (compiler config specific) instruction sets, depending on what's available on the CPU.
- The code doesn't aim to be as fast as possible. It uses intrinsics rather than assembly code, so there is no control over register allocation and instruction scheduling. Also, the increment of the 4 pointers could be done in a YMM register rather than with 4 add instructions.
Goal was to use the FMA3 instructions to process the floats, or fall back to AVX, or C ultimately, if not supported.
- You can use un-aligned memory. Performance penalty on unaligned loads with AVX is rather small, so _mm256_loadu_ps is used. If you know that the memory-blocks will be 32-byte aligned, replace with _mm256_load_ps for minor performance gain.
- The float-arrays to process can have any size (remaining non-multiple-8 data is processed by C function).
If you want to do that in 256-bit registers too, _mm256_maskload_ps is your friend (just do the same as in the loop, but mask the load so that it loads only part of the register instead of the full register) ;)

Code: Select all

#include <immintrin.h>

//////////////////////////////////////////////////////////
// get_cpuid functions - implement for your compiler if missing  //
//////////////////////////////////////////////////////////

// get_cpuid(level, regs): runs CPUID for leaf `level` and stores
// EAX, EBX, ECX, EDX in regs[0], regs[1], regs[2], regs[3].
// regs[] is zeroed first. It is zeroed by hand rather than with memset,
// because this file does not include <string.h>.
#if defined(_MSC_VER)
	#include <intrin.h>
	static void get_cpuid(unsigned int level, unsigned int regs[4]) {
		regs[0] = regs[1] = regs[2] = regs[3] = 0;
		__cpuid((int*)regs, (int)level);
	}
#elif defined(__GNUC__)
	#include <cpuid.h>
	static void get_cpuid(unsigned int level, unsigned int regs[4]) {
		regs[0] = regs[1] = regs[2] = regs[3] = 0;
		// __get_cpuid returns 0 and leaves its outputs untouched when the CPU
		// does not report `level`; regs[] then simply stays all-zero, so the
		// return value carries no extra information and is dropped on purpose.
		(void)__get_cpuid(level, &regs[0], &regs[1], &regs[2], &regs[3]);
	}
#else
	#error "Implement a function to get cpuid on your compiler"
#endif

///////////////////////////////////////////////
// Function: r[i] = a[i] * b[i] + c[i]                            //
///////////////////////////////////////////////

// C (scalar fallback)
// Computes r[i] = a[i] * b[i] + c[i] for every i in [0, n).
//   a, b, c : input arrays, at least n floats each
//   r       : output array, at least n floats
//   n       : number of elements; 0 is allowed and writes nothing
// Declared `static inline`: a plain `inline` definition provides no external
// definition under C99/C11 rules, so taking the function's address (as
// get_madd_func does) or any call the compiler chooses not to inline would
// fail to link. Internal linkage behaves the same in C and C++.
static inline void madd_c(
	float* a, 
	float* b, 
	float* c, 
	float* r,
	size_t n)
{
	for (size_t i = 0; i < n; i++)
	{
		r[i] = a[i] * b[i] + c[i];
	}
}

// AVX
// Computes r[i] = a[i] * b[i] + c[i] for every i in [0, n), eight floats per
// iteration with a separate multiply and add, followed by a scalar loop for
// the remaining n % 8 elements.
//   a, b, c : input arrays, at least n floats each (no alignment required,
//             unaligned loads are used)
//   r       : output array, at least n floats (unaligned stores are used)
//   n       : number of elements; 0 is allowed and writes nothing
// Must only be called on a machine where AVX is usable.
#ifndef MADD_TARGET_AVX
	#if defined(__GNUC__) || defined(__clang__)
		// Enable AVX for this one function only. Building the whole file with
		// -mavx would allow the compiler to emit AVX instructions in the
		// scalar fallback as well, which defeats the runtime dispatch.
		#define MADD_TARGET_AVX __attribute__((target("avx")))
	#else
		#define MADD_TARGET_AVX
	#endif
#endif

MADD_TARGET_AVX
static inline void madd_avx(
	float* a, 
	float* b, 
	float* c, 
	float* r,
	size_t n)
{
	// `end` marks the end of the part of `a` that is a whole multiple of 8.
	float* end = a + (n / 8) * 8;
	size_t nleft = n % 8;

	while (a < end)
	{
		__m256 ma = _mm256_loadu_ps(a);
		__m256 mb = _mm256_loadu_ps(b);
		__m256 mc = _mm256_loadu_ps(c);

		ma = _mm256_mul_ps(ma, mb);
		ma = _mm256_add_ps(ma, mc);

		_mm256_storeu_ps(r, ma);

		a += 8;
		b += 8;
		c += 8;
		r += 8;
	}

	// Scalar tail: all four pointers now point just past the vectorised part.
	for (size_t i = 0; i < nleft; i++)
	{
		r[i] = a[i] * b[i] + c[i];
	}
}

// FMA
// Computes r[i] = a[i] * b[i] + c[i] for every i in [0, n), eight floats per
// iteration with a single fused multiply-add, followed by a scalar loop for
// the remaining n % 8 elements.
//   a, b, c : input arrays, at least n floats each (no alignment required,
//             unaligned loads are used)
//   r       : output array, at least n floats (unaligned stores are used)
//   n       : number of elements; 0 is allowed and writes nothing
// Must only be called on a machine where AVX and FMA are usable.
#ifndef MADD_TARGET_FMA
	#if defined(__GNUC__) || defined(__clang__)
		// Enable AVX+FMA for this one function only. Building the whole file
		// with -mfma would allow the compiler to emit these instructions in
		// the scalar fallback as well, which defeats the runtime dispatch.
		#define MADD_TARGET_FMA __attribute__((target("avx,fma")))
	#else
		#define MADD_TARGET_FMA
	#endif
#endif

MADD_TARGET_FMA
static inline void madd_fma(
	float* a, 
	float* b, 
	float* c, 
	float* r,
	size_t n)
{
	// `end` marks the end of the part of `a` that is a whole multiple of 8.
	float* end = a + (n / 8) * 8;
	size_t nleft = n % 8;

	while (a < end)
	{
		__m256 ma = _mm256_loadu_ps(a);
		__m256 mb = _mm256_loadu_ps(b);
		__m256 mc = _mm256_loadu_ps(c);

		ma = _mm256_fmadd_ps(ma, mb, mc);

		_mm256_storeu_ps(r, ma);

		a += 8;
		b += 8;
		c += 8;
		r += 8;
	}

	// Scalar tail: all four pointers now point just past the vectorised part.
	for (size_t i = 0; i < nleft; i++)
	{
		r[i] = a[i] * b[i] + c[i];
	}
}

///////////////////////////////////////////////
// get_madd_func                                                   //
///////////////////////////////////////////////

// Signature shared by the multiply-add kernels (madd_c, madd_avx, madd_fma).
typedef void (*madd_func_ptr)(float* a, float* b, float* c, float* r, size_t n);

// Reads extended control register 0 (XCR0), which reports the register state
// the operating system saves and restores on a context switch.
// Only call this after CPUID.1:ECX.OSXSAVE (bit 27) has been seen set;
// the xgetbv instruction faults otherwise.
static unsigned long long madd_read_xcr0(void)
{
#if defined(__GNUC__) || defined(__clang__)
	unsigned int lo, hi;
	__asm__ volatile("xgetbv" : "=a"(lo), "=d"(hi) : "c"(0));
	return ((unsigned long long)hi << 32) | lo;
#else
	return _xgetbv(0);
#endif
}

// Stores in `ptr` the fastest multiply-add kernel that is usable here:
// madd_fma, else madd_avx, else madd_c.
// The CPUID feature bits alone are not sufficient: the OS must also have
// enabled saving of the YMM registers, otherwise AVX instructions fault.
// FMA works on YMM registers as well, so it is only chosen when AVX is usable.
static void get_madd_func(madd_func_ptr& ptr)
{
	enum {
		EAX = 0,
		EBX = 1,
		ECX = 2,
		EDX = 3,
		RegisterCount
	};

	// CPUID leaf 1, ECX feature bits.
	const unsigned int ecx_fma     = 1u << 12;
	const unsigned int ecx_osxsave = 1u << 27;
	const unsigned int ecx_avx     = 1u << 28;
	// XCR0 bit 1 = XMM state, bit 2 = YMM state; both must be enabled for AVX.
	const unsigned long long xcr0_xmm_ymm = 0x6;

	unsigned int regs[RegisterCount];

	get_cpuid(1, regs);

	const unsigned int ecx = regs[ECX];

	int os_saves_ymm = 0;
	if (ecx & ecx_osxsave) {
		os_saves_ymm = (madd_read_xcr0() & xcr0_xmm_ymm) == xcr0_xmm_ymm;
	}

	const int has_avx = os_saves_ymm && (ecx & ecx_avx) != 0;
	const int has_fma = has_avx && (ecx & ecx_fma) != 0;

	if (has_fma) {
		ptr = madd_fma;
	} else if (has_avx) {
		ptr = madd_avx;
	} else {
		ptr = madd_c;
	}
}

///////////////////////////////////////////////
// just some dummy code that shows how to use it  //
///////////////////////////////////////////////

// Example: pick the kernel once, then run it over four arrays of NCOUNT floats.
// NOTE: the four arrays together take roughly 190 KB of stack; use heap or
// static storage instead if the calling thread has a small stack.
static void how_to_use()
{
	#define NCOUNT	12345

	madd_func_ptr madd_func;
	get_madd_func(madd_func);

	// Zero-initialised so that the call below never reads indeterminate
	// values, even if the "fill the inputs" step is left out.
	float a[NCOUNT] = {0};
	float b[NCOUNT] = {0};
	float c[NCOUNT] = {0};
	float r[NCOUNT] = {0};

	// .. fill a, b and c with real data here ..

	// do: 
	//   r[i] = a[i] * b[i] + c[i]

	madd_func(a, b, c, r, NCOUNT);
}

Post Reply

Return to “DSP and Plugin Development”