diff --git a/src/kem/frodokem/external/frodo_macrify_optimized.c b/src/kem/frodokem/external/frodo_macrify_optimized.c
index d2654c746..41de5b0cd 100644
--- a/src/kem/frodokem/external/frodo_macrify_optimized.c
+++ b/src/kem/frodokem/external/frodo_macrify_optimized.c
@@ -8,6 +8,12 @@
 #include <stdint.h>
 #include <string.h>
 
+#include
+
+#if defined(USE_AVX2_INSTRUCTIONS)
+    #include <immintrin.h>
+#endif
+
 #include
 
 #include "frodo_internal.h"
@@ -119,6 +125,7 @@ int frodo_mul_add_sa_plus_e(uint16_t *out, const uint16_t *s, const uint16_t *e,
             }
         }
 
+#ifndef USE_AVX2_INSTRUCTIONS
         for (i = 0; i < PARAMS_NBAR; i++) {
             for (k = 0; k < PARAMS_STRIPE_STEP; k += PARAMS_PARALLEL) {
                 uint16_t sum[PARAMS_PARALLEL] = {0};
@@ -135,6 +142,41 @@ int frodo_mul_add_sa_plus_e(uint16_t *out, const uint16_t *s, const uint16_t *e,
                 out[i*PARAMS_N + kk + k + 3] += sum[3];
             }
         }
+#else
+        for (i = 0; i < PARAMS_NBAR; i++) {
+            for (k = 0; k < PARAMS_STRIPE_STEP; k += PARAMS_PARALLEL) {
+                ALIGN_HEADER(32) uint32_t sum[8*PARAMS_PARALLEL] ALIGN_FOOTER(32);
+                __m256i a[PARAMS_PARALLEL], b, acc[PARAMS_PARALLEL];
+                acc[0] = _mm256_setzero_si256();
+                acc[1] = _mm256_setzero_si256();
+                acc[2] = _mm256_setzero_si256();
+                acc[3] = _mm256_setzero_si256();
+                for (j = 0; j < PARAMS_N; j += 16) {    // Matrix-vector multiplication
+                    b = _mm256_load_si256((__m256i*)&s[i*PARAMS_N + j]);
+                    a[0] = _mm256_load_si256((__m256i*)&a_cols_t[(k+0)*PARAMS_N + j]);
+                    a[0] = _mm256_madd_epi16(a[0], b);
+                    acc[0] = _mm256_add_epi16(a[0], acc[0]);
+                    a[1] = _mm256_load_si256((__m256i*)&a_cols_t[(k+1)*PARAMS_N + j]);
+                    a[1] = _mm256_madd_epi16(a[1], b);
+                    acc[1] = _mm256_add_epi16(a[1], acc[1]);
+                    a[2] = _mm256_load_si256((__m256i*)&a_cols_t[(k+2)*PARAMS_N + j]);
+                    a[2] = _mm256_madd_epi16(a[2], b);
+                    acc[2] = _mm256_add_epi16(a[2], acc[2]);
+                    a[3] = _mm256_load_si256((__m256i*)&a_cols_t[(k+3)*PARAMS_N + j]);
+                    a[3] = _mm256_madd_epi16(a[3], b);
+                    acc[3] = _mm256_add_epi16(a[3], acc[3]);
+                }
+                _mm256_store_si256((__m256i*)(sum + (8*0)), acc[0]);
+                out[i*PARAMS_N + kk + k + 0] += sum[8*0 + 0] + sum[8*0 + 1] + sum[8*0 + 2] + sum[8*0 + 3] + sum[8*0 + 4] + sum[8*0 + 5] + sum[8*0 + 6] + sum[8*0 + 7];
+                _mm256_store_si256((__m256i*)(sum + (8*1)), acc[1]);
+                out[i*PARAMS_N + kk + k + 1] += sum[8*1 + 0] + sum[8*1 + 1] + sum[8*1 + 2] + sum[8*1 + 3] + sum[8*1 + 4] + sum[8*1 + 5] + sum[8*1 + 6] + sum[8*1 + 7];
+                _mm256_store_si256((__m256i*)(sum + (8*2)), acc[2]);
+                out[i*PARAMS_N + kk + k + 2] += sum[8*2 + 0] + sum[8*2 + 1] + sum[8*2 + 2] + sum[8*2 + 3] + sum[8*2 + 4] + sum[8*2 + 5] + sum[8*2 + 6] + sum[8*2 + 7];
+                _mm256_store_si256((__m256i*)(sum + (8*3)), acc[3]);
+                out[i*PARAMS_N + kk + k + 3] += sum[8*3 + 0] + sum[8*3 + 1] + sum[8*3 + 2] + sum[8*3 + 3] + sum[8*3 + 4] + sum[8*3 + 5] + sum[8*3 + 6] + sum[8*3 + 7];
+            }
+        }
+#endif
     }
 
     OQS_AES128_free_schedule(aes_key_schedule);
@@ -159,6 +201,7 @@ int frodo_mul_add_sa_plus_e(uint16_t *out, const uint16_t *s, const uint16_t *e,
             a_cols[i] = LE_TO_UINT16(a_cols[i]);
         }
 
+#ifndef USE_AVX2_INSTRUCTIONS
         for (i = 0; i < PARAMS_NBAR; i++) {
             uint16_t sum[PARAMS_N] = {0};
             for (j = 0; j < 4; j++) {
@@ -171,6 +214,34 @@ int frodo_mul_add_sa_plus_e(uint16_t *out, const uint16_t *s, const uint16_t *e,
                 out[i*PARAMS_N + k] += sum[k];
             }
         }
+#else
+        for (i = 0; i < PARAMS_NBAR; i++) {
+            __m256i a, b0, b1, b2, b3, acc[PARAMS_N/16];
+            b0 = _mm256_set1_epi16(s[i*PARAMS_N + kk + 0]);
+            b1 = _mm256_set1_epi16(s[i*PARAMS_N + kk + 1]);
+            b2 = _mm256_set1_epi16(s[i*PARAMS_N + kk + 2]);
+            b3 = _mm256_set1_epi16(s[i*PARAMS_N + kk + 3]);
+            for (j = 0; j < PARAMS_N; j += 16) {    // Matrix-vector multiplication
+                acc[j/16] = _mm256_load_si256((__m256i*)&out[i*PARAMS_N + j]);
+                a = _mm256_load_si256((__m256i*)&a_cols[(t+0)*PARAMS_N + j]);
+                a = _mm256_mullo_epi16(a, b0);
+                acc[j/16] = _mm256_add_epi16(a, acc[j/16]);
+                a = _mm256_load_si256((__m256i*)&a_cols[(t+1)*PARAMS_N + j]);
+                a = _mm256_mullo_epi16(a, b1);
+                acc[j/16] = _mm256_add_epi16(a, acc[j/16]);
+                a = _mm256_load_si256((__m256i*)&a_cols[(t+2)*PARAMS_N + j]);
+                a = _mm256_mullo_epi16(a, b2);
+                acc[j/16] = _mm256_add_epi16(a, acc[j/16]);
+                a = _mm256_load_si256((__m256i*)&a_cols[(t+3)*PARAMS_N + j]);
+                a = _mm256_mullo_epi16(a, b3);
+                acc[j/16] = _mm256_add_epi16(a, acc[j/16]);
+            }
+
+            for (k = 0; k < PARAMS_N/16; k++) {
+                _mm256_store_si256((__m256i*)&out[i*PARAMS_N + 16*k], acc[k]);
+            }
+        }
+#endif
     }
 
 #endif
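
Note (not part of the patch): the two AVX2 blocks rely on two different
mod-2^16 idioms. In the first block, _mm256_madd_epi16 multiplies sixteen
16-bit lanes pairwise and adds adjacent products into eight 32-bit lanes.
Since out[] is uint16_t, only the low 16 bits of every partial sum survive,
so accumulating those 32-bit lanes with the cheaper _mm256_add_epi16 (which
never carries into the upper 16 bits of a lane) is still correct, and the
final horizontal add of the eight lanes can simply be truncated. The second
block needs no horizontal reduction at all: _mm256_set1_epi16 broadcasts one
secret coefficient, _mm256_mullo_epi16 keeps each product's low 16 bits
(i.e. the product mod 2^16), and four scaled rows of A are accumulated
directly into out. The self-contained sketch below checks both idioms
against scalar references; the helper names (dot_mod_2_16, add_scaled_rows)
and the test values are illustrative, not part of liboqs, and an AVX2
target is assumed (e.g. gcc -O2 -mavx2).

#include <stdint.h>
#include <stdio.h>
#include <immintrin.h>

/* Idiom 1: madd + lane accumulation + horizontal reduction, all mod 2^16.
 * Assumes n is a multiple of 16 and x, y are 32-byte aligned. */
static uint16_t dot_mod_2_16(const uint16_t *x, const uint16_t *y, size_t n)
{
    __m256i acc = _mm256_setzero_si256();
    for (size_t j = 0; j < n; j += 16) {
        __m256i a = _mm256_load_si256((const __m256i *)&x[j]);
        __m256i b = _mm256_load_si256((const __m256i *)&y[j]);
        /* 16 products, adjacent pairs summed into 8 x 32-bit lanes */
        acc = _mm256_add_epi16(_mm256_madd_epi16(a, b), acc);
    }
    uint32_t sum[8];
    _mm256_storeu_si256((__m256i *)sum, acc);
    /* horizontal reduction; truncation to uint16_t is the mod-2^16 step */
    return (uint16_t)(sum[0] + sum[1] + sum[2] + sum[3] +
                      sum[4] + sum[5] + sum[6] + sum[7]);
}

/* Idiom 2: broadcast a coefficient, scale four rows with mullo (low 16 bits
 * of each product), accumulate straight into out. Same alignment rules. */
static void add_scaled_rows(uint16_t *out, const uint16_t *a,
                            const uint16_t s4[4], size_t n)
{
    for (size_t j = 0; j < n; j += 16) {
        __m256i acc = _mm256_load_si256((const __m256i *)&out[j]);
        for (size_t t = 0; t < 4; t++) {
            __m256i b   = _mm256_set1_epi16((short)s4[t]);
            __m256i row = _mm256_load_si256((const __m256i *)&a[t * n + j]);
            acc = _mm256_add_epi16(_mm256_mullo_epi16(row, b), acc);
        }
        _mm256_store_si256((__m256i *)&out[j], acc);
    }
}

int main(void)
{
    enum { N = 32 };
    _Alignas(32) uint16_t x[N], y[N], a[4 * N], out[N];
    const uint16_t s4[4] = { 3, 60001, 17, 40000 };
    uint16_t dref = 0, oref[N];

    for (int j = 0; j < N; j++) {
        x[j] = (uint16_t)(123 * j + 7);
        y[j] = (uint16_t)(89 * j + 456);
        dref = (uint16_t)(dref + (uint16_t)(x[j] * y[j])); /* scalar dot mod 2^16 */
        out[j] = (uint16_t)j;
        oref[j] = out[j];
    }
    for (int t = 0; t < 4; t++) {
        for (int j = 0; j < N; j++) {
            a[t * N + j] = (uint16_t)(31 * t + 5 * j + 1);
            oref[j] = (uint16_t)(oref[j] + (uint16_t)(s4[t] * a[t * N + j]));
        }
    }

    add_scaled_rows(out, a, s4, N);
    printf("dot:  avx2=%u scalar=%u\n", dot_mod_2_16(x, y, N), dref);
    printf("rows: avx2=%u scalar=%u (index 5)\n", out[5], oref[5]);
    return 0;
}

Both helpers mirror the patch's layout assumptions: vector lengths are
multiples of 16 coefficients and buffers are 32-byte aligned, matching the
ALIGN_HEADER(32)/ALIGN_FOOTER(32) attribute on sum[] above.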