mirror of
https://github.com/open-quantum-safe/liboqs.git
synced 2025-10-09 00:04:26 -04:00
Add Frodo AVX2 matrix multiplication
This commit is contained in:
parent
dce10891f3
commit
cb15135ab6
@ -8,6 +8,12 @@
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <oqs/oqsconfig.h>
|
||||
|
||||
#if defined(USE_AVX2_INSTRUCTIONS)
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
|
||||
#include <oqs/aes.h>
|
||||
|
||||
#include "frodo_internal.h"
|
||||
@ -119,6 +125,7 @@ int frodo_mul_add_sa_plus_e(uint16_t *out, const uint16_t *s, const uint16_t *e,
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef USE_AVX2_INSTRUCTIONS
|
||||
for (i = 0; i < PARAMS_NBAR; i++) {
|
||||
for (k = 0; k < PARAMS_STRIPE_STEP; k += PARAMS_PARALLEL) {
|
||||
uint16_t sum[PARAMS_PARALLEL] = {0};
|
||||
@ -135,6 +142,41 @@ int frodo_mul_add_sa_plus_e(uint16_t *out, const uint16_t *s, const uint16_t *e,
|
||||
out[i*PARAMS_N + kk + k + 3] += sum[3];
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (i = 0; i < PARAMS_NBAR; i++) {
|
||||
for (k = 0; k < PARAMS_STRIPE_STEP; k += PARAMS_PARALLEL) {
|
||||
ALIGN_HEADER(32) uint32_t sum[8 * PARAMS_PARALLEL] ALIGN_FOOTER(32);
|
||||
__m256i a[PARAMS_PARALLEL], b, acc[PARAMS_PARALLEL];
|
||||
acc[0] = _mm256_setzero_si256();
|
||||
acc[1] = _mm256_setzero_si256();
|
||||
acc[2] = _mm256_setzero_si256();
|
||||
acc[3] = _mm256_setzero_si256();
|
||||
for (j = 0; j < PARAMS_N; j += 16) { // Matrix-vector multiplication
|
||||
b = _mm256_load_si256((__m256i*)&s[i*PARAMS_N + j]);
|
||||
a[0] = _mm256_load_si256((__m256i*)&a_cols_t[(k+0)*PARAMS_N + j]);
|
||||
a[0] = _mm256_madd_epi16(a[0], b);
|
||||
acc[0] = _mm256_add_epi16(a[0], acc[0]);
|
||||
a[1] = _mm256_load_si256((__m256i*)&a_cols_t[(k+1)*PARAMS_N + j]);
|
||||
a[1] = _mm256_madd_epi16(a[1], b);
|
||||
acc[1] = _mm256_add_epi16(a[1], acc[1]);
|
||||
a[2] = _mm256_load_si256((__m256i*)&a_cols_t[(k+2)*PARAMS_N + j]);
|
||||
a[2] = _mm256_madd_epi16(a[2], b);
|
||||
acc[2] = _mm256_add_epi16(a[2], acc[2]);
|
||||
a[3] = _mm256_load_si256((__m256i*)&a_cols_t[(k+3)*PARAMS_N + j]);
|
||||
a[3] = _mm256_madd_epi16(a[3], b);
|
||||
acc[3] = _mm256_add_epi16(a[3], acc[3]);
|
||||
}
|
||||
_mm256_store_si256((__m256i*)(sum + (8*0)), acc[0]);
|
||||
out[i*PARAMS_N + kk + k + 0] += sum[8*0 + 0] + sum[8*0 + 1] + sum[8*0 + 2] + sum[8*0 + 3] + sum[8*0 + 4] + sum[8*0 + 5] + sum[8*0 + 6] + sum[8*0 + 7];
|
||||
_mm256_store_si256((__m256i*)(sum + (8*1)), acc[1]);
|
||||
out[i*PARAMS_N + kk + k + 1] += sum[8*1 + 0] + sum[8*1 + 1] + sum[8*1 + 2] + sum[8*1 + 3] + sum[8*1 + 4] + sum[8*1 + 5] + sum[8*1 + 6] + sum[8*1 + 7];
|
||||
_mm256_store_si256((__m256i*)(sum + (8*2)), acc[2]);
|
||||
out[i*PARAMS_N + kk + k + 2] += sum[8*2 + 0] + sum[8*2 + 1] + sum[8*2 + 2] + sum[8*2 + 3] + sum[8*2 + 4] + sum[8*2 + 5] + sum[8*2 + 6] + sum[8*2 + 7];
|
||||
_mm256_store_si256((__m256i*)(sum + (8*3)), acc[3]);
|
||||
out[i*PARAMS_N + kk + k + 3] += sum[8*3 + 0] + sum[8*3 + 1] + sum[8*3 + 2] + sum[8*3 + 3] + sum[8*3 + 4] + sum[8*3 + 5] + sum[8*3 + 6] + sum[8*3 + 7];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
OQS_AES128_free_schedule(aes_key_schedule);
|
||||
|
||||
@ -159,6 +201,7 @@ int frodo_mul_add_sa_plus_e(uint16_t *out, const uint16_t *s, const uint16_t *e,
|
||||
a_cols[i] = LE_TO_UINT16(a_cols[i]);
|
||||
}
|
||||
|
||||
#ifndef USE_AVX2_INSTRUCTIONS
|
||||
for (i = 0; i < PARAMS_NBAR; i++) {
|
||||
uint16_t sum[PARAMS_N] = {0};
|
||||
for (j = 0; j < 4; j++) {
|
||||
@ -171,6 +214,34 @@ int frodo_mul_add_sa_plus_e(uint16_t *out, const uint16_t *s, const uint16_t *e,
|
||||
out[i*PARAMS_N + k] += sum[k];
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (i = 0; i < PARAMS_NBAR; i++) {
|
||||
__m256i a, b0, b1, b2, b3, acc[PARAMS_N/16];
|
||||
b0 = _mm256_set1_epi16(s[i*PARAMS_N + kk + 0]);
|
||||
b1 = _mm256_set1_epi16(s[i*PARAMS_N + kk + 1]);
|
||||
b2 = _mm256_set1_epi16(s[i*PARAMS_N + kk + 2]);
|
||||
b3 = _mm256_set1_epi16(s[i*PARAMS_N + kk + 3]);
|
||||
for (j = 0; j < PARAMS_N; j+=16) { // Matrix-vector multiplication
|
||||
acc[j/16] = _mm256_load_si256((__m256i*)&out[i*PARAMS_N + j]);
|
||||
a = _mm256_load_si256((__m256i*)&a_cols[(t+0)*PARAMS_N + j]);
|
||||
a = _mm256_mullo_epi16(a, b0);
|
||||
acc[j/16] = _mm256_add_epi16(a, acc[j/16]);
|
||||
a = _mm256_load_si256((__m256i*)&a_cols[(t+1)*PARAMS_N + j]);
|
||||
a = _mm256_mullo_epi16(a, b1);
|
||||
acc[j/16] = _mm256_add_epi16(a, acc[j/16]);
|
||||
a = _mm256_load_si256((__m256i*)&a_cols[(t+2)*PARAMS_N + j]);
|
||||
a = _mm256_mullo_epi16(a, b2);
|
||||
acc[j/16] = _mm256_add_epi16(a, acc[j/16]);
|
||||
a = _mm256_load_si256((__m256i*)&a_cols[(t+3)*PARAMS_N + j]);
|
||||
a = _mm256_mullo_epi16(a, b3);
|
||||
acc[j/16] = _mm256_add_epi16(a, acc[j/16]);
|
||||
}
|
||||
|
||||
for (k = 0; k < PARAMS_N/16; k++) {
|
||||
_mm256_store_si256((__m256i*)&out[i*PARAMS_N + 16*k], acc[k]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user