Add AVX512VL-Optimized SHA3/SHAKE Implementations (#2167)

* Add SHA3-256/384/512 and SHAKE128/256 AVX512VL implementations

Co-authored-by: Tomasz Kantecki <tomasz.kantecki@intel.com>
Co-authored-by: Erdinc Ozturk <erdinc.ozturk@intel.com>
Signed-off-by: Marcel Cornu <marcel.d.cornu@intel.com>
Signed-off-by: Tomasz Kantecki <tomasz.kantecki@intel.com>

* AVX512VL SHA3 is added as an extension of XKCP implementation

Co-authored-by: Marcel Cornu <marcel.d.cornu@intel.com>
Signed-off-by: Tomasz Kantecki <tomasz.kantecki@intel.com>
Signed-off-by: Marcel Cornu <marcel.d.cornu@intel.com>

* Add SHA3-384 tests

Signed-off-by: Marcel Cornu <marcel.d.cornu@intel.com>

* Update namespace test to include SHA3

Signed-off-by: Marcel Cornu <marcel.d.cornu@intel.com>

* Release SHA3 context after triggering dispatcher

Signed-off-by: Marcel Cornu <marcel.d.cornu@intel.com>

* Add linux CI for OQS_USE_SHA3_AVX512VL=OFF config

Signed-off-by: Marcel Cornu <marcel.d.cornu@intel.com>

* Add AVX512 emulation to linux CI

Signed-off-by: Marcel Cornu <marcel.d.cornu@intel.com>

---------

Signed-off-by: Marcel Cornu <marcel.d.cornu@intel.com>
Signed-off-by: Tomasz Kantecki <tomasz.kantecki@intel.com>
Co-authored-by: Tomasz Kantecki <tomasz.kantecki@intel.com>
Co-authored-by: Erdinc Ozturk <erdinc.ozturk@intel.com>
This commit is contained in:
Marcel Cornu 2025-06-20 18:37:32 +01:00 committed by GitHub
parent 47b8fdd404
commit 8f926065eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 3989 additions and 8 deletions

View File

@ -78,6 +78,13 @@ if(OQS_DIST_X86_64_BUILD OR OQS_USE_AVX2_INSTRUCTIONS)
endif()
endif()
# SHA3 AVX512VL only supported on Linux x86_64
# NOTE(review): the AVX512VL SHA3 code is built as an extension of the XKCP
# backend, so it must be disabled whenever the OpenSSL SHA3 backend is chosen.
if(CMAKE_SYSTEM_NAME STREQUAL "Linux" AND (OQS_DIST_X86_64_BUILD OR OQS_USE_AVX512_INSTRUCTIONS))
# Dependent option: defaults to ON, but is forced OFF when OQS_USE_SHA3_OPENSSL is set.
cmake_dependent_option(OQS_USE_SHA3_AVX512VL "Enable SHA3 AVX512VL usage" ON "NOT OQS_USE_SHA3_OPENSSL" OFF)
else()
# Unsupported platform: expose the cache option but default it to OFF.
option(OQS_USE_SHA3_AVX512VL "Enable SHA3 AVX512VL usage" OFF)
endif()
# BIKE is not supported on Windows, 32-bit ARM, X86, S390X (big endian) and PPC64 (big endian)
cmake_dependent_option(OQS_ENABLE_KEM_BIKE "Enable BIKE algorithm family" ON "NOT WIN32; NOT ARCH_ARM32v7; NOT ARCH_X86; NOT ARCH_S390X; NOT ARCH_PPC64" OFF)
# BIKE doesn't work on any 32-bit platform

View File

@ -112,6 +112,11 @@ jobs:
container: openquantumsafe/ci-ubuntu-latest:latest
CMAKE_ARGS: -DCMAKE_C_COMPILER=clang -DCMAKE_BUILD_TYPE=Debug -DUSE_SANITIZER=Address -DOQS_LIBJADE_BUILD=ON -DOQS_MINIMAL_BUILD="${{ vars.LIBJADE_ALG_LIST }}"
PYTEST_ARGS: --ignore=tests/test_distbuild.py --ignore=tests/test_leaks.py --ignore=tests/test_kat_all.py --maxprocesses=10
- name: noble-no-sha3-avx512vl
runner: ubuntu-latest
container: openquantumsafe/ci-ubuntu-latest:latest
CMAKE_ARGS: -DOQS_USE_SHA3_AVX512VL=OFF
PYTEST_ARGS: --ignore=tests/test_leaks.py --ignore=tests/test_kat_all.py
runs-on: ${{ matrix.runner }}
container:
image: ${{ matrix.container }}
@ -271,3 +276,36 @@ jobs:
- name: Build
run: scan-build --status-bugs ninja
working-directory: build
linux_x86_emulated:
runs-on: ubuntu-latest
container:
image: openquantumsafe/ci-ubuntu-latest:latest
strategy:
fail-fast: false
matrix:
include:
- name: avx512-ml-kem_ml-dsa
SDE_ARCH: -skx
CMAKE_ARGS: -DOQS_MINIMAL_BUILD="KEM_ml_kem_512;KEM_ml_kem_768;KEM_ml_kem_1024;SIG_ml_dsa_44;SIG_ml_dsa_65;SIG_ml_dsa_87"
PYTEST_ARGS: tests/test_hash.py::test_sha3 tests/test_kat.py tests/test_acvp_vectors.py
env:
SDE_URL: https://downloadmirror.intel.com/850782/sde-external-9.53.0-2025-03-16-lin.tar.xz
steps:
- name: Checkout code
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # pin@v4
- name: Setup Intel SDE
run: |
wget -O sde.tar.xz "$SDE_URL" && \
mkdir sde && tar -xf sde.tar.xz -C sde --strip-components=1 && \
echo "$(pwd)/sde" >> $GITHUB_PATH
- name: Configure
run: mkdir build && cd build && cmake -GNinja ${{ matrix.CMAKE_ARGS }} .. && cmake -LA -N ..
- name: Build
run: ninja
working-directory: build
- name: Run tests
timeout-minutes: 60
run: |
mkdir -p tmp && sde64 ${{ matrix.SDE_ARCH }} -- \
python3 -m pytest --verbose --numprocesses=auto ${{ matrix.PYTEST_ARGS }}

View File

@ -61,14 +61,19 @@ else()
endif()
if(${OQS_USE_SHA3_OPENSSL})
if (${OQS_ENABLE_SHA3_xkcp_low})
add_subdirectory(sha3/xkcp_low)
endif()
if (${OQS_ENABLE_SHA3_xkcp_low})
add_subdirectory(sha3/xkcp_low)
endif()
set(SHA3_IMPL sha3/ossl_sha3.c sha3/ossl_sha3x4.c)
set(OSSL_HELPERS ossl_helpers.c)
else() # using XKCP
add_subdirectory(sha3/xkcp_low)
set(SHA3_IMPL sha3/xkcp_sha3.c sha3/xkcp_sha3x4.c)
if(OQS_USE_SHA3_AVX512VL)
# also build avx512vl modules
add_subdirectory(sha3/avx512vl_low)
list(APPEND SHA3_IMPL sha3/avx512vl_sha3.c sha3/avx512vl_sha3x4.c)
endif()
endif()
if ((OQS_LIBJADE_BUILD STREQUAL "ON"))
@ -157,6 +162,11 @@ if(${OQS_ENABLE_SHA3_xkcp_low}) # using XKCP
set(_INTERNAL_OBJS ${_INTERNAL_OBJS} ${XKCP_LOW_OBJS})
endif()
if(${OQS_USE_SHA3_AVX512VL})
set(_COMMON_OBJS ${_COMMON_OBJS} ${SHA3_AVX512VL_LOW_OBJS})
set(_INTERNAL_OBJS ${_INTERNAL_OBJS} ${SHA3_AVX512VL_LOW_OBJS})
endif()
set(_COMMON_OBJS ${_COMMON_OBJS} $<TARGET_OBJECTS:common>)
set(COMMON_OBJS ${_COMMON_OBJS} PARENT_SCOPE)
set(_INTERNAL_OBJS ${_INTERNAL_OBJS} $<TARGET_OBJECTS:internal>)

View File

@ -0,0 +1,14 @@
# Copyright (c) 2025 Intel Corporation
#
# SPDX-License-Identifier: MIT

# Build the AVX512VL Keccak/SHA3 assembly sources as an OBJECT library and
# publish the resulting object list to the parent scope through
# SHA3_AVX512VL_LOW_OBJS (left empty when OQS_USE_SHA3_AVX512VL is disabled).
set(_SHA3_AVX512VL_LOW_OBJS "")
if(OQS_USE_SHA3_AVX512VL)
    add_library(sha3_avx512vl_low OBJECT
        KeccakP-1600-AVX512VL.S
        SHA3-AVX512VL.S
        KeccakP-1600-times4-AVX512VL.S
        SHA3-times4-AVX512VL.S
    )
    list(APPEND _SHA3_AVX512VL_LOW_OBJS $<TARGET_OBJECTS:sha3_avx512vl_low>)
endif()
set(SHA3_AVX512VL_LOW_OBJS ${_SHA3_AVX512VL_LOW_OBJS} PARENT_SCOPE)

View File

@ -0,0 +1,548 @@
# Copyright (c) 2025 Intel Corporation
#
# SPDX-License-Identifier: MIT
# Define arg registers
.equ arg1, %rdi
.equ arg2, %rsi
.text
# Initializes the Keccak-1600 state held in registers to all zeroes
#
# input:  none
# output: ymm0-ymm24 - state registers, all zeroed (one 64-bit state word per
#         register; full 256-bit register cleared)
.globl keccak_1600_init_state
.type keccak_1600_init_state,@function
.hidden keccak_1600_init_state
.balign 32
keccak_1600_init_state:
# Zero via self-XOR (also clears the full 512-bit register per zeroing idiom)
vpxorq %xmm0, %xmm0, %xmm0
vpxorq %xmm1, %xmm1, %xmm1
vpxorq %xmm2, %xmm2, %xmm2
# Copy the zeroed ymm0 into the remaining state registers
vmovdqa64 %ymm0, %ymm3
vmovdqa64 %ymm0, %ymm4
vmovdqa64 %ymm0, %ymm5
vmovdqa64 %ymm0, %ymm6
vmovdqa64 %ymm0, %ymm7
vmovdqa64 %ymm0, %ymm8
vmovdqa64 %ymm0, %ymm9
vmovdqa64 %ymm0, %ymm10
vmovdqa64 %ymm0, %ymm11
vmovdqa64 %ymm0, %ymm12
vmovdqa64 %ymm0, %ymm13
vmovdqa64 %ymm0, %ymm14
vmovdqa64 %ymm0, %ymm15
vmovdqa64 %ymm0, %ymm16
vmovdqa64 %ymm0, %ymm17
vmovdqa64 %ymm0, %ymm18
vmovdqa64 %ymm0, %ymm19
vmovdqa64 %ymm0, %ymm20
vmovdqa64 %ymm0, %ymm21
vmovdqa64 %ymm0, %ymm22
vmovdqa64 %ymm0, %ymm23
vmovdqa64 %ymm0, %ymm24
ret
.size keccak_1600_init_state,.-keccak_1600_init_state
# Loads Keccak state from memory into registers
#
# input:  arg1 - state pointer (25 x 64-bit words)
# output: xmm0-xmm24 - one state word per register
#
# Each vmovq loads one 64-bit state word into the low lane of a register;
# the upper bits of the destination register are zeroed by vmovq.
.globl keccak_1600_load_state
.type keccak_1600_load_state,@function
.hidden keccak_1600_load_state
.balign 32
keccak_1600_load_state:
vmovq (8*0)(arg1), %xmm0
vmovq (8*1)(arg1), %xmm1
vmovq (8*2)(arg1), %xmm2
vmovq (8*3)(arg1), %xmm3
vmovq (8*4)(arg1), %xmm4
vmovq (8*5)(arg1), %xmm5
vmovq (8*6)(arg1), %xmm6
vmovq (8*7)(arg1), %xmm7
vmovq (8*8)(arg1), %xmm8
vmovq (8*9)(arg1), %xmm9
vmovq (8*10)(arg1), %xmm10
vmovq (8*11)(arg1), %xmm11
vmovq (8*12)(arg1), %xmm12
vmovq (8*13)(arg1), %xmm13
vmovq (8*14)(arg1), %xmm14
vmovq (8*15)(arg1), %xmm15
vmovq (8*16)(arg1), %xmm16
vmovq (8*17)(arg1), %xmm17
vmovq (8*18)(arg1), %xmm18
vmovq (8*19)(arg1), %xmm19
vmovq (8*20)(arg1), %xmm20
vmovq (8*21)(arg1), %xmm21
vmovq (8*22)(arg1), %xmm22
vmovq (8*23)(arg1), %xmm23
vmovq (8*24)(arg1), %xmm24
ret
.size keccak_1600_load_state,.-keccak_1600_load_state
# Saves Keccak state to memory
#
# input: arg1 - state pointer
#        xmm0-xmm24 - Keccak state registers (low 64 bits of each hold a state word)
# output: memory from [arg1] to [arg1 + 25*8]
.globl keccak_1600_save_state
.type keccak_1600_save_state,@function
.hidden keccak_1600_save_state
.balign 32
keccak_1600_save_state:
vmovq %xmm0, (8*0)(arg1)
vmovq %xmm1, (8*1)(arg1)
vmovq %xmm2, (8*2)(arg1)
vmovq %xmm3, (8*3)(arg1)
vmovq %xmm4, (8*4)(arg1)
vmovq %xmm5, (8*5)(arg1)
vmovq %xmm6, (8*6)(arg1)
vmovq %xmm7, (8*7)(arg1)
vmovq %xmm8, (8*8)(arg1)
vmovq %xmm9, (8*9)(arg1)
vmovq %xmm10, (8*10)(arg1)
vmovq %xmm11, (8*11)(arg1)
vmovq %xmm12, (8*12)(arg1)
vmovq %xmm13, (8*13)(arg1)
vmovq %xmm14, (8*14)(arg1)
vmovq %xmm15, (8*15)(arg1)
vmovq %xmm16, (8*16)(arg1)
vmovq %xmm17, (8*17)(arg1)
vmovq %xmm18, (8*18)(arg1)
vmovq %xmm19, (8*19)(arg1)
vmovq %xmm20, (8*20)(arg1)
vmovq %xmm21, (8*21)(arg1)
vmovq %xmm22, (8*22)(arg1)
vmovq %xmm23, (8*23)(arg1)
vmovq %xmm24, (8*24)(arg1)
ret
.size keccak_1600_save_state,.-keccak_1600_save_state
# Add (XOR) input data into the state when message length is less than rate
#
# input:
# r13 - state pointer
# arg2 - message pointer (updated on output)
# r12 - length (clobbered on output); assumed < rate, so the tail is < 32 bytes
# output:
# memory - state from [r13] to [r13 + r12 - 1]
# clobbered:
# rax, k1, ymm31
# NOTE(review): %r13 is also advanced by the 32-byte loop below, so it is
# effectively clobbered as well — confirm callers do not rely on it.
.globl keccak_1600_partial_add
.type keccak_1600_partial_add,@function
.hidden keccak_1600_partial_add
.balign 32
keccak_1600_partial_add:
.partial_add_loop:
# XOR full 32-byte chunks of the message into the state
cmpq $32, %r12
jb .lt_32_bytes
vmovdqu64 (arg2), %ymm31
vpxorq (%r13), %ymm31, %ymm31
vmovdqu64 %ymm31, (%r13)
addq $32, arg2
addq $32, %r13
subq $32, %r12
jz .partial_add_done
jmp .partial_add_loop
.lt_32_bytes:
# Build byte mask (1 << r12) - 1 for the remaining 0..31 bytes
xorq %rax, %rax
bts %r12, %rax
decq %rax
kmovq %rax, %k1 # k1 is the mask of message bytes to read
vmovdqu8 (arg2), %ymm31{%k1}{z} # Read 0 to 31 bytes
vpxorq (%r13), %ymm31, %ymm31
vmovdqu8 %ymm31, (%r13){%k1}
addq %r12, arg2 # Increment message pointer
.partial_add_done:
ret
.size keccak_1600_partial_add,.-keccak_1600_partial_add
# Extract bytes from state and write to output
#
# input:
# r13 - state (advanced by the 32-byte loop, i.e. clobbered)
# r10 - output pointer (updated on output)
# r12 - length (clobbered on output)
# output:
# memory - output from [r10] to [r10 + r12 - 1]
# clobbered:
# rax, k1, ymm31
.globl keccak_1600_extract_bytes
.type keccak_1600_extract_bytes,@function
.hidden keccak_1600_extract_bytes
.balign 32
keccak_1600_extract_bytes:
.extract_32_byte_loop:
# Copy full 32-byte chunks from the state to the output
cmpq $32, %r12
jb .extract_lt_32_bytes
vmovdqu64 (%r13), %ymm31
vmovdqu64 %ymm31, (%r10)
addq $32, %r13
addq $32, %r10
subq $32, %r12
jz .extract_done
jmp .extract_32_byte_loop
.extract_lt_32_bytes:
# Build byte mask (1 << r12) - 1 for the final 0..31 bytes
xorq %rax, %rax
bts %r12, %rax
decq %rax
kmovq %rax, %k1 # k1 is the mask of the last message bytes
vmovdqu8 (%r13), %ymm31{%k1}{z} # Read 0 to 31 bytes
vmovdqu8 %ymm31, (%r10){%k1}
addq %r12, %r10 # Increment output pointer
.extract_done:
ret
.size keccak_1600_extract_bytes,.-keccak_1600_extract_bytes
# Copy partial block message into temporary buffer, add padding byte and EOM bit
#
# r13 [in] destination pointer (temporary block buffer, 256 bytes cleared below;
#          not advanced by this routine)
# r12 [in] source pointer (not advanced; %r15 is used as a running offset)
# r11 [in/clobbered] length in bytes
# r9  [in] rate (block size in bytes)
# r8  [in] pointer to the padding byte
# output:
# memory - output from [r13] to [r13 + r11 - 1], [r13 + r11] padding, [r13 + r9 - 1] EOM
# clobbered:
# rax, r15, k1, k2, ymm31
.globl keccak_1600_copy_with_padding
.type keccak_1600_copy_with_padding, @function
.hidden keccak_1600_copy_with_padding
.balign 32
keccak_1600_copy_with_padding:
# Clear the temporary buffer (8 x 32 = 256 bytes, enough for any SHA3/SHAKE rate)
vpxorq %ymm31, %ymm31, %ymm31
vmovdqu64 %ymm31, (32*0)(%r13)
vmovdqu64 %ymm31, (32*1)(%r13)
vmovdqu64 %ymm31, (32*2)(%r13)
vmovdqu64 %ymm31, (32*3)(%r13)
vmovdqu64 %ymm31, (32*4)(%r13)
vmovdqu64 %ymm31, (32*5)(%r13)
vmovdqu64 %ymm31, (32*6)(%r13)
vmovdqu64 %ymm31, (32*7)(%r13)
xorq %r15, %r15
.balign 32
.copy32_with_padding_loop:
cmpq $32, %r11 # Check at least 32 remaining
jb .partial32_with_padding # If no, then do final copy with padding
vmovdqu64 (%r12,%r15), %ymm31
vmovdqu64 %ymm31, (%r13,%r15)
subq $32, %r11 # Decrement the remaining length
addq $32, %r15 # Increment offset
jmp .copy32_with_padding_loop
.partial32_with_padding:
# rax = 1 << r11: bit r11 selects the padding position, (1 << r11) - 1 the message bytes
xorq %rax, %rax
bts %r11, %rax
kmovq %rax, %k2 # k2 is mask of the 1st byte after the message
decq %rax
kmovq %rax, %k1 # k1 is the mask of the last message bytes
vmovdqu8 (%r12,%r15), %ymm31{%k1}{z} # Read 0 to 31 bytes
vpbroadcastb (%r8), %ymm31{%k2} # Add padding
vmovdqu64 %ymm31, (%r13,%r15) # Store whole 32 bytes
xorb $0x80, (-1)(%r13,%r9) # EOM bit - XOR the last byte of the block
ret
.size keccak_1600_copy_with_padding,.-keccak_1600_copy_with_padding
# Copy digest bytes from the state buffer to the output
#
# input:
# r13  - destination pointer (advanced by whole 32-byte chunks, i.e. clobbered)
# r12  - source pointer (advanced, i.e. clobbered)
# arg2 - length in bytes (clobbered) — NOTE: unlike the sibling routines above,
#        the length is passed in arg2 (%rsi) here, not %r12
# clobbered:
# rax, k1, ymm31
.globl keccak_1600_copy_digest
.type keccak_1600_copy_digest, @function
.hidden keccak_1600_copy_digest
.balign 32
keccak_1600_copy_digest:
.copy32_digest_loop:
cmp $32, arg2 # Check at least 32 remaining
jb .partial32 # If no, then copy final bytes
vmovdqu64 (%r12), %ymm31
vmovdqu64 %ymm31, (%r13)
addq $32, %r13 # Increment destination pointer
addq $32, %r12 # Increment source pointer
subq $32, arg2 # Decrement the remaining length
jz .done
jmp .copy32_digest_loop
.partial32:
# Build byte mask (1 << arg2) - 1 for the final 0..31 bytes
xorq %rax, %rax
bts arg2, %rax
dec %rax
kmovq %rax, %k1 # k1 is the mask of the last message bytes
vmovdqu8 (%r12), %ymm31{%k1}{z} # Read 0 to 31 bytes
vmovdqu8 %ymm31, (%r13){%k1} # Store 0 to 31 bytes
.done:
ret
.size keccak_1600_copy_digest,.-keccak_1600_copy_digest
# Perform Keccak permutation (24 rounds of Keccak-f[1600])
#
# YMM registers 0 to 24 are used as Keccak state registers.
# This function, as is, can work on 1 to 4 independent states at the same time
# (one 64-bit lane of each YMM register per state).
#
# There is no clear boundary between Theta, Rho, Pi, Chi and Iota steps.
# Instructions corresponding to these steps overlap for better efficiency.
#
# vpternlogq immediate semantics used below:
#   0x96 = three-way XOR: dst = dst ^ src1 ^ src2   (Theta parity / XOR)
#   0xD2 = Chi logic:     dst = dst ^ (~src2 & src1)
#
# ymm0-ymm24 [in/out] Keccak state registers (one SIMD per one state register)
# ymm25-ymm31 [clobbered] temporary SIMD registers
# r13 [clobbered] used for round tracking
# r14 [clobbered] used for access to SHA3 constant table
.globl keccak_1600_permute
.type keccak_1600_permute,@function
.hidden keccak_1600_permute
.balign 32
keccak_1600_permute:
movl $24, %r13d # 24 rounds
leaq sha3_rc(%rip), %r14 # Load the address of the SHA3 round constants
.balign 32
keccak_rnd_loop:
# Theta step
# Compute column parities
# C[5] = [0, 0, 0, 0, 0]
# for x in 0 to 4:
# C[x] = state[x][0] XOR state[x][1] XOR state[x][2] XOR state[x][3] XOR state[x][4]
vmovdqa64 %ymm0, %ymm25
vpternlogq $0x96, %ymm5, %ymm10, %ymm25
vmovdqa64 %ymm1, %ymm26
vpternlogq $0x96, %ymm11, %ymm6, %ymm26
vmovdqa64 %ymm2, %ymm27
vpternlogq $0x96, %ymm12, %ymm7, %ymm27
vmovdqa64 %ymm3, %ymm28
vpternlogq $0x96, %ymm13, %ymm8, %ymm28
vmovdqa64 %ymm4, %ymm29
vpternlogq $0x96, %ymm14, %ymm9, %ymm29
vpternlogq $0x96, %ymm20, %ymm15, %ymm25
vpternlogq $0x96, %ymm21, %ymm16, %ymm26
vpternlogq $0x96, %ymm22, %ymm17, %ymm27
vpternlogq $0x96, %ymm23, %ymm18, %ymm28
# Start computing D values and keep computing column parity
# D[5] = [0, 0, 0, 0, 0]
# for x in 0 to 4:
# D[x] = C[(x+4) mod 5] XOR ROTATE_LEFT(C[(x+1) mod 5], 1)
vprolq $1, %ymm26, %ymm30
vprolq $1, %ymm27, %ymm31
vpternlogq $0x96, %ymm24, %ymm19, %ymm29
# Continue computing D values and apply Theta
# for x in 0 to 4:
# for y in 0 to 4:
# state[x][y] = state[x][y] XOR D[x]
vpternlogq $0x96, %ymm30, %ymm29, %ymm0
vpternlogq $0x96, %ymm30, %ymm29, %ymm10
vpternlogq $0x96, %ymm30, %ymm29, %ymm20
vpternlogq $0x96, %ymm30, %ymm29, %ymm5
vpternlogq $0x96, %ymm30, %ymm29, %ymm15
vprolq $1, %ymm28, %ymm30
vpternlogq $0x96, %ymm31, %ymm25, %ymm6
vpternlogq $0x96, %ymm31, %ymm25, %ymm16
vpternlogq $0x96, %ymm31, %ymm25, %ymm1
vpternlogq $0x96, %ymm31, %ymm25, %ymm11
vpternlogq $0x96, %ymm31, %ymm25, %ymm21
vprolq $1, %ymm29, %ymm31
vpbroadcastq (%r14), %ymm29 # Load the round constant into ymm29 (Iota)
addq $8, %r14 # Increment the pointer to the next round constant
vpternlogq $0x96, %ymm30, %ymm26, %ymm12
vpternlogq $0x96, %ymm30, %ymm26, %ymm7
vpternlogq $0x96, %ymm30, %ymm26, %ymm22
vpternlogq $0x96, %ymm30, %ymm26, %ymm17
vpternlogq $0x96, %ymm30, %ymm26, %ymm2
vprolq $1, %ymm25, %ymm30
# Rho step
# Keep applying Theta and start Rho step
#
# ROTATION_OFFSETS[5][5] = [
# [0, 1, 62, 28, 27],
# [36, 44, 6, 55, 20],
# [3, 10, 43, 25, 39],
# [41, 45, 15, 21, 8],
# [18, 2, 61, 56, 14] ]
#
# for x in 0 to 4:
# for y in 0 to 4:
# state[x][y] = ROTATE_LEFT(state[x][y], ROTATION_OFFSETS[x][y])
vpternlogq $0x96, %ymm31, %ymm27, %ymm3
vpternlogq $0x96, %ymm31, %ymm27, %ymm13
vpternlogq $0x96, %ymm31, %ymm27, %ymm23
vprolq $44, %ymm6, %ymm6
vpternlogq $0x96, %ymm31, %ymm27, %ymm18
vpternlogq $0x96, %ymm31, %ymm27, %ymm8
vprolq $43, %ymm12, %ymm12
vprolq $21, %ymm18, %ymm18
vpternlogq $0x96, %ymm30, %ymm28, %ymm24
vprolq $14, %ymm24, %ymm24
vprolq $28, %ymm3, %ymm3
vpternlogq $0x96, %ymm30, %ymm28, %ymm9
vprolq $20, %ymm9, %ymm9
vprolq $3, %ymm10, %ymm10
vpternlogq $0x96, %ymm30, %ymm28, %ymm19
vprolq $45, %ymm16, %ymm16
vprolq $61, %ymm22, %ymm22
vpternlogq $0x96, %ymm30, %ymm28, %ymm4
vprolq $1, %ymm1, %ymm1
vprolq $6, %ymm7, %ymm7
vpternlogq $0x96, %ymm30, %ymm28, %ymm14
# Continue with Rho and start Pi and Chi steps at the same time
# Ternary logic 0xD2 is used for Chi step
#
# for x in 0 to 4:
# for y in 0 to 4:
# state[x][y] = state[x][y] XOR ((NOT state[(x+1) mod 5][y]) AND state[(x+2) mod 5][y])
vprolq $25, %ymm13, %ymm13
vprolq $8, %ymm19, %ymm19
vmovdqa64 %ymm0, %ymm30
vpternlogq $0xD2, %ymm12, %ymm6, %ymm30
vprolq $18, %ymm20, %ymm20
vprolq $27, %ymm4, %ymm4
vpxorq %ymm29, %ymm30, %ymm30 # Iota step
vprolq $36, %ymm5, %ymm5
vprolq $10, %ymm11, %ymm11
vmovdqa64 %ymm6, %ymm31
vpternlogq $0xD2, %ymm18, %ymm12, %ymm31
vprolq $15, %ymm17, %ymm17
vprolq $56, %ymm23, %ymm23
vpternlogq $0xD2, %ymm24, %ymm18, %ymm12
vprolq $62, %ymm2, %ymm2
vprolq $55, %ymm8, %ymm8
vpternlogq $0xD2, %ymm0, %ymm24, %ymm18
vprolq $39, %ymm14, %ymm14
vprolq $41, %ymm15, %ymm15
vpternlogq $0xD2, %ymm6, %ymm0, %ymm24
vmovdqa64 %ymm30, %ymm0
vmovdqa64 %ymm31, %ymm6
vprolq $2, %ymm21, %ymm21
vmovdqa64 %ymm3, %ymm30
vpternlogq $0xD2, %ymm10, %ymm9, %ymm30
vmovdqa64 %ymm9, %ymm31
vpternlogq $0xD2, %ymm16, %ymm10, %ymm31
vpternlogq $0xD2, %ymm22, %ymm16, %ymm10
vpternlogq $0xD2, %ymm3, %ymm22, %ymm16
vpternlogq $0xD2, %ymm9, %ymm3, %ymm22
vmovdqa64 %ymm30, %ymm3
vmovdqa64 %ymm31, %ymm9
vmovdqa64 %ymm1, %ymm30
vpternlogq $0xD2, %ymm13, %ymm7, %ymm30
vmovdqa64 %ymm7, %ymm31
vpternlogq $0xD2, %ymm19, %ymm13, %ymm31
vpternlogq $0xD2, %ymm20, %ymm19, %ymm13
vpternlogq $0xD2, %ymm1, %ymm20, %ymm19
vpternlogq $0xD2, %ymm7, %ymm1, %ymm20
vmovdqa64 %ymm30, %ymm1
vmovdqa64 %ymm31, %ymm7
vmovdqa64 %ymm4, %ymm30
vpternlogq $0xD2, %ymm11, %ymm5, %ymm30
vmovdqa64 %ymm5, %ymm31
vpternlogq $0xD2, %ymm17, %ymm11, %ymm31
vpternlogq $0xD2, %ymm23, %ymm17, %ymm11
vpternlogq $0xD2, %ymm4, %ymm23, %ymm17
vpternlogq $0xD2, %ymm5, %ymm4, %ymm23
vmovdqa64 %ymm30, %ymm4
vmovdqa64 %ymm31, %ymm5
vmovdqa64 %ymm2, %ymm30
vpternlogq $0xD2, %ymm14, %ymm8, %ymm30
vmovdqa64 %ymm8, %ymm31
vpternlogq $0xD2, %ymm15, %ymm14, %ymm31
vpternlogq $0xD2, %ymm21, %ymm15, %ymm14
vpternlogq $0xD2, %ymm2, %ymm21, %ymm15
vpternlogq $0xD2, %ymm8, %ymm2, %ymm21
vmovdqa64 %ymm30, %ymm2
vmovdqa64 %ymm31, %ymm8
# Complete the steps and get updated state registers in ymm0 to ymm24
# (this register shuffle realizes the Pi lane permutation; ymm30 is used
# as the single temporary for the cyclic chain of moves)
vmovdqa64 %ymm3, %ymm30
vmovdqa64 %ymm18, %ymm3
vmovdqa64 %ymm17, %ymm18
vmovdqa64 %ymm11, %ymm17
vmovdqa64 %ymm7, %ymm11
vmovdqa64 %ymm10, %ymm7
vmovdqa64 %ymm1, %ymm10
vmovdqa64 %ymm6, %ymm1
vmovdqa64 %ymm9, %ymm6
vmovdqa64 %ymm22, %ymm9
vmovdqa64 %ymm14, %ymm22
vmovdqa64 %ymm20, %ymm14
vmovdqa64 %ymm2, %ymm20
vmovdqa64 %ymm12, %ymm2
vmovdqa64 %ymm13, %ymm12
vmovdqa64 %ymm19, %ymm13
vmovdqa64 %ymm23, %ymm19
vmovdqa64 %ymm15, %ymm23
vmovdqa64 %ymm4, %ymm15
vmovdqa64 %ymm24, %ymm4
vmovdqa64 %ymm21, %ymm24
vmovdqa64 %ymm8, %ymm21
vmovdqa64 %ymm16, %ymm8
vmovdqa64 %ymm5, %ymm16
vmovdqa64 %ymm30, %ymm5
decl %r13d # Decrement the round counter
jnz keccak_rnd_loop # Jump to the start of the loop if r13d is not zero
ret
.size keccak_1600_permute,.-keccak_1600_permute
.section .rodata
.balign 64
sha3_rc:
# SHA3 round constants
# These constants are used in each round of the Keccak permutation
# (24 x 64-bit Iota round constants for Keccak-f[1600], consumed in order
# by keccak_1600_permute via %r14)
.quad 0x0000000000000001, 0x0000000000008082
.quad 0x800000000000808a, 0x8000000080008000
.quad 0x000000000000808b, 0x0000000080000001
.quad 0x8000000080008081, 0x8000000000008009
.quad 0x000000000000008a, 0x0000000000000088
.quad 0x0000000080008009, 0x000000008000000a
.quad 0x000000008000808b, 0x800000000000008b
.quad 0x8000000000008089, 0x8000000000008003
.quad 0x8000000000008002, 0x8000000000000080
.quad 0x000000000000800a, 0x800000008000000a
.quad 0x8000000080008081, 0x8000000000008080
.quad 0x0000000080000001, 0x8000000080008008
# Mark the stack as non-executable
.section .note.GNU-stack,"",%progbits

View File

@ -0,0 +1,375 @@
# Copyright (c) 2025 Intel Corporation
#
# SPDX-License-Identifier: MIT
# Define arg registers
.equ arg1, %rdi
.equ arg2, %rsi
.equ arg3, %rdx
.equ arg4, %rcx
.equ arg5, %r8
.text
# Loads 4-way interleaved Keccak state from memory into registers
#
# input: arg1 - state pointer (25 x 4 x 64-bit words; each 32-byte group holds
#        the same state word for the four lanes)
# output: ymm0-ymm24 - one state word (x4 lanes) per register
.globl keccak_1600_load_state_x4
.type keccak_1600_load_state_x4,@function
.hidden keccak_1600_load_state_x4
.balign 32
keccak_1600_load_state_x4:
vmovdqu64 (32*0)(arg1), %ymm0
vmovdqu64 (32*1)(arg1), %ymm1
vmovdqu64 (32*2)(arg1), %ymm2
vmovdqu64 (32*3)(arg1), %ymm3
vmovdqu64 (32*4)(arg1), %ymm4
vmovdqu64 (32*5)(arg1), %ymm5
vmovdqu64 (32*6)(arg1), %ymm6
vmovdqu64 (32*7)(arg1), %ymm7
vmovdqu64 (32*8)(arg1), %ymm8
vmovdqu64 (32*9)(arg1), %ymm9
vmovdqu64 (32*10)(arg1), %ymm10
vmovdqu64 (32*11)(arg1), %ymm11
vmovdqu64 (32*12)(arg1), %ymm12
vmovdqu64 (32*13)(arg1), %ymm13
vmovdqu64 (32*14)(arg1), %ymm14
vmovdqu64 (32*15)(arg1), %ymm15
vmovdqu64 (32*16)(arg1), %ymm16
vmovdqu64 (32*17)(arg1), %ymm17
vmovdqu64 (32*18)(arg1), %ymm18
vmovdqu64 (32*19)(arg1), %ymm19
vmovdqu64 (32*20)(arg1), %ymm20
vmovdqu64 (32*21)(arg1), %ymm21
vmovdqu64 (32*22)(arg1), %ymm22
vmovdqu64 (32*23)(arg1), %ymm23
vmovdqu64 (32*24)(arg1), %ymm24
ret
.size keccak_1600_load_state_x4,.-keccak_1600_load_state_x4
# Saves 4-way interleaved Keccak state to memory
#
# input: arg1 - state pointer
#        ymm0-ymm24 - Keccak state registers (one state word x 4 lanes each)
# output: memory from [arg1] to [arg1 + 100*8]
.globl keccak_1600_save_state_x4
.type keccak_1600_save_state_x4,@function
.hidden keccak_1600_save_state_x4
.balign 32
keccak_1600_save_state_x4:
vmovdqu64 %ymm0, (32*0)(arg1)
vmovdqu64 %ymm1, (32*1)(arg1)
vmovdqu64 %ymm2, (32*2)(arg1)
vmovdqu64 %ymm3, (32*3)(arg1)
vmovdqu64 %ymm4, (32*4)(arg1)
vmovdqu64 %ymm5, (32*5)(arg1)
vmovdqu64 %ymm6, (32*6)(arg1)
vmovdqu64 %ymm7, (32*7)(arg1)
vmovdqu64 %ymm8, (32*8)(arg1)
vmovdqu64 %ymm9, (32*9)(arg1)
vmovdqu64 %ymm10, (32*10)(arg1)
vmovdqu64 %ymm11, (32*11)(arg1)
vmovdqu64 %ymm12, (32*12)(arg1)
vmovdqu64 %ymm13, (32*13)(arg1)
vmovdqu64 %ymm14, (32*14)(arg1)
vmovdqu64 %ymm15, (32*15)(arg1)
vmovdqu64 %ymm16, (32*16)(arg1)
vmovdqu64 %ymm17, (32*17)(arg1)
vmovdqu64 %ymm18, (32*18)(arg1)
vmovdqu64 %ymm19, (32*19)(arg1)
vmovdqu64 %ymm20, (32*20)(arg1)
vmovdqu64 %ymm21, (32*21)(arg1)
vmovdqu64 %ymm22, (32*22)(arg1)
vmovdqu64 %ymm23, (32*23)(arg1)
vmovdqu64 %ymm24, (32*24)(arg1)
ret
.size keccak_1600_save_state_x4,.-keccak_1600_save_state_x4
# Add (XOR) input data into the 4-way state when message length is less than rate
#
# input:
# r10 - state pointer to absorb into (clobbered)
# arg2 - message pointer lane 0 (updated on output)
# arg3 - message pointer lane 1 (updated on output)
# arg4 - message pointer lane 2 (updated on output)
# arg5 - message pointer lane 3 (updated on output)
# r12 - length in bytes (clobbered on output)
# output:
# memory - state from [r10] to [r10 + 4*r12 - 1]
# clobbered:
# rax, rbx, r15, k1, ymm31-ymm29
.globl keccak_1600_partial_add_x4
.type keccak_1600_partial_add_x4,@function
.hidden keccak_1600_partial_add_x4
.balign 32
keccak_1600_partial_add_x4:
# Byte offset into the state is kept at s[100], just past the 25*4 state
# words (presumably the current absorb position — confirm against the caller)
movq (8*100)(%r10), %rax
testl $7, %eax
jz .start_aligned_to_4x8
# start offset is not aligned to register size
# - calculate remaining capacity of the register
# - get the min between length and the capacity of the register
# - perform partial add on the register
# - once aligned to the register go into ymm loop
movq %rax, %r15 # %r15 = s[100]
andl $7, %eax
negl %eax
addl $8, %eax # register capacity = 8 - (offset % 8)
cmpl %eax, %r12d
cmovb %r12d, %eax # %eax = min(register capacity, $length)
leaq byte_kmask_0_to_7(%rip), %rbx
kmovb (%rbx,%rax), %k1 # message load mask
movq %r15, %rbx
andl $~7, %ebx
leaq (%r10, %rbx,4), %r10 # get to state starting register
movq %r15, %rbx
andl $7, %ebx
vmovdqu8 (%r10), %ymm31 # load & store / allocate SB for the register
vmovdqu8 %ymm31, (%r10)
vmovdqu8 (arg2), %xmm31{%k1}{z} # Read 1 to 7 bytes from lane 0
vmovdqu8 (8*0)(%r10,%rbx), %xmm30{%k1}{z} # Read 1 to 7 bytes from state reg lane 0
vpxorq %xmm30, %xmm31, %xmm31
vmovdqu8 %xmm31, (8*0)(%r10,%rbx){%k1} # Write 1 to 7 bytes to state reg lane 0
vmovdqu8 (arg3), %xmm31{%k1}{z} # Read 1 to 7 bytes from lane 1
vmovdqu8 (8*1)(%r10,%rbx), %xmm30{%k1}{z} # Read 1 to 7 bytes from state reg lane 1
vpxorq %xmm30, %xmm31, %xmm31
vmovdqu8 %xmm31, (8*1)(%r10,%rbx){%k1} # Write 1 to 7 bytes to state reg lane 1
vmovdqu8 (arg4), %xmm31{%k1}{z} # Read 1 to 7 bytes from lane 2
vmovdqu8 (8*2)(%r10,%rbx), %xmm30{%k1}{z} # Read 1 to 7 bytes from state reg lane 2
vpxorq %xmm30, %xmm31, %xmm31
vmovdqu8 %xmm31, (8*2)(%r10,%rbx){%k1} # Write 1 to 7 bytes to state reg lane 2
vmovdqu8 (arg5), %xmm31{%k1}{z} # Read 1 to 7 bytes from lane 3
vmovdqu8 (8*3)(%r10,%rbx), %xmm30{%k1}{z} # Read 1 to 7 bytes from state reg lane 3
vpxorq %xmm30, %xmm31, %xmm31
vmovdqu8 %xmm31, (8*3)(%r10,%rbx){%k1} # Write 1 to 7 bytes to state reg lane 3
subq %rax, %r12
jz .zero_bytes
addq %rax, arg2
addq %rax, arg3
addq %rax, arg4
addq %rax, arg5
addq $32, %r10
xorq %rax, %rax
jmp .ymm_loop
.start_aligned_to_4x8:
leaq (%r10,%rax,4), %r10
xorq %rax, %rax
.balign 32
.ymm_loop:
# Main loop: gather 8 bytes from each of the 4 lanes, XOR the combined
# 32-byte vector into one interleaved state register
cmpl $8, %r12d
jb .lt_8_bytes
vmovq (arg2, %rax), %xmm31 # Read 8 bytes from lane 0
vpinsrq $1, (arg3, %rax), %xmm31, %xmm31 # Read 8 bytes from lane 1
vmovq (arg4, %rax), %xmm30 # Read 8 bytes from lane 2
vpinsrq $1, (arg5, %rax),%xmm30, %xmm30 # Read 8 bytes from lane 3
vinserti32x4 $1, %xmm30, %ymm31, %ymm31
vpxorq (%r10,%rax,4), %ymm31, %ymm31 # Add data with the state
vmovdqu64 %ymm31, (%r10,%rax,4)
addq $8, %rax
subq $8, %r12
jz .zero_bytes
jmp .ymm_loop
.balign 32
.zero_bytes:
addq %rax, arg2
addq %rax, arg3
addq %rax, arg4
addq %rax, arg5
ret
.balign 32
.lt_8_bytes:
addq %rax, arg2
addq %rax, arg3
addq %rax, arg4
addq %rax, arg5
leaq (%r10,%rax,4), %r10
leaq byte_kmask_0_to_7(%rip), %rbx
kmovb (%rbx,%r12), %k1 # message load mask
vmovdqu8 (arg2), %xmm31{%k1}{z} # Read 1 to 7 bytes from lane 0
vmovdqu8 (arg3), %xmm30{%k1}{z} # Read 1 to 7 bytes from lane 1
vpunpcklqdq %xmm30, %xmm31, %xmm31 # Interleave data from lane 0 and lane 1
vmovdqu8 (arg4), %xmm30{%k1}{z} # Read 1 to 7 bytes from lane 2
vmovdqu8 (arg5), %xmm29{%k1}{z} # Read 1 to 7 bytes from lane 3
vpunpcklqdq %xmm29, %xmm30, %xmm30 # Interleave data from lane 2 and lane 3
vinserti32x4 $1, %xmm30, %ymm31, %ymm31
vpxorq (%r10), %ymm31, %ymm31 # Add data to the state
vmovdqu64 %ymm31, (%r10) # Update state in memory
addq %r12, arg2 # increment message pointer lane 0
addq %r12, arg3 # increment message pointer lane 1
addq %r12, arg4 # increment message pointer lane 2
addq %r12, arg5 # increment message pointer lane 3
ret
.size keccak_1600_partial_add_x4,.-keccak_1600_partial_add_x4
# Extract bytes from the 4-way interleaved state and write to the four outputs
#
# input:
# r10 - state pointer to start extracting from (clobbered)
# arg1 - output pointer lane 0 (updated on output)
# arg2 - output pointer lane 1 (updated on output)
# arg3 - output pointer lane 2 (updated on output)
# arg4 - output pointer lane 3 (updated on output)
# r12 - length in bytes (clobbered on output)
# r11 - state offset to start extract from
# output:
# memory - output lane 0 from [arg1] to [arg1 + r12 - 1]
# memory - output lane 1 from [arg2] to [arg2 + r12 - 1]
# memory - output lane 2 from [arg3] to [arg3 + r12 - 1]
# memory - output lane 3 from [arg4] to [arg4 + r12 - 1]
# clobbered:
# rax, rbx, k1, ymm31-ymm30
.globl keccak_1600_extract_bytes_x4
.type keccak_1600_extract_bytes_x4,@function
.hidden keccak_1600_extract_bytes_x4
.balign 32
keccak_1600_extract_bytes_x4:
orq %r12, %r12 # nothing to extract for zero length
jz .extract_zero_bytes
testl $7, %r11d
jz .extract_start_aligned_to_4x8
# extract offset is not aligned to the register size (8 bytes)
# - calculate remaining capacity of the register
# - get the min between length to extract and register capacity
# - perform partial extract from the register
movq %r11, %rax # %rax = %r11 = offset in the state
andl $7, %eax
negl %eax
addl $8, %eax # register capacity = 8 - (offset % 8)
cmpl %eax, %r12d
cmovb %r12d, %eax # %eax = min(register capacity, length)
leaq byte_kmask_0_to_7(%rip), %rbx
kmovb (%rbx,%rax), %k1 # message store mask
movq %r11, %rbx
andl $~7, %ebx
leaq (%r10,%rbx,4), %r10 # get to state starting register
movq %r11, %rbx
andl $7, %ebx
vmovdqu8 (8*0)(%r10,%rbx), %xmm31{%k1}{z} # Read 1 to 7 bytes from state reg lane 0
vmovdqu8 %xmm31, (arg1){%k1} # Write 1 to 7 bytes to lane 0 output
vmovdqu8 (8*1)(%r10,%rbx), %xmm31{%k1}{z} # Read 1 to 7 bytes from state reg lane 1
vmovdqu8 %xmm31, (arg2){%k1} # Write 1 to 7 bytes to lane 1 output
vmovdqu8 (8*2)(%r10,%rbx), %xmm31{%k1}{z} # Read 1 to 7 bytes from state reg lane 2
vmovdqu8 %xmm31, (arg3){%k1} # Write 1 to 7 bytes to lane 2 output
vmovdqu8 (8*3)(%r10,%rbx), %xmm31{%k1}{z} # Read 1 to 7 bytes from state reg lane 3
vmovdqu8 %xmm31, (arg4){%k1} # Write 1 to 7 bytes to lane 3 output
# increment output registers
addq %rax, arg1
addq %rax, arg2
addq %rax, arg3
addq %rax, arg4
# decrement length to extract
subq %rax, %r12
jz .extract_zero_bytes
# there is more data to extract, update state register pointer and go to the main loop
addq $32, %r10
xorq %rax, %rax
# FIX: jump to this function's extract loop; the previous target .ymm_loop
# belongs to keccak_1600_partial_add_x4 and would absorb instead of extract
jmp .extract_ymm_loop
.extract_start_aligned_to_4x8:
leaq (%r10,%r11,4), %r10
xorq %rax, %rax
.balign 32
.extract_ymm_loop:
cmpq $8, %r12
jb .extract_lt_8_bytes
vmovdqu64 (%r10), %xmm31 # state word, lanes 0 and 1
vmovdqu64 (16)(%r10), %xmm30 # state word, lanes 2 and 3
vmovq %xmm31, (arg1, %rax)
vpextrq $1, %xmm31, (arg2, %rax)
vmovq %xmm30, (arg3, %rax)
vpextrq $1, %xmm30, (arg4, %rax)
addq $8, %rax
subq $8, %r12
jz .zero_bytes_left
addq $4*8, %r10
jmp .extract_ymm_loop
.balign 32
.zero_bytes_left:
# increment output pointers
addq %rax, arg1
addq %rax, arg2
addq %rax, arg3
addq %rax, arg4
.extract_zero_bytes:
ret
.balign 32
.extract_lt_8_bytes:
addq %rax, arg1
addq %rax, arg2
addq %rax, arg3
addq %rax, arg4
leaq byte_kmask_0_to_7(%rip), %rax
kmovb (%rax,%r12), %k1 # k1 is the mask of message bytes to read
vmovq (0*8)(%r10), %xmm31 # Read 8 bytes from state lane 0
vmovdqu8 %xmm31, (arg1){%k1} # Extract 1 to 7 bytes of state into output 0
vmovq (1*8)(%r10), %xmm31 # Read 8 bytes from state lane 1
vmovdqu8 %xmm31, (arg2){%k1} # Extract 1 to 7 bytes of state into output 1
vmovq (2*8)(%r10), %xmm31 # Read 8 bytes from state lane 2
vmovdqu8 %xmm31, (arg3){%k1} # Extract 1 to 7 bytes of state into output 2
vmovq (3*8)(%r10), %xmm31 # Read 8 bytes from state lane 3
vmovdqu8 %xmm31, (arg4){%k1} # Extract 1 to 7 bytes of state into output 3
# increment output pointers
# FIX: advance the four output pointers arg1-arg4; previously arg2-arg5 were
# advanced, leaving arg1 stale and corrupting arg5 (%r8, not an output here)
addq %r12, arg1
addq %r12, arg2
addq %r12, arg3
addq %r12, arg4
ret
.size keccak_1600_extract_bytes_x4,.-keccak_1600_extract_bytes_x4
.section .rodata
.balign 8
# Lookup table: index n (0..7) -> kmask with the low n bits set, used to
# build byte-granular load/store masks for partial (sub-8-byte) accesses
byte_kmask_0_to_7:
.byte 0x00, 0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f # 0xff should never happen
# Mark the stack as non-executable
.section .note.GNU-stack,"",%progbits

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,242 @@
/**
* \file avx512vl_sha3.c
* \brief Implementation of the OQS SHA3 API using the AVX512VL low interface.
*
* Copyright (c) 2025 Intel Corporation
*
* SPDX-License-Identifier: MIT
*/
#include "sha3.h"
#include <oqs/common.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
/* Alignment required for the Keccak context used by the AVX512VL assembly. */
#define KECCAK_CTX_ALIGNMENT 32
/* Raw context size: 200-byte Keccak-1600 state plus one uint64_t bookkeeping
 * word (presumably the absorb/squeeze byte offset — confirm against the
 * assembly, which stores an offset word after the state). */
#define _KECCAK_CTX_BYTES (200 + sizeof(uint64_t))
/* Raw size rounded up to the next multiple of KECCAK_CTX_ALIGNMENT. */
#define KECCAK_CTX_BYTES \
(KECCAK_CTX_ALIGNMENT * \
((_KECCAK_CTX_BYTES + KECCAK_CTX_ALIGNMENT - 1) / KECCAK_CTX_ALIGNMENT))
/*
* External compact functions
*/
extern void SHA3_sha3_256_avx512vl(uint8_t *output, const uint8_t *input,
size_t inlen);
extern void SHA3_sha3_384_avx512vl(uint8_t *output, const uint8_t *input,
size_t inlen);
extern void SHA3_sha3_512_avx512vl(uint8_t *output, const uint8_t *input,
size_t inlen);
extern void SHA3_shake128_avx512vl(uint8_t *output, size_t outlen,
const uint8_t *input, size_t inlen);
extern void SHA3_shake256_avx512vl(uint8_t *output, size_t outlen,
const uint8_t *input, size_t inlen);
/*
* External reset functions
*/
extern void
SHA3_sha3_256_inc_ctx_reset_avx512vl(OQS_SHA3_sha3_256_inc_ctx *state);
extern void
SHA3_sha3_384_inc_ctx_reset_avx512vl(OQS_SHA3_sha3_384_inc_ctx *state);
extern void
SHA3_sha3_512_inc_ctx_reset_avx512vl(OQS_SHA3_sha3_512_inc_ctx *state);
extern void
SHA3_shake128_inc_ctx_reset_avx512vl(OQS_SHA3_shake128_inc_ctx *state);
extern void
SHA3_shake256_inc_ctx_reset_avx512vl(OQS_SHA3_shake256_inc_ctx *state);
/*
* External absorb functions
*/
extern void SHA3_sha3_256_inc_absorb_avx512vl(OQS_SHA3_sha3_256_inc_ctx *state,
const uint8_t *input,
size_t inlen);
extern void SHA3_sha3_384_inc_absorb_avx512vl(OQS_SHA3_sha3_384_inc_ctx *state,
const uint8_t *input,
size_t inlen);
extern void SHA3_sha3_512_inc_absorb_avx512vl(OQS_SHA3_sha3_512_inc_ctx *state,
const uint8_t *input,
size_t inlen);
extern void SHA3_shake128_inc_absorb_avx512vl(OQS_SHA3_shake128_inc_ctx *state,
const uint8_t *input,
size_t inlen);
extern void SHA3_shake256_inc_absorb_avx512vl(OQS_SHA3_shake256_inc_ctx *state,
const uint8_t *input,
size_t inlen);
/*
* External finalize functions
*/
extern void
SHA3_sha3_256_inc_finalize_avx512vl(uint8_t *output,
OQS_SHA3_sha3_256_inc_ctx *state);
extern void
SHA3_sha3_384_inc_finalize_avx512vl(uint8_t *output,
OQS_SHA3_sha3_384_inc_ctx *state);
extern void
SHA3_sha3_512_inc_finalize_avx512vl(uint8_t *output,
OQS_SHA3_sha3_512_inc_ctx *state);
extern void
SHA3_shake128_inc_finalize_avx512vl(OQS_SHA3_shake128_inc_ctx *state);
extern void
SHA3_shake256_inc_finalize_avx512vl(OQS_SHA3_shake256_inc_ctx *state);
/*
* External squeeze functions
*/
extern void
SHA3_shake128_inc_squeeze_avx512vl(uint8_t *output, size_t outlen,
OQS_SHA3_shake128_inc_ctx *state);
extern void
SHA3_shake256_inc_squeeze_avx512vl(uint8_t *output, size_t outlen,
OQS_SHA3_shake256_inc_ctx *state);
/*
* SHA3-256
*/
/* Allocate an aligned Keccak context for SHA3-256 and reset it to the
 * initial hashing state. Exits the process on allocation failure. */
static void SHA3_sha3_256_inc_init_avx512vl(OQS_SHA3_sha3_256_inc_ctx *state) {
	void *ctx = OQS_MEM_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES);
	state->ctx = ctx;
	OQS_EXIT_IF_NULLPTR(ctx, "SHA3");
	SHA3_sha3_256_inc_ctx_reset_avx512vl(state);
}

/* Reset the SHA3-256 context (re-initializing its state) before freeing
 * the aligned allocation. */
static void SHA3_sha3_256_inc_ctx_release_avx512vl(OQS_SHA3_sha3_256_inc_ctx *state) {
	SHA3_sha3_256_inc_ctx_reset_avx512vl(state);
	OQS_MEM_aligned_free(state->ctx);
}

/* Duplicate src's Keccak context into dest; dest->ctx must already have
 * been allocated via the init callback. */
static void SHA3_sha3_256_inc_ctx_clone_avx512vl(OQS_SHA3_sha3_256_inc_ctx *dest,
        const OQS_SHA3_sha3_256_inc_ctx *src) {
	memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES);
}
/*
* SHA3-384
*/
/* Allocate an aligned Keccak context for SHA3-384 and reset it to the
 * initial hashing state. Exits the process on allocation failure. */
static void SHA3_sha3_384_inc_init_avx512vl(OQS_SHA3_sha3_384_inc_ctx *state) {
	void *ctx = OQS_MEM_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES);
	state->ctx = ctx;
	OQS_EXIT_IF_NULLPTR(ctx, "SHA3");
	SHA3_sha3_384_inc_ctx_reset_avx512vl(state);
}

/* Reset the SHA3-384 context (re-initializing its state) before freeing
 * the aligned allocation. */
static void SHA3_sha3_384_inc_ctx_release_avx512vl(OQS_SHA3_sha3_384_inc_ctx *state) {
	SHA3_sha3_384_inc_ctx_reset_avx512vl(state);
	OQS_MEM_aligned_free(state->ctx);
}

/* Duplicate src's Keccak context into dest; dest->ctx must already have
 * been allocated via the init callback. */
static void SHA3_sha3_384_inc_ctx_clone_avx512vl(OQS_SHA3_sha3_384_inc_ctx *dest,
        const OQS_SHA3_sha3_384_inc_ctx *src) {
	memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES);
}
/*
* SHA3-512
*/
/* Allocate an aligned Keccak context for SHA3-512 and reset it to the
 * initial hashing state. Exits the process on allocation failure. */
static void SHA3_sha3_512_inc_init_avx512vl(OQS_SHA3_sha3_512_inc_ctx *state) {
	void *ctx = OQS_MEM_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES);
	state->ctx = ctx;
	OQS_EXIT_IF_NULLPTR(ctx, "SHA3");
	SHA3_sha3_512_inc_ctx_reset_avx512vl(state);
}

/* Reset the SHA3-512 context (re-initializing its state) before freeing
 * the aligned allocation. */
static void SHA3_sha3_512_inc_ctx_release_avx512vl(OQS_SHA3_sha3_512_inc_ctx *state) {
	SHA3_sha3_512_inc_ctx_reset_avx512vl(state);
	OQS_MEM_aligned_free(state->ctx);
}

/* Duplicate src's Keccak context into dest; dest->ctx must already have
 * been allocated via the init callback. */
static void SHA3_sha3_512_inc_ctx_clone_avx512vl(OQS_SHA3_sha3_512_inc_ctx *dest,
        const OQS_SHA3_sha3_512_inc_ctx *src) {
	memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES);
}
/*
* SHAKE128
*/
/* Allocate an aligned Keccak context for SHAKE128 and reset it to the
 * initial sponge state. Exits the process on allocation failure. */
static void SHA3_shake128_inc_init_avx512vl(OQS_SHA3_shake128_inc_ctx *state) {
	void *ctx = OQS_MEM_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES);
	state->ctx = ctx;
	OQS_EXIT_IF_NULLPTR(ctx, "SHA3");
	SHA3_shake128_inc_ctx_reset_avx512vl(state);
}

/* Duplicate src's Keccak context into dest; dest->ctx must already have
 * been allocated via the init callback. */
static void SHA3_shake128_inc_ctx_clone_avx512vl(OQS_SHA3_shake128_inc_ctx *dest,
        const OQS_SHA3_shake128_inc_ctx *src) {
	memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES);
}

/* Reset the SHAKE128 context (re-initializing its state) before freeing
 * the aligned allocation. */
static void SHA3_shake128_inc_ctx_release_avx512vl(OQS_SHA3_shake128_inc_ctx *state) {
	SHA3_shake128_inc_ctx_reset_avx512vl(state);
	OQS_MEM_aligned_free(state->ctx);
}
/*
* SHAKE256
*/
/* Allocate an aligned Keccak context for SHAKE256 and reset it to the
 * initial sponge state. Exits the process on allocation failure. */
static void SHA3_shake256_inc_init_avx512vl(OQS_SHA3_shake256_inc_ctx *state) {
	void *ctx = OQS_MEM_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES);
	state->ctx = ctx;
	OQS_EXIT_IF_NULLPTR(ctx, "SHA3");
	SHA3_shake256_inc_ctx_reset_avx512vl(state);
}

/* Reset the SHAKE256 context (re-initializing its state) before freeing
 * the aligned allocation. */
static void SHA3_shake256_inc_ctx_release_avx512vl(OQS_SHA3_shake256_inc_ctx *state) {
	SHA3_shake256_inc_ctx_reset_avx512vl(state);
	OQS_MEM_aligned_free(state->ctx);
}

/* Duplicate src's Keccak context into dest; dest->ctx must already have
 * been allocated via the init callback. */
static void SHA3_shake256_inc_ctx_clone_avx512vl(OQS_SHA3_shake256_inc_ctx *dest,
        const OQS_SHA3_shake256_inc_ctx *src) {
	memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES);
}
/*
 * AVX512VL SHA3 callback table installed by the runtime dispatcher.
 * The initializer is positional: entry order must match the member order
 * of struct OQS_SHA3_callbacks exactly — do not reorder independently of
 * that declaration.
 */
const struct OQS_SHA3_callbacks sha3_avx512vl_callbacks = {
	/* SHA3-256: one-shot, then inc. init/absorb/finalize/release/reset/clone */
	SHA3_sha3_256_avx512vl,
	SHA3_sha3_256_inc_init_avx512vl,
	SHA3_sha3_256_inc_absorb_avx512vl,
	SHA3_sha3_256_inc_finalize_avx512vl,
	SHA3_sha3_256_inc_ctx_release_avx512vl,
	SHA3_sha3_256_inc_ctx_reset_avx512vl,
	SHA3_sha3_256_inc_ctx_clone_avx512vl,
	/* SHA3-384 */
	SHA3_sha3_384_avx512vl,
	SHA3_sha3_384_inc_init_avx512vl,
	SHA3_sha3_384_inc_absorb_avx512vl,
	SHA3_sha3_384_inc_finalize_avx512vl,
	SHA3_sha3_384_inc_ctx_release_avx512vl,
	SHA3_sha3_384_inc_ctx_reset_avx512vl,
	SHA3_sha3_384_inc_ctx_clone_avx512vl,
	/* SHA3-512 */
	SHA3_sha3_512_avx512vl,
	SHA3_sha3_512_inc_init_avx512vl,
	SHA3_sha3_512_inc_absorb_avx512vl,
	SHA3_sha3_512_inc_finalize_avx512vl,
	SHA3_sha3_512_inc_ctx_release_avx512vl,
	SHA3_sha3_512_inc_ctx_reset_avx512vl,
	SHA3_sha3_512_inc_ctx_clone_avx512vl,
	/* SHAKE128: one-shot, then inc. init/absorb/finalize/squeeze/release/clone/reset */
	SHA3_shake128_avx512vl,
	SHA3_shake128_inc_init_avx512vl,
	SHA3_shake128_inc_absorb_avx512vl,
	SHA3_shake128_inc_finalize_avx512vl,
	SHA3_shake128_inc_squeeze_avx512vl,
	SHA3_shake128_inc_ctx_release_avx512vl,
	SHA3_shake128_inc_ctx_clone_avx512vl,
	SHA3_shake128_inc_ctx_reset_avx512vl,
	/* SHAKE256 */
	SHA3_shake256_avx512vl,
	SHA3_shake256_inc_init_avx512vl,
	SHA3_shake256_inc_absorb_avx512vl,
	SHA3_shake256_inc_finalize_avx512vl,
	SHA3_shake256_inc_squeeze_avx512vl,
	SHA3_shake256_inc_ctx_release_avx512vl,
	SHA3_shake256_inc_ctx_clone_avx512vl,
	SHA3_shake256_inc_ctx_reset_avx512vl,
};

View File

@ -0,0 +1,133 @@
/**
* \file avx512vl_sha3x4.c
* \brief Implementation of the OQS SHA3 times 4 API using the AVX512VL low interface.
*
* Copyright (c) 2025 Intel Corporation
*
* SPDX-License-Identifier: MIT
*/
#include "sha3x4.h"
#include <oqs/common.h>
#include <oqs/oqsconfig.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
/* Alignment (bytes) used for 4-way Keccak context allocations. */
#define KECCAK_X4_CTX_ALIGNMENT 32
/* Raw x4 context size: 800 bytes (four 200-byte Keccak-1600 states) plus one
 * extra uint64_t. NOTE(review): the extra word presumably tracks the shared
 * absorb/squeeze offset — confirm against the AVX512VL assembly. */
#define _KECCAK_X4_CTX_BYTES (800 + sizeof(uint64_t))
/* x4 context size rounded up to a whole multiple of KECCAK_X4_CTX_ALIGNMENT. */
#define KECCAK_X4_CTX_BYTES \
    (KECCAK_X4_CTX_ALIGNMENT * \
     ((_KECCAK_X4_CTX_BYTES + KECCAK_X4_CTX_ALIGNMENT - 1) / KECCAK_X4_CTX_ALIGNMENT))
/********** SHAKE128 ***********/
/* SHAKE128 external */
extern void
SHA3_shake128_x4_avx512vl(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen,
const uint8_t *in0, const uint8_t *in1, const uint8_t *in2,
const uint8_t *in3, size_t inlen);
extern void
SHA3_shake128_x4_inc_ctx_reset_avx512vl(OQS_SHA3_shake128_x4_inc_ctx *state);
extern void
SHA3_shake128_x4_inc_absorb_avx512vl(OQS_SHA3_shake128_x4_inc_ctx *state, const uint8_t *in0,
const uint8_t *in1, const uint8_t *in2, const uint8_t *in3,
size_t inlen);
extern void
SHA3_shake128_x4_inc_finalize_avx512vl(OQS_SHA3_shake128_x4_inc_ctx *state);
extern void
SHA3_shake128_x4_inc_squeeze_avx512vl(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3,
size_t outlen, OQS_SHA3_shake128_x4_inc_ctx *state);
/* SHAKE128 incremental */
/* Allocate an aligned 4-way Keccak context for SHAKE128 and reset it to
 * the initial sponge state. Exits the process on allocation failure. */
static void SHA3_shake128_x4_inc_init_avx512vl(OQS_SHA3_shake128_x4_inc_ctx *state) {
	void *ctx = OQS_MEM_aligned_alloc(KECCAK_X4_CTX_ALIGNMENT, KECCAK_X4_CTX_BYTES);
	state->ctx = ctx;
	OQS_EXIT_IF_NULLPTR(ctx, "SHA3x4");
	SHA3_shake128_x4_inc_ctx_reset_avx512vl(state);
}

/* Duplicate src's 4-way Keccak context into dest; dest->ctx must already
 * have been allocated via the init callback. */
static void SHA3_shake128_x4_inc_ctx_clone_avx512vl(OQS_SHA3_shake128_x4_inc_ctx *dest,
        const OQS_SHA3_shake128_x4_inc_ctx *src) {
	memcpy(dest->ctx, src->ctx, KECCAK_X4_CTX_BYTES);
}

/* Reset the 4-way SHAKE128 context (re-initializing its state) before
 * freeing the aligned allocation. */
static void SHA3_shake128_x4_inc_ctx_release_avx512vl(OQS_SHA3_shake128_x4_inc_ctx *state) {
	SHA3_shake128_x4_inc_ctx_reset_avx512vl(state);
	OQS_MEM_aligned_free(state->ctx);
}
/********** SHAKE256 ***********/
/* SHAKE256 external */
extern void
SHA3_shake256_x4_avx512vl(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen,
const uint8_t *in0, const uint8_t *in1, const uint8_t *in2,
const uint8_t *in3, size_t inlen);
extern void
SHA3_shake256_x4_inc_ctx_reset_avx512vl(OQS_SHA3_shake256_x4_inc_ctx *state);
extern void
SHA3_shake256_x4_inc_absorb_avx512vl(OQS_SHA3_shake256_x4_inc_ctx *state, const uint8_t *in0,
const uint8_t *in1, const uint8_t *in2, const uint8_t *in3,
size_t inlen);
extern void
SHA3_shake256_x4_inc_finalize_avx512vl(OQS_SHA3_shake256_x4_inc_ctx *state);
extern void
SHA3_shake256_x4_inc_squeeze_avx512vl(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3,
size_t outlen, OQS_SHA3_shake256_x4_inc_ctx *state);
/* SHAKE256 incremental */
/* Allocate an aligned 4-way Keccak context for SHAKE256 and reset it to
 * the initial sponge state. Exits the process on allocation failure. */
static void SHA3_shake256_x4_inc_init_avx512vl(OQS_SHA3_shake256_x4_inc_ctx *state) {
	void *ctx = OQS_MEM_aligned_alloc(KECCAK_X4_CTX_ALIGNMENT, KECCAK_X4_CTX_BYTES);
	state->ctx = ctx;
	OQS_EXIT_IF_NULLPTR(ctx, "SHA3x4");
	SHA3_shake256_x4_inc_ctx_reset_avx512vl(state);
}

/* Duplicate src's 4-way Keccak context into dest; dest->ctx must already
 * have been allocated via the init callback. */
static void SHA3_shake256_x4_inc_ctx_clone_avx512vl(OQS_SHA3_shake256_x4_inc_ctx *dest,
        const OQS_SHA3_shake256_x4_inc_ctx *src) {
	memcpy(dest->ctx, src->ctx, KECCAK_X4_CTX_BYTES);
}

/* Reset the 4-way SHAKE256 context (re-initializing its state) before
 * freeing the aligned allocation. */
static void SHA3_shake256_x4_inc_ctx_release_avx512vl(OQS_SHA3_shake256_x4_inc_ctx *state) {
	SHA3_shake256_x4_inc_ctx_reset_avx512vl(state);
	OQS_MEM_aligned_free(state->ctx);
}
/*
 * AVX512VL 4-way SHA3 callback table installed by the runtime dispatcher.
 * The initializer is positional: entry order must match the member order
 * of struct OQS_SHA3_x4_callbacks exactly — do not reorder independently
 * of that declaration.
 */
const struct OQS_SHA3_x4_callbacks sha3_x4_avx512vl_callbacks = {
	/* SHAKE128 x4: one-shot, then inc. init/absorb/finalize/squeeze/release/clone/reset */
	SHA3_shake128_x4_avx512vl,
	SHA3_shake128_x4_inc_init_avx512vl,
	SHA3_shake128_x4_inc_absorb_avx512vl,
	SHA3_shake128_x4_inc_finalize_avx512vl,
	SHA3_shake128_x4_inc_squeeze_avx512vl,
	SHA3_shake128_x4_inc_ctx_release_avx512vl,
	SHA3_shake128_x4_inc_ctx_clone_avx512vl,
	SHA3_shake128_x4_inc_ctx_reset_avx512vl,
	/* SHAKE256 x4 */
	SHA3_shake256_x4_avx512vl,
	SHA3_shake256_x4_inc_init_avx512vl,
	SHA3_shake256_x4_inc_absorb_avx512vl,
	SHA3_shake256_x4_inc_finalize_avx512vl,
	SHA3_shake256_x4_inc_squeeze_avx512vl,
	SHA3_shake256_x4_inc_ctx_release_avx512vl,
	SHA3_shake256_x4_inc_ctx_clone_avx512vl,
	SHA3_shake256_x4_inc_ctx_reset_avx512vl,
};

View File

@ -37,10 +37,19 @@ static KeccakPermuteFn *Keccak_Permute_ptr = NULL;
static KeccakExtractBytesFn *Keccak_ExtractBytes_ptr = NULL;
static KeccakFastLoopAbsorbFn *Keccak_FastLoopAbsorb_ptr = NULL;
extern struct OQS_SHA3_callbacks sha3_default_callbacks;
static void Keccak_Dispatch(void) {
// TODO: Simplify this when we have a Windows-compatible AVX2 implementation of SHA3
#if defined(OQS_DIST_X86_64_BUILD)
#if defined(OQS_ENABLE_SHA3_xkcp_low_avx2)
#if defined(OQS_USE_SHA3_AVX512VL)
if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX512)) {
extern const struct OQS_SHA3_callbacks sha3_avx512vl_callbacks;
sha3_default_callbacks = sha3_avx512vl_callbacks;
}
#endif
if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX2)) {
Keccak_Initialize_ptr = &KeccakP1600_Initialize_avx2;
Keccak_AddByte_ptr = &KeccakP1600_AddByte_avx2;
@ -385,8 +394,6 @@ static void SHA3_shake256_inc_ctx_reset(OQS_SHA3_shake256_inc_ctx *state) {
keccak_inc_reset((uint64_t *)state->ctx);
}
extern struct OQS_SHA3_callbacks sha3_default_callbacks;
struct OQS_SHA3_callbacks sha3_default_callbacks = {
SHA3_sha3_256,
SHA3_sha3_256_inc_init,

View File

@ -31,10 +31,19 @@ static KeccakX4AddBytesFn *Keccak_X4_AddBytes_ptr = NULL;
static KeccakX4PermuteFn *Keccak_X4_Permute_ptr = NULL;
static KeccakX4ExtractBytesFn *Keccak_X4_ExtractBytes_ptr = NULL;
extern struct OQS_SHA3_x4_callbacks sha3_x4_default_callbacks;
static void Keccak_X4_Dispatch(void) {
// TODO: Simplify this when we have a Windows-compatible AVX2 implementation of SHA3
#if defined(OQS_DIST_X86_64_BUILD)
#if defined(OQS_ENABLE_SHA3_xkcp_low_avx2)
#if defined(OQS_USE_SHA3_AVX512VL)
if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX512)) {
extern const struct OQS_SHA3_x4_callbacks sha3_x4_avx512vl_callbacks;
sha3_x4_default_callbacks = sha3_x4_avx512vl_callbacks;
}
#endif
if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX2)) {
Keccak_X4_Initialize_ptr = &KeccakP1600times4_InitializeAll_avx2;
Keccak_X4_AddByte_ptr = &KeccakP1600times4_AddByte_avx2;
@ -238,8 +247,6 @@ static void SHA3_shake256_x4_inc_ctx_reset(OQS_SHA3_shake256_x4_inc_ctx *state)
keccak_x4_inc_reset((uint64_t *)state->ctx);
}
extern struct OQS_SHA3_x4_callbacks sha3_x4_default_callbacks;
struct OQS_SHA3_x4_callbacks sha3_x4_default_callbacks = {
SHA3_shake128_x4,
SHA3_shake128_x4_inc_init,

View File

@ -69,6 +69,7 @@
#cmakedefine OQS_ENABLE_TEST_CONSTANT_TIME 1
#cmakedefine OQS_ENABLE_SHA3_xkcp_low_avx2 1
#cmakedefine OQS_USE_SHA3_AVX512VL 1
#cmakedefine01 OQS_USE_CUPQC

View File

@ -252,6 +252,14 @@ static void print_oqs_configuration(void) {
#endif
#if defined(OQS_USE_SHA3_OPENSSL)
printf("SHA-3: OpenSSL\n");
#elif defined(OQS_USE_SHA3_AVX512VL)
if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX512)) {
printf("SHA-3: AVX512VL\n");
} else if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX2)) {
printf("SHA-3: AVX2\n");
} else {
printf("SHA-3: C\n");
}
#else
printf("SHA-3: C\n");
#endif

View File

@ -33,7 +33,7 @@ def test_namespace():
symbols.append(line)
# ideally this would be just ['oqs', 'pqclean'], but contains exceptions (e.g., providing compat implementations of unavailable platform functions)
namespaces = ['oqs', 'pqclean', 'keccak', 'pqcrystals', 'pqmayo', 'init', 'fini', 'seedexpander', '__x86.get_pc_thunk', 'libjade', 'jade', '__jade', '__jasmin_syscall', 'pqcp', 'pqov', '_snova']
namespaces = ['oqs', 'pqclean', 'keccak', 'pqcrystals', 'pqmayo', 'init', 'fini', 'seedexpander', '__x86.get_pc_thunk', 'libjade', 'jade', '__jade', '__jasmin_syscall', 'pqcp', 'pqov', '_snova', 'sha3']
non_namespaced = []
for symbolstr in symbols:

View File

@ -197,6 +197,135 @@ int sha3_256_kat_test(void) {
return status;
}
/**
 * \brief Tests the 384 bit version of the SHA3 message digest for correct
 * operation, using selected known-answer vectors (NIST FIPS 202 style).
 *
 * Exercises both the compact (one-shot) API and the long-form incremental
 * init/absorb/finalize API against the same four message/digest pairs.
 *
 * \return EXIT_SUCCESS when every digest matches, EXIT_FAILURE otherwise
 */
int sha3_384_kat_test(void) {
	int status = EXIT_SUCCESS;
	uint8_t output[48];
	/* Empty message (length 0 is passed; the array contents are never read). */
	const uint8_t msg0[1] = {0x0};
	/* 24-bit message: ASCII "abc". */
	const uint8_t msg24[3] = {0x61, 0x62, 0x63};
	/* 448-bit message: ASCII "abcdbcdecdefdefg..." (standard FIPS 202 vector). */
	const uint8_t msg448[56] = {
		0x61, 0x62, 0x63, 0x64, 0x62, 0x63, 0x64, 0x65, 0x63, 0x64, 0x65, 0x66, 0x64, 0x65, 0x66, 0x67,
		0x65, 0x66, 0x67, 0x68, 0x66, 0x67, 0x68, 0x69, 0x67, 0x68, 0x69, 0x6A, 0x68, 0x69, 0x6A, 0x6B,
		0x69, 0x6A, 0x6B, 0x6C, 0x6A, 0x6B, 0x6C, 0x6D, 0x6B, 0x6C, 0x6D, 0x6E, 0x6C, 0x6D, 0x6E, 0x6F,
		0x6D, 0x6E, 0x6F, 0x70, 0x6E, 0x6F, 0x70, 0x71
	};
	/* 1600-bit message: 200 repetitions of the byte 0xA3. */
	const uint8_t msg1600[200] = {
		0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
		0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
		0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
		0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
		0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
		0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
		0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
		0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
		0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
		0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
		0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
		0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
		0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3
	};
	/* Expected SHA3-384 digest of the empty message. */
	const uint8_t exp0[48] = {
		0x0C, 0x63, 0xA7, 0x5B, 0x84, 0x5E, 0x4F, 0x7D, 0x01, 0x10, 0x7D, 0x85, 0x2E, 0x4C, 0x24,
		0x85, 0xC5, 0x1A, 0x50, 0xAA, 0xAA, 0x94, 0xFC, 0x61, 0x99, 0x5E, 0x71, 0xBB, 0xEE, 0x98,
		0x3A, 0x2A, 0xC3, 0x71, 0x38, 0x31, 0x26, 0x4A, 0xDB, 0x47, 0xFB, 0x6B, 0xD1, 0xE0, 0x58,
		0xD5, 0xF0, 0x04
	};
	/* Expected SHA3-384 digest of msg24 ("abc"). */
	const uint8_t exp24[48] = {
		0xEC, 0x01, 0x49, 0x82, 0x88, 0x51, 0x6F, 0xC9, 0x26, 0x45, 0x9F, 0x58, 0xE2, 0xC6, 0xAD,
		0x8D, 0xF9, 0xB4, 0x73, 0xCB, 0x0F, 0xC0, 0x8C, 0x25, 0x96, 0xDA, 0x7C, 0xF0, 0xE4, 0x9B,
		0xE4, 0xB2, 0x98, 0xD8, 0x8C, 0xEA, 0x92, 0x7A, 0xC7, 0xF5, 0x39, 0xF1, 0xED, 0xF2, 0x28,
		0x37, 0x6D, 0x25
	};
	/* Expected SHA3-384 digest of msg448. */
	const uint8_t exp448[48] = {
		0x99, 0x1C, 0x66, 0x57, 0x55, 0xEB, 0x3A, 0x4B, 0x6B, 0xBD, 0xFB, 0x75, 0xC7, 0x8A, 0x49,
		0x2E, 0x8C, 0x56, 0xA2, 0x2C, 0x5C, 0x4D, 0x7E, 0x42, 0x9B, 0xFD, 0xBC, 0x32, 0xB9, 0xD4,
		0xAD, 0x5A, 0xA0, 0x4A, 0x1F, 0x07, 0x6E, 0x62, 0xFE, 0xA1, 0x9E, 0xEF, 0x51, 0xAC, 0xD0,
		0x65, 0x7C, 0x22
	};
	/* Expected SHA3-384 digest of msg1600. */
	const uint8_t exp1600[48] = {
		0x18, 0x81, 0xDE, 0x2C, 0xA7, 0xE4, 0x1E, 0xF9, 0x5D, 0xC4, 0x73, 0x2B, 0x8F, 0x5F, 0x00,
		0x2B, 0x18, 0x9C, 0xC1, 0xE4, 0x2B, 0x74, 0x16, 0x8E, 0xD1, 0x73, 0x26, 0x49, 0xCE, 0x1D,
		0xBC, 0xDD, 0x76, 0x19, 0x7A, 0x31, 0xFD, 0x55, 0xEE, 0x98, 0x9F, 0x2D, 0x70, 0x50, 0xDD,
		0x47, 0x3E, 0x8F
	};
	/* test compact (one-shot) api: clear the output buffer before each run
	 * so a no-op implementation cannot pass by accident */
	clear8(output, 48);
	OQS_SHA3_sha3_384(output, msg0, 0);
	if (are_equal8(output, exp0, 48) == EXIT_FAILURE) {
		status = EXIT_FAILURE;
	}
	clear8(output, 48);
	OQS_SHA3_sha3_384(output, msg24, 3);
	if (are_equal8(output, exp24, 48) == EXIT_FAILURE) {
		status = EXIT_FAILURE;
	}
	clear8(output, 48);
	OQS_SHA3_sha3_384(output, msg448, 56);
	if (are_equal8(output, exp448, 48) == EXIT_FAILURE) {
		status = EXIT_FAILURE;
	}
	clear8(output, 48);
	OQS_SHA3_sha3_384(output, msg1600, 200);
	if (are_equal8(output, exp1600, 48) == EXIT_FAILURE) {
		status = EXIT_FAILURE;
	}
	/* test long-form (incremental) api with the same vectors; the context
	 * is reused via reset between cases */
	OQS_SHA3_sha3_384_inc_ctx state;
	uint8_t hash[200];
	clear8(hash, 200);
	OQS_SHA3_sha3_384_inc_init(&state);
	OQS_SHA3_sha3_384_inc_absorb(&state, msg0, 0);
	OQS_SHA3_sha3_384_inc_finalize(hash, &state);
	if (are_equal8(hash, exp0, 48) == EXIT_FAILURE) {
		status = EXIT_FAILURE;
	}
	clear8(hash, 200);
	OQS_SHA3_sha3_384_inc_ctx_reset(&state);
	OQS_SHA3_sha3_384_inc_absorb(&state, msg24, 3);
	OQS_SHA3_sha3_384_inc_finalize(hash, &state);
	if (are_equal8(hash, exp24, 48) == EXIT_FAILURE) {
		status = EXIT_FAILURE;
	}
	clear8(hash, 200);
	OQS_SHA3_sha3_384_inc_ctx_reset(&state);
	OQS_SHA3_sha3_384_inc_absorb(&state, msg448, 56);
	OQS_SHA3_sha3_384_inc_finalize(hash, &state);
	if (are_equal8(hash, exp448, 48) == EXIT_FAILURE) {
		status = EXIT_FAILURE;
	}
	clear8(hash, 200);
	OQS_SHA3_sha3_384_inc_ctx_reset(&state);
	OQS_SHA3_sha3_384_inc_absorb(&state, msg1600, 200);
	OQS_SHA3_sha3_384_inc_finalize(hash, &state);
	/* context no longer needed: release before the final comparison,
	 * since hash already holds the digest */
	OQS_SHA3_sha3_384_inc_ctx_release(&state);
	if (are_equal8(hash, exp1600, 48) == EXIT_FAILURE) {
		status = EXIT_FAILURE;
	}
	return status;
}
/**
* \brief Tests the 512 bit version of the keccak message digest for correct operation,
* using selected vectors from NIST Fips202 and alternative references.
@ -1054,12 +1183,31 @@ static void override_SHA3_shake128_x4_inc_init(OQS_SHA3_shake128_x4_inc_ctx *sta
sha3_x4_default_callbacks.SHA3_shake128_x4_inc_init(state);
}
#ifdef OQS_USE_SHA3_AVX512VL
/**
 * \brief Trigger the SHA3 internal callback dispatcher.
 *
 * Creating and immediately releasing one incremental SHA3-256 context
 * forces the runtime AVX512/AVX2 detection to run and install the SHA3
 * default callbacks before the test overrides them with its own wrappers.
 */
static void sha3_trigger_dispatcher(void) {
	OQS_SHA3_sha3_256_inc_ctx ctx;

	OQS_SHA3_sha3_256_inc_init(&ctx);
	OQS_SHA3_sha3_256_inc_ctx_release(&ctx);
}
#endif
/**
* \brief Run the SHA3 and SHAKE KAT tests
*/
int main(UNUSED int argc, UNUSED char **argv) {
int ret = EXIT_SUCCESS;
#ifdef OQS_USE_SHA3_AVX512VL
/* set SHA3 default callbacks */
sha3_trigger_dispatcher();
#endif
struct OQS_SHA3_callbacks sha3_callbacks = sha3_default_callbacks;
sha3_callbacks.SHA3_sha3_256_inc_init = override_SHA3_sha3_256_inc_init;
@ -1080,6 +1228,13 @@ int main(UNUSED int argc, UNUSED char **argv) {
ret = EXIT_FAILURE;
}
if (sha3_384_kat_test() == EXIT_SUCCESS) {
printf("Success! passed sha3-384 known answer tests \n");
} else {
printf("Failure! failed sha3-384 known answer tests \n");
ret = EXIT_FAILURE;
}
if (sha3_512_kat_test() == EXIT_SUCCESS) {
printf("Success! passed sha3-512 known answer tests \n");
} else {