ml: Store decoded public/private key and matrix A on initiator

While this does require quite a bit of memory, on initiators there are
usually fewer concurrent SAs getting created so this should be less of
an issue than on a gateway that handles lots of SAs as responder.

The speed up is about 30% on the initiator during the decapsulation,
while the key generation does take a bit more time (about 3%).
This commit is contained in:
Tobias Brunner 2024-10-28 15:12:32 +01:00
parent 89f4b345e3
commit 9de4efb1ae

View File

@ -43,10 +43,22 @@ struct private_key_exchange_t {
const ml_kem_params_t *params;
/**
* Decryption/private key as initiator.
* Decryption/private key as initiator (array of k polynomials).
*/
chunk_t private_key;
/**
* Encryption/public key and matrix A as initiator (array of k polynomials,
* followed by a matrix of k*k polynomials).
*/
chunk_t public_key;
/**
* Additional key data as initiator (hash of encoded public key,
* rejection seed z).
*/
chunk_t key_data;
/**
* Ciphertext as responder.
*/
@ -507,12 +519,12 @@ static void poly_to_message(ml_poly_t *p, uint8_t *m)
}
/**
* Generate a key pair from the given random seed d.
* Generate a key pair from the given random seed d. Returns the encoded public
* key.
*
* Algorithm 13 in FIPS 203.
*/
static bool pke_keygen(private_key_exchange_t *this, chunk_t d, chunk_t *ek,
chunk_t *dk)
static bool pke_keygen(private_key_exchange_t *this, chunk_t d, chunk_t *ek)
{
const uint8_t k = this->params->k;
const uint8_t eta1 = this->params->eta1;
@ -521,7 +533,7 @@ static bool pke_keygen(private_key_exchange_t *this, chunk_t d, chunk_t *ek,
uint8_t *rho = seeds;
uint8_t *sigma = seeds + ML_KEM_SEED_LEN;
uint8_t N = 0;
ml_poly_t a[k*k], s[k], e[k], t[k];
ml_poly_t *a, *s, e[k], *t;
int i;
bool success = FALSE;
@ -533,12 +545,19 @@ static bool pke_keygen(private_key_exchange_t *this, chunk_t d, chunk_t *ek,
goto err;
}
this->public_key = chunk_alloc((k+1) * k * sizeof(ml_poly_t));
t = (ml_poly_t*)this->public_key.ptr;
a = (ml_poly_t*)this->public_key.ptr + k;
/* generate matrix A */
if (!generate_a(this, a, rho))
{
goto err;
}
this->private_key = chunk_alloc(k * sizeof(ml_poly_t));
s = (ml_poly_t*)this->private_key.ptr;
/* sample s from CBD using noise seed sigma and nonce N as input */
for (i = 0; i < k; i++)
{
@ -575,16 +594,11 @@ static bool pke_keygen(private_key_exchange_t *this, chunk_t d, chunk_t *ek,
encode_poly_arr(k, t, ek->ptr);
memcpy(ek->ptr + k * ML_KEM_POLY_LEN, rho, ML_KEM_SEED_LEN);
/* pack private key */
*dk = chunk_alloc(k * ML_KEM_POLY_LEN);
encode_poly_arr(k, s, dk->ptr);
success = TRUE;
err:
memwipe(seeds, sizeof(seeds));
memwipe(sigma, ML_KEM_SEED_LEN);
memwipe(s, sizeof(s));
memwipe(e, sizeof(e));
return success;
}
@ -606,18 +620,28 @@ static bool pke_encrypt(private_key_exchange_t *this, chunk_t ek, uint8_t *m,
uint8_t rho[ML_KEM_SEED_LEN];
uint8_t N = 0;
ml_poly_t a[k*k], t[k], y[k], e1[k], e2, u[k], mu, v;
ml_poly_t a_gen[k*k], *a = a_gen, t_dec[k], *t = t_dec;
ml_poly_t y[k], e1[k], e2, u[k], mu, v;
int i;
bool success = FALSE;
/* decode polynomial t and extract seed rho from the public key */
decode_poly_arr(k, ek.ptr, t);
memcpy(rho, ek.ptr + k * ML_KEM_POLY_LEN, ML_KEM_SEED_LEN);
/* generate matrix A */
if (!generate_a(this, a, rho))
if (!this->public_key.ptr)
{
goto err;
/* decode polynomial t and extract seed rho from the public key */
decode_poly_arr(k, ek.ptr, t);
memcpy(rho, ek.ptr + k * ML_KEM_POLY_LEN, ML_KEM_SEED_LEN);
/* generate matrix A */
if (!generate_a(this, a, rho))
{
goto err;
}
}
else
{
/* as initiator, we already have the decoded polynomial and matrix A */
t = (ml_poly_t*)this->public_key.ptr;
a = (ml_poly_t*)this->public_key.ptr + k;
}
/* sample y from CBD using noise seed r and nonce N as input */
@ -675,26 +699,26 @@ err:
}
/**
* Decrypt message m using the given private key and ciphertext.
* Decrypt message m using the stored private key and given ciphertext.
*
* Algorithm 14 in FIPS 203.
*/
static bool pke_decrypt(private_key_exchange_t *this, chunk_t dk,
chunk_t ciphertext, uint8_t *m)
static bool pke_decrypt(private_key_exchange_t *this, chunk_t ciphertext,
uint8_t *m)
{
const uint8_t k = this->params->k;
const uint8_t du = this->params->du;
const uint8_t dv = this->params->dv;
ml_poly_t s[k], u[k], v, w;
ml_poly_t *s, u[k], v, w;
int i;
/* decode u and v from c1 and c2, the two parts of the ciphertext */
decompress_poly_arr(k, du, ciphertext.ptr, u);
decompress_poly_arr(1, dv, ciphertext.ptr + k * du * ML_KEM_N / 8, &v);
/* decode polynomial s from private key */
decode_poly_arr(k, dk.ptr, s);
/* we already have private key s stored */
s = (ml_poly_t*)this->private_key.ptr;
/* calculate w = v - NTT^-1(s * NTT(u)) */
for (i = 0; i < k; i++)
@ -707,9 +731,6 @@ static bool pke_decrypt(private_key_exchange_t *this, chunk_t dk,
/* decode plaintext message m from polynomial w */
poly_to_message(&w, m);
memwipe(s, sizeof(s));
memwipe(&w, sizeof(w));
return TRUE;
}
@ -723,7 +744,7 @@ static bool generate_keypair(private_key_exchange_t *this, chunk_t *ek)
uint8_t dz[2*ML_KEM_SEED_LEN];
chunk_t d = chunk_create(dz, ML_KEM_SEED_LEN);
chunk_t z = chunk_create(dz + ML_KEM_SEED_LEN, ML_KEM_SEED_LEN);
chunk_t dk = chunk_empty, Hek;
chunk_t Hek;
bool success = FALSE;
/* get random seeds d and z */
@ -732,17 +753,16 @@ static bool generate_keypair(private_key_exchange_t *this, chunk_t *ek)
return FALSE;
}
/* generate a key pair and store the private key, the public key, a hash
* of the latter and seed z as our secret key */
if (pke_keygen(this, d, ek, &dk) &&
/* generate a key pair and generate a hash of the latter to be stored
* together with the rejection seed z */
if (pke_keygen(this, d, ek) &&
this->H->allocate_hash(this->H, *ek, &Hek))
{
this->private_key = chunk_cat("ccmc", dk, *ek, Hek, z);
this->key_data = chunk_cat("mc", Hek, z);
success = TRUE;
}
memwipe(dz, sizeof(dz));
chunk_clear(&dk);
return success;
}
@ -769,25 +789,21 @@ METHOD(key_exchange_t, get_public_key, bool,
*/
static bool decaps_shared_secret(private_key_exchange_t *this, chunk_t ciphertext)
{
const uint8_t k = this->params->k;
chunk_t dk, ek, Hek, z, zc, c = chunk_empty;
chunk_t Hek, z, zc, c = chunk_empty;
chunk_t m = chunk_alloca(ML_KEM_SEED_LEN);
uint8_t Kr[2*ML_KEM_SEED_LEN];
uint8_t *r = Kr + ML_KEM_SEED_LEN;
bool success = FALSE;
/* get the private and public keys, a hash of the latter and seed z */
chunk_split(this->private_key, "mmmm",
k * ML_KEM_POLY_LEN, &dk,
k * ML_KEM_POLY_LEN + ML_KEM_SEED_LEN, &ek,
/* get the hash of the encoded public key and seed z */
chunk_split(this->key_data, "mm",
ML_KEM_SEED_LEN, &Hek,
ML_KEM_SEED_LEN, &z);
/* prepare the seed to derive the implicit rejection secret */
zc = chunk_cat("cc", z, ciphertext);
/* decrypt message m */
if (!pke_decrypt(this, dk, ciphertext, m.ptr))
if (!pke_decrypt(this, ciphertext, m.ptr))
{
goto err;
}
@ -801,7 +817,7 @@ static bool decaps_shared_secret(private_key_exchange_t *this, chunk_t ciphertex
/* encrypt the decrypted message again using the derived r */
c = chunk_alloc(this->params->ct_len);
if (!pke_encrypt(this, ek, m.ptr, r, c))
if (!pke_encrypt(this, chunk_empty, m.ptr, r, c))
{
goto err;
}
@ -936,8 +952,10 @@ METHOD(key_exchange_t, set_seed, bool,
METHOD(key_exchange_t, destroy, void,
private_key_exchange_t *this)
{
chunk_clear(&this->shared_secret);
chunk_clear(&this->private_key);
chunk_clear(&this->key_data);
chunk_clear(&this->shared_secret);
chunk_free(&this->public_key);
chunk_free(&this->ciphertext);
DESTROY_IF(this->drbg);
DESTROY_IF(this->shake128);