liboqs/tests/test_kex.c
2018-10-02 17:26:22 -04:00

411 lines
12 KiB
C

#if defined(_WIN32)
#pragma warning(disable : 4244 4293)
#endif
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <oqs/oqs.h>
#include "ds_benchmark.h"
#if __ANDROID__
#define eprintf(...) printf(__VA_ARGS__);
#else
#define eprintf(...) fprintf(stderr, __VA_ARGS__);
#endif
struct kex_testcase {
enum OQS_KEX_alg_name alg_name;
unsigned char *seed;
size_t seed_len;
const char *named_parameters;
const char *id;
int run;
int iter;
};
/* Add new testcases here */
struct kex_testcase kex_testcases[] = {
#ifndef DISABLE_NTRU_ON_WINDOWS_BY_DEFAULT
#ifdef ENABLE_KEX_NTRU
{OQS_KEX_alg_ntru, NULL, 0, NULL, "ntru", 0, 25},
#endif
#endif
};
#define KEX_TEST_ITERATIONS 100
#define KEX_BENCH_SECONDS_DEFAULT 1
static OQS_STATUS kex_test_correctness(OQS_RAND *rand, enum OQS_KEX_alg_name alg_name, const uint8_t *seed, const size_t seed_len, const char *named_parameters, const int print, unsigned long occurrences[256]) {
OQS_KEX *kex = NULL;
OQS_STATUS rc;
void *alice_priv = NULL;
uint8_t *alice_msg = NULL;
size_t alice_msg_len;
uint8_t *alice_key = NULL;
size_t alice_key_len = 0;
uint8_t *bob_msg = NULL;
size_t bob_msg_len;
uint8_t *bob_key = NULL;
size_t bob_key_len = 0;
/* setup KEX */
kex = OQS_KEX_new(rand, alg_name, seed, seed_len, named_parameters);
if (kex == NULL) {
eprintf("new_method failed\n");
goto err;
}
if (print) {
printf("================================================================================\n");
printf("Sample computation for key exchange method %s\n", kex->method_name);
printf("================================================================================\n");
}
/* Alice's initial message */
rc = OQS_KEX_alice_0(kex, &alice_priv, &alice_msg, &alice_msg_len);
if (rc != OQS_SUCCESS) {
eprintf("OQS_KEX_alice_0 failed\n");
goto err;
}
if (print) {
OQS_print_part_hex_string("Alice message", alice_msg, alice_msg_len, 20);
}
/* Bob's response */
rc = OQS_KEX_bob(kex, alice_msg, alice_msg_len, &bob_msg, &bob_msg_len, &bob_key, &bob_key_len);
if (rc != OQS_SUCCESS) {
eprintf("OQS_KEX_bob failed\n");
goto err;
}
if (print) {
OQS_print_part_hex_string("Bob message", bob_msg, bob_msg_len, 20);
OQS_print_hex_string("Bob session key", bob_key, bob_key_len);
}
/* Alice processes Bob's response */
rc = OQS_KEX_alice_1(kex, alice_priv, bob_msg, bob_msg_len, &alice_key, &alice_key_len);
if (rc != OQS_SUCCESS) {
eprintf("OQS_KEX_alice_1 failed\n");
goto err;
}
if (print) {
OQS_print_hex_string("Alice session key", alice_key, alice_key_len);
}
/* compare session key lengths and values */
if (alice_key_len != bob_key_len) {
eprintf("ERROR: Alice's session key and Bob's session key are different lengths (%zu vs %zu)\n", alice_key_len, bob_key_len);
goto err;
}
if (memcmp(alice_key, bob_key, alice_key_len) != 0) {
eprintf("ERROR: Alice's session key and Bob's session key are not equal\n");
OQS_print_hex_string("Alice session key", alice_key, alice_key_len);
OQS_print_hex_string("Bob session key", bob_key, bob_key_len);
goto err;
}
if (print) {
printf("Alice and Bob's session keys match.\n");
printf("\n\n");
}
/* record generated bytes for statistical analysis */
for (size_t i = 0; i < alice_key_len; i++) {
OQS_RAND_test_record_occurrence(alice_key[i], occurrences);
}
rc = OQS_SUCCESS;
goto cleanup;
err:
rc = OQS_ERROR;
cleanup:
OQS_MEM_insecure_free(alice_msg);
OQS_MEM_secure_free(alice_key, alice_key_len);
OQS_MEM_insecure_free(bob_msg);
OQS_MEM_secure_free(bob_key, bob_key_len);
OQS_KEX_alice_priv_free(kex, alice_priv);
OQS_KEX_free(kex);
return rc;
}
static OQS_STATUS kex_test_correctness_wrapper(OQS_RAND *rand, enum OQS_KEX_alg_name alg_name, const uint8_t *seed, const size_t seed_len, const char *named_parameters, int iterations, bool quiet) {
OQS_KEX *kex = NULL;
OQS_STATUS ret;
unsigned long occurrences[256];
for (int i = 0; i < 256; i++) {
occurrences[i] = 0;
}
ret = kex_test_correctness(rand, alg_name, seed, seed_len, named_parameters, quiet ? 0 : 1, occurrences);
if (ret != OQS_SUCCESS) {
goto err;
}
/* setup KEX */
kex = OQS_KEX_new(rand, alg_name, seed, seed_len, named_parameters);
if (kex == NULL) {
goto err;
}
printf("================================================================================\n");
printf("Testing correctness and randomness of key exchange method %s (params=%s) for %d iterations\n",
kex->method_name, named_parameters, iterations);
printf("================================================================================\n");
for (int i = 0; i < iterations; i++) {
ret = kex_test_correctness(rand, alg_name, seed, seed_len, named_parameters, 0, occurrences);
if (ret != OQS_SUCCESS) {
goto err;
}
}
printf("All session keys matched.\n");
OQS_RAND_report_statistics(occurrences, "");
printf("\n\n");
ret = OQS_SUCCESS;
goto cleanup;
err:
ret = OQS_ERROR;
cleanup:
OQS_KEX_free(kex);
return ret;
}
static void cleanup_alice_0(OQS_KEX *kex, void *alice_priv, uint8_t *alice_msg) {
OQS_MEM_insecure_free(alice_msg);
OQS_KEX_alice_priv_free(kex, alice_priv);
}
static void cleanup_bob(uint8_t *bob_msg, uint8_t *bob_key, size_t bob_key_len) {
OQS_MEM_insecure_free(bob_msg);
OQS_MEM_secure_free(bob_key, bob_key_len);
}
static OQS_STATUS kex_bench_wrapper(OQS_RAND *rand, enum OQS_KEX_alg_name alg_name, const uint8_t *seed, const size_t seed_len, const char *named_parameters, const size_t seconds) {
OQS_KEX *kex = NULL;
OQS_STATUS rc;
void *alice_priv = NULL;
uint8_t *alice_msg = NULL;
size_t alice_msg_len;
uint8_t *alice_key = NULL;
size_t alice_key_len = 0;
uint8_t *bob_msg = NULL;
size_t bob_msg_len;
uint8_t *bob_key = NULL;
size_t bob_key_len = 0;
/* setup KEX */
kex = OQS_KEX_new(rand, alg_name, seed, seed_len, named_parameters);
if (kex == NULL) {
eprintf("new_method failed\n");
goto err;
}
printf("%-30s | %10s | %14s | %15s | %10s | %16s | %10s\n", kex->method_name, "", "", "", "", "", "");
TIME_OPERATION_SECONDS({ OQS_KEX_alice_0(kex, &alice_priv, &alice_msg, &alice_msg_len); cleanup_alice_0(kex, alice_priv, alice_msg); }, "alice 0", seconds);
OQS_KEX_alice_0(kex, &alice_priv, &alice_msg, &alice_msg_len);
TIME_OPERATION_SECONDS({ OQS_KEX_bob(kex, alice_msg, alice_msg_len, &bob_msg, &bob_msg_len, &bob_key, &bob_key_len); cleanup_bob(bob_msg, bob_key, bob_key_len); }, "bob", seconds);
OQS_KEX_bob(kex, alice_msg, alice_msg_len, &bob_msg, &bob_msg_len, &bob_key, &bob_key_len);
TIME_OPERATION_SECONDS({ OQS_KEX_alice_1(kex, alice_priv, bob_msg, bob_msg_len, &alice_key, &alice_key_len); OQS_MEM_secure_free(alice_key, alice_key_len); }, "alice 1", seconds);
alice_key = NULL;
printf("Communication (bytes): A->B: %zu, B->A: %zu, total: %zu; classical/quantum security bits [%u:%u] \n", alice_msg_len, bob_msg_len, alice_msg_len + bob_msg_len, kex->estimated_classical_security, kex->estimated_quantum_security);
rc = OQS_SUCCESS;
goto cleanup;
err:
rc = OQS_ERROR;
cleanup:
OQS_MEM_insecure_free(alice_msg);
OQS_MEM_secure_free(alice_key, alice_key_len);
OQS_MEM_insecure_free(bob_msg);
OQS_MEM_secure_free(bob_key, bob_key_len);
OQS_KEX_alice_priv_free(kex, alice_priv);
OQS_KEX_free(kex);
return rc;
}
static OQS_STATUS kex_mem_bench_wrapper(OQS_RAND *rand, enum OQS_KEX_alg_name alg_name, const uint8_t *seed, const size_t seed_len, const char *named_parameters) {
OQS_KEX *kex = NULL;
OQS_STATUS rc;
void *alice_priv = NULL;
uint8_t *alice_msg = NULL;
size_t alice_msg_len;
uint8_t *alice_key = NULL;
size_t alice_key_len = 0;
uint8_t *bob_msg = NULL;
size_t bob_msg_len;
uint8_t *bob_key = NULL;
size_t bob_key_len = 0;
kex = OQS_KEX_new(rand, alg_name, seed, seed_len, named_parameters);
if (kex == NULL) {
fprintf(stderr, "new_method failed\n");
goto err;
}
printf("running %s..\n", kex->method_name);
OQS_KEX_alice_0(kex, &alice_priv, &alice_msg, &alice_msg_len);
OQS_KEX_bob(kex, alice_msg, alice_msg_len, &bob_msg, &bob_msg_len, &bob_key, &bob_key_len);
OQS_KEX_alice_1(kex, alice_priv, bob_msg, bob_msg_len, &alice_key, &alice_key_len);
rc = OQS_SUCCESS;
goto cleanup;
err:
rc = OQS_ERROR;
cleanup:
OQS_MEM_insecure_free(alice_msg);
OQS_MEM_secure_free(alice_key, alice_key_len);
OQS_MEM_insecure_free(bob_msg);
OQS_MEM_secure_free(bob_key, bob_key_len);
OQS_KEX_alice_priv_free(kex, alice_priv);
OQS_KEX_free(kex);
return rc;
}
void print_help() {
printf("Usage: ./test_kex [options] [algorithms]\n");
printf("\nOptions:\n");
printf(" --quiet, -q\n");
printf(" Less verbose output\n");
printf(" --bench, -b\n");
printf(" Run benchmarks\n");
printf(" --seconds -s [SECONDS]\n");
printf(" Number of seconds to run benchmarks (default==%d)\n", KEX_BENCH_SECONDS_DEFAULT);
printf(" --mem-bench\n");
printf(" Run memory benchmarks (run once and allocate only what is required)\n");
printf("\nalgorithms:\n");
size_t kex_testcases_len = sizeof(kex_testcases) / sizeof(struct kex_testcase);
for (size_t i = 0; i < kex_testcases_len; i++) {
printf(" %s\n", kex_testcases[i].id);
}
}
int main(int argc, char **argv) {
OQS_STATUS success = OQS_SUCCESS;
bool run_all = true;
bool quiet = false;
bool bench = false;
bool mem_bench = false;
size_t kex_testcases_len = sizeof(kex_testcases) / sizeof(struct kex_testcase);
size_t kex_bench_seconds = KEX_BENCH_SECONDS_DEFAULT;
for (int i = 1; i < argc; i++) {
if (argv[i][0] == '-') {
if ((strcmp(argv[i], "-h") == 0) || (strcmp(argv[i], "-help") == 0) || (strcmp(argv[i], "--help") == 0)) {
print_help();
return EXIT_SUCCESS;
} else if (strcmp(argv[i], "--quiet") == 0 || strcmp(argv[i], "-q") == 0) {
quiet = true;
} else if (strcmp(argv[i], "--bench") == 0 || strcmp(argv[i], "-b") == 0) {
bench = true;
} else if (strcmp(argv[i], "--seconds") == 0 || strcmp(argv[i], "-s") == 0) {
if (++i == argc) {
print_help();
return EXIT_SUCCESS;
}
char *end;
int kex_bench_seconds_input = strtol(argv[i], &end, 10);
if (kex_bench_seconds_input < 1) {
print_help();
return EXIT_SUCCESS;
}
kex_bench_seconds = kex_bench_seconds_input;
} else if ((strcmp(argv[i], "--mem-bench") == 0 || strcmp(argv[i], "-m") == 0)) {
mem_bench = true;
}
} else {
run_all = false;
for (size_t j = 0; j < kex_testcases_len; j++) {
if (strcmp(argv[i], kex_testcases[j].id) == 0) {
kex_testcases[j].run = 1;
}
}
}
}
/* setup RAND */
OQS_RAND *rand = OQS_RAND_new(OQS_RAND_alg_urandom_chacha20);
if (rand == NULL) {
goto err;
}
if (mem_bench) {
for (size_t i = 0; i < kex_testcases_len; i++) {
if (run_all || kex_testcases[i].run == 1) {
success = kex_mem_bench_wrapper(rand, kex_testcases[i].alg_name, kex_testcases[i].seed, kex_testcases[i].seed_len, kex_testcases[i].named_parameters);
}
if (success != OQS_SUCCESS) {
goto err;
}
}
printf("memory benchmarks done, exiting..\n");
success = OQS_SUCCESS;
goto cleanup;
}
for (size_t i = 0; i < kex_testcases_len; i++) {
if (run_all || kex_testcases[i].run == 1) {
int num_iter = kex_testcases[i].iter;
success = kex_test_correctness_wrapper(rand, kex_testcases[i].alg_name, kex_testcases[i].seed, kex_testcases[i].seed_len, kex_testcases[i].named_parameters, num_iter, quiet);
}
if (success != OQS_SUCCESS) {
goto err;
}
}
if (bench) {
PRINT_TIMER_HEADER
for (size_t i = 0; i < kex_testcases_len; i++) {
if (run_all || kex_testcases[i].run == 1) {
kex_bench_wrapper(rand, kex_testcases[i].alg_name, kex_testcases[i].seed, kex_testcases[i].seed_len, kex_testcases[i].named_parameters, kex_bench_seconds);
}
}
PRINT_TIMER_FOOTER
}
success = OQS_SUCCESS;
goto cleanup;
err:
success = OQS_ERROR;
eprintf("ERROR!\n");
cleanup:
OQS_RAND_free(rand);
return (success == OQS_SUCCESS) ? EXIT_SUCCESS : EXIT_FAILURE;
}