I don't know if any of you have a use for this, but I have now written this code for the 3rd time, for yet another project.
It's not much code, but I thought I'd post it here, so that I have a place to copy&paste from if I need it for a 4th project, and so that people on this forum who are not familiar with AVX can see how it's done (and adapt it to their needs).
(AVX, or SIMD in general, is about processing multiple values with a single instruction - i.e. the AVX and FMA code below processes 8 floats at a time, instead of 8x 1 float.)
Notes:
- get_madd_func stores a function pointer in its argument that maps to either the FMA3 (vfmadd132ps), AVX (vmulps+vaddps) or plain C (compiler config specific) implementation, depending on what's available on the CPU.
- The code doesn't aim to be as fast as possible. It uses intrinsics rather than assembly code, so there is no control over register allocation and instruction scheduling. Also, the increment of the 4 pointers could be done on a YMM register rather than with 4 add instructions.
Goal was to use the FMA3 instructions to process the floats, or fall back to AVX, or C ultimately, if not supported.
- You can use un-aligned memory. Performance penalty on unaligned loads with AVX is rather small, so _mm256_loadu_ps is used. If you know that the memory-blocks will be 32-byte aligned, replace with _mm256_load_ps for minor performance gain.
- The float-arrays to process can have any size (remaining non-multiple-8 data is processed by C function).
If you want to do that on 256-bit registers too, _mm256_maskload_ps is your friend (just do the same as in the loop, but mask the load so that it loads only part of the register instead of the full register).
Code: Select all
#include <immintrin.h>
//////////////////////////////////////////////////////////
// get_cpuid functions - implement for your compiler if missing //
//////////////////////////////////////////////////////////
#if defined(_MSC_VER)
#include <intrin.h>
// Executes CPUID for the given leaf ("level") and stores the resulting
// EAX, EBX, ECX, EDX values (in that order) in regs[0..3].
// regs is zeroed first so that callers always see defined values.
static void get_cpuid(unsigned int level, unsigned int regs[4]) {
    regs[0] = regs[1] = regs[2] = regs[3] = 0;
    __cpuid((int*)regs, (int)level);
}
#elif defined(__GNUC__)
#include <cpuid.h>
// Executes CPUID for the given leaf ("level") and stores the resulting
// EAX, EBX, ECX, EDX values (in that order) in regs[0..3].
// __get_cpuid() returns 0 and writes nothing when the leaf is not supported;
// in that case regs stays all-zero, i.e. "no feature bits set".
static void get_cpuid(unsigned int level, unsigned int regs[4]) {
    regs[0] = regs[1] = regs[2] = regs[3] = 0;
    (void)__get_cpuid(level, &regs[0], &regs[1], &regs[2], &regs[3]);
}
#else
#error "Implement a function to get cpuid on your compiler"
#endif
///////////////////////////////////////////////
// Function: r[i] = a[i] * b[i] + c[i] //
///////////////////////////////////////////////
// C
// Scalar fallback: computes r[i] = a[i] * b[i] + c[i] for every i in [0, n).
//   a, b, c : input arrays with at least n elements each
//   r       : output array with at least n elements
//   n       : number of elements to process (0 is allowed and does nothing)
// Declared "static inline" like the other helpers in this file, so the
// function always has a definition in this translation unit, also when its
// address is taken or the call is not inlined.
static inline void madd_c(
    float* a,
    float* b,
    float* c,
    float* r,
    size_t n)
{
    for (size_t i = 0; i < n; i++)
    {
        r[i] = a[i] * b[i] + c[i];
    }
}
// AVX
// GCC and Clang only accept AVX intrinsics inside code that is compiled for
// AVX. Enabling it per function (instead of -mavx for the whole file) keeps
// the compiler from using AVX instructions in the scalar fallback path.
#if defined(__GNUC__) || defined(__clang__)
#define MADD_TARGET_AVX __attribute__((target("avx")))
#else
#define MADD_TARGET_AVX
#endif

// AVX version: computes r[i] = a[i] * b[i] + c[i] for every i in [0, n),
// 8 floats per loop iteration using _mm256_mul_ps followed by _mm256_add_ps.
//   a, b, c : input arrays with at least n elements each (no alignment
//             requirement, unaligned loads are used)
//   r       : output array with at least n elements (unaligned stores)
//   n       : number of elements; the last n % 8 elements are handed to
//             madd_c
// Must only be called on a CPU/OS that supports AVX.
MADD_TARGET_AVX
static inline void madd_avx(
    float* a,
    float* b,
    float* c,
    float* r,
    size_t n)
{
    // Number of elements covered by complete 8-float vectors.
    const size_t nvec = (n / 8) * 8;

    for (size_t i = 0; i < nvec; i += 8)
    {
        const __m256 va = _mm256_loadu_ps(a + i);
        const __m256 vb = _mm256_loadu_ps(b + i);
        const __m256 vc = _mm256_loadu_ps(c + i);
        const __m256 prod = _mm256_mul_ps(va, vb);
        _mm256_storeu_ps(r + i, _mm256_add_ps(prod, vc));
    }

    // Tail: fewer than 8 elements left, process them one by one.
    if (nvec < n)
    {
        madd_c(a + nvec, b + nvec, c + nvec, r + nvec, n - nvec);
    }
}
// FMA
// GCC and Clang only accept FMA intrinsics inside code that is compiled for
// FMA. Enabling it per function (instead of -mfma for the whole file) keeps
// the compiler from using AVX/FMA instructions in the scalar fallback path.
#if defined(__GNUC__) || defined(__clang__)
#define MADD_TARGET_FMA __attribute__((target("avx,fma")))
#else
#define MADD_TARGET_FMA
#endif

// FMA version: computes r[i] = a[i] * b[i] + c[i] for every i in [0, n),
// 8 floats per loop iteration using a single _mm256_fmadd_ps.
//   a, b, c : input arrays with at least n elements each (no alignment
//             requirement, unaligned loads are used)
//   r       : output array with at least n elements (unaligned stores)
//   n       : number of elements; the last n % 8 elements are handed to
//             madd_c
// Must only be called on a CPU/OS that supports AVX and FMA3.
MADD_TARGET_FMA
static inline void madd_fma(
    float* a,
    float* b,
    float* c,
    float* r,
    size_t n)
{
    // Number of elements covered by complete 8-float vectors.
    const size_t nvec = (n / 8) * 8;

    for (size_t i = 0; i < nvec; i += 8)
    {
        const __m256 va = _mm256_loadu_ps(a + i);
        const __m256 vb = _mm256_loadu_ps(b + i);
        const __m256 vc = _mm256_loadu_ps(c + i);
        _mm256_storeu_ps(r + i, _mm256_fmadd_ps(va, vb, vc));
    }

    // Tail: fewer than 8 elements left, process them one by one.
    if (nvec < n)
    {
        madd_c(a + nvec, b + nvec, c + nvec, r + nvec, n - nvec);
    }
}
///////////////////////////////////////////////
// get_madd_func //
///////////////////////////////////////////////
// Signature shared by all madd implementations: r[i] = a[i] * b[i] + c[i].
typedef void (*madd_func_ptr)(float* a, float* b, float* c, float* r, size_t n);

// Reads extended control register XCR0 with the XGETBV instruction.
// Only valid when CPUID.1:ECX.OSXSAVE is set, otherwise XGETBV faults.
static unsigned long long madd_read_xcr0(void)
{
#if defined(_MSC_VER)
    return _xgetbv(0);
#else
    unsigned int lo, hi;
    __asm__ __volatile__ ("xgetbv" : "=a"(lo), "=d"(hi) : "c"(0));
    return ((unsigned long long)hi << 32) | lo;
#endif
}

// Selects the best madd implementation for the CPU the program runs on and
// stores it in ptr: madd_fma if AVX and FMA3 are usable, else madd_avx if AVX
// is usable, else madd_c.
// "Usable" means that the CPU reports the feature AND the operating system
// has enabled saving of the YMM registers (OSXSAVE set, XCR0 bits 1 and 2
// set). Without the OS part, AVX instructions raise an invalid-opcode fault
// even though the CPUID feature bit is set.
static void get_madd_func(madd_func_ptr& ptr)
{
    enum {
        EAX = 0,
        EBX = 1,
        ECX = 2,
        EDX = 3,
        RegisterCount
    };
    unsigned int regs[RegisterCount];
    get_cpuid(1, regs);

    const unsigned int ecx = regs[ECX];
    const int cpu_fma     = (ecx >> 12) & 1;
    const int cpu_osxsave = (ecx >> 27) & 1;
    const int cpu_avx     = (ecx >> 28) & 1;

    // XCR0 bit 1 = SSE (XMM) state, bit 2 = AVX (upper YMM) state.
    int ymm_enabled = 0;
    if (cpu_osxsave) {
        ymm_enabled = (madd_read_xcr0() & 0x6) == 0x6;
    }

    const int has_avx = cpu_avx && ymm_enabled;
    const int has_fma = has_avx && cpu_fma; // madd_fma also uses AVX loads/stores

    if (has_fma) {
        ptr = madd_fma;
    } else if (has_avx) {
        ptr = madd_avx;
    } else {
        ptr = madd_c;
    }
}
///////////////////////////////////////////////
// just some dummy code that shows how to use it //
///////////////////////////////////////////////
// Usage example: fetch the best implementation once, then call it through
// the function pointer like a normal function.
static void how_to_use()
{
#define NCOUNT 12345
    madd_func_ptr madd_func;
    get_madd_func(madd_func);

    // Zero-initialised so that the call below never reads indeterminate
    // values; real code fills a, b and c with its input data.
    float a[NCOUNT] = { 0 };
    float b[NCOUNT] = { 0 };
    float c[NCOUNT] = { 0 };
    float r[NCOUNT] = { 0 };
    // ..
    // do:
    // r[i] = a[i] * b[i] + c[i]
    madd_func(a, b, c, r, NCOUNT);
}