diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c index c9bab85e82f..126eb70974a 100644 --- a/src/backend/libpq/auth-scram.c +++ b/src/backend/libpq/auth-scram.c @@ -141,10 +141,14 @@ typedef struct Port *port; bool channel_binding_in_use; + /* State data depending on the hash type */ + pg_cryptohash_type hash_type; + int key_length; + int iterations; char *salt; /* base64-encoded */ - uint8 StoredKey[SCRAM_KEY_LEN]; - uint8 ServerKey[SCRAM_KEY_LEN]; + uint8 StoredKey[SCRAM_MAX_KEY_LEN]; + uint8 ServerKey[SCRAM_MAX_KEY_LEN]; /* Fields of the first message from client */ char cbind_flag; @@ -155,7 +159,7 @@ typedef struct /* Fields from the last message from client */ char *client_final_message_without_proof; char *client_final_nonce; - char ClientProof[SCRAM_KEY_LEN]; + char ClientProof[SCRAM_MAX_KEY_LEN]; /* Fields generated in the server */ char *server_first_message; @@ -177,12 +181,15 @@ static char *build_server_first_message(scram_state *state); static char *build_server_final_message(scram_state *state); static bool verify_client_proof(scram_state *state); static bool verify_final_nonce(scram_state *state); -static void mock_scram_secret(const char *username, int *iterations, - char **salt, uint8 *stored_key, uint8 *server_key); +static void mock_scram_secret(const char *username, pg_cryptohash_type *hash_type, + int *iterations, int *key_length, char **salt, + uint8 *stored_key, uint8 *server_key); static bool is_scram_printable(char *p); static char *sanitize_char(char c); static char *sanitize_str(const char *s); -static char *scram_mock_salt(const char *username); +static char *scram_mock_salt(const char *username, + pg_cryptohash_type hash_type, + int key_length); /* * Get a list of SASL mechanisms that this module supports. @@ -266,8 +273,11 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass) if (password_type == PASSWORD_TYPE_SCRAM_SHA_256) { - if (parse_scram_secret(shadow_pass, &state->iterations, &state->salt, - state->StoredKey, state->ServerKey)) + if (parse_scram_secret(shadow_pass, &state->iterations, + &state->hash_type, &state->key_length, + &state->salt, + state->StoredKey, + state->ServerKey)) got_secret = true; else { @@ -310,8 +320,10 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass) */ if (!got_secret) { - mock_scram_secret(state->port->user_name, &state->iterations, - &state->salt, state->StoredKey, state->ServerKey); + mock_scram_secret(state->port->user_name, &state->hash_type, + &state->iterations, &state->key_length, + &state->salt, + state->StoredKey, state->ServerKey); state->doomed = true; } @@ -482,7 +494,8 @@ pg_be_scram_build_secret(const char *password) (errcode(ERRCODE_INTERNAL_ERROR), errmsg("could not generate random salt"))); - result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN, + result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN, + saltbuf, SCRAM_DEFAULT_SALT_LEN, SCRAM_DEFAULT_ITERATIONS, password, &errstr); @@ -505,16 +518,18 @@ scram_verify_plain_password(const char *username, const char *password, char *salt; int saltlen; int iterations; - uint8 salted_password[SCRAM_KEY_LEN]; - uint8 stored_key[SCRAM_KEY_LEN]; - uint8 server_key[SCRAM_KEY_LEN]; - uint8 computed_key[SCRAM_KEY_LEN]; + int key_length = 0; + pg_cryptohash_type hash_type; + uint8 salted_password[SCRAM_MAX_KEY_LEN]; + uint8 stored_key[SCRAM_MAX_KEY_LEN]; + uint8 server_key[SCRAM_MAX_KEY_LEN]; + uint8 computed_key[SCRAM_MAX_KEY_LEN]; char *prep_password; pg_saslprep_rc rc; const char *errstr = NULL; - if (!parse_scram_secret(secret, &iterations, &encoded_salt, - stored_key, server_key)) + if (!parse_scram_secret(secret, &iterations, &hash_type, &key_length, + &encoded_salt, stored_key, server_key)) { /* * The password looked like a SCRAM secret, but could not be parsed. @@ -541,9 +556,11 @@ scram_verify_plain_password(const char *username, const char *password, password = prep_password; /* Compute Server Key based on the user-supplied plaintext password */ - if (scram_SaltedPassword(password, salt, saltlen, iterations, + if (scram_SaltedPassword(password, hash_type, key_length, + salt, saltlen, iterations, salted_password, &errstr) < 0 || - scram_ServerKey(salted_password, computed_key, &errstr) < 0) + scram_ServerKey(salted_password, hash_type, key_length, + computed_key, &errstr) < 0) { elog(ERROR, "could not compute server key: %s", errstr); } @@ -555,7 +572,7 @@ scram_verify_plain_password(const char *username, const char *password, * Compare the secret's Server Key with the one computed from the * user-supplied password. */ - return memcmp(computed_key, server_key, SCRAM_KEY_LEN) == 0; + return memcmp(computed_key, server_key, key_length) == 0; } @@ -565,14 +582,15 @@ scram_verify_plain_password(const char *username, const char *password, * On success, the iteration count, salt, stored key, and server key are * extracted from the secret, and returned to the caller. For 'stored_key' * and 'server_key', the caller must pass pre-allocated buffers of size - * SCRAM_KEY_LEN. Salt is returned as a base64-encoded, null-terminated + * SCRAM_MAX_KEY_LEN. Salt is returned as a base64-encoded, null-terminated * string. The buffer for the salt is palloc'd by this function. * * Returns true if the SCRAM secret has been parsed, and false otherwise. */ bool -parse_scram_secret(const char *secret, int *iterations, char **salt, - uint8 *stored_key, uint8 *server_key) +parse_scram_secret(const char *secret, int *iterations, + pg_cryptohash_type *hash_type, int *key_length, + char **salt, uint8 *stored_key, uint8 *server_key) { char *v; char *p; @@ -606,6 +624,8 @@ parse_scram_secret(const char *secret, int *iterations, char **salt, /* Parse the fields */ if (strcmp(scheme_str, "SCRAM-SHA-256") != 0) goto invalid_secret; + *hash_type = PG_SHA256; + *key_length = SCRAM_SHA_256_KEY_LEN; errno = 0; *iterations = strtol(iterations_str, &p, 10); @@ -631,17 +651,17 @@ parse_scram_secret(const char *secret, int *iterations, char **salt, decoded_stored_buf = palloc(decoded_len); decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str), decoded_stored_buf, decoded_len); - if (decoded_len != SCRAM_KEY_LEN) + if (decoded_len != *key_length) goto invalid_secret; - memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN); + memcpy(stored_key, decoded_stored_buf, *key_length); decoded_len = pg_b64_dec_len(strlen(serverkey_str)); decoded_server_buf = palloc(decoded_len); decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str), decoded_server_buf, decoded_len); - if (decoded_len != SCRAM_KEY_LEN) + if (decoded_len != *key_length) goto invalid_secret; - memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN); + memcpy(server_key, decoded_server_buf, *key_length); return true; @@ -655,20 +675,25 @@ invalid_secret: * * In a normal authentication, these are extracted from the secret * stored in the server. This function generates values that look - * realistic, for when there is no stored secret. + * realistic, for when there is no stored secret, using SCRAM-SHA-256. * * Like in parse_scram_secret(), for 'stored_key' and 'server_key', the - * caller must pass pre-allocated buffers of size SCRAM_KEY_LEN, and + * caller must pass pre-allocated buffers of size SCRAM_MAX_KEY_LEN, and * the buffer for the salt is palloc'd by this function. */ static void -mock_scram_secret(const char *username, int *iterations, char **salt, +mock_scram_secret(const char *username, pg_cryptohash_type *hash_type, + int *iterations, int *key_length, char **salt, uint8 *stored_key, uint8 *server_key) { char *raw_salt; char *encoded_salt; int encoded_len; + /* Enforce the use of SHA-256, which would be realistic enough */ + *hash_type = PG_SHA256; + *key_length = SCRAM_SHA_256_KEY_LEN; + /* * Generate deterministic salt. * @@ -677,7 +702,7 @@ mock_scram_secret(const char *username, int *iterations, char **salt, * as the salt generated for mock authentication uses the cluster's nonce * value. */ - raw_salt = scram_mock_salt(username); + raw_salt = scram_mock_salt(username, *hash_type, *key_length); if (raw_salt == NULL) elog(ERROR, "could not encode salt"); @@ -695,8 +720,8 @@ mock_scram_secret(const char *username, int *iterations, char **salt, *iterations = SCRAM_DEFAULT_ITERATIONS; /* StoredKey and ServerKey are not used in a doomed authentication */ - memset(stored_key, 0, SCRAM_KEY_LEN); - memset(server_key, 0, SCRAM_KEY_LEN); + memset(stored_key, 0, SCRAM_MAX_KEY_LEN); + memset(server_key, 0, SCRAM_MAX_KEY_LEN); } /* @@ -1111,10 +1136,10 @@ verify_final_nonce(scram_state *state) static bool verify_client_proof(scram_state *state) { - uint8 ClientSignature[SCRAM_KEY_LEN]; - uint8 ClientKey[SCRAM_KEY_LEN]; - uint8 client_StoredKey[SCRAM_KEY_LEN]; - pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256); + uint8 ClientSignature[SCRAM_MAX_KEY_LEN]; + uint8 ClientKey[SCRAM_MAX_KEY_LEN]; + uint8 client_StoredKey[SCRAM_MAX_KEY_LEN]; + pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type); int i; const char *errstr = NULL; @@ -1123,7 +1148,7 @@ verify_client_proof(scram_state *state) * here even when processing the calculations as this could involve a mock * authentication. */ - if (pg_hmac_init(ctx, state->StoredKey, SCRAM_KEY_LEN) < 0 || + if (pg_hmac_init(ctx, state->StoredKey, state->key_length) < 0 || pg_hmac_update(ctx, (uint8 *) state->client_first_message_bare, strlen(state->client_first_message_bare)) < 0 || @@ -1135,7 +1160,7 @@ verify_client_proof(scram_state *state) pg_hmac_update(ctx, (uint8 *) state->client_final_message_without_proof, strlen(state->client_final_message_without_proof)) < 0 || - pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0) + pg_hmac_final(ctx, ClientSignature, state->key_length) < 0) { elog(ERROR, "could not calculate client signature: %s", pg_hmac_error(ctx)); @@ -1144,14 +1169,15 @@ verify_client_proof(scram_state *state) pg_hmac_free(ctx); /* Extract the ClientKey that the client calculated from the proof */ - for (i = 0; i < SCRAM_KEY_LEN; i++) + for (i = 0; i < state->key_length; i++) ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i]; /* Hash it one more time, and compare with StoredKey */ - if (scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey, &errstr) < 0) + if (scram_H(ClientKey, state->hash_type, state->key_length, + client_StoredKey, &errstr) < 0) elog(ERROR, "could not hash stored key: %s", errstr); - if (memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) != 0) + if (memcmp(client_StoredKey, state->StoredKey, state->key_length) != 0) return false; return true; @@ -1349,12 +1375,12 @@ read_client_final_message(scram_state *state, const char *input) client_proof_len = pg_b64_dec_len(strlen(value)); client_proof = palloc(client_proof_len); if (pg_b64_decode(value, strlen(value), client_proof, - client_proof_len) != SCRAM_KEY_LEN) + client_proof_len) != state->key_length) ereport(ERROR, (errcode(ERRCODE_PROTOCOL_VIOLATION), errmsg("malformed SCRAM message"), errdetail("Malformed proof in client-final-message."))); - memcpy(state->ClientProof, client_proof, SCRAM_KEY_LEN); + memcpy(state->ClientProof, client_proof, state->key_length); pfree(client_proof); if (*p != '\0') @@ -1374,13 +1400,13 @@ read_client_final_message(scram_state *state, const char *input) static char * build_server_final_message(scram_state *state) { - uint8 ServerSignature[SCRAM_KEY_LEN]; + uint8 ServerSignature[SCRAM_MAX_KEY_LEN]; char *server_signature_base64; int siglen; - pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256); + pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type); /* calculate ServerSignature */ - if (pg_hmac_init(ctx, state->ServerKey, SCRAM_KEY_LEN) < 0 || + if (pg_hmac_init(ctx, state->ServerKey, state->key_length) < 0 || pg_hmac_update(ctx, (uint8 *) state->client_first_message_bare, strlen(state->client_first_message_bare)) < 0 || @@ -1392,7 +1418,7 @@ build_server_final_message(scram_state *state) pg_hmac_update(ctx, (uint8 *) state->client_final_message_without_proof, strlen(state->client_final_message_without_proof)) < 0 || - pg_hmac_final(ctx, ServerSignature, sizeof(ServerSignature)) < 0) + pg_hmac_final(ctx, ServerSignature, state->key_length) < 0) { elog(ERROR, "could not calculate server signature: %s", pg_hmac_error(ctx)); @@ -1400,11 +1426,11 @@ build_server_final_message(scram_state *state) pg_hmac_free(ctx); - siglen = pg_b64_enc_len(SCRAM_KEY_LEN); + siglen = pg_b64_enc_len(state->key_length); /* don't forget the zero-terminator */ server_signature_base64 = palloc(siglen + 1); siglen = pg_b64_encode((const char *) ServerSignature, - SCRAM_KEY_LEN, server_signature_base64, + state->key_length, server_signature_base64, siglen); if (siglen < 0) elog(ERROR, "could not encode server signature"); @@ -1431,10 +1457,11 @@ build_server_final_message(scram_state *state) * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN, or NULL. */ static char * -scram_mock_salt(const char *username) +scram_mock_salt(const char *username, pg_cryptohash_type hash_type, + int key_length) { pg_cryptohash_ctx *ctx; - static uint8 sha_digest[PG_SHA256_DIGEST_LENGTH]; + static uint8 sha_digest[SCRAM_MAX_KEY_LEN]; char *mock_auth_nonce = GetMockAuthenticationNonce(); /* @@ -1446,11 +1473,17 @@ scram_mock_salt(const char *username) StaticAssertDecl(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN, "salt length greater than SHA256 digest length"); - ctx = pg_cryptohash_create(PG_SHA256); + /* + * This may be worth refreshing if support for more hash methods is\ + * added. + */ + Assert(hash_type == PG_SHA256); + + ctx = pg_cryptohash_create(hash_type); if (pg_cryptohash_init(ctx) < 0 || pg_cryptohash_update(ctx, (uint8 *) username, strlen(username)) < 0 || pg_cryptohash_update(ctx, (uint8 *) mock_auth_nonce, MOCK_AUTH_NONCE_LEN) < 0 || - pg_cryptohash_final(ctx, sha_digest, sizeof(sha_digest)) < 0) + pg_cryptohash_final(ctx, sha_digest, key_length) < 0) { pg_cryptohash_free(ctx); return NULL; diff --git a/src/backend/libpq/crypt.c b/src/backend/libpq/crypt.c index 1ff8b0507d4..a81af0749a0 100644 --- a/src/backend/libpq/crypt.c +++ b/src/backend/libpq/crypt.c @@ -90,15 +90,17 @@ get_password_type(const char *shadow_pass) { char *encoded_salt; int iterations; - uint8 stored_key[SCRAM_KEY_LEN]; - uint8 server_key[SCRAM_KEY_LEN]; + int key_length = 0; + pg_cryptohash_type hash_type; + uint8 stored_key[SCRAM_MAX_KEY_LEN]; + uint8 server_key[SCRAM_MAX_KEY_LEN]; if (strncmp(shadow_pass, "md5", 3) == 0 && strlen(shadow_pass) == MD5_PASSWD_LEN && strspn(shadow_pass + 3, MD5_PASSWD_CHARSET) == MD5_PASSWD_LEN - 3) return PASSWORD_TYPE_MD5; - if (parse_scram_secret(shadow_pass, &iterations, &encoded_salt, - stored_key, server_key)) + if (parse_scram_secret(shadow_pass, &iterations, &hash_type, &key_length, + &encoded_salt, stored_key, server_key)) return PASSWORD_TYPE_SCRAM_SHA_256; return PASSWORD_TYPE_PLAINTEXT; } diff --git a/src/common/scram-common.c b/src/common/scram-common.c index 12686259299..bffbbb43172 100644 --- a/src/common/scram-common.c +++ b/src/common/scram-common.c @@ -33,6 +33,7 @@ */ int scram_SaltedPassword(const char *password, + pg_cryptohash_type hash_type, int key_length, const char *salt, int saltlen, int iterations, uint8 *result, const char **errstr) { @@ -40,9 +41,9 @@ scram_SaltedPassword(const char *password, uint32 one = pg_hton32(1); int i, j; - uint8 Ui[SCRAM_KEY_LEN]; - uint8 Ui_prev[SCRAM_KEY_LEN]; - pg_hmac_ctx *hmac_ctx = pg_hmac_create(PG_SHA256); + uint8 Ui[SCRAM_MAX_KEY_LEN]; + uint8 Ui_prev[SCRAM_MAX_KEY_LEN]; + pg_hmac_ctx *hmac_ctx = pg_hmac_create(hash_type); if (hmac_ctx == NULL) { @@ -60,30 +61,30 @@ scram_SaltedPassword(const char *password, if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 || pg_hmac_update(hmac_ctx, (uint8 *) salt, saltlen) < 0 || pg_hmac_update(hmac_ctx, (uint8 *) &one, sizeof(uint32)) < 0 || - pg_hmac_final(hmac_ctx, Ui_prev, sizeof(Ui_prev)) < 0) + pg_hmac_final(hmac_ctx, Ui_prev, key_length) < 0) { *errstr = pg_hmac_error(hmac_ctx); pg_hmac_free(hmac_ctx); return -1; } - memcpy(result, Ui_prev, SCRAM_KEY_LEN); + memcpy(result, Ui_prev, key_length); /* Subsequent iterations */ for (i = 2; i <= iterations; i++) { if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 || - pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, SCRAM_KEY_LEN) < 0 || - pg_hmac_final(hmac_ctx, Ui, sizeof(Ui)) < 0) + pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, key_length) < 0 || + pg_hmac_final(hmac_ctx, Ui, key_length) < 0) { *errstr = pg_hmac_error(hmac_ctx); pg_hmac_free(hmac_ctx); return -1; } - for (j = 0; j < SCRAM_KEY_LEN; j++) + for (j = 0; j < key_length; j++) result[j] ^= Ui[j]; - memcpy(Ui_prev, Ui, SCRAM_KEY_LEN); + memcpy(Ui_prev, Ui, key_length); } pg_hmac_free(hmac_ctx); @@ -92,16 +93,17 @@ scram_SaltedPassword(const char *password, /* - * Calculate SHA-256 hash for a NULL-terminated string. (The NULL terminator is + * Calculate hash for a NULL-terminated string. (The NULL terminator is * not included in the hash). Returns 0 on success, -1 on failure with *errstr * pointing to a message about the error details. */ int -scram_H(const uint8 *input, int len, uint8 *result, const char **errstr) +scram_H(const uint8 *input, pg_cryptohash_type hash_type, int key_length, + uint8 *result, const char **errstr) { pg_cryptohash_ctx *ctx; - ctx = pg_cryptohash_create(PG_SHA256); + ctx = pg_cryptohash_create(hash_type); if (ctx == NULL) { *errstr = pg_cryptohash_error(NULL); /* returns OOM */ @@ -109,8 +111,8 @@ scram_H(const uint8 *input, int len, uint8 *result, const char **errstr) } if (pg_cryptohash_init(ctx) < 0 || - pg_cryptohash_update(ctx, input, len) < 0 || - pg_cryptohash_final(ctx, result, SCRAM_KEY_LEN) < 0) + pg_cryptohash_update(ctx, input, key_length) < 0 || + pg_cryptohash_final(ctx, result, key_length) < 0) { *errstr = pg_cryptohash_error(ctx); pg_cryptohash_free(ctx); @@ -126,10 +128,11 @@ scram_H(const uint8 *input, int len, uint8 *result, const char **errstr) * pointing to a message about the error details. */ int -scram_ClientKey(const uint8 *salted_password, uint8 *result, - const char **errstr) +scram_ClientKey(const uint8 *salted_password, + pg_cryptohash_type hash_type, int key_length, + uint8 *result, const char **errstr) { - pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256); + pg_hmac_ctx *ctx = pg_hmac_create(hash_type); if (ctx == NULL) { @@ -137,9 +140,9 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result, return -1; } - if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 || + if (pg_hmac_init(ctx, salted_password, key_length) < 0 || pg_hmac_update(ctx, (uint8 *) "Client Key", strlen("Client Key")) < 0 || - pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0) + pg_hmac_final(ctx, result, key_length) < 0) { *errstr = pg_hmac_error(ctx); pg_hmac_free(ctx); @@ -155,10 +158,11 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result, * pointing to a message about the error details. */ int -scram_ServerKey(const uint8 *salted_password, uint8 *result, - const char **errstr) +scram_ServerKey(const uint8 *salted_password, + pg_cryptohash_type hash_type, int key_length, + uint8 *result, const char **errstr) { - pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256); + pg_hmac_ctx *ctx = pg_hmac_create(hash_type); if (ctx == NULL) { @@ -166,9 +170,9 @@ scram_ServerKey(const uint8 *salted_password, uint8 *result, return -1; } - if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 || + if (pg_hmac_init(ctx, salted_password, key_length) < 0 || pg_hmac_update(ctx, (uint8 *) "Server Key", strlen("Server Key")) < 0 || - pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0) + pg_hmac_final(ctx, result, key_length) < 0) { *errstr = pg_hmac_error(ctx); pg_hmac_free(ctx); @@ -192,12 +196,13 @@ scram_ServerKey(const uint8 *salted_password, uint8 *result, * error details. */ char * -scram_build_secret(const char *salt, int saltlen, int iterations, +scram_build_secret(pg_cryptohash_type hash_type, int key_length, + const char *salt, int saltlen, int iterations, const char *password, const char **errstr) { - uint8 salted_password[SCRAM_KEY_LEN]; - uint8 stored_key[SCRAM_KEY_LEN]; - uint8 server_key[SCRAM_KEY_LEN]; + uint8 salted_password[SCRAM_MAX_KEY_LEN]; + uint8 stored_key[SCRAM_MAX_KEY_LEN]; + uint8 server_key[SCRAM_MAX_KEY_LEN]; char *result; char *p; int maxlen; @@ -206,15 +211,22 @@ scram_build_secret(const char *salt, int saltlen, int iterations, int encoded_server_len; int encoded_result; + /* Only this hash method is supported currently */ + Assert(hash_type == PG_SHA256); + if (iterations <= 0) iterations = SCRAM_DEFAULT_ITERATIONS; /* Calculate StoredKey and ServerKey */ - if (scram_SaltedPassword(password, salt, saltlen, iterations, + if (scram_SaltedPassword(password, hash_type, key_length, + salt, saltlen, iterations, salted_password, errstr) < 0 || - scram_ClientKey(salted_password, stored_key, errstr) < 0 || - scram_H(stored_key, SCRAM_KEY_LEN, stored_key, errstr) < 0 || - scram_ServerKey(salted_password, server_key, errstr) < 0) + scram_ClientKey(salted_password, hash_type, key_length, + stored_key, errstr) < 0 || + scram_H(stored_key, hash_type, key_length, + stored_key, errstr) < 0 || + scram_ServerKey(salted_password, hash_type, key_length, + server_key, errstr) < 0) { /* errstr is filled already here */ #ifdef FRONTEND @@ -231,8 +243,8 @@ scram_build_secret(const char *salt, int saltlen, int iterations, *---------- */ encoded_salt_len = pg_b64_enc_len(saltlen); - encoded_stored_len = pg_b64_enc_len(SCRAM_KEY_LEN); - encoded_server_len = pg_b64_enc_len(SCRAM_KEY_LEN); + encoded_stored_len = pg_b64_enc_len(key_length); + encoded_server_len = pg_b64_enc_len(key_length); maxlen = strlen("SCRAM-SHA-256") + 1 + 10 + 1 /* iteration count */ @@ -269,7 +281,7 @@ scram_build_secret(const char *salt, int saltlen, int iterations, *(p++) = '$'; /* stored key */ - encoded_result = pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p, + encoded_result = pg_b64_encode((char *) stored_key, key_length, p, encoded_stored_len); if (encoded_result < 0) { @@ -286,7 +298,7 @@ scram_build_secret(const char *salt, int saltlen, int iterations, *(p++) = ':'; /* server key */ - encoded_result = pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p, + encoded_result = pg_b64_encode((char *) server_key, key_length, p, encoded_server_len); if (encoded_result < 0) { diff --git a/src/include/common/scram-common.h b/src/include/common/scram-common.h index 4acf2a78adb..953d30ac549 100644 --- a/src/include/common/scram-common.h +++ b/src/include/common/scram-common.h @@ -21,7 +21,13 @@ #define SCRAM_SHA_256_PLUS_NAME "SCRAM-SHA-256-PLUS" /* with channel binding */ /* Length of SCRAM keys (client and server) */ -#define SCRAM_KEY_LEN PG_SHA256_DIGEST_LENGTH +#define SCRAM_SHA_256_KEY_LEN PG_SHA256_DIGEST_LENGTH + +/* + * Size of buffers used internally by SCRAM routines, that should be the + * maximum of SCRAM_SHA_*_KEY_LEN among the hash methods supported. + */ +#define SCRAM_MAX_KEY_LEN SCRAM_SHA_256_KEY_LEN /* * Size of random nonce generated in the authentication exchange. This @@ -43,17 +49,22 @@ */ #define SCRAM_DEFAULT_ITERATIONS 4096 -extern int scram_SaltedPassword(const char *password, const char *salt, - int saltlen, int iterations, uint8 *result, - const char **errstr); -extern int scram_H(const uint8 *input, int len, uint8 *result, +extern int scram_SaltedPassword(const char *password, + pg_cryptohash_type hash_type, int key_length, + const char *salt, int saltlen, int iterations, + uint8 *result, const char **errstr); +extern int scram_H(const uint8 *input, pg_cryptohash_type hash_type, + int key_length, uint8 *result, const char **errstr); -extern int scram_ClientKey(const uint8 *salted_password, uint8 *result, - const char **errstr); -extern int scram_ServerKey(const uint8 *salted_password, uint8 *result, - const char **errstr); +extern int scram_ClientKey(const uint8 *salted_password, + pg_cryptohash_type hash_type, int key_length, + uint8 *result, const char **errstr); +extern int scram_ServerKey(const uint8 *salted_password, + pg_cryptohash_type hash_type, int key_length, + uint8 *result, const char **errstr); -extern char *scram_build_secret(const char *salt, int saltlen, int iterations, +extern char *scram_build_secret(pg_cryptohash_type hash_type, int key_length, + const char *salt, int saltlen, int iterations, const char *password, const char **errstr); #endif /* SCRAM_COMMON_H */ diff --git a/src/include/libpq/scram.h b/src/include/libpq/scram.h index c51e848c24d..b29501ef969 100644 --- a/src/include/libpq/scram.h +++ b/src/include/libpq/scram.h @@ -13,6 +13,7 @@ #ifndef PG_SCRAM_H #define PG_SCRAM_H +#include "common/cryptohash.h" #include "lib/stringinfo.h" #include "libpq/libpq-be.h" #include "libpq/sasl.h" @@ -22,7 +23,10 @@ extern PGDLLIMPORT const pg_be_sasl_mech pg_be_scram_mech; /* Routines to handle and check SCRAM-SHA-256 secret */ extern char *pg_be_scram_build_secret(const char *password); -extern bool parse_scram_secret(const char *secret, int *iterations, char **salt, +extern bool parse_scram_secret(const char *secret, + int *iterations, + pg_cryptohash_type *hash_type, + int *key_length, char **salt, uint8 *stored_key, uint8 *server_key); extern bool scram_verify_plain_password(const char *username, const char *password, const char *secret); diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c index c500bea9e74..7410d5ba529 100644 --- a/src/interfaces/libpq/fe-auth-scram.c +++ b/src/interfaces/libpq/fe-auth-scram.c @@ -58,8 +58,12 @@ typedef struct char *password; char *sasl_mechanism; + /* State data depending on the hash type */ + pg_cryptohash_type hash_type; + int key_length; + /* We construct these */ - uint8 SaltedPassword[SCRAM_KEY_LEN]; + uint8 SaltedPassword[SCRAM_MAX_KEY_LEN]; char *client_nonce; char *client_first_message_bare; char *client_final_message_without_proof; @@ -73,7 +77,7 @@ typedef struct /* These come from the server-final message */ char *server_final_message; - char ServerSignature[SCRAM_KEY_LEN]; + char ServerSignature[SCRAM_MAX_KEY_LEN]; } fe_scram_state; static bool read_server_first_message(fe_scram_state *state, char *input); @@ -106,8 +110,10 @@ scram_init(PGconn *conn, memset(state, 0, sizeof(fe_scram_state)); state->conn = conn; state->state = FE_SCRAM_INIT; - state->sasl_mechanism = strdup(sasl_mechanism); + state->key_length = SCRAM_SHA_256_KEY_LEN; + state->hash_type = PG_SHA256; + state->sasl_mechanism = strdup(sasl_mechanism); if (!state->sasl_mechanism) { free(state); @@ -450,7 +456,7 @@ build_client_final_message(fe_scram_state *state) { PQExpBufferData buf; PGconn *conn = state->conn; - uint8 client_proof[SCRAM_KEY_LEN]; + uint8 client_proof[SCRAM_MAX_KEY_LEN]; char *result; int encoded_len; const char *errstr = NULL; @@ -565,11 +571,11 @@ build_client_final_message(fe_scram_state *state) } appendPQExpBufferStr(&buf, ",p="); - encoded_len = pg_b64_enc_len(SCRAM_KEY_LEN); + encoded_len = pg_b64_enc_len(state->key_length); if (!enlargePQExpBuffer(&buf, encoded_len)) goto oom_error; encoded_len = pg_b64_encode((char *) client_proof, - SCRAM_KEY_LEN, + state->key_length, buf.data + buf.len, encoded_len); if (encoded_len < 0) @@ -738,13 +744,14 @@ read_server_final_message(fe_scram_state *state, char *input) strlen(encoded_server_signature), decoded_server_signature, server_signature_len); - if (server_signature_len != SCRAM_KEY_LEN) + if (server_signature_len != state->key_length) { free(decoded_server_signature); libpq_append_conn_error(conn, "malformed SCRAM message (invalid server signature)"); return false; } - memcpy(state->ServerSignature, decoded_server_signature, SCRAM_KEY_LEN); + memcpy(state->ServerSignature, decoded_server_signature, + state->key_length); free(decoded_server_signature); return true; @@ -760,13 +767,13 @@ calculate_client_proof(fe_scram_state *state, const char *client_final_message_without_proof, uint8 *result, const char **errstr) { - uint8 StoredKey[SCRAM_KEY_LEN]; - uint8 ClientKey[SCRAM_KEY_LEN]; - uint8 ClientSignature[SCRAM_KEY_LEN]; + uint8 StoredKey[SCRAM_MAX_KEY_LEN]; + uint8 ClientKey[SCRAM_MAX_KEY_LEN]; + uint8 ClientSignature[SCRAM_MAX_KEY_LEN]; int i; pg_hmac_ctx *ctx; - ctx = pg_hmac_create(PG_SHA256); + ctx = pg_hmac_create(state->hash_type); if (ctx == NULL) { *errstr = pg_hmac_error(NULL); /* returns OOM */ @@ -777,18 +784,21 @@ calculate_client_proof(fe_scram_state *state, * Calculate SaltedPassword, and store it in 'state' so that we can reuse * it later in verify_server_signature. */ - if (scram_SaltedPassword(state->password, state->salt, state->saltlen, + if (scram_SaltedPassword(state->password, state->hash_type, + state->key_length, state->salt, state->saltlen, state->iterations, state->SaltedPassword, errstr) < 0 || - scram_ClientKey(state->SaltedPassword, ClientKey, errstr) < 0 || - scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey, errstr) < 0) + scram_ClientKey(state->SaltedPassword, state->hash_type, + state->key_length, ClientKey, errstr) < 0 || + scram_H(ClientKey, state->hash_type, state->key_length, + StoredKey, errstr) < 0) { /* errstr is already filled here */ pg_hmac_free(ctx); return false; } - if (pg_hmac_init(ctx, StoredKey, SCRAM_KEY_LEN) < 0 || + if (pg_hmac_init(ctx, StoredKey, state->key_length) < 0 || pg_hmac_update(ctx, (uint8 *) state->client_first_message_bare, strlen(state->client_first_message_bare)) < 0 || @@ -800,14 +810,14 @@ calculate_client_proof(fe_scram_state *state, pg_hmac_update(ctx, (uint8 *) client_final_message_without_proof, strlen(client_final_message_without_proof)) < 0 || - pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0) + pg_hmac_final(ctx, ClientSignature, state->key_length) < 0) { *errstr = pg_hmac_error(ctx); pg_hmac_free(ctx); return false; } - for (i = 0; i < SCRAM_KEY_LEN; i++) + for (i = 0; i < state->key_length; i++) result[i] = ClientKey[i] ^ ClientSignature[i]; pg_hmac_free(ctx); @@ -825,18 +835,19 @@ static bool verify_server_signature(fe_scram_state *state, bool *match, const char **errstr) { - uint8 expected_ServerSignature[SCRAM_KEY_LEN]; - uint8 ServerKey[SCRAM_KEY_LEN]; + uint8 expected_ServerSignature[SCRAM_MAX_KEY_LEN]; + uint8 ServerKey[SCRAM_MAX_KEY_LEN]; pg_hmac_ctx *ctx; - ctx = pg_hmac_create(PG_SHA256); + ctx = pg_hmac_create(state->hash_type); if (ctx == NULL) { *errstr = pg_hmac_error(NULL); /* returns OOM */ return false; } - if (scram_ServerKey(state->SaltedPassword, ServerKey, errstr) < 0) + if (scram_ServerKey(state->SaltedPassword, state->hash_type, + state->key_length, ServerKey, errstr) < 0) { /* errstr is filled already */ pg_hmac_free(ctx); @@ -844,7 +855,7 @@ verify_server_signature(fe_scram_state *state, bool *match, } /* calculate ServerSignature */ - if (pg_hmac_init(ctx, ServerKey, SCRAM_KEY_LEN) < 0 || + if (pg_hmac_init(ctx, ServerKey, state->key_length) < 0 || pg_hmac_update(ctx, (uint8 *) state->client_first_message_bare, strlen(state->client_first_message_bare)) < 0 || @@ -857,7 +868,7 @@ verify_server_signature(fe_scram_state *state, bool *match, (uint8 *) state->client_final_message_without_proof, strlen(state->client_final_message_without_proof)) < 0 || pg_hmac_final(ctx, expected_ServerSignature, - sizeof(expected_ServerSignature)) < 0) + state->key_length) < 0) { *errstr = pg_hmac_error(ctx); pg_hmac_free(ctx); @@ -867,7 +878,8 @@ verify_server_signature(fe_scram_state *state, bool *match, pg_hmac_free(ctx); /* signature processed, so now check after it */ - if (memcmp(expected_ServerSignature, state->ServerSignature, SCRAM_KEY_LEN) != 0) + if (memcmp(expected_ServerSignature, state->ServerSignature, + state->key_length) != 0) *match = false; else *match = true; @@ -912,7 +924,8 @@ pg_fe_scram_build_secret(const char *password, const char **errstr) return NULL; } - result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN, + result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN, saltbuf, + SCRAM_DEFAULT_SALT_LEN, SCRAM_DEFAULT_ITERATIONS, password, errstr);