From 8f926065ebd90591106e121a847f586488e6071f Mon Sep 17 00:00:00 2001 From: Marcel Cornu Date: Fri, 20 Jun 2025 18:37:32 +0100 Subject: [PATCH] Add AVX512VL-Optimized SHA3/SHAKE Implementations (#2167) * Add SHA3-256/384/512 and SHAKE128/256 AVX512VL implementations Co-authored-by: Tomasz Kantecki Co-authored-by: Erdinc Ozturk Signed-off-by: Marcel Cornu Signed-off-by: Tomasz Kantecki * AVX512VL SHA3 is added as an extension of XKCP implementation Co-authored-by: Marcel Cornu Signed-off-by: Tomasz Kantecki Signed-off-by: Marcel Cornu * Add SHA3-384 tests Signed-off-by: Marcel Cornu * Update namespace test to include SHA3 Signed-off-by: Marcel Cornu * Release SHA3 context after triggering dispatcher Signed-off-by: Marcel Cornu * Add linux CI for OQS_USE_SHA3_AVX512VL=OFF config Signed-off-by: Marcel Cornu * Add AVX512 emulation to linux CI Signed-off-by: Marcel Cornu --------- Signed-off-by: Marcel Cornu Signed-off-by: Tomasz Kantecki Co-authored-by: Tomasz Kantecki Co-authored-by: Erdinc Ozturk --- .CMake/alg_support.cmake | 7 + .github/workflows/linux.yml | 38 + src/common/CMakeLists.txt | 16 +- src/common/sha3/avx512vl_low/CMakeLists.txt | 14 + .../sha3/avx512vl_low/KeccakP-1600-AVX512VL.S | 548 +++++++ .../KeccakP-1600-times4-AVX512VL.S | 375 +++++ src/common/sha3/avx512vl_low/SHA3-AVX512VL.S | 1159 +++++++++++++++ .../sha3/avx512vl_low/SHA3-times4-AVX512VL.S | 1277 +++++++++++++++++ src/common/sha3/avx512vl_sha3.c | 242 ++++ src/common/sha3/avx512vl_sha3x4.c | 133 ++ src/common/sha3/xkcp_sha3.c | 11 +- src/common/sha3/xkcp_sha3x4.c | 11 +- src/oqsconfig.h.cmake | 1 + tests/system_info.c | 8 + tests/test_binary.py | 2 +- tests/test_sha3.c | 155 ++ 16 files changed, 3989 insertions(+), 8 deletions(-) create mode 100644 src/common/sha3/avx512vl_low/CMakeLists.txt create mode 100644 src/common/sha3/avx512vl_low/KeccakP-1600-AVX512VL.S create mode 100644 src/common/sha3/avx512vl_low/KeccakP-1600-times4-AVX512VL.S create mode 100644 src/common/sha3/avx512vl_low/SHA3-AVX512VL.S create mode 100644 src/common/sha3/avx512vl_low/SHA3-times4-AVX512VL.S create mode 100644 src/common/sha3/avx512vl_sha3.c create mode 100644 src/common/sha3/avx512vl_sha3x4.c diff --git a/.CMake/alg_support.cmake b/.CMake/alg_support.cmake index d5b5f0a77..06fcb095e 100644 --- a/.CMake/alg_support.cmake +++ b/.CMake/alg_support.cmake @@ -78,6 +78,13 @@ if(OQS_DIST_X86_64_BUILD OR OQS_USE_AVX2_INSTRUCTIONS) endif() endif() +# SHA3 AVX512VL only supported on Linux x86_64 +if(CMAKE_SYSTEM_NAME STREQUAL "Linux" AND (OQS_DIST_X86_64_BUILD OR OQS_USE_AVX512_INSTRUCTIONS)) + cmake_dependent_option(OQS_USE_SHA3_AVX512VL "Enable SHA3 AVX512VL usage" ON "NOT OQS_USE_SHA3_OPENSSL" OFF) +else() + option(OQS_USE_SHA3_AVX512VL "Enable SHA3 AVX512VL usage" OFF) +endif() + # BIKE is not supported on Windows, 32-bit ARM, X86, S390X (big endian) and PPC64 (big endian) cmake_dependent_option(OQS_ENABLE_KEM_BIKE "Enable BIKE algorithm family" ON "NOT WIN32; NOT ARCH_ARM32v7; NOT ARCH_X86; NOT ARCH_S390X; NOT ARCH_PPC64" OFF) # BIKE doesn't work on any 32-bit platform diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 8621fb0dc..7f83c7864 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -112,6 +112,11 @@ jobs: container: openquantumsafe/ci-ubuntu-latest:latest CMAKE_ARGS: -DCMAKE_C_COMPILER=clang -DCMAKE_BUILD_TYPE=Debug -DUSE_SANITIZER=Address -DOQS_LIBJADE_BUILD=ON -DOQS_MINIMAL_BUILD="${{ vars.LIBJADE_ALG_LIST }}" PYTEST_ARGS: --ignore=tests/test_distbuild.py 
--ignore=tests/test_leaks.py --ignore=tests/test_kat_all.py --maxprocesses=10
+          - name: noble-no-sha3-avx512vl
+            runner: ubuntu-latest
+            container: openquantumsafe/ci-ubuntu-latest:latest
+            CMAKE_ARGS: -DOQS_USE_SHA3_AVX512VL=OFF
+            PYTEST_ARGS: --ignore=tests/test_leaks.py --ignore=tests/test_kat_all.py
     runs-on: ${{ matrix.runner }}
     container:
       image: ${{ matrix.container }}
@@ -271,3 +276,36 @@ jobs:
       - name: Build
         run: scan-build --status-bugs ninja
         working-directory: build
+
+  linux_x86_emulated:
+    runs-on: ubuntu-latest
+    container:
+      image: openquantumsafe/ci-ubuntu-latest:latest
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - name: avx512-ml-kem_ml-dsa
+            SDE_ARCH: -skx
+            CMAKE_ARGS: -DOQS_MINIMAL_BUILD="KEM_ml_kem_512;KEM_ml_kem_768;KEM_ml_kem_1024;SIG_ml_dsa_44;SIG_ml_dsa_65;SIG_ml_dsa_87"
+            PYTEST_ARGS: tests/test_hash.py::test_sha3 tests/test_kat.py tests/test_acvp_vectors.py
+    env:
+      SDE_URL: https://downloadmirror.intel.com/850782/sde-external-9.53.0-2025-03-16-lin.tar.xz
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # pin@v4
+      - name: Setup Intel SDE
+        run: |
+          wget -O sde.tar.xz "$SDE_URL" && \
+          mkdir sde && tar -xf sde.tar.xz -C sde --strip-components=1 && \
+          echo "$(pwd)/sde" >> $GITHUB_PATH
+      - name: Configure
+        run: mkdir build && cd build && cmake -GNinja ${{ matrix.CMAKE_ARGS }} .. && cmake -LA -N ..
+      - name: Build
+        run: ninja
+        working-directory: build
+      - name: Run tests
+        timeout-minutes: 60
+        run: |
+          mkdir -p tmp && sde64 ${{ matrix.SDE_ARCH }} -- \
+          python3 -m pytest --verbose --numprocesses=auto ${{ matrix.PYTEST_ARGS }}
diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt
index 98badd237..a4965af6f 100644
--- a/src/common/CMakeLists.txt
+++ b/src/common/CMakeLists.txt
@@ -61,14 +61,19 @@ else()
 endif()
 
 if(${OQS_USE_SHA3_OPENSSL})
-  if (${OQS_ENABLE_SHA3_xkcp_low})
-    add_subdirectory(sha3/xkcp_low)
-  endif()
+    if (${OQS_ENABLE_SHA3_xkcp_low})
+        add_subdirectory(sha3/xkcp_low)
+    endif()
     set(SHA3_IMPL sha3/ossl_sha3.c sha3/ossl_sha3x4.c)
     set(OSSL_HELPERS ossl_helpers.c)
 else() # using XKCP
     add_subdirectory(sha3/xkcp_low)
     set(SHA3_IMPL sha3/xkcp_sha3.c sha3/xkcp_sha3x4.c)
+    if(OQS_USE_SHA3_AVX512VL)
+        # also build avx512vl modules
+        add_subdirectory(sha3/avx512vl_low)
+        list(APPEND SHA3_IMPL sha3/avx512vl_sha3.c sha3/avx512vl_sha3x4.c)
+    endif()
 endif()
 
 if ((OQS_LIBJADE_BUILD STREQUAL "ON"))
@@ -157,6 +162,11 @@ if(${OQS_ENABLE_SHA3_xkcp_low}) # using XKCP
     set(_INTERNAL_OBJS ${_INTERNAL_OBJS} ${XKCP_LOW_OBJS})
 endif()
 
+if(${OQS_USE_SHA3_AVX512VL})
+    set(_COMMON_OBJS ${_COMMON_OBJS} ${SHA3_AVX512VL_LOW_OBJS})
+    set(_INTERNAL_OBJS ${_INTERNAL_OBJS} ${SHA3_AVX512VL_LOW_OBJS})
+endif()
+
 set(_COMMON_OBJS ${_COMMON_OBJS} $)
 set(COMMON_OBJS ${_COMMON_OBJS} PARENT_SCOPE)
 set(_INTERNAL_OBJS ${_INTERNAL_OBJS} $)
diff --git a/src/common/sha3/avx512vl_low/CMakeLists.txt b/src/common/sha3/avx512vl_low/CMakeLists.txt
new file mode 100644
index 000000000..74339c0cc
--- /dev/null
+++ b/src/common/sha3/avx512vl_low/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# SPDX-License-Identifier: MIT
+
+set(_SHA3_AVX512VL_LOW_OBJS "")
+
+if(OQS_USE_SHA3_AVX512VL)
+    add_library(sha3_avx512vl_low OBJECT
+        KeccakP-1600-AVX512VL.S SHA3-AVX512VL.S KeccakP-1600-times4-AVX512VL.S SHA3-times4-AVX512VL.S)
+    set(_SHA3_AVX512VL_LOW_OBJS ${_SHA3_AVX512VL_LOW_OBJS} $<TARGET_OBJECTS:sha3_avx512vl_low>)
+endif()
+
+set(SHA3_AVX512VL_LOW_OBJS ${_SHA3_AVX512VL_LOW_OBJS} PARENT_SCOPE)
+
diff --git
a/src/common/sha3/avx512vl_low/KeccakP-1600-AVX512VL.S b/src/common/sha3/avx512vl_low/KeccakP-1600-AVX512VL.S new file mode 100644 index 000000000..7de6783c1 --- /dev/null +++ b/src/common/sha3/avx512vl_low/KeccakP-1600-AVX512VL.S @@ -0,0 +1,548 @@ +# Copyright (c) 2025 Intel Corporation +# +# SPDX-License-Identifier: MIT + +# Define arg registers +.equ arg1, %rdi +.equ arg2, %rsi + +.text + +# Initialized Keccak state in registers +# +# input: +# output: xmm0-xmm24 +.globl keccak_1600_init_state +.type keccak_1600_init_state,@function +.hidden keccak_1600_init_state +.balign 32 +keccak_1600_init_state: + vpxorq %xmm0, %xmm0, %xmm0 + vpxorq %xmm1, %xmm1, %xmm1 + vpxorq %xmm2, %xmm2, %xmm2 + vmovdqa64 %ymm0, %ymm3 + vmovdqa64 %ymm0, %ymm4 + vmovdqa64 %ymm0, %ymm5 + vmovdqa64 %ymm0, %ymm6 + vmovdqa64 %ymm0, %ymm7 + vmovdqa64 %ymm0, %ymm8 + vmovdqa64 %ymm0, %ymm9 + vmovdqa64 %ymm0, %ymm10 + vmovdqa64 %ymm0, %ymm11 + vmovdqa64 %ymm0, %ymm12 + vmovdqa64 %ymm0, %ymm13 + vmovdqa64 %ymm0, %ymm14 + vmovdqa64 %ymm0, %ymm15 + vmovdqa64 %ymm0, %ymm16 + vmovdqa64 %ymm0, %ymm17 + vmovdqa64 %ymm0, %ymm18 + vmovdqa64 %ymm0, %ymm19 + vmovdqa64 %ymm0, %ymm20 + vmovdqa64 %ymm0, %ymm21 + vmovdqa64 %ymm0, %ymm22 + vmovdqa64 %ymm0, %ymm23 + vmovdqa64 %ymm0, %ymm24 + ret +.size keccak_1600_init_state,.-keccak_1600_init_state + + +# Loads Keccak state from memory into registers +# +# input: arg1 - state pointer +# output: xmm0-xmm24 +.globl keccak_1600_load_state +.type keccak_1600_load_state,@function +.hidden keccak_1600_load_state +.balign 32 +keccak_1600_load_state: + vmovq (8*0)(arg1), %xmm0 + vmovq (8*1)(arg1), %xmm1 + vmovq (8*2)(arg1), %xmm2 + vmovq (8*3)(arg1), %xmm3 + vmovq (8*4)(arg1), %xmm4 + vmovq (8*5)(arg1), %xmm5 + vmovq (8*6)(arg1), %xmm6 + vmovq (8*7)(arg1), %xmm7 + vmovq (8*8)(arg1), %xmm8 + vmovq (8*9)(arg1), %xmm9 + vmovq (8*10)(arg1), %xmm10 + vmovq (8*11)(arg1), %xmm11 + vmovq (8*12)(arg1), %xmm12 + vmovq (8*13)(arg1), %xmm13 + vmovq (8*14)(arg1), %xmm14 + vmovq (8*15)(arg1), %xmm15 + vmovq (8*16)(arg1), %xmm16 + vmovq (8*17)(arg1), %xmm17 + vmovq (8*18)(arg1), %xmm18 + vmovq (8*19)(arg1), %xmm19 + vmovq (8*20)(arg1), %xmm20 + vmovq (8*21)(arg1), %xmm21 + vmovq (8*22)(arg1), %xmm22 + vmovq (8*23)(arg1), %xmm23 + vmovq (8*24)(arg1), %xmm24 + ret +.size keccak_1600_load_state,.-keccak_1600_load_state + + +# Saves Keccak state to memory +# +# input: arg1 - state pointer +# xmm0-xmm24 - Keccak state registers +# output: memory from [arg1] to [arg1 + 25*8] +.globl keccak_1600_save_state +.type keccak_1600_save_state,@function +.hidden keccak_1600_save_state +.balign 32 +keccak_1600_save_state: + vmovq %xmm0, (8*0)(arg1) + vmovq %xmm1, (8*1)(arg1) + vmovq %xmm2, (8*2)(arg1) + vmovq %xmm3, (8*3)(arg1) + vmovq %xmm4, (8*4)(arg1) + vmovq %xmm5, (8*5)(arg1) + vmovq %xmm6, (8*6)(arg1) + vmovq %xmm7, (8*7)(arg1) + vmovq %xmm8, (8*8)(arg1) + vmovq %xmm9, (8*9)(arg1) + vmovq %xmm10, (8*10)(arg1) + vmovq %xmm11, (8*11)(arg1) + vmovq %xmm12, (8*12)(arg1) + vmovq %xmm13, (8*13)(arg1) + vmovq %xmm14, (8*14)(arg1) + vmovq %xmm15, (8*15)(arg1) + vmovq %xmm16, (8*16)(arg1) + vmovq %xmm17, (8*17)(arg1) + vmovq %xmm18, (8*18)(arg1) + vmovq %xmm19, (8*19)(arg1) + vmovq %xmm20, (8*20)(arg1) + vmovq %xmm21, (8*21)(arg1) + vmovq %xmm22, (8*22)(arg1) + vmovq %xmm23, (8*23)(arg1) + vmovq %xmm24, (8*24)(arg1) + ret +.size keccak_1600_save_state,.-keccak_1600_save_state + + +# Add input data to state when message length is less than rate +# +# input: +# r13 - state +# arg2 - message pointer (updated on output) +# r12 - 
length (clobbered on output) +# output: +# memory - state from [r13] to [r13 + r12 - 1] +# clobbered: +# rax, k1, ymm31 +.globl keccak_1600_partial_add +.type keccak_1600_partial_add,@function +.hidden keccak_1600_partial_add +.balign 32 +keccak_1600_partial_add: +.partial_add_loop: + cmpq $32, %r12 + jb .lt_32_bytes + + vmovdqu64 (arg2), %ymm31 + vpxorq (%r13), %ymm31, %ymm31 + vmovdqu64 %ymm31, (%r13) + addq $32, arg2 + addq $32, %r13 + subq $32, %r12 + jz .partial_add_done + jmp .partial_add_loop + +.lt_32_bytes: + xorq %rax, %rax + bts %r12, %rax + decq %rax + kmovq %rax, %k1 # k1 is the mask of message bytes to read + vmovdqu8 (arg2), %ymm31{%k1}{z} # Read 0 to 31 bytes + vpxorq (%r13), %ymm31, %ymm31 + vmovdqu8 %ymm31, (%r13){%k1} + addq %r12, arg2 # Increment message pointer + +.partial_add_done: + ret +.size keccak_1600_partial_add,.-keccak_1600_partial_add + + +# Extract bytes from state and write to output +# +# input: +# r13 - state +# r10 - output pointer (updated on output) +# r12 - length (clobbered on output) +# output: +# memory - output from [r10] to [r10 + r12 - 1] +# clobbered: +# rax, k1, ymm31 +.globl keccak_1600_extract_bytes +.type keccak_1600_extract_bytes,@function +.hidden keccak_1600_extract_bytes +.balign 32 +keccak_1600_extract_bytes: +.extract_32_byte_loop: + cmpq $32, %r12 + jb .extract_lt_32_bytes + + vmovdqu64 (%r13), %ymm31 + vmovdqu64 %ymm31, (%r10) + addq $32, %r13 + addq $32, %r10 + subq $32, %r12 + jz .extract_done + jmp .extract_32_byte_loop + +.extract_lt_32_bytes: + xorq %rax, %rax + bts %r12, %rax + decq %rax + kmovq %rax, %k1 # k1 is the mask of the last message bytes + vmovdqu8 (%r13), %ymm31{%k1}{z} # Read 0 to 31 bytes + vmovdqu8 %ymm31, (%r10){%k1} + addq %r12, %r10 # Increment output pointer +.extract_done: + ret +.size keccak_1600_extract_bytes,.-keccak_1600_extract_bytes + +# Copy partial block message into temporary buffer, add padding byte and EOM bit +# +# r13 [in/out] destination pointer +# r12 [in/out] source pointer +# r11 [in/out] length in bytes +# r9 [in] rate +# r8 [in] pointer to the padding byte +# output: +# memory - output from [r13] to [r13 + r11 - 1], [r13 + r11] padding, [r13 + r9 - 1] EOM +# clobbered: +# rax, r15, k1, k2, ymm31 +.globl keccak_1600_copy_with_padding +.type keccak_1600_copy_with_padding, @function +.hidden keccak_1600_copy_with_padding +.balign 32 +keccak_1600_copy_with_padding: + # Clear the temporary buffer + vpxorq %ymm31, %ymm31, %ymm31 + vmovdqu64 %ymm31, (32*0)(%r13) + vmovdqu64 %ymm31, (32*1)(%r13) + vmovdqu64 %ymm31, (32*2)(%r13) + vmovdqu64 %ymm31, (32*3)(%r13) + vmovdqu64 %ymm31, (32*4)(%r13) + vmovdqu64 %ymm31, (32*5)(%r13) + vmovdqu64 %ymm31, (32*6)(%r13) + vmovdqu64 %ymm31, (32*7)(%r13) + + xorq %r15, %r15 +.balign 32 +.copy32_with_padding_loop: + cmpq $32, %r11 # Check at least 32 remaining + jb .partial32_with_padding # If no, then do final copy with padding + + vmovdqu64 (%r12,%r15), %ymm31 + vmovdqu64 %ymm31, (%r13,%r15) + subq $32, %r11 # Decrement the remaining length + addq $32, %r15 # Increment offset + jmp .copy32_with_padding_loop + +.partial32_with_padding: + xorq %rax, %rax + bts %r11, %rax + kmovq %rax, %k2 # k2 is mask of the 1st byte after the message + decq %rax + kmovq %rax, %k1 # k1 is the mask of the last message bytes + vmovdqu8 (%r12,%r15), %ymm31{%k1}{z} # Read 0 to 31 bytes + vpbroadcastb (%r8), %ymm31{%k2} # Add padding + vmovdqu64 %ymm31, (%r13,%r15) # Store whole 32 bytes + xorb $0x80, (-1)(%r13,%r9) # EOM bit - XOR the last byte of the block + ret +.size 
keccak_1600_copy_with_padding,.-keccak_1600_copy_with_padding + + +.globl keccak_1600_copy_digest +.type keccak_1600_copy_digest, @function +.hidden keccak_1600_copy_digest +.balign 32 +keccak_1600_copy_digest: + .copy32_digest_loop: + cmp $32, arg2 # Check at least 32 remaining + jb .partial32 # If no, then copy final bytes + + vmovdqu64 (%r12), %ymm31 + vmovdqu64 %ymm31, (%r13) + addq $32, %r13 # Increment destination pointer + addq $32, %r12 # Increment source pointer + subq $32, arg2 # Decrement the remaining length + jz .done + jmp .copy32_digest_loop + +.partial32: + xorq %rax, %rax + bts arg2, %rax + dec %rax + kmovq %rax, %k1 # k1 is the mask of the last message bytes + vmovdqu8 (%r12), %ymm31{%k1}{z} # Read 0 to 31 bytes + vmovdqu8 %ymm31, (%r13){%k1} # Store 0 to 31 bytes +.done: + ret +.size keccak_1600_copy_digest,.-keccak_1600_copy_digest + +# Perform Keccak permutation +# +# YMM registers 0 to 24 are used as Keccak state registers. +# This function, as is, can work on 1 to 4 independent states at the same time. +# +# There is no clear boundary between Theta, Rho, Pi, Chi and Iota steps. +# Instructions corresponding to these steps overlap for better efficiency. +# +# ymm0-ymm24 [in/out] Keccak state registers (one SIMD per one state register) +# ymm25-ymm31 [clobbered] temporary SIMD registers +# r13 [clobbered] used for round tracking +# r14 [clobbered] used for access to SHA3 constant table +.globl keccak_1600_permute +.type keccak_1600_permute,@function +.hidden keccak_1600_permute +.balign 32 +keccak_1600_permute: + movl $24, %r13d # 24 rounds + leaq sha3_rc(%rip), %r14 # Load the address of the SHA3 round constants + +.balign 32 +keccak_rnd_loop: + # Theta step + + # Compute column parities + # C[5] = [0, 0, 0, 0, 0] + # for x in 0 to 4: + # C[x] = state[x][0] XOR state[x][1] XOR state[x][2] XOR state[x][3] XOR state[x][4] + + vmovdqa64 %ymm0, %ymm25 + vpternlogq $0x96, %ymm5, %ymm10, %ymm25 + vmovdqa64 %ymm1, %ymm26 + vpternlogq $0x96, %ymm11, %ymm6, %ymm26 + vmovdqa64 %ymm2, %ymm27 + vpternlogq $0x96, %ymm12, %ymm7, %ymm27 + + vmovdqa64 %ymm3, %ymm28 + vpternlogq $0x96, %ymm13, %ymm8, %ymm28 + vmovdqa64 %ymm4, %ymm29 + vpternlogq $0x96, %ymm14, %ymm9, %ymm29 + vpternlogq $0x96, %ymm20, %ymm15, %ymm25 + + vpternlogq $0x96, %ymm21, %ymm16, %ymm26 + vpternlogq $0x96, %ymm22, %ymm17, %ymm27 + vpternlogq $0x96, %ymm23, %ymm18, %ymm28 + + # Start computing D values and keep computing column parity + # D[5] = [0, 0, 0, 0, 0] + # for x in 0 to 4: + # D[x] = C[(x+4) mod 5] XOR ROTATE_LEFT(C[(x+1) mod 5], 1) + + vprolq $1, %ymm26, %ymm30 + vprolq $1, %ymm27, %ymm31 + vpternlogq $0x96, %ymm24, %ymm19, %ymm29 + + # Continue computing D values and apply Theta + # for x in 0 to 4: + # for y in 0 to 4: + # state[x][y] = state[x][y] XOR D[x] + + vpternlogq $0x96, %ymm30, %ymm29, %ymm0 + vpternlogq $0x96, %ymm30, %ymm29, %ymm10 + vpternlogq $0x96, %ymm30, %ymm29, %ymm20 + + vpternlogq $0x96, %ymm30, %ymm29, %ymm5 + vpternlogq $0x96, %ymm30, %ymm29, %ymm15 + vprolq $1, %ymm28, %ymm30 + + vpternlogq $0x96, %ymm31, %ymm25, %ymm6 + vpternlogq $0x96, %ymm31, %ymm25, %ymm16 + vpternlogq $0x96, %ymm31, %ymm25, %ymm1 + + vpternlogq $0x96, %ymm31, %ymm25, %ymm11 + vpternlogq $0x96, %ymm31, %ymm25, %ymm21 + vprolq $1, %ymm29, %ymm31 + + vpbroadcastq (%r14), %ymm29 # Load the round constant into ymm29 (Iota) + addq $8, %r14 # Increment the pointer to the next round constant + + vpternlogq $0x96, %ymm30, %ymm26, %ymm12 + vpternlogq $0x96, %ymm30, %ymm26, %ymm7 + vpternlogq $0x96, %ymm30, %ymm26, 
%ymm22 + + vpternlogq $0x96, %ymm30, %ymm26, %ymm17 + vpternlogq $0x96, %ymm30, %ymm26, %ymm2 + vprolq $1, %ymm25, %ymm30 + + # Rho step + # Keep applying Theta and start Rho step + # + # ROTATION_OFFSETS[5][5] = [ + # [0, 1, 62, 28, 27], + # [36, 44, 6, 55, 20], + # [3, 10, 43, 25, 39], + # [41, 45, 15, 21, 8], + # [18, 2, 61, 56, 14] ] + # + # for x in 0 to 4: + # for y in 0 to 4: + # state[x][y] = ROTATE_LEFT(state[x][y], ROTATION_OFFSETS[x][y]) + + vpternlogq $0x96, %ymm31, %ymm27, %ymm3 + vpternlogq $0x96, %ymm31, %ymm27, %ymm13 + vpternlogq $0x96, %ymm31, %ymm27, %ymm23 + + vprolq $44, %ymm6, %ymm6 + vpternlogq $0x96, %ymm31, %ymm27, %ymm18 + vpternlogq $0x96, %ymm31, %ymm27, %ymm8 + + vprolq $43, %ymm12, %ymm12 + vprolq $21, %ymm18, %ymm18 + vpternlogq $0x96, %ymm30, %ymm28, %ymm24 + + vprolq $14, %ymm24, %ymm24 + vprolq $28, %ymm3, %ymm3 + vpternlogq $0x96, %ymm30, %ymm28, %ymm9 + + vprolq $20, %ymm9, %ymm9 + vprolq $3, %ymm10, %ymm10 + vpternlogq $0x96, %ymm30, %ymm28, %ymm19 + + vprolq $45, %ymm16, %ymm16 + vprolq $61, %ymm22, %ymm22 + vpternlogq $0x96, %ymm30, %ymm28, %ymm4 + + vprolq $1, %ymm1, %ymm1 + vprolq $6, %ymm7, %ymm7 + vpternlogq $0x96, %ymm30, %ymm28, %ymm14 + + # Continue with Rho and start Pi and Chi steps at the same time + # Ternary logic 0xD2 is used for Chi step + # + # for x in 0 to 4: + # for y in 0 to 4: + # state[x][y] = state[x][y] XOR ((NOT state[(x+1) mod 5][y]) AND state[(x+2) mod 5][y]) + + vprolq $25, %ymm13, %ymm13 + vprolq $8, %ymm19, %ymm19 + vmovdqa64 %ymm0, %ymm30 + vpternlogq $0xD2, %ymm12, %ymm6, %ymm30 + + vprolq $18, %ymm20, %ymm20 + vprolq $27, %ymm4, %ymm4 + vpxorq %ymm29, %ymm30, %ymm30 # Iota step + + vprolq $36, %ymm5, %ymm5 + vprolq $10, %ymm11, %ymm11 + vmovdqa64 %ymm6, %ymm31 + vpternlogq $0xD2, %ymm18, %ymm12, %ymm31 + + vprolq $15, %ymm17, %ymm17 + vprolq $56, %ymm23, %ymm23 + vpternlogq $0xD2, %ymm24, %ymm18, %ymm12 + + vprolq $62, %ymm2, %ymm2 + vprolq $55, %ymm8, %ymm8 + vpternlogq $0xD2, %ymm0, %ymm24, %ymm18 + + vprolq $39, %ymm14, %ymm14 + vprolq $41, %ymm15, %ymm15 + vpternlogq $0xD2, %ymm6, %ymm0, %ymm24 + vmovdqa64 %ymm30, %ymm0 + vmovdqa64 %ymm31, %ymm6 + + vprolq $2, %ymm21, %ymm21 + vmovdqa64 %ymm3, %ymm30 + vpternlogq $0xD2, %ymm10, %ymm9, %ymm30 + vmovdqa64 %ymm9, %ymm31 + vpternlogq $0xD2, %ymm16, %ymm10, %ymm31 + + vpternlogq $0xD2, %ymm22, %ymm16, %ymm10 + vpternlogq $0xD2, %ymm3, %ymm22, %ymm16 + vpternlogq $0xD2, %ymm9, %ymm3, %ymm22 + vmovdqa64 %ymm30, %ymm3 + vmovdqa64 %ymm31, %ymm9 + + vmovdqa64 %ymm1, %ymm30 + vpternlogq $0xD2, %ymm13, %ymm7, %ymm30 + vmovdqa64 %ymm7, %ymm31 + vpternlogq $0xD2, %ymm19, %ymm13, %ymm31 + vpternlogq $0xD2, %ymm20, %ymm19, %ymm13 + + vpternlogq $0xD2, %ymm1, %ymm20, %ymm19 + vpternlogq $0xD2, %ymm7, %ymm1, %ymm20 + vmovdqa64 %ymm30, %ymm1 + vmovdqa64 %ymm31, %ymm7 + vmovdqa64 %ymm4, %ymm30 + vpternlogq $0xD2, %ymm11, %ymm5, %ymm30 + + vmovdqa64 %ymm5, %ymm31 + vpternlogq $0xD2, %ymm17, %ymm11, %ymm31 + vpternlogq $0xD2, %ymm23, %ymm17, %ymm11 + vpternlogq $0xD2, %ymm4, %ymm23, %ymm17 + + vpternlogq $0xD2, %ymm5, %ymm4, %ymm23 + vmovdqa64 %ymm30, %ymm4 + vmovdqa64 %ymm31, %ymm5 + vmovdqa64 %ymm2, %ymm30 + vpternlogq $0xD2, %ymm14, %ymm8, %ymm30 + vmovdqa64 %ymm8, %ymm31 + vpternlogq $0xD2, %ymm15, %ymm14, %ymm31 + + vpternlogq $0xD2, %ymm21, %ymm15, %ymm14 + vpternlogq $0xD2, %ymm2, %ymm21, %ymm15 + vpternlogq $0xD2, %ymm8, %ymm2, %ymm21 + vmovdqa64 %ymm30, %ymm2 + vmovdqa64 %ymm31, %ymm8 + + # Complete the steps and get updated state registers in ymm0 to ymm24 + vmovdqa64 %ymm3, 
%ymm30 + vmovdqa64 %ymm18, %ymm3 + vmovdqa64 %ymm17, %ymm18 + vmovdqa64 %ymm11, %ymm17 + vmovdqa64 %ymm7, %ymm11 + vmovdqa64 %ymm10, %ymm7 + vmovdqa64 %ymm1, %ymm10 + vmovdqa64 %ymm6, %ymm1 + vmovdqa64 %ymm9, %ymm6 + vmovdqa64 %ymm22, %ymm9 + vmovdqa64 %ymm14, %ymm22 + vmovdqa64 %ymm20, %ymm14 + vmovdqa64 %ymm2, %ymm20 + vmovdqa64 %ymm12, %ymm2 + vmovdqa64 %ymm13, %ymm12 + vmovdqa64 %ymm19, %ymm13 + vmovdqa64 %ymm23, %ymm19 + vmovdqa64 %ymm15, %ymm23 + vmovdqa64 %ymm4, %ymm15 + vmovdqa64 %ymm24, %ymm4 + vmovdqa64 %ymm21, %ymm24 + vmovdqa64 %ymm8, %ymm21 + vmovdqa64 %ymm16, %ymm8 + vmovdqa64 %ymm5, %ymm16 + vmovdqa64 %ymm30, %ymm5 + + decl %r13d # Decrement the round counter + jnz keccak_rnd_loop # Jump to the start of the loop if r13d is not zero + ret +.size keccak_1600_permute,.-keccak_1600_permute + +.section .rodata + +.balign 64 +sha3_rc: +# SHA3 round constants +# These constants are used in each round of the Keccak permutation +.quad 0x0000000000000001, 0x0000000000008082 +.quad 0x800000000000808a, 0x8000000080008000 +.quad 0x000000000000808b, 0x0000000080000001 +.quad 0x8000000080008081, 0x8000000000008009 +.quad 0x000000000000008a, 0x0000000000000088 +.quad 0x0000000080008009, 0x000000008000000a +.quad 0x000000008000808b, 0x800000000000008b +.quad 0x8000000000008089, 0x8000000000008003 +.quad 0x8000000000008002, 0x8000000000000080 +.quad 0x000000000000800a, 0x800000008000000a +.quad 0x8000000080008081, 0x8000000000008080 +.quad 0x0000000080000001, 0x8000000080008008 + +.section .note.GNU-stack,"",%progbits diff --git a/src/common/sha3/avx512vl_low/KeccakP-1600-times4-AVX512VL.S b/src/common/sha3/avx512vl_low/KeccakP-1600-times4-AVX512VL.S new file mode 100644 index 000000000..a7df09142 --- /dev/null +++ b/src/common/sha3/avx512vl_low/KeccakP-1600-times4-AVX512VL.S @@ -0,0 +1,375 @@ +# Copyright (c) 2025 Intel Corporation +# +# SPDX-License-Identifier: MIT + +# Define arg registers +.equ arg1, %rdi +.equ arg2, %rsi +.equ arg3, %rdx +.equ arg4, %rcx +.equ arg5, %r8 + +.text + +# Loads Keccak state from memory into registers +# +# input: arg1 - state pointer +# output: ymm0-ymm24 +.globl keccak_1600_load_state_x4 +.type keccak_1600_load_state_x4,@function +.hidden keccak_1600_load_state_x4 +.balign 32 +keccak_1600_load_state_x4: + vmovdqu64 (32*0)(arg1), %ymm0 + vmovdqu64 (32*1)(arg1), %ymm1 + vmovdqu64 (32*2)(arg1), %ymm2 + vmovdqu64 (32*3)(arg1), %ymm3 + vmovdqu64 (32*4)(arg1), %ymm4 + vmovdqu64 (32*5)(arg1), %ymm5 + vmovdqu64 (32*6)(arg1), %ymm6 + vmovdqu64 (32*7)(arg1), %ymm7 + vmovdqu64 (32*8)(arg1), %ymm8 + vmovdqu64 (32*9)(arg1), %ymm9 + vmovdqu64 (32*10)(arg1), %ymm10 + vmovdqu64 (32*11)(arg1), %ymm11 + vmovdqu64 (32*12)(arg1), %ymm12 + vmovdqu64 (32*13)(arg1), %ymm13 + vmovdqu64 (32*14)(arg1), %ymm14 + vmovdqu64 (32*15)(arg1), %ymm15 + vmovdqu64 (32*16)(arg1), %ymm16 + vmovdqu64 (32*17)(arg1), %ymm17 + vmovdqu64 (32*18)(arg1), %ymm18 + vmovdqu64 (32*19)(arg1), %ymm19 + vmovdqu64 (32*20)(arg1), %ymm20 + vmovdqu64 (32*21)(arg1), %ymm21 + vmovdqu64 (32*22)(arg1), %ymm22 + vmovdqu64 (32*23)(arg1), %ymm23 + vmovdqu64 (32*24)(arg1), %ymm24 + ret +.size keccak_1600_load_state_x4,.-keccak_1600_load_state_x4 + + +# Saves Keccak state to memory +# +# input: arg1 - state pointer +# ymm0-ymm24 - Keccak state registers +# output: memory from [arg1] to [arg1 + 100*8] +.globl keccak_1600_save_state_x4 +.type keccak_1600_save_state_x4,@function +.hidden keccak_1600_save_state_x4 +.balign 32 +keccak_1600_save_state_x4: + vmovdqu64 %ymm0, (32*0)(arg1) + vmovdqu64 %ymm1, (32*1)(arg1) + 
vmovdqu64 %ymm2, (32*2)(arg1) + vmovdqu64 %ymm3, (32*3)(arg1) + vmovdqu64 %ymm4, (32*4)(arg1) + vmovdqu64 %ymm5, (32*5)(arg1) + vmovdqu64 %ymm6, (32*6)(arg1) + vmovdqu64 %ymm7, (32*7)(arg1) + vmovdqu64 %ymm8, (32*8)(arg1) + vmovdqu64 %ymm9, (32*9)(arg1) + vmovdqu64 %ymm10, (32*10)(arg1) + vmovdqu64 %ymm11, (32*11)(arg1) + vmovdqu64 %ymm12, (32*12)(arg1) + vmovdqu64 %ymm13, (32*13)(arg1) + vmovdqu64 %ymm14, (32*14)(arg1) + vmovdqu64 %ymm15, (32*15)(arg1) + vmovdqu64 %ymm16, (32*16)(arg1) + vmovdqu64 %ymm17, (32*17)(arg1) + vmovdqu64 %ymm18, (32*18)(arg1) + vmovdqu64 %ymm19, (32*19)(arg1) + vmovdqu64 %ymm20, (32*20)(arg1) + vmovdqu64 %ymm21, (32*21)(arg1) + vmovdqu64 %ymm22, (32*22)(arg1) + vmovdqu64 %ymm23, (32*23)(arg1) + vmovdqu64 %ymm24, (32*24)(arg1) + ret +.size keccak_1600_save_state_x4,.-keccak_1600_save_state_x4 + + +# Add input data to state when message length is less than rate +# +# input: +# r10 - state pointer to absorb into (clobbered) +# arg2 - message pointer lane 0 (updated on output) +# arg3 - message pointer lane 1 (updated on output) +# arg4 - message pointer lane 2 (updated on output) +# arg5 - message pointer lane 3 (updated on output) +# r12 - length in bytes (clobbered on output) +# output: +# memory - state from [r10] to [r10 + 4*r12 - 1] +# clobbered: +# rax, rbx, r15, k1, ymm31-ymm29 +.globl keccak_1600_partial_add_x4 +.type keccak_1600_partial_add_x4,@function +.hidden keccak_1600_partial_add_x4 +.balign 32 +keccak_1600_partial_add_x4: + movq (8*100)(%r10), %rax + testl $7, %eax + jz .start_aligned_to_4x8 + + # start offset is not aligned to register size + # - calculate remaining capacity of the register + # - get the min between length and the capacity of the register + # - perform partial add on the register + # - once aligned to the register go into ymm loop + + movq %rax, %r15 # %r15 = s[100] + + andl $7, %eax + negl %eax + addl $8, %eax # register capacity = 8 - (offset % 8) + cmpl %eax, %r12d + cmovb %r12d, %eax # %eax = min(register capacity, $length) + + leaq byte_kmask_0_to_7(%rip), %rbx + kmovb (%rbx,%rax), %k1 # message load mask + + movq %r15, %rbx + andl $~7, %ebx + leaq (%r10, %rbx,4), %r10 # get to state starting register + + movq %r15, %rbx + andl $7, %ebx + + vmovdqu8 (%r10), %ymm31 # load & store / allocate SB for the register + vmovdqu8 %ymm31, (%r10) + + vmovdqu8 (arg2), %xmm31{%k1}{z} # Read 1 to 7 bytes from lane 0 + vmovdqu8 (8*0)(%r10,%rbx), %xmm30{%k1}{z} # Read 1 to 7 bytes from state reg lane 0 + vpxorq %xmm30, %xmm31, %xmm31 + vmovdqu8 %xmm31, (8*0)(%r10,%rbx){%k1} # Write 1 to 7 bytes to state reg lane 0 + + vmovdqu8 (arg3), %xmm31{%k1}{z} # Read 1 to 7 bytes from lane 1 + vmovdqu8 (8*1)(%r10,%rbx), %xmm30{%k1}{z} # Read 1 to 7 bytes from state reg lane 1 + vpxorq %xmm30, %xmm31, %xmm31 + vmovdqu8 %xmm31, (8*1)(%r10,%rbx){%k1} # Write 1 to 7 bytes to state reg lane 1 + + vmovdqu8 (arg4), %xmm31{%k1}{z} # Read 1 to 7 bytes from lane 2 + vmovdqu8 (8*2)(%r10,%rbx), %xmm30{%k1}{z} # Read 1 to 7 bytes from state reg lane 2 + vpxorq %xmm30, %xmm31, %xmm31 + vmovdqu8 %xmm31, (8*2)(%r10,%rbx){%k1} # Write 1 to 7 bytes to state reg lane 2 + + vmovdqu8 (arg5), %xmm31{%k1}{z} # Read 1 to 7 bytes from lane 3 + vmovdqu8 (8*3)(%r10,%rbx), %xmm30{%k1}{z} # Read 1 to 7 bytes from state reg lane 3 + vpxorq %xmm30, %xmm31, %xmm31 + vmovdqu8 %xmm31, (8*3)(%r10,%rbx){%k1} # Write 1 to 7 bytes to state reg lane 3 + + subq %rax, %r12 + jz .zero_bytes + + addq %rax, arg2 + addq %rax, arg3 + addq %rax, arg4 + addq %rax, arg5 + addq $32, %r10 + xorq %rax, 
%rax + jmp .ymm_loop + +.start_aligned_to_4x8: + leaq (%r10,%rax,4), %r10 + xorq %rax, %rax + +.balign 32 +.ymm_loop: + cmpl $8, %r12d + jb .lt_8_bytes + + vmovq (arg2, %rax), %xmm31 # Read 8 bytes from lane 0 + vpinsrq $1, (arg3, %rax), %xmm31, %xmm31 # Read 8 bytes from lane 1 + vmovq (arg4, %rax), %xmm30 # Read 8 bytes from lane 2 + vpinsrq $1, (arg5, %rax),%xmm30, %xmm30 # Read 8 bytes from lane 3 + vinserti32x4 $1, %xmm30, %ymm31, %ymm31 + vpxorq (%r10,%rax,4), %ymm31, %ymm31 # Add data with the state + vmovdqu64 %ymm31, (%r10,%rax,4) + addq $8, %rax + subq $8, %r12 + jz .zero_bytes + + jmp .ymm_loop + +.balign 32 +.zero_bytes: + addq %rax, arg2 + addq %rax, arg3 + addq %rax, arg4 + addq %rax, arg5 + ret + +.balign 32 +.lt_8_bytes: + addq %rax, arg2 + addq %rax, arg3 + addq %rax, arg4 + addq %rax, arg5 + leaq (%r10,%rax,4), %r10 + + leaq byte_kmask_0_to_7(%rip), %rbx + kmovb (%rbx,%r12), %k1 # message load mask + + vmovdqu8 (arg2), %xmm31{%k1}{z} # Read 1 to 7 bytes from lane 0 + vmovdqu8 (arg3), %xmm30{%k1}{z} # Read 1 to 7 bytes from lane 1 + vpunpcklqdq %xmm30, %xmm31, %xmm31 # Interleave data from lane 0 and lane 1 + vmovdqu8 (arg4), %xmm30{%k1}{z} # Read 1 to 7 bytes from lane 2 + vmovdqu8 (arg5), %xmm29{%k1}{z} # Read 1 to 7 bytes from lane 3 + vpunpcklqdq %xmm29, %xmm30, %xmm30 # Interleave data from lane 2 and lane 3 + vinserti32x4 $1, %xmm30, %ymm31, %ymm31 + + vpxorq (%r10), %ymm31, %ymm31 # Add data to the state + vmovdqu64 %ymm31, (%r10) # Update state in memory + + addq %r12, arg2 # increment message pointer lane 0 + addq %r12, arg3 # increment message pointer lane 1 + addq %r12, arg4 # increment message pointer lane 2 + addq %r12, arg5 # increment message pointer lane 3 + ret +.size keccak_1600_partial_add_x4,.-keccak_1600_partial_add_x4 + + +# Extract bytes from state and write to outputs +# +# input: +# r10 - state pointer to start extracting from (clobbered) +# arg1 - output pointer lane 0 (updated on output) +# arg2 - output pointer lane 1 (updated on output) +# arg3 - output pointer lane 2 (updated on output) +# arg4 - output pointer lane 3 (updated on output) +# r12 - length in bytes (clobbered on output) +# r11 - state offset to start extract from +# output: +# memory - output lane 0 from [arg1] to [arg1 + r12 - 1] +# memory - output lane 1 from [arg2] to [arg2 + r12 - 1] +# memory - output lane 2 from [arg3] to [arg3 + r12 - 1] +# memory - output lane 3 from [arg4] to [arg4 + r12 - 1] +# clobbered: +# rax, rbx, k1, ymm31-ymm30 +.globl keccak_1600_extract_bytes_x4 +.type keccak_1600_extract_bytes_x4,@function +.hidden keccak_1600_extract_bytes_x4 +.balign 32 +keccak_1600_extract_bytes_x4: + orq %r12, %r12 + jz .extract_zero_bytes + + testl $7, %r11d + jz .extract_start_aligned_to_4x8 + + # extract offset is not aligned to the register size (8 bytes) + # - calculate remaining capacity of the register + # - get the min between length to extract and register capacity + # - perform partial add on the register + + movq %r11, %rax # %rax = %r11 = offset in the state + + andl $7, %eax + negl %eax + addl $8, %eax # register capacity = 8 - (offset % 8) + cmpl %eax, %r12d + cmovb %r12d, %eax # %eax = min(register capacity, length) + + leaq byte_kmask_0_to_7(%rip), %rbx + kmovb (%rbx,%rax), %k1 # message store mask + + movq %r11, %rbx + andl $~7, %ebx + leaq (%r10,%rbx,4), %r10 # get to state starting register + + movq %r11, %rbx + andl $7, %ebx + + vmovdqu8 (8*0)(%r10,%rbx), %xmm31{%k1}{z} # Read 1 to 7 bytes from state reg lane 0 + vmovdqu8 %xmm31, (arg1){%k1} # Write 1 to 
7 bytes to lane 0 output
+
+    vmovdqu8 (8*1)(%r10,%rbx), %xmm31{%k1}{z}  # Read 1 to 7 bytes from state reg lane 1
+    vmovdqu8 %xmm31, (arg2){%k1}               # Write 1 to 7 bytes to lane 1 output
+
+    vmovdqu8 (8*2)(%r10,%rbx), %xmm31{%k1}{z}  # Read 1 to 7 bytes from state reg lane 2
+    vmovdqu8 %xmm31, (arg3){%k1}               # Write 1 to 7 bytes to lane 2 output
+
+    vmovdqu8 (8*3)(%r10,%rbx), %xmm31{%k1}{z}  # Read 1 to 7 bytes from state reg lane 3
+    vmovdqu8 %xmm31, (arg4){%k1}               # Write 1 to 7 bytes to lane 3 output
+
+    # increment output registers
+    addq %rax, arg1
+    addq %rax, arg2
+    addq %rax, arg3
+    addq %rax, arg4
+
+    # decrement length to extract
+    subq %rax, %r12
+    jz .extract_zero_bytes
+
+    # there is more data to extract, update state register pointer and go to the main loop
+    addq $32, %r10
+    xorq %rax, %rax
+    jmp .extract_ymm_loop
+
+.extract_start_aligned_to_4x8:
+    leaq (%r10,%r11,4), %r10
+    xorq %rax, %rax
+
+.balign 32
+.extract_ymm_loop:
+    cmpq $8, %r12
+    jb .extract_lt_8_bytes
+
+    vmovdqu64 (%r10), %xmm31
+    vmovdqu64 (16)(%r10), %xmm30
+    vmovq %xmm31, (arg1, %rax)
+    vpextrq $1, %xmm31, (arg2, %rax)
+    vmovq %xmm30, (arg3, %rax)
+    vpextrq $1, %xmm30, (arg4, %rax)
+    addq $8, %rax
+    subq $8, %r12
+    jz .zero_bytes_left
+
+    addq $4*8, %r10
+    jmp .extract_ymm_loop
+
+
+.balign 32
+.zero_bytes_left:
+    # increment output pointers
+    addq %rax, arg1
+    addq %rax, arg2
+    addq %rax, arg3
+    addq %rax, arg4
+.extract_zero_bytes:
+    ret
+
+.balign 32
+.extract_lt_8_bytes:
+    addq %rax, arg1
+    addq %rax, arg2
+    addq %rax, arg3
+    addq %rax, arg4
+
+    leaq byte_kmask_0_to_7(%rip), %rax
+    kmovb (%rax,%r12), %k1             # k1 is the mask of message bytes to read
+
+    vmovq (0*8)(%r10), %xmm31          # Read 8 bytes from state lane 0
+    vmovdqu8 %xmm31, (arg1){%k1}       # Extract 1 to 7 bytes of state into output 0
+    vmovq (1*8)(%r10), %xmm31          # Read 8 bytes from state lane 1
+    vmovdqu8 %xmm31, (arg2){%k1}       # Extract 1 to 7 bytes of state into output 1
+    vmovq (2*8)(%r10), %xmm31          # Read 8 bytes from state lane 2
+    vmovdqu8 %xmm31, (arg3){%k1}       # Extract 1 to 7 bytes of state into output 2
+    vmovq (3*8)(%r10), %xmm31          # Read 8 bytes from state lane 3
+    vmovdqu8 %xmm31, (arg4){%k1}       # Extract 1 to 7 bytes of state into output 3
+
+    # increment output pointers
+    addq %r12, arg1
+    addq %r12, arg2
+    addq %r12, arg3
+    addq %r12, arg4
+    ret
+.size keccak_1600_extract_bytes_x4,.-keccak_1600_extract_bytes_x4
+
+.section .rodata
+
+.balign 8
+byte_kmask_0_to_7:
+    .byte 0x00, 0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f # 0xff should never happen
+
+.section .note.GNU-stack,"",%progbits
diff --git a/src/common/sha3/avx512vl_low/SHA3-AVX512VL.S b/src/common/sha3/avx512vl_low/SHA3-AVX512VL.S
new file mode 100644
index 000000000..5a7e61094
--- /dev/null
+++ b/src/common/sha3/avx512vl_low/SHA3-AVX512VL.S
@@ -0,0 +1,1159 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# SPDX-License-Identifier: MIT
+
+# Define arg registers
+.equ arg1, %rdi
+.equ arg2, %rsi
+.equ arg3, %rdx
+.equ arg4, %rcx
+
+# Define SHA3 rates
+.equ SHA3_256_RATE, 136
+.equ SHA3_384_RATE, 104
+.equ SHA3_512_RATE, 72
+.equ SHAKE128_RATE, 168
+.equ SHAKE256_RATE, 136
+
+# Define SHA3 digest sizes
+.equ SHA3_256_DIGEST_SZ, 32
+.equ SHA3_384_DIGEST_SZ, 48
+.equ SHA3_512_DIGEST_SZ, 64
+
+# Define SHA3 EOM bytes
+.equ SHA3_256_EOM, 0x06
+.equ SHA3_384_EOM, 0x06
+.equ SHA3_512_EOM, 0x06
+.equ SHAKE128_EOM, 0x1F
+.equ SHAKE256_EOM, 0x1F
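+
+# Note on the constants above: each rate is 200 - capacity bytes, where the
+# capacity is twice the security strength in bytes, e.g. SHA3-256 gives
+# 200 - 2*32 = 136, SHA3-512 gives 200 - 2*64 = 72 and SHAKE128 gives
+# 200 - 2*16 = 168. The EOM bytes are the FIPS 202 domain-separation suffix
+# with the first pad10*1 bit appended ('01' -> 0x06 for SHA3, '1111' -> 0x1F
+# for SHAKE); the closing pad bit is the 0x80 XORed into the last byte of the
+# block at finalization.
+
+# External utility functions
+.extern keccak_1600_permute
+.extern keccak_1600_init_state
+.extern keccak_1600_load_state
+.extern keccak_1600_save_state
+.extern keccak_1600_partial_add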
+.extern keccak_1600_copy_with_padding +.extern keccak_1600_copy_digest +.extern keccak_1600_extract_bytes + + +# Define macros + +# Absorb input bytes into state registers +# +# input [in] message pointer +# offset [in] message offset pointer +# rate [in] SHA3 variant absorb rate +.macro absorb_bytes input, offset, rate + vmovq (\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm0, %ymm0 + vmovq 8(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm1, %ymm1 + vmovq 16(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm2, %ymm2 + vmovq 24(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm3, %ymm3 + vmovq 32(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm4, %ymm4 + vmovq 40(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm5, %ymm5 + vmovq 48(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm6, %ymm6 + vmovq 56(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm7, %ymm7 + vmovq 64(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm8, %ymm8 + # SHA3_512 rate, 72 bytes +.if \rate > SHA3_512_RATE + # SHA3_384 rate + vmovq 72(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm9, %ymm9 + vmovq 80(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm10, %ymm10 + vmovq 88(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm11, %ymm11 + vmovq 96(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm12, %ymm12 +.endif +.if \rate > SHA3_384_RATE + # SHA3_256 and shake256 rate + vmovq 104(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm13, %ymm13 + vmovq 112(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm14, %ymm14 + vmovq 120(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm15, %ymm15 + vmovq 128(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm16, %ymm16 +.endif +.if \rate > SHA3_256_RATE + # SHAKE128 rate + vmovq 136(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm17, %ymm17 + vmovq 144(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm18, %ymm18 + vmovq 152(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm19, %ymm19 + vmovq 160(\input, \offset, 1), %xmm31 + vpxorq %ymm31, %ymm20, %ymm20 +.endif +.endm + + +# Store state from SIMD registers to memory +# State registers are kept in xmm0-xmm24 +# +# output [in] destination pointer +# offset [in] destination offset +# n [in] numerical values, number of 8-byte state registers to extract +.macro extract_state output, offset, n + # SHA3 256 + vmovq %xmm0, 0(\output, \offset) + vmovq %xmm1, 8(\output, \offset) + vmovq %xmm2, 16(\output, \offset) + vmovq %xmm3, 24(\output, \offset) +.if \n > 4 + # SHA3 384 + vmovq %xmm4, 32(\output, \offset) + vmovq %xmm5, 40(\output, \offset) +.endif +.if \n > 6 + # SHA3 512 + vmovq %xmm6, 48(\output, \offset) + vmovq %xmm7, 56(\output, \offset) +.endif +.if \n > 8 + # SHAKE 256 + vmovq %xmm8, 64(\output, \offset) + vmovq %xmm9, 72(\output, \offset) + vmovq %xmm10, 80(\output, \offset) + vmovq %xmm11, 88(\output, \offset) + vmovq %xmm12, 96(\output, \offset) + vmovq %xmm13, 104(\output, \offset) + vmovq %xmm14, 112(\output, \offset) + vmovq %xmm15, 120(\output, \offset) + vmovq %xmm16, 128(\output, \offset) +.endif +.if \n > 17 + # SHAKE 128 + vmovq %xmm17, 136(\output, \offset) + vmovq %xmm18, 144(\output, \offset) + vmovq %xmm19, 152(\output, \offset) + vmovq %xmm20, 160(\output, \offset) +.endif +.endm + + +# Process a message with SHA3-256/384/512 and store digest +# +# input: +# rate [in] SHA3 rate +.macro sha3_complete rate + pushq %rbp + pushq %rbx + pushq %r12 + pushq %r13 + pushq %r14 + pushq %r15 + subq $8*32, %rsp + + mov $\rate, %r9d # Initialize the rate + movq arg3, %r11 # copy message length to r11 + xorq %r12, %r12 # zero 
message offset + + # Initialize the state array to zero + call keccak_1600_init_state + + # Process the input message in blocks +.balign 32 +1: # main loop + + cmp %r9, %r11 + jb 2f + + absorb_bytes arg2, %r12, \rate + + subq %r9, %r11 # Subtract the rate from the remaining length + addq %r9, %r12 # Adjust the pointer to the next block of the input message + call keccak_1600_permute # Perform the Keccak permutation + jmp 1b + +.balign 32 +2: # main loop done + + movq %rsp, %r13 # dst pointer + addq arg2, %r12 # src pointer + # r11 is length in bytes already + # r9 is rate in bytes already + lea sha3_eom(%rip), %r8 + call keccak_1600_copy_with_padding + + # Add padded block to the state + xorq %r8, %r8 + absorb_bytes %rsp, %r8, \rate + + # Finalize the state and extract the output + call keccak_1600_permute + + # Clear the temporary buffer + vpxorq %ymm31, %ymm31, %ymm31 + vmovdqu64 %ymm31, (32*0)(%rsp) + vmovdqu64 %ymm31, (32*1)(%rsp) + vmovdqu64 %ymm31, (32*2)(%rsp) + vmovdqu64 %ymm31, (32*3)(%rsp) + vmovdqu64 %ymm31, (32*4)(%rsp) + vmovdqu64 %ymm31, (32*5)(%rsp) + vmovdqu64 %ymm31, (32*6)(%rsp) + vmovdqu64 %ymm31, (32*7)(%rsp) + + # Store the state into the digest buffer + xorq %r8, %r8 +.if \rate == SHA3_256_RATE + extract_state arg1, %r8, 4 +.endif +.if \rate == SHA3_384_RATE + extract_state arg1, %r8, 6 +.endif +.if \rate == SHA3_512_RATE + extract_state arg1, %r8, 8 +.endif + + vpxorq %xmm16, %xmm16, %xmm16 + vmovdqa64 %ymm16, %ymm17 + vmovdqa64 %ymm16, %ymm18 + vmovdqa64 %ymm16, %ymm19 + vmovdqa64 %ymm16, %ymm20 + vmovdqa64 %ymm16, %ymm21 + vmovdqa64 %ymm16, %ymm22 + vmovdqa64 %ymm16, %ymm23 + vmovdqa64 %ymm16, %ymm24 + vmovdqa64 %ymm16, %ymm25 + vmovdqa64 %ymm16, %ymm26 + vmovdqa64 %ymm16, %ymm27 + vmovdqa64 %ymm16, %ymm28 + vmovdqa64 %ymm16, %ymm29 + vmovdqa64 %ymm16, %ymm30 + vmovdqa64 %ymm16, %ymm31 + vzeroall + + addq $8*32, %rsp + popq %r15 + popq %r14 + popq %r13 + popq %r12 + popq %rbx + popq %rbp +.endm + + +# Process a message with SHAKE-128/256 and store digest +# +# input: +# rate [in] SHAKE rate +.macro shake_complete rate + pushq %rbp + pushq %rbx + pushq %r12 + pushq %r13 + pushq %r14 + pushq %r15 + subq $8*32, %rsp + + mov $\rate, %r9d # Initialize the rate for SHAKE + movq arg4, %r11 # copy message length to %r11 + xorq %r12, %r12 # zero message offset + xorq %r10, %r10 + + # Initialize the state array to zero + call keccak_1600_init_state + + # Process the input message in blocks +.balign 32 +1: # main loop + + cmp %r9, %r11 + jb 2f + + absorb_bytes arg3, %r12, \rate + + subq %r9, %r11 # Subtract the rate from the remaining length + addq %r9, %r12 # Adjust the pointer to the next block of the input message + call keccak_1600_permute # Perform the Keccak permutation + + jmp 1b + +.balign 32 +2: # main loop done + + movq %rsp, %r13 # dst pointer + addq arg3, %r12 # src pointer + + # r11 is length in bytes already + # r9 is rate in bytes already + lea shake_eom(%rip), %r8 + call keccak_1600_copy_with_padding + + # Add padded block to the state + xorq %r8, %r8 + absorb_bytes %rsp, %r8, \rate + call keccak_1600_permute # Perform the Keccak permutation + + # Clear the temporary buffer + vpxorq %ymm31, %ymm31, %ymm31 + vmovdqu64 %ymm31, (32*0)(%rsp) + vmovdqu64 %ymm31, (32*1)(%rsp) + vmovdqu64 %ymm31, (32*2)(%rsp) + vmovdqu64 %ymm31, (32*3)(%rsp) + vmovdqu64 %ymm31, (32*4)(%rsp) + vmovdqu64 %ymm31, (32*5)(%rsp) + vmovdqu64 %ymm31, (32*6)(%rsp) + vmovdqu64 %ymm31, (32*7)(%rsp) + +.balign 32 +3: # xof loop + + cmp %r9, arg2 + jb 4f + + extract_state arg1, %r10, (\rate / 8) 
# Store the state into the digest buffer + call keccak_1600_permute # Perform the Keccak permutation + + subq %r9, arg2 # Subtract the rate from the remaining length + jz 5f # If equal, jump to the done label + addq %r9, %r10 # Adjust the output digest pointer for the next block + jmp 3b + +.balign 32 +4: # store final block + + # Store the state for the last block of SHAKE in the temporary buffer + xorq %r8, %r8 + extract_state %rsp, %r8, (\rate / 8) + + # Copy digest from the buffer to the output buffer byte by byte + lea (arg1, %r10), %r13 + movq %rsp, %r12 + + # arg2 is length in bytes + call keccak_1600_copy_digest + + # Clear the temporary buffer + vpxorq %ymm31, %ymm31, %ymm31 + vmovdqu64 %ymm31, (32*0)(%rsp) + vmovdqu64 %ymm31, (32*1)(%rsp) + vmovdqu64 %ymm31, (32*2)(%rsp) + vmovdqu64 %ymm31, (32*3)(%rsp) + vmovdqu64 %ymm31, (32*4)(%rsp) + vmovdqu64 %ymm31, (32*5)(%rsp) + vmovdqu64 %ymm31, (32*6)(%rsp) + vmovdqu64 %ymm31, (32*7)(%rsp) + +5: # done + + vpxorq %xmm16, %xmm16, %xmm16 + vmovdqa64 %ymm16, %ymm17 + vmovdqa64 %ymm16, %ymm18 + vmovdqa64 %ymm16, %ymm19 + vmovdqa64 %ymm16, %ymm20 + vmovdqa64 %ymm16, %ymm21 + vmovdqa64 %ymm16, %ymm22 + vmovdqa64 %ymm16, %ymm23 + vmovdqa64 %ymm16, %ymm24 + vmovdqa64 %ymm16, %ymm25 + vmovdqa64 %ymm16, %ymm26 + vmovdqa64 %ymm16, %ymm27 + vmovdqa64 %ymm16, %ymm28 + vmovdqa64 %ymm16, %ymm29 + vmovdqa64 %ymm16, %ymm30 + vmovdqa64 %ymm16, %ymm31 + vzeroall + + addq $8*32, %rsp + popq %r15 + popq %r14 + popq %r13 + popq %r12 + popq %rbx + popq %rbp +.endm + + +# Absorb input into state +# +# input: +# rate [in] SHA3 rate +.macro sha3_absorb rate + pushq %rbp + pushq %rbx + pushq %r12 + pushq %r13 + pushq %r14 + pushq %r15 + + movq (arg1), arg1 # state.ctx into arg1 + + # check for partially processed block + movq (8*25)(arg1), %r14 + orq %r14, %r14 # s[25] == 0? 
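+    # Note: s[25] is the extra 64-bit word kept right after the 200-byte
+    # Keccak state in the incremental context; the absorb path stores the
+    # byte count of the partially filled block there, so a non-zero value
+    # means that block must be topped up before the next permutation.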
+    je 2f
+
+
+    # process remaining bytes if message long enough
+    movq $\rate, %r12                  # c = rate - s[25]
+    subq %r14, %r12                    # %r12 = capacity
+
+    cmp %r12, arg3                     # if mlen < capacity then cannot permute yet
+    jb 1f                              # skip permute
+
+    movq arg3, %r10
+    subq %r12, %r10                    # %r10 = remaining length once the block is filled
+    leaq (arg1, %r14), %r13            # %r13 = state + s[25]
+    call keccak_1600_partial_add
+
+    movq %r10, arg3                    # arg3 = remaining message length
+
+    call keccak_1600_load_state
+    call keccak_1600_permute
+
+    movq $0, (8*25)(arg1)              # clear s[25]
+    jmp 3f                             # partial block done
+
+1: # skip permute
+
+    movq arg3, %r11                    # copy message length to %r11
+    xorq %r12, %r12                    # zero message offset
+    addq %r11, (8*25)(arg1)            # store partially processed length in s[25]
+    addq %r14, arg1                    # state += s[25]
+    jmp 6f
+
+2: # absorb start
+
+    call keccak_1600_load_state
+
+3: # partial block done
+
+    movq arg3, %r11                    # copy message length to %r11
+    xorq %r12, %r12                    # zero message offset
+
+    # Process the input message in blocks
+.balign 32
+4: # main absorb loop
+
+    cmpq $\rate, %r11                  # compare mlen to rate
+    jb 5f
+
+    absorb_bytes arg2, %r12, \rate     # input
+
+    subq $\rate, %r11                  # Subtract the rate from the remaining length
+    addq $\rate, %r12                  # Adjust the pointer to the next block of the input message
+    call keccak_1600_permute           # Perform the Keccak permutation
+
+    jmp 4b
+
+.balign 32
+5: # absorb loop done
+
+    call keccak_1600_save_state
+
+    addq %r11, (8*25)(arg1)            # store partially processed length in s[25]
+
+6: # final partial add
+
+    addq %r12, arg2
+    movq arg1, %r13
+    movq %r11, %r12
+    call keccak_1600_partial_add
+
+    vpxorq %xmm16, %xmm16, %xmm16
+    vmovdqa64 %ymm16, %ymm17
+    vmovdqa64 %ymm16, %ymm18
+    vmovdqa64 %ymm16, %ymm19
+    vmovdqa64 %ymm16, %ymm20
+    vmovdqa64 %ymm16, %ymm21
+    vmovdqa64 %ymm16, %ymm22
+    vmovdqa64 %ymm16, %ymm23
+    vmovdqa64 %ymm16, %ymm24
+    vmovdqa64 %ymm16, %ymm25
+    vmovdqa64 %ymm16, %ymm26
+    vmovdqa64 %ymm16, %ymm27
+    vmovdqa64 %ymm16, %ymm28
+    vmovdqa64 %ymm16, %ymm29
+    vmovdqa64 %ymm16, %ymm30
+    vmovdqa64 %ymm16, %ymm31
+    vzeroall
+
+    popq %r15
+    popq %r14
+    popq %r13
+    popq %r12
+    popq %rbx
+    popq %rbp
+.endm
+
+
+# Finalize processing state and store SHA3 digest in output buffer
+#
+# Input:
+# - output/arg1: pointer to the output buffer
+# - state/arg2: pointer to the state
+.macro sha3_finalize rate, eom_byte
+    pushq %rbp
+    pushq %rbx
+    pushq %r12
+    pushq %r13
+    pushq %r14
+    pushq %r15
+
+    movq (arg2), arg2                  # state.ctx into arg2
+    movq (8*25)(arg2), %r11            # load state offset from s[25]
+    xorb $\eom_byte, (arg2, %r11)      # add EOM byte
+    xorb $0x80, (\rate-1)(arg2)        # add final padding bit to the last byte of the block
+
+    movq arg1, %r15                    # save arg1
+    movq arg2, arg1
+    call keccak_1600_load_state
+
+    movq %r15, arg1
+
+    # clobbers r13, r14
+    call keccak_1600_permute
+
+    # extract digest
+    xorq %r8, %r8
+.if \rate == SHA3_256_RATE
+    extract_state arg1, %r8, 4
+.endif
+.if \rate == SHA3_384_RATE
+    extract_state arg1, %r8, 6
+.endif
+.if \rate == SHA3_512_RATE
+    extract_state arg1, %r8, 8
+.endif
+
+    vpxorq %xmm16, %xmm16, %xmm16
+    vmovdqa64 %ymm16, %ymm17
+    vmovdqa64 %ymm16, %ymm18
+    vmovdqa64 %ymm16, %ymm19
+    vmovdqa64 %ymm16, %ymm20
+    vmovdqa64 %ymm16, %ymm21
+    vmovdqa64 %ymm16, %ymm22
+    vmovdqa64 %ymm16, %ymm23
+    vmovdqa64 %ymm16, %ymm24
+    vmovdqa64 %ymm16, %ymm25
+    vmovdqa64 %ymm16, %ymm26
+    vmovdqa64 %ymm16, %ymm27
+    vmovdqa64 %ymm16, %ymm28
+    vmovdqa64 %ymm16, %ymm29
+    vmovdqa64 %ymm16, %ymm30
+    vmovdqa64 %ymm16, %ymm31
+    vzeroall
+
+    popq %r15
+    popq %r14
+    popq %r13
+    popq %r12
+    popq %rbx
+    popq %rbp
+.endm
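+
+# For reference, the finalize sequence above in C-like pseudocode (a sketch;
+# the names are illustrative, the real entry points are the
+# SHA3_*_inc_finalize_avx512vl functions defined further below):
+#
+#   uint8_t *s = state->ctx;          /* 200-byte Keccak state + s[25]      */
+#   s[offset]   ^= eom_byte;          /* suffix + first pad10*1 bit         */
+#   s[rate - 1] ^= 0x80;              /* closing pad10*1 bit                */
+#   keccak_permute(s);                /* one Keccak-f[1600] permutation     */
+#   memcpy(output, s, digest_size);   /* digest = leading bytes of state    */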
+
+# Absorb input into state
+#
+# input:
+# rate [in] SHA3 rate
+.macro shake_absorb rate
+    pushq %rbp
+    pushq %rbx
+    pushq %r12
+    pushq %r13
+    pushq %r14
+    pushq %r15
+
+    movq (arg1), arg1                  # state.ctx into arg1
+
+    # check for partially processed block
+    movq (8*25)(arg1), %r14
+    orq %r14, %r14                     # check if s[25] is 0
+    je 2f
+
+
+    # process remaining bytes if message long enough
+    movq $\rate, %r12
+    subq %r14, %r12                    # %r12 = capacity = rate - s[25]
+
+    cmp %r12, arg3                     # if mlen <= capacity then no permute
+    jbe 1f                             # skip permute
+
+    subq %r12, arg3
+
+    # r13/state, arg2/input, r12/length
+    leaq (arg1, %r14), %r13            # %r13 = state + s[25]
+    call keccak_1600_partial_add       # arg2 is updated
+
+    call keccak_1600_load_state
+
+    call keccak_1600_permute
+
+    movq $0, (8*25)(arg1)              # clear s[25]
+    jmp 3f
+
+1: # skip permute
+
+    leaq (arg3, %r14), %r10
+    movq %r10, (8*25)(arg1)            # s[25] += inlen
+    # r13/state, arg2/input, r12/length
+    leaq (arg1, %r14), %r13            # state + s[25]
+    movq arg3, %r12
+    call keccak_1600_partial_add
+
+    cmpq $\rate, %r10                  # s[25] >= rate
+    jb 6f
+
+    call keccak_1600_load_state
+    call keccak_1600_permute
+    call keccak_1600_save_state
+
+    movq $0, (8*25)(arg1)              # clear s[25]
+    jmp 6f
+
+2: # main loop start
+
+    call keccak_1600_load_state
+
+3: # partial block done
+
+    movq arg3, %r11                    # copy message length to %r11
+    xorq %r12, %r12                    # zero message offset
+
+    # Process the input message in blocks
+.balign 32
+4: # main loop
+
+    cmpq $\rate, %r11                  # compare mlen to rate
+    jb 5f
+
+    absorb_bytes arg2, %r12, \rate
+
+    subq $\rate, %r11                  # Subtract the rate from the remaining length
+    addq $\rate, %r12                  # Adjust the pointer to the next block of the input message
+    call keccak_1600_permute           # Perform the Keccak permutation
+
+    jmp 4b                             # next loop
+
+.balign 32
+5: # main loop done
+
+    call keccak_1600_save_state
+
+    movq %r11, (8*25)(arg1)            # update s[25]
+    orq %r11, %r11
+    jz 6f
+
+    # r13/state, arg2/input, r12/length
+    addq %r12, arg2
+    movq arg1, %r13
+    movq %r11, %r12
+    call keccak_1600_partial_add
+
+6: # done
+
+    vpxorq %xmm16, %xmm16, %xmm16
+    vmovdqa64 %ymm16, %ymm17
+    vmovdqa64 %ymm16, %ymm18
+    vmovdqa64 %ymm16, %ymm19
+    vmovdqa64 %ymm16, %ymm20
+    vmovdqa64 %ymm16, %ymm21
+    vmovdqa64 %ymm16, %ymm22
+    vmovdqa64 %ymm16, %ymm23
+    vmovdqa64 %ymm16, %ymm24
+    vmovdqa64 %ymm16, %ymm25
+    vmovdqa64 %ymm16, %ymm26
+    vmovdqa64 %ymm16, %ymm27
+    vmovdqa64 %ymm16, %ymm28
+    vmovdqa64 %ymm16, %ymm29
+    vmovdqa64 %ymm16, %ymm30
+    vmovdqa64 %ymm16, %ymm31
+    vzeroall
+
+    popq %r15
+    popq %r14
+    popq %r13
+    popq %r12
+    popq %rbx
+    popq %rbp
+.endm
+
+
+# Squeeze bytes from state to output buffer
+#
+# Input:
+# - output/arg1: pointer to the output buffer for the message digest
+# - outlen/arg2: length of the output in bytes
+# - state/arg3: pointer to the state
+.macro shake_squeeze rate
+    pushq %rbp
+    pushq %rbx
+    pushq %r12
+    pushq %r13
+    pushq %r14
+    pushq %r15
+
+    or arg2, arg2
+    jz 5f
+
+    movq arg1, %rax                    # rotate arg1 with arg3
+    movq (arg3), arg1                  # arg1 - state.ctx
+    movq %rax, arg3                    # arg3 - output buffer
+
+    # check for partially processed block
+    movq (8*25)(arg1), %r14            # s[25] = unread bytes left in the current block
+    orq %r14, %r14
+    jnz 1f
+
+    call keccak_1600_load_state
+
+    jmp 2f
+
+1: # no init permute
+
+    # extract bytes: r13 - state/src, r10 - output/dst, r12 - length = min(capacity, outlen)
+    movl $\rate, %r11d
+    subq %r14, %r11                    # state offset
+    leaq (arg1, %r11), %r13
+
+    movq arg3, %r10
+
+    movq %r14, %r12
+    cmpq %r14, arg2
+    cmovb arg2, %r12                   # %r12 = min(capacity, outlen)
+
+    subq %r12, arg2                    # outlen -= length
+    subq %r12, %r14                    # capacity -= length
+    movq %r14, (8*25)(arg1)            # update s[25]
+
+    call
keccak_1600_extract_bytes + + orq %r14, %r14 + jnz 5f # check if s[25] not 0 + + movq %r10, arg3 # updated output buffer + + call keccak_1600_load_state + +.balign 32 +2: # main loop + + cmp $\rate, arg2 # outlen > r + jb 3f + + call keccak_1600_permute + + # Extract SHAKE rate bytes into the destination buffer + xorq %r8, %r8 + extract_state arg3, %r8, (\rate / 8) + + addq $\rate, arg3 # dst += r + subq $\rate, arg2 # outlen -= r + jmp 2b + +.balign 32 +3: # final extract + + or arg2, arg2 + jz 4f # no end permute + + movl $\rate, %r14d + subq arg2, %r14 + movq %r14, (8*25)(arg1) # s[25] = c + + call keccak_1600_permute + + call keccak_1600_save_state + + # extract bytes: r13 - state/src, r10 - output/dst, r12 - length + movq arg1, %r13 + movq arg3, %r10 + movq arg2, %r12 + call keccak_1600_extract_bytes + + jmp 5f # jump to done + +4: # no end permute + + movq $0, (8*25)(arg1) # s[25] = c + call keccak_1600_save_state + +5: # done + + vpxorq %xmm16, %xmm16, %xmm16 + vmovdqa64 %ymm16, %ymm17 + vmovdqa64 %ymm16, %ymm18 + vmovdqa64 %ymm16, %ymm19 + vmovdqa64 %ymm16, %ymm20 + vmovdqa64 %ymm16, %ymm21 + vmovdqa64 %ymm16, %ymm22 + vmovdqa64 %ymm16, %ymm23 + vmovdqa64 %ymm16, %ymm24 + vmovdqa64 %ymm16, %ymm25 + vmovdqa64 %ymm16, %ymm26 + vmovdqa64 %ymm16, %ymm27 + vmovdqa64 %ymm16, %ymm28 + vmovdqa64 %ymm16, %ymm29 + vmovdqa64 %ymm16, %ymm30 + vmovdqa64 %ymm16, %ymm31 + vzeroall + + popq %r15 + popq %r14 + popq %r13 + popq %r12 + popq %rbx + popq %rbp +.endm + + +.text + +# +# Compact API +# + +# ----------------------------------------------------------------------------- +# +# void SHA3_sha3_256_avx512vl(uint8_t *output, const uint8_t *input, size_t inplen); +# +.globl SHA3_sha3_256_avx512vl +.type SHA3_sha3_256_avx512vl,@function +.hidden SHA3_sha3_256_avx512vl +.balign 32 +SHA3_sha3_256_avx512vl: + sha3_complete SHA3_256_RATE + ret +.size SHA3_sha3_256_avx512vl,.-SHA3_sha3_256_avx512vl + +# ----------------------------------------------------------------------------- +# +# void SHA3_sha3_384_avx512vl(uint8_t *output, const uint8_t *input, size_t inplen); +# +.globl SHA3_sha3_384_avx512vl +.type SHA3_sha3_384_avx512vl,@function +.hidden SHA3_sha3_384_avx512vl +.balign 32 +SHA3_sha3_384_avx512vl: + sha3_complete SHA3_384_RATE + ret +.size SHA3_sha3_384_avx512vl,.-SHA3_sha3_384_avx512vl + + +# ----------------------------------------------------------------------------- +# +# void SHA3_sha3_512_avx512vl(uint8_t *output, const uint8_t *input, size_t inplen); +# +.globl SHA3_sha3_512_avx512vl +.type SHA3_sha3_512_avx512vl,@function +.hidden SHA3_sha3_512_avx512vl +.balign 32 +SHA3_sha3_512_avx512vl: + sha3_complete SHA3_512_RATE + ret +.size SHA3_sha3_512_avx512vl,.-SHA3_sha3_512_avx512vl + + +# ----------------------------------------------------------------------------- +# +# void SHA3_shake128_avx512vl(uint8_t *output, size_t outlen, const uint8_t *input, size_t inplen); +# +.globl SHA3_shake128_avx512vl +.type SHA3_shake128_avx512vl,@function +.hidden SHA3_shake128_avx512vl +.balign 32 +SHA3_shake128_avx512vl: + shake_complete SHAKE128_RATE + ret +.size SHA3_shake128_avx512vl,.-SHA3_shake128_avx512vl + + +# ----------------------------------------------------------------------------- +# +# void SHA3_shake256_avx512vl(uint8_t *output, size_t outlen, const uint8_t *input, size_t inplen); +# +.globl SHA3_shake256_avx512vl +.type SHA3_shake256_avx512vl,@function +.hidden SHA3_shake256_avx512vl +.balign 32 +SHA3_shake256_avx512vl: + shake_complete SHAKE256_RATE + ret +.size 
SHA3_shake256_avx512vl,.-SHA3_shake256_avx512vl + + +# +# Init/Reset API +# + +# ----------------------------------------------------------------------------- +# +# void SHA3_sha3_256_inc_ctx_reset_avx512vl(OQS_SHA3_sha3_256_inc_ctx *state); +# +.globl SHA3_sha3_256_inc_ctx_reset_avx512vl +.type SHA3_sha3_256_inc_ctx_reset_avx512vl,@function +.hidden SHA3_sha3_256_inc_ctx_reset_avx512vl + + +# ----------------------------------------------------------------------------- +# +# void SHA3_sha3_384_inc_ctx_reset_avx512vl(OQS_SHA3_sha3_384_inc_ctx *state); +# +.globl SHA3_sha3_384_inc_ctx_reset_avx512vl +.type SHA3_sha3_384_inc_ctx_reset_avx512vl,@function +.hidden SHA3_sha3_384_inc_ctx_reset_avx512vl + + +# ----------------------------------------------------------------------------- +# +# void SHA3_sha3_512_inc_ctx_reset_avx512vl(OQS_SHA3_sha3_512_inc_ctx *state); +# +.globl SHA3_sha3_512_inc_ctx_reset_avx512vl +.type SHA3_sha3_512_inc_ctx_reset_avx512vl,@function +.hidden SHA3_sha3_512_inc_ctx_reset_avx512vl + +# ----------------------------------------------------------------------------- +# +# void SHA3_shake128_inc_ctx_reset_avx512vl(OQS_SHA3_shake128_inc_ctx *state); +# +.globl SHA3_shake128_inc_ctx_reset_avx512vl +.type SHA3_shake128_inc_ctx_reset_avx512vl,@function +.hidden SHA3_shake128_inc_ctx_reset_avx512vl + + +# ----------------------------------------------------------------------------- +# +# void SHA3_shake256_inc_ctx_reset_avx512vl(OQS_SHA3_shake256_inc_ctx *state); +# +.globl SHA3_shake256_inc_ctx_reset_avx512vl +.type SHA3_shake256_inc_ctx_reset_avx512vl,@function +.hidden SHA3_shake256_inc_ctx_reset_avx512vl + + +.balign 32 +SHA3_sha3_256_inc_ctx_reset_avx512vl: +SHA3_sha3_384_inc_ctx_reset_avx512vl: +SHA3_sha3_512_inc_ctx_reset_avx512vl: +SHA3_shake128_inc_ctx_reset_avx512vl: +SHA3_shake256_inc_ctx_reset_avx512vl: + movq (arg1), arg1 # load arg1.ctx + vpxorq %xmm0, %xmm0, %xmm0 + vmovdqu64 %ymm0, 0(arg1) # clear 200 bytes of arg1s + vmovdqu64 %ymm0, 32(arg1) + vmovdqu64 %ymm0, 64(arg1) + vmovdqu64 %ymm0, 96(arg1) + vmovdqu64 %ymm0, 128(arg1) + vmovdqu64 %ymm0, 160(arg1) + vmovdqu64 %xmm0, 192(arg1) # also clear additional 8 bytes s[25] + ret +.size SHA3_sha3_256_inc_ctx_reset_avx512vl,.-SHA3_sha3_256_inc_ctx_reset_avx512vl +.size SHA3_sha3_384_inc_ctx_reset_avx512vl,.-SHA3_sha3_384_inc_ctx_reset_avx512vl +.size SHA3_sha3_512_inc_ctx_reset_avx512vl,.-SHA3_sha3_512_inc_ctx_reset_avx512vl +.size SHA3_shake128_inc_ctx_reset_avx512vl,.-SHA3_shake128_inc_ctx_reset_avx512vl +.size SHA3_shake256_inc_ctx_reset_avx512vl,.-SHA3_shake256_inc_ctx_reset_avx512vl + + +# +# Absorb API +# + +# ----------------------------------------------------------------------------- +# +# void SHA3_sha3_256_inc_absorb_avx512vl(OQS_SHA3_sha3_256_inc_ctx *state, const uint8_t *input, size_t inlen); +# +.globl SHA3_sha3_256_inc_absorb_avx512vl +.type SHA3_sha3_256_inc_absorb_avx512vl,@function +.hidden SHA3_sha3_256_inc_absorb_avx512vl +.balign 32 +SHA3_sha3_256_inc_absorb_avx512vl: + sha3_absorb SHA3_256_RATE + ret +.size SHA3_sha3_256_inc_absorb_avx512vl,.-SHA3_sha3_256_inc_absorb_avx512vl + + +# ----------------------------------------------------------------------------- +# +# void SHA3_sha3_384_inc_absorb_avx512vl(OQS_SHA3_sha3_384_inc_ctx *state, const uint8_t *input, size_t inlen); +# +.globl SHA3_sha3_384_inc_absorb_avx512vl +.type SHA3_sha3_384_inc_absorb_avx512vl,@function +.hidden SHA3_sha3_384_inc_absorb_avx512vl +.balign 32 +SHA3_sha3_384_inc_absorb_avx512vl: + sha3_absorb SHA3_384_RATE + ret 
+.size SHA3_sha3_384_inc_absorb_avx512vl,.-SHA3_sha3_384_inc_absorb_avx512vl + + +# ----------------------------------------------------------------------------- +# +# void SHA3_sha3_512_inc_absorb_avx512vl(OQS_SHA3_sha3_512_inc_ctx *state, const uint8_t *input, size_t inlen); +# +.globl SHA3_sha3_512_inc_absorb_avx512vl +.type SHA3_sha3_512_inc_absorb_avx512vl,@function +.hidden SHA3_sha3_512_inc_absorb_avx512vl +.balign 32 +SHA3_sha3_512_inc_absorb_avx512vl: + sha3_absorb SHA3_512_RATE + ret +.size SHA3_sha3_512_inc_absorb_avx512vl,.-SHA3_sha3_512_inc_absorb_avx512vl + + +# ----------------------------------------------------------------------------- +# +# void SHA3_shake128_inc_absorb_avx512vl(OQS_SHA3_shake128_inc_ctx *state, const uint8_t *input, size_t inlen); +# +.globl SHA3_shake128_inc_absorb_avx512vl +.type SHA3_shake128_inc_absorb_avx512vl,@function +.hidden SHA3_shake128_inc_absorb_avx512vl +.balign 32 +SHA3_shake128_inc_absorb_avx512vl: + shake_absorb SHAKE128_RATE + ret +.size SHA3_shake128_inc_absorb_avx512vl,.-SHA3_shake128_inc_absorb_avx512vl + + +# ----------------------------------------------------------------------------- +# +# void SHA3_shake256_inc_absorb_avx512vl(OQS_SHA3_shake256_inc_ctx *state, const uint8_t *input, size_t inlen); +# +.globl SHA3_shake256_inc_absorb_avx512vl +.type SHA3_shake256_inc_absorb_avx512vl,@function +.hidden SHA3_shake256_inc_absorb_avx512vl +.balign 32 +SHA3_shake256_inc_absorb_avx512vl: + shake_absorb SHAKE256_RATE + ret +.size SHA3_shake256_inc_absorb_avx512vl,.-SHA3_shake256_inc_absorb_avx512vl + + +# +# Finalize API +# + +# ----------------------------------------------------------------------------- +# +# void SHA3_sha3_256_inc_finalize_avx512vl(uint8_t *output, OQS_SHA3_sha3_256_inc_ctx *state); +# +.globl SHA3_sha3_256_inc_finalize_avx512vl +.type SHA3_sha3_256_inc_finalize_avx512vl,@function +.hidden SHA3_sha3_256_inc_finalize_avx512vl +.balign 32 +SHA3_sha3_256_inc_finalize_avx512vl: + sha3_finalize SHA3_256_RATE, SHA3_256_EOM + ret +.size SHA3_sha3_256_inc_finalize_avx512vl,.-SHA3_sha3_256_inc_finalize_avx512vl + + +# ----------------------------------------------------------------------------- +# +# void SHA3_sha3_384_inc_finalize_avx512vl(uint8_t *output, OQS_SHA3_sha3_384_inc_ctx *state); +# +.globl SHA3_sha3_384_inc_finalize_avx512vl +.type SHA3_sha3_384_inc_finalize_avx512vl,@function +.hidden SHA3_sha3_384_inc_finalize_avx512vl +.balign 32 +SHA3_sha3_384_inc_finalize_avx512vl: + sha3_finalize SHA3_384_RATE, SHA3_384_EOM + ret +.size SHA3_sha3_384_inc_finalize_avx512vl,.-SHA3_sha3_384_inc_finalize_avx512vl + + +# ----------------------------------------------------------------------------- +# +# void SHA3_sha3_512_inc_finalize_avx512vl(uint8_t *output, OQS_SHA3_sha3_512_inc_ctx *state); +# +.globl SHA3_sha3_512_inc_finalize_avx512vl +.type SHA3_sha3_512_inc_finalize_avx512vl,@function +.hidden SHA3_sha3_512_inc_finalize_avx512vl +.balign 32 +SHA3_sha3_512_inc_finalize_avx512vl: + sha3_finalize SHA3_512_RATE, SHA3_512_EOM + ret +.size SHA3_sha3_512_inc_finalize_avx512vl,.-SHA3_sha3_512_inc_finalize_avx512vl + + +# ----------------------------------------------------------------------------- +# +# void SHA3_shake128_inc_finalize_avx512vl(OQS_SHA3_shake128_inc_ctx *state); +# +.globl SHA3_shake128_inc_finalize_avx512vl +.type SHA3_shake128_inc_finalize_avx512vl,@function +.hidden SHA3_shake128_inc_finalize_avx512vl +.balign 32 +SHA3_shake128_inc_finalize_avx512vl: + movq (arg1), arg1 # state.ctx into %rdi + movq (8*25)(arg1), 
%r11 # load state offset from s[25]
+ xorb $SHAKE128_EOM, (arg1, %r11) # add EOM
+ xorb $0x80, (SHAKE128_RATE-1)(arg1) # set terminator bit at rate - 1
+ movq $0, (8*25)(arg1) # clear s[25]
+ ret
+.size SHA3_shake128_inc_finalize_avx512vl,.-SHA3_shake128_inc_finalize_avx512vl
+
+
+# -----------------------------------------------------------------------------
+#
+# void SHA3_shake256_inc_finalize_avx512vl(OQS_SHA3_shake256_inc_ctx *state);
+#
+.globl SHA3_shake256_inc_finalize_avx512vl
+.type SHA3_shake256_inc_finalize_avx512vl,@function
+.hidden SHA3_shake256_inc_finalize_avx512vl
+.balign 32
+SHA3_shake256_inc_finalize_avx512vl:
+ movq (arg1), arg1 # state.ctx into %rdi
+ movq (8*25)(arg1), %r11 # load state offset from s[25]
+ xorb $SHAKE256_EOM, (arg1, %r11) # add EOM
+ xorb $0x80, (SHAKE256_RATE-1)(arg1) # set terminator bit at rate - 1
+ movq $0, (8*25)(arg1) # clear s[25]
+ ret
+.size SHA3_shake256_inc_finalize_avx512vl,.-SHA3_shake256_inc_finalize_avx512vl
+
+
+#
+# Squeeze API
+#
+
+# -----------------------------------------------------------------------------
+#
+# void SHA3_shake128_inc_squeeze_avx512vl(uint8_t *output, size_t outlen, OQS_SHA3_shake128_inc_ctx *state);
+#
+.globl SHA3_shake128_inc_squeeze_avx512vl
+.type SHA3_shake128_inc_squeeze_avx512vl,@function
+.hidden SHA3_shake128_inc_squeeze_avx512vl
+.balign 32
+SHA3_shake128_inc_squeeze_avx512vl:
+ shake_squeeze SHAKE128_RATE
+ ret
+.size SHA3_shake128_inc_squeeze_avx512vl,.-SHA3_shake128_inc_squeeze_avx512vl
+
+
+# -----------------------------------------------------------------------------
+#
+# void SHA3_shake256_inc_squeeze_avx512vl(uint8_t *output, size_t outlen, OQS_SHA3_shake256_inc_ctx *state);
+#
+.globl SHA3_shake256_inc_squeeze_avx512vl
+.type SHA3_shake256_inc_squeeze_avx512vl,@function
+.hidden SHA3_shake256_inc_squeeze_avx512vl
+.balign 32
+SHA3_shake256_inc_squeeze_avx512vl:
+ shake_squeeze SHAKE256_RATE
+ ret
+.size SHA3_shake256_inc_squeeze_avx512vl,.-SHA3_shake256_inc_squeeze_avx512vl
+
+
+.section .rodata
+
+sha3_eom:
+.byte 0x06
+
+shake_eom:
+.byte 0x1F
+
+.section .note.GNU-stack,"",%progbits
diff --git a/src/common/sha3/avx512vl_low/SHA3-times4-AVX512VL.S b/src/common/sha3/avx512vl_low/SHA3-times4-AVX512VL.S
new file mode 100644
index 000000000..393c3aaac
--- /dev/null
+++ b/src/common/sha3/avx512vl_low/SHA3-times4-AVX512VL.S
@@ -0,0 +1,1277 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# SPDX-License-Identifier: MIT
+
+# Define arg registers
+.equ arg1, %rdi
+.equ arg2, %rsi
+.equ arg3, %rdx
+.equ arg4, %rcx
+.equ arg5, %r8
+.equ arg5d, %r8d
+.equ arg6, %r9
+
+# arg7-10 on stack
+#define arg7 (2*8)(%rbp)
+#define arg8 (3*8)(%rbp)
+#define arg9 (4*8)(%rbp)
+#define arg10 (5*8)(%rbp)
+
+
+# Define SHA3 rates
+.equ SHA3_256_RATE, 136
+.equ SHA3_384_RATE, 104
+.equ SHA3_512_RATE, 72
+.equ SHAKE128_RATE, 168
+.equ SHAKE256_RATE, 136
+
+
+# SHAKE multi-rate padding byte (added after the message)
+.equ SHAKE_MRATE_PADDING, 0x1F
+
+
+# Stack frame layout for shake128_x4 and shake256_x4 operations
+.equ STATE_SIZE, ((25 * 8 * 4) + 8)
+.equ sf_arg1, 0
+.equ sf_arg2, sf_arg1 + 8 # save arg2, output pointer
+.equ sf_arg3, sf_arg2 + 8 # save arg3, output pointer
+.equ sf_arg4, sf_arg3 + 8 # save arg4, output pointer
+.equ sf_arg5, sf_arg4 + 8 # save arg5, output length
+.equ sf_state_ptr, sf_arg5 + 8 # state context structure (pointer to a pointer)
+.equ sf_state_x4, sf_state_ptr + 8 # start of x4 state structure
+.equ sf_size, sf_state_x4 + STATE_SIZE
+
+
+# External utility functions
+.extern keccak_1600_init_state
+.extern keccak_1600_permute
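+
+# Register contract assumed by the x4 helpers below, as documented at
+# their call sites in this file: the interleaved x4 state lives in
+# %ymm0-%ymm24, one Keccak state word per register with one 64-bit lane
+# per parallel instance. keccak_1600_partial_add_x4 takes the in-memory
+# state in %r10, the four input pointers in arg2-arg5 and the byte count
+# in %r12; keccak_1600_extract_bytes_x4 takes the in-memory state in
+# %r10, the four output pointers in arg1-arg4, the byte count in %r12
+# and the state byte offset in %r11.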
+.extern keccak_1600_load_state_x4
+.extern keccak_1600_save_state_x4
+.extern keccak_1600_partial_add_x4
+.extern keccak_1600_extract_bytes_x4
+
+
+# Define macros for x4 operations
+
+# Absorb input bytes into x4 state registers
+# ymm0-ymm24 [in/out] x4 state registers
+# ymm30-ymm31 [clobbered] used as temporary registers
+.macro absorb_bytes_x4 input0, input1, input2, input3, offset, rate
+ vmovq (\input0, \offset), %xmm31
+ vpinsrq $1, (\input1, \offset), %xmm31, %xmm31
+ vmovq (\input2, \offset), %xmm30
+ vpinsrq $1, (\input3, \offset), %xmm30, %xmm30
+ vinserti32x4 $1, %xmm30, %ymm31, %ymm31
+ vpxorq %ymm31, %ymm0, %ymm0
+
+ vmovq 8(\input0, \offset), %xmm31
+ vpinsrq $1, 8(\input1, \offset), %xmm31, %xmm31
+ vmovq 8(\input2, \offset), %xmm30
+ vpinsrq $1, 8(\input3, \offset), %xmm30, %xmm30
+ vinserti32x4 $1, %xmm30, %ymm31, %ymm31
+ vpxorq %ymm31, %ymm1, %ymm1
+
+ vmovq 16(\input0, \offset), %xmm31
+ vpinsrq $1, 16(\input1, \offset), %xmm31, %xmm31
+ vmovq 16(\input2, \offset), %xmm30
+ vpinsrq $1, 16(\input3, \offset), %xmm30, %xmm30
+ vinserti32x4 $1, %xmm30, %ymm31, %ymm31
+ vpxorq %ymm31, %ymm2, %ymm2
+
+ vmovq 24(\input0, \offset), %xmm31
+ vpinsrq $1, 24(\input1, \offset), %xmm31, %xmm31
+ vmovq 24(\input2, \offset), %xmm30
+ vpinsrq $1, 24(\input3, \offset), %xmm30, %xmm30
+ vinserti32x4 $1, %xmm30, %ymm31, %ymm31
+ vpxorq %ymm31, %ymm3, %ymm3
+
+ vmovq 32(\input0, \offset), %xmm31
+ vpinsrq $1, 32(\input1, \offset), %xmm31, %xmm31
+ vmovq 32(\input2, \offset), %xmm30
+ vpinsrq $1, 32(\input3, \offset), %xmm30, %xmm30
+ vinserti32x4 $1, %xmm30, %ymm31, %ymm31
+ vpxorq %ymm31, %ymm4, %ymm4
+
+ vmovq 40(\input0, \offset), %xmm31
+ vpinsrq $1, 40(\input1, \offset), %xmm31, %xmm31
+ vmovq 40(\input2, \offset), %xmm30
+ vpinsrq $1, 40(\input3, \offset), %xmm30, %xmm30
+ vinserti32x4 $1, %xmm30, %ymm31, %ymm31
+ vpxorq %ymm31, %ymm5, %ymm5
+
+ vmovq 48(\input0, \offset), %xmm31
+ vpinsrq $1, 48(\input1, \offset), %xmm31, %xmm31
+ vmovq 48(\input2, \offset), %xmm30
+ vpinsrq $1, 48(\input3, \offset), %xmm30, %xmm30
+ vinserti32x4 $1, %xmm30, %ymm31, %ymm31
+ vpxorq %ymm31, %ymm6, %ymm6
+
+ vmovq 56(\input0, \offset), %xmm31
+ vpinsrq $1, 56(\input1, \offset), %xmm31, %xmm31
+ vmovq 56(\input2, \offset), %xmm30
+ vpinsrq $1, 56(\input3, \offset), %xmm30, %xmm30
+ vinserti32x4 $1, %xmm30, %ymm31, %ymm31
+ vpxorq %ymm31, %ymm7, %ymm7
+
+ vmovq 64(\input0, \offset), %xmm31
+ vpinsrq $1, 64(\input1, \offset), %xmm31, %xmm31
+ vmovq 64(\input2, \offset), %xmm30
+ vpinsrq $1, 64(\input3, \offset), %xmm30, %xmm30
+ vinserti32x4 $1, %xmm30, %ymm31, %ymm31
+ vpxorq %ymm31, %ymm8, %ymm8
+
+ vmovq 72(\input0, \offset), %xmm31
+ vpinsrq $1, 72(\input1, \offset), %xmm31, %xmm31
+ vmovq 72(\input2, \offset), %xmm30
+ vpinsrq $1, 72(\input3, \offset), %xmm30, %xmm30
+ vinserti32x4 $1, %xmm30, %ymm31, %ymm31
+ vpxorq %ymm31, %ymm9, %ymm9
+
+ vmovq 80(\input0, \offset), %xmm31
+ vpinsrq $1, 80(\input1, \offset), %xmm31, %xmm31
+ vmovq 80(\input2, \offset), %xmm30
+ vpinsrq $1, 80(\input3, \offset), %xmm30, %xmm30
+ vinserti32x4 $1, %xmm30, %ymm31, %ymm31
+ vpxorq %ymm31, %ymm10, %ymm10
+
+ vmovq 88(\input0, \offset), %xmm31
+ vpinsrq $1, 88(\input1, \offset), %xmm31, %xmm31
+ vmovq 88(\input2, \offset), %xmm30
+ vpinsrq $1, 88(\input3, \offset), %xmm30, %xmm30
+ vinserti32x4 $1, %xmm30, %ymm31, %ymm31
+ vpxorq %ymm31, %ymm11, %ymm11
+
+ vmovq 96(\input0, \offset), %xmm31
+ vpinsrq $1, 96(\input1, \offset), %xmm31, %xmm31
+ vmovq 96(\input2, \offset), %xmm30
+ vpinsrq $1, 96(\input3, \offset), %xmm30,
%xmm30 + vinserti32x4 $1, %xmm30, %ymm31, %ymm31 + vpxorq %ymm31, %ymm12, %ymm12 + vmovq 104(\input0, \offset), %xmm31 + vpinsrq $1, 104(\input1, \offset), %xmm31, %xmm31 + vmovq 104(\input2, \offset), %xmm30 + vpinsrq $1, 104(\input3, \offset), %xmm30, %xmm30 + vinserti32x4 $1, %xmm30, %ymm31, %ymm31 + vpxorq %ymm31, %ymm13, %ymm13 + + vmovq 112(\input0, \offset), %xmm31 + vpinsrq $1, 112(\input1, \offset), %xmm31, %xmm31 + vmovq 112(\input2, \offset), %xmm30 + vpinsrq $1, 112(\input3, \offset), %xmm30, %xmm30 + vinserti32x4 $1, %xmm30, %ymm31, %ymm31 + vpxorq %ymm31, %ymm14, %ymm14 + + vmovq 120(\input0, \offset), %xmm31 + vpinsrq $1, 120(\input1, \offset), %xmm31, %xmm31 + vmovq 120(\input2, \offset), %xmm30 + vpinsrq $1, 120(\input3, \offset), %xmm30, %xmm30 + vinserti32x4 $1, %xmm30, %ymm31, %ymm31 + vpxorq %ymm31, %ymm15, %ymm15 + + vmovq 128(\input0, \offset), %xmm31 + vpinsrq $1, 128(\input1, \offset), %xmm31, %xmm31 + vmovq 128(\input2, \offset), %xmm30 + vpinsrq $1, 128(\input3, \offset), %xmm30, %xmm30 + vinserti32x4 $1, %xmm30, %ymm31, %ymm31 + vpxorq %ymm31, %ymm16, %ymm16 +.if \rate > SHA3_256_RATE + # SHAKE128 rate + vmovq 136(\input0, \offset), %xmm31 + vpinsrq $1, 136(\input1, \offset), %xmm31, %xmm31 + vmovq 136(\input2, \offset), %xmm30 + vpinsrq $1, 136(\input3, \offset), %xmm30, %xmm30 + vinserti32x4 $1, %xmm30, %ymm31, %ymm31 + vpxorq %ymm31, %ymm17, %ymm17 + + vmovq 144(\input0, \offset), %xmm31 + vpinsrq $1, 144(\input1, \offset), %xmm31, %xmm31 + vmovq 144(\input2, \offset), %xmm30 + vpinsrq $1, 144(\input3, \offset), %xmm30, %xmm30 + vinserti32x4 $1, %xmm30, %ymm31, %ymm31 + vpxorq %ymm31, %ymm18, %ymm18 + + vmovq 152(\input0, \offset), %xmm31 + vpinsrq $1, 152(\input1, \offset), %xmm31, %xmm31 + vmovq 152(\input2, \offset), %xmm30 + vpinsrq $1, 152(\input3, \offset), %xmm30, %xmm30 + vinserti32x4 $1, %xmm30, %ymm31, %ymm31 + vpxorq %ymm31, %ymm19, %ymm19 + + vmovq 160(\input0, \offset), %xmm31 + vpinsrq $1, 160(\input1, \offset), %xmm31, %xmm31 + vmovq 160(\input2, \offset), %xmm30 + vpinsrq $1, 160(\input3, \offset), %xmm30, %xmm30 + vinserti32x4 $1, %xmm30, %ymm31, %ymm31 + vpxorq %ymm31, %ymm20, %ymm20 +.endif +.endm + +# Store x4 state from SIMD registers to memory +# +# ymm0-ymm24 [in] x4 state registers +# ymm31 [clobbered] used as a temporary register +.macro extract_state_x4 output0, output1, output2, output3, offset, n + vextracti64x2 $1, %ymm0, %xmm31 + vmovq %xmm0, (8*0)(\output0, \offset) + vpextrq $1, %xmm0, (8*0)(\output1, \offset) + vmovq %xmm31, (8*0)(\output2, \offset) + vpextrq $1, %xmm31, (8*0)(\output3, \offset) + + vextracti64x2 $1, %ymm1, %xmm31 + vmovq %xmm1, (8*1)(\output0, \offset) + vpextrq $1, %xmm1, (8*1)(\output1, \offset) + vmovq %xmm31, (8*1)(\output2, \offset) + vpextrq $1, %xmm31, (8*1)(\output3, \offset) + + vextracti64x2 $1, %ymm2, %xmm31 + vmovq %xmm2, (8*2)(\output0, \offset) + vpextrq $1, %xmm2, (8*2)(\output1, \offset) + vmovq %xmm31, (8*2)(\output2, \offset) + vpextrq $1, %xmm31, (8*2)(\output3, \offset) + + vextracti64x2 $1, %ymm3, %xmm31 + vmovq %xmm3, (8*3)(\output0, \offset) + vpextrq $1, %xmm3, (8*3)(\output1, \offset) + vmovq %xmm31, (8*3)(\output2, \offset) + vpextrq $1, %xmm31, (8*3)(\output3, \offset) + + vextracti64x2 $1, %ymm4, %xmm31 + vmovq %xmm4, (8*4)(\output0, \offset) + vpextrq $1, %xmm4, (8*4)(\output1, \offset) + vmovq %xmm31, (8*4)(\output2, \offset) + vpextrq $1, %xmm31, (8*4)(\output3, \offset) + + vextracti64x2 $1, %ymm5, %xmm31 + vmovq %xmm5, (8*5)(\output0, \offset) + vpextrq $1, %xmm5, 
(8*5)(\output1, \offset) + vmovq %xmm31, (8*5)(\output2, \offset) + vpextrq $1, %xmm31, (8*5)(\output3, \offset) + + vextracti64x2 $1, %ymm6, %xmm31 + vmovq %xmm6, (8*6)(\output0, \offset) + vpextrq $1, %xmm6, (8*6)(\output1, \offset) + vmovq %xmm31, (8*6)(\output2, \offset) + vpextrq $1, %xmm31, (8*6)(\output3, \offset) + + vextracti64x2 $1, %ymm7, %xmm31 + vmovq %xmm7, (8*7)(\output0, \offset) + vpextrq $1, %xmm7, (8*7)(\output1, \offset) + vmovq %xmm31, (8*7)(\output2, \offset) + vpextrq $1, %xmm31, (8*7)(\output3, \offset) + + vextracti64x2 $1, %ymm8, %xmm31 + vmovq %xmm8, (8*8)(\output0, \offset) + vpextrq $1, %xmm8, (8*8)(\output1, \offset) + vmovq %xmm31, (8*8)(\output2, \offset) + vpextrq $1, %xmm31, (8*8)(\output3, \offset) + + vextracti64x2 $1, %ymm9, %xmm31 + vmovq %xmm9, (8*9)(\output0, \offset) + vpextrq $1, %xmm9, (8*9)(\output1, \offset) + vmovq %xmm31, (8*9)(\output2, \offset) + vpextrq $1, %xmm31, (8*9)(\output3, \offset) + + vextracti64x2 $1, %ymm10, %xmm31 + vmovq %xmm10, (8*10)(\output0, \offset) + vpextrq $1, %xmm10, (8*10)(\output1, \offset) + vmovq %xmm31, (8*10)(\output2, \offset) + vpextrq $1, %xmm31, (8*10)(\output3, \offset) + + vextracti64x2 $1, %ymm11, %xmm31 + vmovq %xmm11, (8*11)(\output0, \offset) + vpextrq $1, %xmm11, (8*11)(\output1, \offset) + vmovq %xmm31, (8*11)(\output2, \offset) + vpextrq $1, %xmm31, (8*11)(\output3, \offset) + + vextracti64x2 $1, %ymm12, %xmm31 + vmovq %xmm12, (8*12)(\output0, \offset) + vpextrq $1, %xmm12, (8*12)(\output1, \offset) + vmovq %xmm31, (8*12)(\output2, \offset) + vpextrq $1, %xmm31, (8*12)(\output3, \offset) + + vextracti64x2 $1, %ymm13, %xmm31 + vmovq %xmm13, (8*13)(\output0, \offset) + vpextrq $1, %xmm13, (8*13)(\output1, \offset) + vmovq %xmm31, (8*13)(\output2, \offset) + vpextrq $1, %xmm31, (8*13)(\output3, \offset) + + vextracti64x2 $1, %ymm14, %xmm31 + vmovq %xmm14, (8*14)(\output0, \offset) + vpextrq $1, %xmm14, (8*14)(\output1, \offset) + vmovq %xmm31, (8*14)(\output2, \offset) + vpextrq $1, %xmm31, (8*14)(\output3, \offset) + + vextracti64x2 $1, %ymm15, %xmm31 + vmovq %xmm15, (8*15)(\output0, \offset) + vpextrq $1, %xmm15, (8*15)(\output1, \offset) + vmovq %xmm31, (8*15)(\output2, \offset) + vpextrq $1, %xmm31, (8*15)(\output3, \offset) + + vextracti64x2 $1, %ymm16, %xmm31 + vmovq %xmm16, (8*16)(\output0, \offset) + vpextrq $1, %xmm16, (8*16)(\output1, \offset) + vmovq %xmm31, (8*16)(\output2, \offset) + vpextrq $1, %xmm31, (8*16)(\output3, \offset) + +.if \n > 17 + vextracti64x2 $1, %ymm17, %xmm31 + vmovq %xmm17, (8*17)(\output0, \offset) + vpextrq $1, %xmm17, (8*17)(\output1, \offset) + vmovq %xmm31, (8*17)(\output2, \offset) + vpextrq $1, %xmm31, (8*17)(\output3, \offset) + + vextracti64x2 $1, %ymm18, %xmm31 + vmovq %xmm18, (8*18)(\output0, \offset) + vpextrq $1, %xmm18, (8*18)(\output1, \offset) + vmovq %xmm31, (8*18)(\output2, \offset) + vpextrq $1, %xmm31, (8*18)(\output3, \offset) + + vextracti64x2 $1, %ymm19, %xmm31 + vmovq %xmm19, (8*19)(\output0, \offset) + vpextrq $1, %xmm19, (8*19)(\output1, \offset) + vmovq %xmm31, (8*19)(\output2, \offset) + vpextrq $1, %xmm31, (8*19)(\output3, \offset) + + vextracti64x2 $1, %ymm20, %xmm31 + vmovq %xmm20, (8*20)(\output0, \offset) + vpextrq $1, %xmm20, (8*20)(\output1, \offset) + vmovq %xmm31, (8*20)(\output2, \offset) + vpextrq $1, %xmm31, (8*20)(\output3, \offset) +.endif +.endm + +.text + +# +# Init/Reset API +# + +# ----------------------------------------------------------------------------- +# +# void 
SHA3_shake128_x4_inc_ctx_reset_avx512vl(OQS_SHA3_shake128_x4_inc_ctx *state);
+#
+.globl SHA3_shake128_x4_inc_ctx_reset_avx512vl
+.type SHA3_shake128_x4_inc_ctx_reset_avx512vl,@function
+.hidden SHA3_shake128_x4_inc_ctx_reset_avx512vl
+
+
+# -----------------------------------------------------------------------------
+#
+# void SHA3_shake256_x4_inc_ctx_reset_avx512vl(OQS_SHA3_shake256_x4_inc_ctx *state);
+#
+.globl SHA3_shake256_x4_inc_ctx_reset_avx512vl
+.type SHA3_shake256_x4_inc_ctx_reset_avx512vl,@function
+.hidden SHA3_shake256_x4_inc_ctx_reset_avx512vl
+
+
+.balign 32
+SHA3_shake128_x4_inc_ctx_reset_avx512vl:
+SHA3_shake256_x4_inc_ctx_reset_avx512vl:
+ movq (arg1), arg1 # load arg1.ctx
+ vpxorq %xmm31, %xmm31, %xmm31
+ vmovdqu64 %ymm31, (32*0)(arg1) # clear 800 bytes of state
+ vmovdqu64 %ymm31, (32*1)(arg1)
+ vmovdqu64 %ymm31, (32*2)(arg1)
+ vmovdqu64 %ymm31, (32*3)(arg1)
+ vmovdqu64 %ymm31, (32*4)(arg1)
+ vmovdqu64 %ymm31, (32*5)(arg1)
+ vmovdqu64 %ymm31, (32*6)(arg1)
+ vmovdqu64 %ymm31, (32*7)(arg1)
+ vmovdqu64 %ymm31, (32*8)(arg1)
+ vmovdqu64 %ymm31, (32*9)(arg1)
+ vmovdqu64 %ymm31, (32*10)(arg1)
+ vmovdqu64 %ymm31, (32*11)(arg1)
+ vmovdqu64 %ymm31, (32*12)(arg1)
+ vmovdqu64 %ymm31, (32*13)(arg1)
+ vmovdqu64 %ymm31, (32*14)(arg1)
+ vmovdqu64 %ymm31, (32*15)(arg1)
+ vmovdqu64 %ymm31, (32*16)(arg1)
+ vmovdqu64 %ymm31, (32*17)(arg1)
+ vmovdqu64 %ymm31, (32*18)(arg1)
+ vmovdqu64 %ymm31, (32*19)(arg1)
+ vmovdqu64 %ymm31, (32*20)(arg1)
+ vmovdqu64 %ymm31, (32*21)(arg1)
+ vmovdqu64 %ymm31, (32*22)(arg1)
+ vmovdqu64 %ymm31, (32*23)(arg1)
+ vmovdqu64 %ymm31, (32*24)(arg1)
+ vmovq %xmm31, (32*25)(arg1) # also clear additional 8 bytes s[100]
+ ret
+.size SHA3_shake128_x4_inc_ctx_reset_avx512vl,.-SHA3_shake128_x4_inc_ctx_reset_avx512vl
+.size SHA3_shake256_x4_inc_ctx_reset_avx512vl,.-SHA3_shake256_x4_inc_ctx_reset_avx512vl
+
+#
+# SHAKE128 API
+#
+
+# -----------------------------------------------------------------------------
+#
+# void SHA3_shake128_x4_avx512vl(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3,
+# size_t outlen, const uint8_t *in0, const uint8_t *in1,
+# const uint8_t *in2, const uint8_t *in3, size_t inlen);
+#
+.globl SHA3_shake128_x4_avx512vl
+.type SHA3_shake128_x4_avx512vl,@function
+.hidden SHA3_shake128_x4_avx512vl
+.balign 32
+SHA3_shake128_x4_avx512vl:
+ pushq %rbp
+ movq %rsp, %rbp
+ pushq %rbx
+
+ subq $sf_size, %rsp
+ movq %rsp, %rbx
+
+ mov arg1, (sf_arg1)(%rbx)
+ mov arg2, (sf_arg2)(%rbx)
+ mov arg3, (sf_arg3)(%rbx)
+ mov arg4, (sf_arg4)(%rbx)
+ mov arg5, (sf_arg5)(%rbx)
+
+ lea (sf_state_x4)(%rbx), arg1 # start of x4 state on the stack frame
+ mov arg1, (sf_state_ptr)(%rbx)
+
+ # Initialize the state array to zero
+ call keccak_1600_init_state
+
+ call keccak_1600_save_state_x4
+
+ movq $0, (8*100)(arg1) # clear s[100]
+
+ lea (sf_state_ptr)(%rbx), arg1
+ mov arg6, arg2
+ mov arg7, arg3
+ mov arg8, arg4
+ mov arg9, arg5
+ mov arg10, arg6
+ call SHA3_shake128_x4_inc_absorb_avx512vl
+
+ lea (sf_state_ptr)(%rbx), arg1
+ call SHA3_shake128_x4_inc_finalize_avx512vl
+
+ # squeeze
+ mov (sf_arg1)(%rbx), arg1
+ mov (sf_arg2)(%rbx), arg2
+ mov (sf_arg3)(%rbx), arg3
+ mov (sf_arg4)(%rbx), arg4
+ mov (sf_arg5)(%rbx), arg5
+ lea (sf_state_ptr)(%rbx), arg6
+ call SHA3_shake128_x4_inc_squeeze_avx512vl
+
+ # Clear the temporary buffer
+ lea (sf_state_x4)(%rbx), %rax
+ vpxorq %ymm31, %ymm31, %ymm31
+ vmovdqu64 %ymm31, (32*0)(%rax)
+ vmovdqu64 %ymm31, (32*1)(%rax)
+ vmovdqu64 %ymm31, (32*2)(%rax)
+ vmovdqu64 %ymm31, (32*3)(%rax)
+ vmovdqu64 %ymm31, (32*4)(%rax)
+ 
vmovdqu64 %ymm31, (32*5)(%rax) + vmovdqu64 %ymm31, (32*6)(%rax) + vmovdqu64 %ymm31, (32*7)(%rax) + vmovdqu64 %ymm31, (32*8)(%rax) + vmovdqu64 %ymm31, (32*9)(%rax) + vmovdqu64 %ymm31, (32*10)(%rax) + vmovdqu64 %ymm31, (32*11)(%rax) + vmovdqu64 %ymm31, (32*12)(%rax) + vmovdqu64 %ymm31, (32*13)(%rax) + vmovdqu64 %ymm31, (32*14)(%rax) + vmovdqu64 %ymm31, (32*15)(%rax) + vmovdqu64 %ymm31, (32*16)(%rax) + vmovdqu64 %ymm31, (32*17)(%rax) + vmovdqu64 %ymm31, (32*18)(%rax) + vmovdqu64 %ymm31, (32*19)(%rax) + vmovdqu64 %ymm31, (32*20)(%rax) + vmovdqu64 %ymm31, (32*21)(%rax) + vmovdqu64 %ymm31, (32*22)(%rax) + vmovdqu64 %ymm31, (32*23)(%rax) + vmovdqu64 %ymm31, (32*24)(%rax) + vmovq %xmm31, (32*25)(%rax) + + addq $sf_size, %rsp + popq %rbx + popq %rbp + ret +.size SHA3_shake128_x4_avx512vl,.-SHA3_shake128_x4_avx512vl + + + +# ----------------------------------------------------------------------------- +# +# void SHA3_shake128_x4_inc_absorb_avx512vl( +# OQS_SHA3_shake128_x4_inc_ctx *state, +# const uint8_t *in0, +# const uint8_t *in1, +# const uint8_t *in2, +# const uint8_t *in3, +# size_t inlen); +# +.globl SHA3_shake128_x4_inc_absorb_avx512vl +.type SHA3_shake128_x4_inc_absorb_avx512vl,@function +.hidden SHA3_shake128_x4_inc_absorb_avx512vl +.balign 32 +SHA3_shake128_x4_inc_absorb_avx512vl: + pushq %rbp + movq %rsp, %rbp + pushq %rbx + pushq %r12 + pushq %r13 + pushq %r14 + pushq %r15 + + mov (arg1), arg1 # state.ctx into arg1 + + # check for partially processed block + movq (8*100)(arg1), %r14 + orq %r14, %r14 # s[100] == 0? + je .shake128_absorb_main_loop_start + + # process remaining bytes if message long enough + movq $SHAKE128_RATE, %r12 # c = rate - s[100] + subq %r14, %r12 # %r12 = capacity + + cmp %r12, arg6 # if mlen <= capacity then no permute + jbe .shake128_absorb_skip_permute + + subq %r12, arg6 + + # r10/state, arg2-arg5/inputs, r12/length + movq arg1, %r10 # %r10 = state + call keccak_1600_partial_add_x4 # arg2-arg5 are updated + + call keccak_1600_load_state_x4 + + call keccak_1600_permute + + movq $0, (8*100)(arg1) # clear s[100] + jmp .shake128_absorb_partial_block_done + +.shake128_absorb_skip_permute: + # r10/state, arg2-arg5/inputs, r12/length + movq arg1, %r10 + movq arg6, %r12 + call keccak_1600_partial_add_x4 + + leaq (arg6, %r14), %r15 + mov %r15, (8*100)(arg1) # s[100] += inlen + + cmpq $SHAKE128_RATE, %r15 # check s[100] below rate + jb .shake128_absorb_exit + + call keccak_1600_load_state_x4 + + call keccak_1600_permute + + call keccak_1600_save_state_x4 + + movq $0, (8*100)(arg1) # clear s[100] + jmp .shake128_absorb_exit + +.shake128_absorb_main_loop_start: + call keccak_1600_load_state_x4 + +.shake128_absorb_partial_block_done: + movq arg6, %r11 # copy message length to %r11 + xorq %r12, %r12 # zero message offset + + # Process the input message in blocks +.balign 32 +.shake128_absorb_while_loop: + cmpq $SHAKE128_RATE, %r11 # compare mlen to rate + jb .shake128_absorb_while_loop_done + + absorb_bytes_x4 arg2, arg3, arg4, arg5, %r12, SHAKE128_RATE + + subq $SHAKE128_RATE, %r11 # Subtract the rate from the remaining length + addq $SHAKE128_RATE, %r12 # Adjust the pointer to the next block of the input message + call keccak_1600_permute # Perform the Keccak permutation + + jmp .shake128_absorb_while_loop + +.balign 32 +.shake128_absorb_while_loop_done: + call keccak_1600_save_state_x4 + + mov %r11, (8*100)(arg1) # update s[100] + orq %r11, %r11 + jz .shake128_absorb_exit + + movq $0, (8*100)(arg1) # clear s[100] + + # r10/state, arg2-arg5/input, r12/length + movq arg1, 
%r10 + addq %r12, arg2 + addq %r12, arg3 + addq %r12, arg4 + addq %r12, arg5 + movq %r11, %r12 + call keccak_1600_partial_add_x4 + + mov %r11, (8*100)(arg1) # update s[100] + +.shake128_absorb_exit: + vpxorq %xmm16, %xmm16, %xmm16 + vmovdqa64 %ymm16, %ymm17 + vmovdqa64 %ymm16, %ymm18 + vmovdqa64 %ymm16, %ymm19 + vmovdqa64 %ymm16, %ymm20 + vmovdqa64 %ymm16, %ymm21 + vmovdqa64 %ymm16, %ymm22 + vmovdqa64 %ymm16, %ymm23 + vmovdqa64 %ymm16, %ymm24 + vmovdqa64 %ymm16, %ymm25 + vmovdqa64 %ymm16, %ymm26 + vmovdqa64 %ymm16, %ymm27 + vmovdqa64 %ymm16, %ymm28 + vmovdqa64 %ymm16, %ymm29 + vmovdqa64 %ymm16, %ymm30 + vmovdqa64 %ymm16, %ymm31 + vzeroall + + popq %r15 + popq %r14 + popq %r13 + popq %r12 + popq %rbx + popq %rbp + ret +.size SHA3_shake128_x4_inc_absorb_avx512vl,.-SHA3_shake128_x4_inc_absorb_avx512vl + + +# ----------------------------------------------------------------------------- +# +# void SHA3_shake128_x4_inc_finalize_avx512vl(OQS_SHA3_shake128_x4_inc_ctx *state); +# +.globl SHA3_shake128_x4_inc_finalize_avx512vl +.type SHA3_shake128_x4_inc_finalize_avx512vl,@function +.hidden SHA3_shake128_x4_inc_finalize_avx512vl +.balign 32 +SHA3_shake128_x4_inc_finalize_avx512vl: + mov (arg1), arg1 # state.ctx into arg1 + movq (8*100)(arg1), %r11 # load state offset from s[100] + movq %r11, %r10 + andl $~7, %r10d # offset to the state register + andl $7, %r11d # offset within the register + + # add EOM byte right after the message + vmovdqu32 (arg1, %r10, 4), %ymm31 + leaq shake_msg_pad_x4(%rip), %rax + subq %r11, %rax + vmovdqu32 (%rax), %ymm30 + vpxorq %ymm30, %ymm31, %ymm31 + vmovdqu32 %ymm31, (arg1, %r10, 4) + + # add terminating byte at offset equal to rate - 1 + vmovdqu32 (SHAKE128_RATE*4 - 4*8)(arg1), %ymm31 + vmovdqa32 shake_terminator_byte_x4(%rip), %ymm30 + vpxorq %ymm30, %ymm31, %ymm31 + vmovdqu32 %ymm31, (SHAKE128_RATE*4 - 4*8)(arg1) + + movq $0, (8*100)(arg1) # clear s[100] + vpxorq %ymm31, %ymm31, %ymm31 + ret +.size SHA3_shake128_x4_inc_finalize_avx512vl,.-SHA3_shake128_x4_inc_finalize_avx512vl + + +# ----------------------------------------------------------------------------- +# +# void SHA3_shake128_x4_inc_squeeze_avx512vl( +# uint8_t *out0, +# uint8_t *out1, +# uint8_t *out2, +# uint8_t *out3, +# size_t outlen, +# OQS_SHA3_shake128_x4_inc_ctx *state); +# +.globl SHA3_shake128_x4_inc_squeeze_avx512vl +.type SHA3_shake128_x4_inc_squeeze_avx512vl,@function +.hidden SHA3_shake128_x4_inc_squeeze_avx512vl +.balign 32 +SHA3_shake128_x4_inc_squeeze_avx512vl: + pushq %rbp + movq %rsp, %rbp + pushq %rbx + pushq %r12 + pushq %r13 + pushq %r14 + pushq %r15 + + or arg5, arg5 + jz .shake128_squeeze_done + + mov (arg6), arg6 # arg6 - state.ctx + + # check for partially processed block + movq (8*100)(arg6), %r15 # s[100] - capacity + orq %r15, %r15 + jnz .shake128_squeeze_no_init_permute + + + movq arg1, %r14 + mov arg6, arg1 + call keccak_1600_load_state_x4 + + movq %r14, arg1 + + xorq %rax, %rax + jmp .shake128_squeeze_loop + + +.balign 32 +.shake128_squeeze_no_init_permute: + # extract bytes: r10 - state/src, arg1-arg4 - output/dst, r12 - length = min(capacity, outlen), r11 - offset + movq arg6, %r10 + + movq %r15, %r12 + cmpq %r15, arg5 + cmovb arg5, %r12 # %r12 = min(capacity, $outlen) + + sub %r12, arg5 # outlen -= length + + movl $SHAKE128_RATE, %r11d + subq %r15, %r11 # state offset + + subq %r12, %r15 # capacity -= length + mov %r15, (8*100)(arg6) # update s[100] + + call keccak_1600_extract_bytes_x4 + + orq %r15, %r15 + jnz .shake128_squeeze_done # check s[100] not zero + + movq arg1, 
%r14# preserve arg1 + mov arg6, arg1 + call keccak_1600_load_state_x4 + + movq %r14, arg1 + xorq %rax, %rax + +.balign 32 +.shake128_squeeze_loop: + cmp $SHAKE128_RATE, arg5 # outlen > r + jb .shake128_squeeze_final_extract + + call keccak_1600_permute + + # Extract SHAKE128 rate bytes into the destination buffer + extract_state_x4 arg1, arg2, arg3, arg4, %rax, (SHAKE128_RATE / 8) + + addq $SHAKE128_RATE, %rax # dst offset += r + sub $SHAKE128_RATE, arg5 # outlen -= r + jmp .shake128_squeeze_loop + +.balign 32 +.shake128_squeeze_final_extract: + or arg5, arg5 + jz .shake128_squeeze_no_end_permute + + # update output pointers + addq %rax, arg1 + addq %rax, arg2 + addq %rax, arg3 + addq %rax, arg4 + + movl $SHAKE128_RATE, %r15d + subq arg5, %r15 + mov %r15, (8*100)(arg6) # s[100] = c + + call keccak_1600_permute + + + movq arg1, %r14 + mov arg6, arg1 + call keccak_1600_save_state_x4 + + movq %r14, arg1 + + # extract bytes: r10 - state/src, arg1-arg4 - output/dst, r12 - length, r11 - offset = 0 + movq arg6, %r10 + movq arg5, %r12 + xorq %r11, %r11 + call keccak_1600_extract_bytes_x4 + + jmp .shake128_squeeze_done + +.shake128_squeeze_no_end_permute: + movq $0, (8*100)(arg6) # s[100] = 0 + mov arg6, arg1 + call keccak_1600_save_state_x4 + +.shake128_squeeze_done: + vpxorq %xmm16, %xmm16, %xmm16 + vmovdqa64 %ymm16, %ymm17 + vmovdqa64 %ymm16, %ymm18 + vmovdqa64 %ymm16, %ymm19 + vmovdqa64 %ymm16, %ymm20 + vmovdqa64 %ymm16, %ymm21 + vmovdqa64 %ymm16, %ymm22 + vmovdqa64 %ymm16, %ymm23 + vmovdqa64 %ymm16, %ymm24 + vmovdqa64 %ymm16, %ymm25 + vmovdqa64 %ymm16, %ymm26 + vmovdqa64 %ymm16, %ymm27 + vmovdqa64 %ymm16, %ymm28 + vmovdqa64 %ymm16, %ymm29 + vmovdqa64 %ymm16, %ymm30 + vmovdqa64 %ymm16, %ymm31 + vzeroall + + popq %r15 + popq %r14 + popq %r13 + popq %r12 + popq %rbx + popq %rbp + + ret +.size SHA3_shake128_x4_inc_squeeze_avx512vl,.-SHA3_shake128_x4_inc_squeeze_avx512vl + + +# +# SHAKE256 API +# + +# ----------------------------------------------------------------------------- +# +# void SHA3_shake256_x4_avx512vl(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, +# size_t outlen, const uint8_t *in0, const uint8_t *in1, +# const uint8_t *in2, const uint8_t *in3, size_t inlen); +# +.globl SHA3_shake256_x4_avx512vl +.type SHA3_shake256_x4_avx512vl,@function +.hidden SHA3_shake256_x4_avx512vl +.balign 32 +SHA3_shake256_x4_avx512vl: + pushq %rbp + movq %rsp, %rbp + pushq %rbx + + subq $sf_size, %rsp + movq %rsp, %rbx + + mov arg1, (sf_arg1)(%rbx) + mov arg2, (sf_arg2)(%rbx) + mov arg3, (sf_arg3)(%rbx) + mov arg4, (sf_arg4)(%rbx) + mov arg5, (sf_arg5)(%rbx) + + lea (sf_state_x4)(%rbx), arg1 # start of x4 state on the stack frame + mov arg1, (sf_state_ptr)(%rbx) + + # Initialize the state array to zero + call keccak_1600_init_state + + call keccak_1600_save_state_x4 + + movq $0, (8*100)(arg1) # clear s[100] + + lea (sf_state_ptr)(%rbx), arg1 + mov arg6, arg2 + mov arg7, arg3 + mov arg8, arg4 + mov arg9, arg5 + mov arg10, arg6 + call SHA3_shake256_x4_inc_absorb_avx512vl + + + lea (sf_state_ptr)(%rbx), arg1 + call SHA3_shake256_x4_inc_finalize_avx512vl + + + # squeeze + mov (sf_arg1)(%rbx), arg1 + mov (sf_arg2)(%rbx), arg2 + mov (sf_arg3)(%rbx), arg3 + mov (sf_arg4)(%rbx), arg4 + mov (sf_arg5)(%rbx), arg5 + lea (sf_state_ptr)(%rbx), arg6 + call SHA3_shake256_x4_inc_squeeze_avx512vl + + # Clear the temporary buffer + lea (sf_state_x4)(%rbx), %rax + vpxorq %ymm31, %ymm31, %ymm31 + vmovdqu64 %ymm31, (32*0)(%rax) + vmovdqu64 %ymm31, (32*1)(%rax) + vmovdqu64 %ymm31, (32*2)(%rax) + vmovdqu64 %ymm31, 
(32*3)(%rax) + vmovdqu64 %ymm31, (32*4)(%rax) + vmovdqu64 %ymm31, (32*5)(%rax) + vmovdqu64 %ymm31, (32*6)(%rax) + vmovdqu64 %ymm31, (32*7)(%rax) + vmovdqu64 %ymm31, (32*8)(%rax) + vmovdqu64 %ymm31, (32*9)(%rax) + vmovdqu64 %ymm31, (32*10)(%rax) + vmovdqu64 %ymm31, (32*11)(%rax) + vmovdqu64 %ymm31, (32*12)(%rax) + vmovdqu64 %ymm31, (32*13)(%rax) + vmovdqu64 %ymm31, (32*14)(%rax) + vmovdqu64 %ymm31, (32*15)(%rax) + vmovdqu64 %ymm31, (32*16)(%rax) + vmovdqu64 %ymm31, (32*17)(%rax) + vmovdqu64 %ymm31, (32*18)(%rax) + vmovdqu64 %ymm31, (32*19)(%rax) + vmovdqu64 %ymm31, (32*20)(%rax) + vmovdqu64 %ymm31, (32*21)(%rax) + vmovdqu64 %ymm31, (32*22)(%rax) + vmovdqu64 %ymm31, (32*23)(%rax) + vmovdqu64 %ymm31, (32*24)(%rax) + vmovq %xmm31, (32*25)(%rax) + + addq $sf_size, %rsp + popq %rbx + popq %rbp + ret +.size SHA3_shake256_x4_avx512vl,.-SHA3_shake256_x4_avx512vl + + + +# ----------------------------------------------------------------------------- +# +# void SHA3_shake256_x4_inc_absorb_avx512vl( +# OQS_SHA3_shake256_x4_inc_ctx *state, +# const uint8_t *in0, +# const uint8_t *in1, +# const uint8_t *in2, +# const uint8_t *in3, +# size_t inlen); +# +.globl SHA3_shake256_x4_inc_absorb_avx512vl +.type SHA3_shake256_x4_inc_absorb_avx512vl,@function +.hidden SHA3_shake256_x4_inc_absorb_avx512vl +.balign 32 +SHA3_shake256_x4_inc_absorb_avx512vl: + pushq %rbp + movq %rsp, %rbp + pushq %rbx + pushq %r12 + pushq %r13 + pushq %r14 + pushq %r15 + + mov (arg1), arg1 # state.ctx into arg1 + + # check for partially processed block + movq (8*100)(arg1), %r14 + orq %r14, %r14 # check s[100] is zero + je .shake256_absorb_main_loop_start + + + # process remaining bytes if message long enough + movq $SHAKE256_RATE, %r12 # c = rate - s[100] + subq %r14, %r12 # %r12 = capacity + + cmp %r12, arg6 # if mlen <= capacity then no permute + jbe .shake256_absorb_skip_permute + + subq %r12, arg6 + + # r10/state, arg2-arg5/inputs, r12/length + movq arg1, %r10 # %r10 = state + call keccak_1600_partial_add_x4 # arg2-arg5 are updated + + call keccak_1600_load_state_x4 + + call keccak_1600_permute + + movq $0, (8*100)(arg1) # clear s[100] + jmp .shake256_absorb_partial_block_done + +.shake256_absorb_skip_permute: + # r10/state, arg2-arg5/inputs, r12/length + movq arg1, %r10 + movq arg6, %r12 + call keccak_1600_partial_add_x4 + + leaq (arg6, %r14), %r15 + mov %r15, (8*100)(arg1) # s[100] += inlen + + cmpq $SHAKE256_RATE, %r15 # s[100] >= rate ? 
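+ # (buffered bytes are still below SHAKE256_RATE: keep the partial
+ # block in the state, with the offset tracked in s[100], and return
+ # without permuting)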
+ jb .shake256_absorb_exit
+
+ call keccak_1600_load_state_x4
+
+ call keccak_1600_permute
+
+ call keccak_1600_save_state_x4
+
+ movq $0, (8*100)(arg1) # clear s[100]
+ jmp .shake256_absorb_exit
+
+.shake256_absorb_main_loop_start:
+ call keccak_1600_load_state_x4
+
+.shake256_absorb_partial_block_done:
+ movq arg6, %r11 # copy message length to %r11
+ xorq %r12, %r12 # zero message offset
+
+ # Process the input message in blocks
+.balign 32
+.shake256_absorb_while_loop:
+ cmpq $SHAKE256_RATE, %r11 # compare mlen to rate
+ jb .shake256_absorb_while_loop_done
+
+ absorb_bytes_x4 arg2, arg3, arg4, arg5, %r12, SHAKE256_RATE
+
+ subq $SHAKE256_RATE, %r11 # Subtract the rate from the remaining length
+ addq $SHAKE256_RATE, %r12 # Adjust the pointer to the next block of the input message
+ call keccak_1600_permute # Perform the Keccak permutation
+
+ jmp .shake256_absorb_while_loop
+
+.balign 32
+.shake256_absorb_while_loop_done:
+ call keccak_1600_save_state_x4
+
+ mov %r11, (8*100)(arg1) # update s[100]
+ orq %r11, %r11
+ jz .shake256_absorb_exit
+
+ movq $0, (8*100)(arg1) # clear s[100]
+
+ # r10/state, arg2-arg5/input, r12/length
+ movq arg1, %r10
+ addq %r12, arg2
+ addq %r12, arg3
+ addq %r12, arg4
+ addq %r12, arg5
+ movq %r11, %r12
+ call keccak_1600_partial_add_x4
+
+ mov %r11, (8*100)(arg1) # update s[100]
+
+.shake256_absorb_exit:
+ vpxorq %xmm16, %xmm16, %xmm16
+ vmovdqa64 %ymm16, %ymm17
+ vmovdqa64 %ymm16, %ymm18
+ vmovdqa64 %ymm16, %ymm19
+ vmovdqa64 %ymm16, %ymm20
+ vmovdqa64 %ymm16, %ymm21
+ vmovdqa64 %ymm16, %ymm22
+ vmovdqa64 %ymm16, %ymm23
+ vmovdqa64 %ymm16, %ymm24
+ vmovdqa64 %ymm16, %ymm25
+ vmovdqa64 %ymm16, %ymm26
+ vmovdqa64 %ymm16, %ymm27
+ vmovdqa64 %ymm16, %ymm28
+ vmovdqa64 %ymm16, %ymm29
+ vmovdqa64 %ymm16, %ymm30
+ vmovdqa64 %ymm16, %ymm31
+ vzeroall
+
+ popq %r15
+ popq %r14
+ popq %r13
+ popq %r12
+ popq %rbx
+ popq %rbp
+ ret
+.size SHA3_shake256_x4_inc_absorb_avx512vl,.-SHA3_shake256_x4_inc_absorb_avx512vl
+
+
+# -----------------------------------------------------------------------------
+#
+# void SHA3_shake256_x4_inc_finalize_avx512vl(OQS_SHA3_shake256_x4_inc_ctx *state);
+#
+.globl SHA3_shake256_x4_inc_finalize_avx512vl
+.type SHA3_shake256_x4_inc_finalize_avx512vl,@function
+.hidden SHA3_shake256_x4_inc_finalize_avx512vl
+.balign 32
+SHA3_shake256_x4_inc_finalize_avx512vl:
+ mov (arg1), arg1 # state.ctx into arg1
+ movq (8*100)(arg1), %r11 # load state offset from s[100]
+ movq %r11, %r10
+ andl $~7, %r10d # offset to the state register
+ andl $7, %r11d # offset within the register
+
+ # add EOM byte right after the message
+ vmovdqu32 (arg1, %r10, 4), %ymm31
+ leaq shake_msg_pad_x4(%rip), %rax
+ subq %r11, %rax
+ vmovdqu32 (%rax), %ymm30
+ vpxorq %ymm30, %ymm31, %ymm31
+ vmovdqu32 %ymm31, (arg1, %r10, 4)
+
+ # add terminating byte at offset equal to rate - 1
+ vmovdqu32 (SHAKE256_RATE*4 - 4*8)(arg1), %ymm31
+ vmovdqa32 shake_terminator_byte_x4(%rip), %ymm30
+ vpxorq %ymm30, %ymm31, %ymm31
+ vmovdqu32 %ymm31, (SHAKE256_RATE*4 - 4*8)(arg1)
+
+ movq $0, (8*100)(arg1) # clear s[100]
+ vpxorq %ymm31, %ymm31, %ymm31
+ ret
+.size SHA3_shake256_x4_inc_finalize_avx512vl,.-SHA3_shake256_x4_inc_finalize_avx512vl
+
+
+# -----------------------------------------------------------------------------
+#
+# void SHA3_shake256_x4_inc_squeeze_avx512vl(
+# uint8_t *out0,
+# uint8_t *out1,
+# uint8_t *out2,
+# uint8_t *out3,
+# size_t outlen,
+# OQS_SHA3_shake256_x4_inc_ctx *state);
+#
+.globl SHA3_shake256_x4_inc_squeeze_avx512vl
+.type 
SHA3_shake256_x4_inc_squeeze_avx512vl,@function +.hidden SHA3_shake256_x4_inc_squeeze_avx512vl +.balign 32 +SHA3_shake256_x4_inc_squeeze_avx512vl: + pushq %rbp + movq %rsp, %rbp + pushq %rbx + pushq %r12 + pushq %r13 + pushq %r14 + pushq %r15 + + or arg5, arg5 + jz .shake256_squeeze_done + + mov (arg6), arg6 # arg6 - state.ctx + + # check for partially processed block + movq (8*100)(arg6), %r15 # s[100] - capacity + orq %r15, %r15 + jnz .shake256_squeeze_no_init_permute + + + movq arg1, %r14 + mov arg6, arg1 + call keccak_1600_load_state_x4 + + movq %r14, arg1 + + xorq %rax, %rax + jmp .shake256_squeeze_loop + +.balign 32 +.shake256_squeeze_no_init_permute: + # extract bytes: r10 - state/src, arg1-arg4 - output/dst, r12 - length = min(capacity, outlen), r11 - offset + movq arg6, %r10 + + movq %r15, %r12 + cmpq %r15, arg5 + cmovb arg5, %r12 # %r12 = min(capacity, $outlen) + + sub %r12, arg5 # outlen -= length + + movl $SHAKE256_RATE, %r11d + subq %r15, %r11 # state offset + + subq %r12, %r15 # capacity -= length + mov %r15, (8*100)(arg6) # update s[100] + + call keccak_1600_extract_bytes_x4 + + orq %r15, %r15 + jnz .shake256_squeeze_done # check s[100] not zero + + movq arg1, %r14 # preserve arg1 + mov arg6, arg1 + call keccak_1600_load_state_x4 + + movq %r14, arg1 + xorq %rax, %rax + +.balign 32 +.shake256_squeeze_loop: + cmp $SHAKE256_RATE, arg5 # outlen > r + jb .shake256_squeeze_final_extract + + call keccak_1600_permute + + # Extract SHAKE256 rate bytes into the destination buffer + extract_state_x4 arg1, arg2, arg3, arg4, %rax, (SHAKE256_RATE / 8) + + addq $SHAKE256_RATE, %rax # dst offset += r + sub $SHAKE256_RATE, arg5 # outlen -= r + jmp .shake256_squeeze_loop + +.balign 32 +.shake256_squeeze_final_extract: + or arg5, arg5 + jz .shake256_squeeze_no_end_permute + + # update output pointers + addq %rax, arg1 + addq %rax, arg2 + addq %rax, arg3 + addq %rax, arg4 + + movl $SHAKE256_RATE, %r15d + subq arg5, %r15 + mov %r15, (8*100)(arg6) # s[100] = c + + call keccak_1600_permute + + movq arg1, %r14 + mov arg6, arg1 + call keccak_1600_save_state_x4 + + movq %r14, arg1 + + # extract bytes: r10 - state/src, arg1-arg4 - output/dst, r12 - length, r11 - offset = 0 + movq arg6, %r10 + movq arg5, %r12 + xorq %r11, %r11 + call keccak_1600_extract_bytes_x4 + + jmp .shake256_squeeze_done + +.shake256_squeeze_no_end_permute: + movq $0, (8*100)(arg6) # s[100] = 0 + mov arg6, arg1 + call keccak_1600_save_state_x4 + +.shake256_squeeze_done: + vpxorq %xmm16, %xmm16, %xmm16 + vmovdqa64 %ymm16, %ymm17 + vmovdqa64 %ymm16, %ymm18 + vmovdqa64 %ymm16, %ymm19 + vmovdqa64 %ymm16, %ymm20 + vmovdqa64 %ymm16, %ymm21 + vmovdqa64 %ymm16, %ymm22 + vmovdqa64 %ymm16, %ymm23 + vmovdqa64 %ymm16, %ymm24 + vmovdqa64 %ymm16, %ymm25 + vmovdqa64 %ymm16, %ymm26 + vmovdqa64 %ymm16, %ymm27 + vmovdqa64 %ymm16, %ymm28 + vmovdqa64 %ymm16, %ymm29 + vmovdqa64 %ymm16, %ymm30 + vmovdqa64 %ymm16, %ymm31 + vzeroall + + popq %r15 + popq %r14 + popq %r13 + popq %r12 + popq %rbx + popq %rbp + + ret +.size SHA3_shake256_x4_inc_squeeze_avx512vl,.-SHA3_shake256_x4_inc_squeeze_avx512vl + + +.section .rodata + +# SHAKE128 and SHAKE256 use the same terminator byte +.balign 32 +shake_terminator_byte_x4: +.byte 0, 0, 0, 0, 0, 0, 0, 0x80 +.byte 0, 0, 0, 0, 0, 0, 0, 0x80 +.byte 0, 0, 0, 0, 0, 0, 0, 0x80 +.byte 0, 0, 0, 0, 0, 0, 0, 0x80 + +# SHAKE128 and SHAKE256 use the same multi-rate padding byte +.balign 8 +# This is not a mistake and these 8 zero bytes are required here. 
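+# (The x4 finalize functions above load 32 bytes starting at
+# shake_msg_pad_x4 minus the intra-lane offset, s[100] & 7, so the read
+# may begin up to 7 bytes before the label; the zero bytes make those
+# positions XOR as no-ops.)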
+# Address is decremented depending on the offset within the state register.
+.byte 0, 0, 0, 0, 0, 0, 0, 0
+shake_msg_pad_x4:
+.byte SHAKE_MRATE_PADDING, 0, 0, 0, 0, 0, 0, 0
+.byte SHAKE_MRATE_PADDING, 0, 0, 0, 0, 0, 0, 0
+.byte SHAKE_MRATE_PADDING, 0, 0, 0, 0, 0, 0, 0
+.byte SHAKE_MRATE_PADDING, 0, 0, 0, 0, 0, 0, 0
+
+.section .note.GNU-stack,"",%progbits
diff --git a/src/common/sha3/avx512vl_sha3.c b/src/common/sha3/avx512vl_sha3.c
new file mode 100644
index 000000000..6c1805f21
--- /dev/null
+++ b/src/common/sha3/avx512vl_sha3.c
@@ -0,0 +1,242 @@
+/**
+ * \file avx512vl_sha3.c
+ * \brief Implementation of the OQS SHA3 API using the AVX512VL low interface.
+ *
+ * Copyright (c) 2025 Intel Corporation
+ *
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "sha3.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <string.h>
+#include <oqs/common.h>
+#include <oqs/oqsconfig.h>
+#include <oqs/sha3.h>
+
+#define KECCAK_CTX_ALIGNMENT 32
+#define _KECCAK_CTX_BYTES (200 + sizeof(uint64_t))
+#define KECCAK_CTX_BYTES \
+ (KECCAK_CTX_ALIGNMENT * \
+ ((_KECCAK_CTX_BYTES + KECCAK_CTX_ALIGNMENT - 1) / KECCAK_CTX_ALIGNMENT))
+
+/*
+ * External compact functions
+ */
+extern void SHA3_sha3_256_avx512vl(uint8_t *output, const uint8_t *input,
+ size_t inlen);
+extern void SHA3_sha3_384_avx512vl(uint8_t *output, const uint8_t *input,
+ size_t inlen);
+extern void SHA3_sha3_512_avx512vl(uint8_t *output, const uint8_t *input,
+ size_t inlen);
+extern void SHA3_shake128_avx512vl(uint8_t *output, size_t outlen,
+ const uint8_t *input, size_t inlen);
+extern void SHA3_shake256_avx512vl(uint8_t *output, size_t outlen,
+ const uint8_t *input, size_t inlen);
+
+/*
+ * External reset functions
+ */
+extern void
+SHA3_sha3_256_inc_ctx_reset_avx512vl(OQS_SHA3_sha3_256_inc_ctx *state);
+extern void
+SHA3_sha3_384_inc_ctx_reset_avx512vl(OQS_SHA3_sha3_384_inc_ctx *state);
+extern void
+SHA3_sha3_512_inc_ctx_reset_avx512vl(OQS_SHA3_sha3_512_inc_ctx *state);
+extern void
+SHA3_shake128_inc_ctx_reset_avx512vl(OQS_SHA3_shake128_inc_ctx *state);
+extern void
+SHA3_shake256_inc_ctx_reset_avx512vl(OQS_SHA3_shake256_inc_ctx *state);
+
+/*
+ * External absorb functions
+ */
+extern void SHA3_sha3_256_inc_absorb_avx512vl(OQS_SHA3_sha3_256_inc_ctx *state,
+ const uint8_t *input,
+ size_t inlen);
+extern void SHA3_sha3_384_inc_absorb_avx512vl(OQS_SHA3_sha3_384_inc_ctx *state,
+ const uint8_t *input,
+ size_t inlen);
+extern void SHA3_sha3_512_inc_absorb_avx512vl(OQS_SHA3_sha3_512_inc_ctx *state,
+ const uint8_t *input,
+ size_t inlen);
+extern void SHA3_shake128_inc_absorb_avx512vl(OQS_SHA3_shake128_inc_ctx *state,
+ const uint8_t *input,
+ size_t inlen);
+extern void SHA3_shake256_inc_absorb_avx512vl(OQS_SHA3_shake256_inc_ctx *state,
+ const uint8_t *input,
+ size_t inlen);
+/*
+ * External finalize functions
+ */
+extern void
+SHA3_sha3_256_inc_finalize_avx512vl(uint8_t *output,
+ OQS_SHA3_sha3_256_inc_ctx *state);
+extern void
+SHA3_sha3_384_inc_finalize_avx512vl(uint8_t *output,
+ OQS_SHA3_sha3_384_inc_ctx *state);
+extern void
+SHA3_sha3_512_inc_finalize_avx512vl(uint8_t *output,
+ OQS_SHA3_sha3_512_inc_ctx *state);
+extern void
+SHA3_shake128_inc_finalize_avx512vl(OQS_SHA3_shake128_inc_ctx *state);
+extern void
+SHA3_shake256_inc_finalize_avx512vl(OQS_SHA3_shake256_inc_ctx *state);
+
+/*
+ * External squeeze functions
+ */
+extern void
+SHA3_shake128_inc_squeeze_avx512vl(uint8_t *output, size_t outlen,
+ OQS_SHA3_shake128_inc_ctx *state);
+extern void
+SHA3_shake256_inc_squeeze_avx512vl(uint8_t *output, size_t outlen,
+ OQS_SHA3_shake256_inc_ctx *state);
+
+/*
+ * SHA3-256
+ */
+static void 
SHA3_sha3_256_inc_init_avx512vl(OQS_SHA3_sha3_256_inc_ctx *state) { + state->ctx = OQS_MEM_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES); + OQS_EXIT_IF_NULLPTR(state->ctx, "SHA3"); + SHA3_sha3_256_inc_ctx_reset_avx512vl(state); +} + +static void +SHA3_sha3_256_inc_ctx_release_avx512vl(OQS_SHA3_sha3_256_inc_ctx *state) { + SHA3_sha3_256_inc_ctx_reset_avx512vl(state); + OQS_MEM_aligned_free(state->ctx); +} + +static void +SHA3_sha3_256_inc_ctx_clone_avx512vl(OQS_SHA3_sha3_256_inc_ctx *dest, + const OQS_SHA3_sha3_256_inc_ctx *src) { + memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES); +} + +/* + * SHA3-384 + */ +static void SHA3_sha3_384_inc_init_avx512vl(OQS_SHA3_sha3_384_inc_ctx *state) { + state->ctx = OQS_MEM_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES); + OQS_EXIT_IF_NULLPTR(state->ctx, "SHA3"); + SHA3_sha3_384_inc_ctx_reset_avx512vl(state); +} + +static void +SHA3_sha3_384_inc_ctx_release_avx512vl(OQS_SHA3_sha3_384_inc_ctx *state) { + SHA3_sha3_384_inc_ctx_reset_avx512vl(state); + OQS_MEM_aligned_free(state->ctx); +} + +static void +SHA3_sha3_384_inc_ctx_clone_avx512vl(OQS_SHA3_sha3_384_inc_ctx *dest, + const OQS_SHA3_sha3_384_inc_ctx *src) { + memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES); +} + +/* + * SHA3-512 + */ +static void SHA3_sha3_512_inc_init_avx512vl(OQS_SHA3_sha3_512_inc_ctx *state) { + state->ctx = OQS_MEM_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES); + OQS_EXIT_IF_NULLPTR(state->ctx, "SHA3"); + SHA3_sha3_512_inc_ctx_reset_avx512vl(state); +} + +static void +SHA3_sha3_512_inc_ctx_release_avx512vl(OQS_SHA3_sha3_512_inc_ctx *state) { + SHA3_sha3_512_inc_ctx_reset_avx512vl(state); + OQS_MEM_aligned_free(state->ctx); +} + +static void +SHA3_sha3_512_inc_ctx_clone_avx512vl(OQS_SHA3_sha3_512_inc_ctx *dest, + const OQS_SHA3_sha3_512_inc_ctx *src) { + memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES); +} + +/* + * SHAKE128 + */ +static void SHA3_shake128_inc_init_avx512vl(OQS_SHA3_shake128_inc_ctx *state) { + state->ctx = OQS_MEM_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES); + OQS_EXIT_IF_NULLPTR(state->ctx, "SHA3"); + SHA3_shake128_inc_ctx_reset_avx512vl(state); +} + +static void +SHA3_shake128_inc_ctx_clone_avx512vl(OQS_SHA3_shake128_inc_ctx *dest, + const OQS_SHA3_shake128_inc_ctx *src) { + memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES); +} + +static void +SHA3_shake128_inc_ctx_release_avx512vl(OQS_SHA3_shake128_inc_ctx *state) { + SHA3_shake128_inc_ctx_reset_avx512vl(state); + OQS_MEM_aligned_free(state->ctx); +} + +/* + * SHAKE256 + */ +static void SHA3_shake256_inc_init_avx512vl(OQS_SHA3_shake256_inc_ctx *state) { + state->ctx = OQS_MEM_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES); + OQS_EXIT_IF_NULLPTR(state->ctx, "SHA3"); + SHA3_shake256_inc_ctx_reset_avx512vl(state); +} + +static void +SHA3_shake256_inc_ctx_release_avx512vl(OQS_SHA3_shake256_inc_ctx *state) { + SHA3_shake256_inc_ctx_reset_avx512vl(state); + OQS_MEM_aligned_free(state->ctx); +} + +static void +SHA3_shake256_inc_ctx_clone_avx512vl(OQS_SHA3_shake256_inc_ctx *dest, + const OQS_SHA3_shake256_inc_ctx *src) { + memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES); +} + +const struct OQS_SHA3_callbacks sha3_avx512vl_callbacks = { + SHA3_sha3_256_avx512vl, + SHA3_sha3_256_inc_init_avx512vl, + SHA3_sha3_256_inc_absorb_avx512vl, + SHA3_sha3_256_inc_finalize_avx512vl, + SHA3_sha3_256_inc_ctx_release_avx512vl, + SHA3_sha3_256_inc_ctx_reset_avx512vl, + SHA3_sha3_256_inc_ctx_clone_avx512vl, + SHA3_sha3_384_avx512vl, + SHA3_sha3_384_inc_init_avx512vl, + SHA3_sha3_384_inc_absorb_avx512vl, + 
SHA3_sha3_384_inc_finalize_avx512vl,
+ SHA3_sha3_384_inc_ctx_release_avx512vl,
+ SHA3_sha3_384_inc_ctx_reset_avx512vl,
+ SHA3_sha3_384_inc_ctx_clone_avx512vl,
+ SHA3_sha3_512_avx512vl,
+ SHA3_sha3_512_inc_init_avx512vl,
+ SHA3_sha3_512_inc_absorb_avx512vl,
+ SHA3_sha3_512_inc_finalize_avx512vl,
+ SHA3_sha3_512_inc_ctx_release_avx512vl,
+ SHA3_sha3_512_inc_ctx_reset_avx512vl,
+ SHA3_sha3_512_inc_ctx_clone_avx512vl,
+ SHA3_shake128_avx512vl,
+ SHA3_shake128_inc_init_avx512vl,
+ SHA3_shake128_inc_absorb_avx512vl,
+ SHA3_shake128_inc_finalize_avx512vl,
+ SHA3_shake128_inc_squeeze_avx512vl,
+ SHA3_shake128_inc_ctx_release_avx512vl,
+ SHA3_shake128_inc_ctx_clone_avx512vl,
+ SHA3_shake128_inc_ctx_reset_avx512vl,
+ SHA3_shake256_avx512vl,
+ SHA3_shake256_inc_init_avx512vl,
+ SHA3_shake256_inc_absorb_avx512vl,
+ SHA3_shake256_inc_finalize_avx512vl,
+ SHA3_shake256_inc_squeeze_avx512vl,
+ SHA3_shake256_inc_ctx_release_avx512vl,
+ SHA3_shake256_inc_ctx_clone_avx512vl,
+ SHA3_shake256_inc_ctx_reset_avx512vl,
+};
diff --git a/src/common/sha3/avx512vl_sha3x4.c b/src/common/sha3/avx512vl_sha3x4.c
new file mode 100644
index 000000000..ee596ff5f
--- /dev/null
+++ b/src/common/sha3/avx512vl_sha3x4.c
@@ -0,0 +1,133 @@
+/**
+ * \file avx512vl_sha3x4.c
+ * \brief Implementation of the OQS SHA3 times 4 API using the AVX512VL low interface.
+ *
+ * Copyright (c) 2025 Intel Corporation
+ *
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "sha3x4.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <string.h>
+#include <oqs/common.h>
+#include <oqs/oqsconfig.h>
+#include <oqs/sha3x4.h>
+
+#define KECCAK_X4_CTX_ALIGNMENT 32
+#define _KECCAK_X4_CTX_BYTES (800 + sizeof(uint64_t))
+#define KECCAK_X4_CTX_BYTES \
+ (KECCAK_X4_CTX_ALIGNMENT * \
+ ((_KECCAK_X4_CTX_BYTES + KECCAK_X4_CTX_ALIGNMENT - 1) / KECCAK_X4_CTX_ALIGNMENT))
+
+/********** SHAKE128 ***********/
+
+/* SHAKE128 external */
+
+extern void
+SHA3_shake128_x4_avx512vl(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen,
+ const uint8_t *in0, const uint8_t *in1, const uint8_t *in2,
+ const uint8_t *in3, size_t inlen);
+
+extern void
+SHA3_shake128_x4_inc_ctx_reset_avx512vl(OQS_SHA3_shake128_x4_inc_ctx *state);
+
+extern void
+SHA3_shake128_x4_inc_absorb_avx512vl(OQS_SHA3_shake128_x4_inc_ctx *state, const uint8_t *in0,
+ const uint8_t *in1, const uint8_t *in2, const uint8_t *in3,
+ size_t inlen);
+
+extern void
+SHA3_shake128_x4_inc_finalize_avx512vl(OQS_SHA3_shake128_x4_inc_ctx *state);
+
+extern void
+SHA3_shake128_x4_inc_squeeze_avx512vl(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3,
+ size_t outlen, OQS_SHA3_shake128_x4_inc_ctx *state);
+
+/* SHAKE128 incremental */
+
+static void
+SHA3_shake128_x4_inc_init_avx512vl(OQS_SHA3_shake128_x4_inc_ctx *state) {
+ state->ctx = OQS_MEM_aligned_alloc(KECCAK_X4_CTX_ALIGNMENT, KECCAK_X4_CTX_BYTES);
+ OQS_EXIT_IF_NULLPTR(state->ctx, "SHA3x4");
+ SHA3_shake128_x4_inc_ctx_reset_avx512vl(state);
+}
+
+static void
+SHA3_shake128_x4_inc_ctx_clone_avx512vl(OQS_SHA3_shake128_x4_inc_ctx *dest,
+ const OQS_SHA3_shake128_x4_inc_ctx *src) {
+ memcpy(dest->ctx, src->ctx, KECCAK_X4_CTX_BYTES);
+}
+
+static void
+SHA3_shake128_x4_inc_ctx_release_avx512vl(OQS_SHA3_shake128_x4_inc_ctx *state) {
+ SHA3_shake128_x4_inc_ctx_reset_avx512vl(state);
+ OQS_MEM_aligned_free(state->ctx);
+}
+
+/********** SHAKE256 ***********/
+
+/* SHAKE256 external */
+
+extern void
+SHA3_shake256_x4_avx512vl(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen,
+ const uint8_t *in0, const uint8_t *in1, const uint8_t *in2,
+ const uint8_t *in3, size_t inlen);
+
+extern void 
+SHA3_shake256_x4_inc_ctx_reset_avx512vl(OQS_SHA3_shake256_x4_inc_ctx *state); + +extern void +SHA3_shake256_x4_inc_absorb_avx512vl(OQS_SHA3_shake256_x4_inc_ctx *state, const uint8_t *in0, + const uint8_t *in1, const uint8_t *in2, const uint8_t *in3, + size_t inlen); + +extern void +SHA3_shake256_x4_inc_finalize_avx512vl(OQS_SHA3_shake256_x4_inc_ctx *state); + +extern void +SHA3_shake256_x4_inc_squeeze_avx512vl(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, + size_t outlen, OQS_SHA3_shake256_x4_inc_ctx *state); + +/* SHAKE256 incremental */ + +static void +SHA3_shake256_x4_inc_init_avx512vl(OQS_SHA3_shake256_x4_inc_ctx *state) { + state->ctx = OQS_MEM_aligned_alloc(KECCAK_X4_CTX_ALIGNMENT, KECCAK_X4_CTX_BYTES); + OQS_EXIT_IF_NULLPTR(state->ctx, "SHA3x4"); + SHA3_shake256_x4_inc_ctx_reset_avx512vl(state); +} + +static void +SHA3_shake256_x4_inc_ctx_clone_avx512vl(OQS_SHA3_shake256_x4_inc_ctx *dest, + const OQS_SHA3_shake256_x4_inc_ctx *src) { + memcpy(dest->ctx, src->ctx, KECCAK_X4_CTX_BYTES); +} + +static void +SHA3_shake256_x4_inc_ctx_release_avx512vl(OQS_SHA3_shake256_x4_inc_ctx *state) { + SHA3_shake256_x4_inc_ctx_reset_avx512vl(state); + OQS_MEM_aligned_free(state->ctx); +} + +const struct OQS_SHA3_x4_callbacks sha3_x4_avx512vl_callbacks = { + SHA3_shake128_x4_avx512vl, + SHA3_shake128_x4_inc_init_avx512vl, + SHA3_shake128_x4_inc_absorb_avx512vl, + SHA3_shake128_x4_inc_finalize_avx512vl, + SHA3_shake128_x4_inc_squeeze_avx512vl, + SHA3_shake128_x4_inc_ctx_release_avx512vl, + SHA3_shake128_x4_inc_ctx_clone_avx512vl, + SHA3_shake128_x4_inc_ctx_reset_avx512vl, + SHA3_shake256_x4_avx512vl, + SHA3_shake256_x4_inc_init_avx512vl, + SHA3_shake256_x4_inc_absorb_avx512vl, + SHA3_shake256_x4_inc_finalize_avx512vl, + SHA3_shake256_x4_inc_squeeze_avx512vl, + SHA3_shake256_x4_inc_ctx_release_avx512vl, + SHA3_shake256_x4_inc_ctx_clone_avx512vl, + SHA3_shake256_x4_inc_ctx_reset_avx512vl, +}; diff --git a/src/common/sha3/xkcp_sha3.c b/src/common/sha3/xkcp_sha3.c index 32b0db6a8..8087c7f02 100644 --- a/src/common/sha3/xkcp_sha3.c +++ b/src/common/sha3/xkcp_sha3.c @@ -37,10 +37,19 @@ static KeccakPermuteFn *Keccak_Permute_ptr = NULL; static KeccakExtractBytesFn *Keccak_ExtractBytes_ptr = NULL; static KeccakFastLoopAbsorbFn *Keccak_FastLoopAbsorb_ptr = NULL; +extern struct OQS_SHA3_callbacks sha3_default_callbacks; + static void Keccak_Dispatch(void) { // TODO: Simplify this when we have a Windows-compatible AVX2 implementation of SHA3 #if defined(OQS_DIST_X86_64_BUILD) #if defined(OQS_ENABLE_SHA3_xkcp_low_avx2) +#if defined(OQS_USE_SHA3_AVX512VL) + if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX512)) { + extern const struct OQS_SHA3_callbacks sha3_avx512vl_callbacks; + + sha3_default_callbacks = sha3_avx512vl_callbacks; + } +#endif if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX2)) { Keccak_Initialize_ptr = &KeccakP1600_Initialize_avx2; Keccak_AddByte_ptr = &KeccakP1600_AddByte_avx2; @@ -385,8 +394,6 @@ static void SHA3_shake256_inc_ctx_reset(OQS_SHA3_shake256_inc_ctx *state) { keccak_inc_reset((uint64_t *)state->ctx); } -extern struct OQS_SHA3_callbacks sha3_default_callbacks; - struct OQS_SHA3_callbacks sha3_default_callbacks = { SHA3_sha3_256, SHA3_sha3_256_inc_init, diff --git a/src/common/sha3/xkcp_sha3x4.c b/src/common/sha3/xkcp_sha3x4.c index 893744def..bbf3f34a3 100644 --- a/src/common/sha3/xkcp_sha3x4.c +++ b/src/common/sha3/xkcp_sha3x4.c @@ -31,10 +31,19 @@ static KeccakX4AddBytesFn *Keccak_X4_AddBytes_ptr = NULL; static KeccakX4PermuteFn *Keccak_X4_Permute_ptr = NULL; static 
KeccakX4ExtractBytesFn *Keccak_X4_ExtractBytes_ptr = NULL; +extern struct OQS_SHA3_x4_callbacks sha3_x4_default_callbacks; + static void Keccak_X4_Dispatch(void) { // TODO: Simplify this when we have a Windows-compatible AVX2 implementation of SHA3 #if defined(OQS_DIST_X86_64_BUILD) #if defined(OQS_ENABLE_SHA3_xkcp_low_avx2) +#if defined(OQS_USE_SHA3_AVX512VL) + if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX512)) { + extern const struct OQS_SHA3_x4_callbacks sha3_x4_avx512vl_callbacks; + + sha3_x4_default_callbacks = sha3_x4_avx512vl_callbacks; + } +#endif if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX2)) { Keccak_X4_Initialize_ptr = &KeccakP1600times4_InitializeAll_avx2; Keccak_X4_AddByte_ptr = &KeccakP1600times4_AddByte_avx2; @@ -238,8 +247,6 @@ static void SHA3_shake256_x4_inc_ctx_reset(OQS_SHA3_shake256_x4_inc_ctx *state) keccak_x4_inc_reset((uint64_t *)state->ctx); } -extern struct OQS_SHA3_x4_callbacks sha3_x4_default_callbacks; - struct OQS_SHA3_x4_callbacks sha3_x4_default_callbacks = { SHA3_shake128_x4, SHA3_shake128_x4_inc_init, diff --git a/src/oqsconfig.h.cmake b/src/oqsconfig.h.cmake index ddfc87289..5d02314c2 100644 --- a/src/oqsconfig.h.cmake +++ b/src/oqsconfig.h.cmake @@ -69,6 +69,7 @@ #cmakedefine OQS_ENABLE_TEST_CONSTANT_TIME 1 #cmakedefine OQS_ENABLE_SHA3_xkcp_low_avx2 1 +#cmakedefine OQS_USE_SHA3_AVX512VL 1 #cmakedefine01 OQS_USE_CUPQC diff --git a/tests/system_info.c b/tests/system_info.c index 7b52ffe7a..f51006137 100644 --- a/tests/system_info.c +++ b/tests/system_info.c @@ -252,6 +252,14 @@ static void print_oqs_configuration(void) { #endif #if defined(OQS_USE_SHA3_OPENSSL) printf("SHA-3: OpenSSL\n"); +#elif defined(OQS_USE_SHA3_AVX512VL) + if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX512)) { + printf("SHA-3: AVX512VL\n"); + } else if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX2)) { + printf("SHA-3: AVX2\n"); + } else { + printf("SHA-3: C\n"); + } #else printf("SHA-3: C\n"); #endif diff --git a/tests/test_binary.py b/tests/test_binary.py index 9686388f6..3b8a3ec6c 100644 --- a/tests/test_binary.py +++ b/tests/test_binary.py @@ -33,7 +33,7 @@ def test_namespace(): symbols.append(line) # ideally this would be just ['oqs', 'pqclean'], but contains exceptions (e.g., providing compat implementations of unavailable platform functions) - namespaces = ['oqs', 'pqclean', 'keccak', 'pqcrystals', 'pqmayo', 'init', 'fini', 'seedexpander', '__x86.get_pc_thunk', 'libjade', 'jade', '__jade', '__jasmin_syscall', 'pqcp', 'pqov', '_snova'] + namespaces = ['oqs', 'pqclean', 'keccak', 'pqcrystals', 'pqmayo', 'init', 'fini', 'seedexpander', '__x86.get_pc_thunk', 'libjade', 'jade', '__jade', '__jasmin_syscall', 'pqcp', 'pqov', '_snova', 'sha3'] non_namespaced = [] for symbolstr in symbols: diff --git a/tests/test_sha3.c b/tests/test_sha3.c index eadd9a095..b402ddc0b 100644 --- a/tests/test_sha3.c +++ b/tests/test_sha3.c @@ -197,6 +197,135 @@ int sha3_256_kat_test(void) { return status; } +int sha3_384_kat_test(void) { + int status = EXIT_SUCCESS; + + uint8_t output[48]; + + const uint8_t msg0[1] = {0x0}; + const uint8_t msg24[3] = {0x61, 0x62, 0x63}; + const uint8_t msg448[56] = { + 0x61, 0x62, 0x63, 0x64, 0x62, 0x63, 0x64, 0x65, 0x63, 0x64, 0x65, 0x66, 0x64, 0x65, 0x66, 0x67, + 0x65, 0x66, 0x67, 0x68, 0x66, 0x67, 0x68, 0x69, 0x67, 0x68, 0x69, 0x6A, 0x68, 0x69, 0x6A, 0x6B, + 0x69, 0x6A, 0x6B, 0x6C, 0x6A, 0x6B, 0x6C, 0x6D, 0x6B, 0x6C, 0x6D, 0x6E, 0x6C, 0x6D, 0x6E, 0x6F, + 0x6D, 0x6E, 0x6F, 0x70, 0x6E, 0x6F, 0x70, 0x71 + }; + const uint8_t msg1600[200] = { + 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 
diff --git a/tests/test_sha3.c b/tests/test_sha3.c
index eadd9a095..b402ddc0b 100644
--- a/tests/test_sha3.c
+++ b/tests/test_sha3.c
@@ -197,6 +197,135 @@ int sha3_256_kat_test(void) {
     return status;
 }
 
+/**
+ * \brief Tests the 384 bit version of the keccak message digest for correct operation,
+ * using selected vectors from NIST Fips202 and alternative references.
+ */
+int sha3_384_kat_test(void) {
+    int status = EXIT_SUCCESS;
+
+    uint8_t output[48];
+
+    const uint8_t msg0[1] = {0x0};
+    const uint8_t msg24[3] = {0x61, 0x62, 0x63};
+    const uint8_t msg448[56] = {
+        0x61, 0x62, 0x63, 0x64, 0x62, 0x63, 0x64, 0x65, 0x63, 0x64, 0x65, 0x66, 0x64, 0x65, 0x66, 0x67,
+        0x65, 0x66, 0x67, 0x68, 0x66, 0x67, 0x68, 0x69, 0x67, 0x68, 0x69, 0x6A, 0x68, 0x69, 0x6A, 0x6B,
+        0x69, 0x6A, 0x6B, 0x6C, 0x6A, 0x6B, 0x6C, 0x6D, 0x6B, 0x6C, 0x6D, 0x6E, 0x6C, 0x6D, 0x6E, 0x6F,
+        0x6D, 0x6E, 0x6F, 0x70, 0x6E, 0x6F, 0x70, 0x71
+    };
+    const uint8_t msg1600[200] = {
+        0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
+        0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
+        0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
+        0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
+        0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
+        0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
+        0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
+        0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
+        0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
+        0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
+        0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
+        0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3,
+        0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3, 0xA3
+    };
+
+    const uint8_t exp0[48] = {
+        0x0C, 0x63, 0xA7, 0x5B, 0x84, 0x5E, 0x4F, 0x7D, 0x01, 0x10, 0x7D, 0x85, 0x2E, 0x4C, 0x24,
+        0x85, 0xC5, 0x1A, 0x50, 0xAA, 0xAA, 0x94, 0xFC, 0x61, 0x99, 0x5E, 0x71, 0xBB, 0xEE, 0x98,
+        0x3A, 0x2A, 0xC3, 0x71, 0x38, 0x31, 0x26, 0x4A, 0xDB, 0x47, 0xFB, 0x6B, 0xD1, 0xE0, 0x58,
+        0xD5, 0xF0, 0x04
+    };
+    const uint8_t exp24[48] = {
+        0xEC, 0x01, 0x49, 0x82, 0x88, 0x51, 0x6F, 0xC9, 0x26, 0x45, 0x9F, 0x58, 0xE2, 0xC6, 0xAD,
+        0x8D, 0xF9, 0xB4, 0x73, 0xCB, 0x0F, 0xC0, 0x8C, 0x25, 0x96, 0xDA, 0x7C, 0xF0, 0xE4, 0x9B,
+        0xE4, 0xB2, 0x98, 0xD8, 0x8C, 0xEA, 0x92, 0x7A, 0xC7, 0xF5, 0x39, 0xF1, 0xED, 0xF2, 0x28,
+        0x37, 0x6D, 0x25
+    };
+    const uint8_t exp448[48] = {
+        0x99, 0x1C, 0x66, 0x57, 0x55, 0xEB, 0x3A, 0x4B, 0x6B, 0xBD, 0xFB, 0x75, 0xC7, 0x8A, 0x49,
+        0x2E, 0x8C, 0x56, 0xA2, 0x2C, 0x5C, 0x4D, 0x7E, 0x42, 0x9B, 0xFD, 0xBC, 0x32, 0xB9, 0xD4,
+        0xAD, 0x5A, 0xA0, 0x4A, 0x1F, 0x07, 0x6E, 0x62, 0xFE, 0xA1, 0x9E, 0xEF, 0x51, 0xAC, 0xD0,
+        0x65, 0x7C, 0x22
+    };
+    const uint8_t exp1600[48] = {
+        0x18, 0x81, 0xDE, 0x2C, 0xA7, 0xE4, 0x1E, 0xF9, 0x5D, 0xC4, 0x73, 0x2B, 0x8F, 0x5F, 0x00,
+        0x2B, 0x18, 0x9C, 0xC1, 0xE4, 0x2B, 0x74, 0x16, 0x8E, 0xD1, 0x73, 0x26, 0x49, 0xCE, 0x1D,
+        0xBC, 0xDD, 0x76, 0x19, 0x7A, 0x31, 0xFD, 0x55, 0xEE, 0x98, 0x9F, 0x2D, 0x70, 0x50, 0xDD,
+        0x47, 0x3E, 0x8F
+    };
+
+    /* test compact api */
+
+    clear8(output, 48);
+    OQS_SHA3_sha3_384(output, msg0, 0);
+
+    if (are_equal8(output, exp0, 48) == EXIT_FAILURE) {
+        status = EXIT_FAILURE;
+    }
+
+    clear8(output, 48);
+    OQS_SHA3_sha3_384(output, msg24, 3);
+
+    if (are_equal8(output, exp24, 48) == EXIT_FAILURE) {
+        status = EXIT_FAILURE;
+    }
+
+    clear8(output, 48);
+    OQS_SHA3_sha3_384(output, msg448, 56);
+
+    if (are_equal8(output, exp448, 48) == EXIT_FAILURE) {
+        status = EXIT_FAILURE;
+    }
+
+    clear8(output, 48);
+    OQS_SHA3_sha3_384(output, msg1600, 200);
+
+    if (are_equal8(output, exp1600, 48) == EXIT_FAILURE) {
+        status = EXIT_FAILURE;
+    }
+
+    /* test long-form api */
+
+    OQS_SHA3_sha3_384_inc_ctx state;
+    uint8_t hash[200];
+
+    clear8(hash, 200);
+    OQS_SHA3_sha3_384_inc_init(&state);
+    OQS_SHA3_sha3_384_inc_absorb(&state, msg0, 0);
+    OQS_SHA3_sha3_384_inc_finalize(hash, &state);
+
+    if (are_equal8(hash, exp0, 48) == EXIT_FAILURE) {
+        status = EXIT_FAILURE;
+    }
+
+    clear8(hash, 200);
+    OQS_SHA3_sha3_384_inc_ctx_reset(&state);
+    OQS_SHA3_sha3_384_inc_absorb(&state, msg24, 3);
+    OQS_SHA3_sha3_384_inc_finalize(hash, &state);
+
+    if (are_equal8(hash, exp24, 48) == EXIT_FAILURE) {
+        status = EXIT_FAILURE;
+    }
+
+    clear8(hash, 200);
+    OQS_SHA3_sha3_384_inc_ctx_reset(&state);
+    OQS_SHA3_sha3_384_inc_absorb(&state, msg448, 56);
+    OQS_SHA3_sha3_384_inc_finalize(hash, &state);
+
+    if (are_equal8(hash, exp448, 48) == EXIT_FAILURE) {
+        status = EXIT_FAILURE;
+    }
+
+    clear8(hash, 200);
+    OQS_SHA3_sha3_384_inc_ctx_reset(&state);
+    OQS_SHA3_sha3_384_inc_absorb(&state, msg1600, 200);
+    OQS_SHA3_sha3_384_inc_finalize(hash, &state);
+    OQS_SHA3_sha3_384_inc_ctx_release(&state);
+
+    if (are_equal8(hash, exp1600, 48) == EXIT_FAILURE) {
+        status = EXIT_FAILURE;
+    }
+
+    return status;
+}
+
 /**
 * \brief Tests the 512 bit version of the keccak message digest for correct operation,
 * using selected vectors from NIST Fips202 and alternative references.
 */
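The vectors above follow the file's convention of FIPS 202 test messages: exp24, for instance, is the published SHA3-384 digest of "abc". A stand-alone reproduction sketch (hypothetical demo code, outside the test harness), using the public one-shot API and exercising whichever backend the dispatcher selected:

    #include <oqs/oqs.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const uint8_t msg[3] = {0x61, 0x62, 0x63}; /* "abc" */
        uint8_t digest[48];

        /* Should print ec014982...376d25, matching exp24 above,
           on AVX512VL, AVX2, and plain-C builds alike. */
        OQS_SHA3_sha3_384(digest, msg, sizeof msg);
        for (size_t i = 0; i < sizeof digest; i++) {
            printf("%02x", digest[i]);
        }
        printf("\n");
        return 0;
    }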
@@ -1054,12 +1183,31 @@ static void override_SHA3_shake128_x4_inc_init(OQS_SHA3_shake128_x4_inc_ctx *sta
     sha3_x4_default_callbacks.SHA3_shake128_x4_inc_init(state);
 }
 
+#ifdef OQS_USE_SHA3_AVX512VL
+/**
+ * \brief Trigger SHA3 internal callback dispatcher.
+ *
+ * This is required to trigger runtime AVX512/AVX2 detection and set
+ * SHA3 default callbacks before test application callbacks are configured.
+ */
+static void sha3_trigger_dispatcher(void) {
+    OQS_SHA3_sha3_256_inc_ctx state;
+
+    OQS_SHA3_sha3_256_inc_init(&state);
+    OQS_SHA3_sha3_256_inc_ctx_release(&state);
+}
+#endif
+
 /**
 * \brief Run the SHA3 and SHAKE KAT tests
 */
 int main(UNUSED int argc, UNUSED char **argv) {
     int ret = EXIT_SUCCESS;
 
+#ifdef OQS_USE_SHA3_AVX512VL
+    /* set SHA3 default callbacks */
+    sha3_trigger_dispatcher();
+#endif
+
     struct OQS_SHA3_callbacks sha3_callbacks = sha3_default_callbacks;
 
     sha3_callbacks.SHA3_sha3_256_inc_init = override_SHA3_sha3_256_inc_init;
@@ -1080,6 +1228,13 @@ int main(UNUSED int argc, UNUSED char **argv) {
         ret = EXIT_FAILURE;
     }
 
+    if (sha3_384_kat_test() == EXIT_SUCCESS) {
+        printf("Success! passed sha3-384 known answer tests \n");
+    } else {
+        printf("Failure! failed sha3-384 known answer tests \n");
+        ret = EXIT_FAILURE;
+    }
+
     if (sha3_512_kat_test() == EXIT_SUCCESS) {
         printf("Success! passed sha3-512 known answer tests \n");
     } else {