Remove hardcoded dependency to cryptohash type in the internals of SCRAM

SCRAM_KEY_LEN was a variable used in the internal routines of SCRAM to
size a set of fixed-sized arrays used in the SHA and HMAC computations
during the SASL exchange or when building a SCRAM password.  This had a
hard dependency on SHA-256, reducing the flexibility of SCRAM when it
comes to the addition of more hash methods.  A second issue was that
SHA-256 is assumed as the cryptohash method to use all the time.

This commit renames SCRAM_KEY_LEN to a more generic SCRAM_KEY_MAX_LEN,
which is used as the size of the buffers used by the internal routines
of SCRAM.  This is aimed at tracking centrally the maximum size
necessary for all the hash methods supported by SCRAM.  A global
variable has the advantage of keeping the code in its simplest form,
reducing the need of more alloc/free logic for all the buffers used in
the hash calculations.

A second change is that the key length (SHA digest length) and hash
types are now tracked by the state data in the backend and the frontend,
the common portions being extended to handle these as arguments by the
internal routines of SCRAM.  There are a few RFC proposals floating
around to extend the SCRAM protocol, including some to use stronger
cryptohash algorithms, so this lifts some of the existing restrictions
in the code.

The code in charge of parsing and building SCRAM secrets is extended to
rely on the key length and on the cryptohash type used for the exchange,
assuming currently that only SHA-256 is supported for the moment.  Note
that the mock authentication simply enforces SHA-256.

Author: Michael Paquier
Reviewed-by: Peter Eisentraut, Jonathan Katz
Discussion: https://postgr.es/m/Y5k3Qiweo/1g9CG6@paquier.xyz
This commit is contained in:
Michael Paquier 2022-12-20 08:53:22 +09:00
parent eb60eb08a9
commit b3bb7d12af
6 changed files with 206 additions and 131 deletions

View File

@ -141,10 +141,14 @@ typedef struct
Port *port; Port *port;
bool channel_binding_in_use; bool channel_binding_in_use;
/* State data depending on the hash type */
pg_cryptohash_type hash_type;
int key_length;
int iterations; int iterations;
char *salt; /* base64-encoded */ char *salt; /* base64-encoded */
uint8 StoredKey[SCRAM_KEY_LEN]; uint8 StoredKey[SCRAM_MAX_KEY_LEN];
uint8 ServerKey[SCRAM_KEY_LEN]; uint8 ServerKey[SCRAM_MAX_KEY_LEN];
/* Fields of the first message from client */ /* Fields of the first message from client */
char cbind_flag; char cbind_flag;
@ -155,7 +159,7 @@ typedef struct
/* Fields from the last message from client */ /* Fields from the last message from client */
char *client_final_message_without_proof; char *client_final_message_without_proof;
char *client_final_nonce; char *client_final_nonce;
char ClientProof[SCRAM_KEY_LEN]; char ClientProof[SCRAM_MAX_KEY_LEN];
/* Fields generated in the server */ /* Fields generated in the server */
char *server_first_message; char *server_first_message;
@ -177,12 +181,15 @@ static char *build_server_first_message(scram_state *state);
static char *build_server_final_message(scram_state *state); static char *build_server_final_message(scram_state *state);
static bool verify_client_proof(scram_state *state); static bool verify_client_proof(scram_state *state);
static bool verify_final_nonce(scram_state *state); static bool verify_final_nonce(scram_state *state);
static void mock_scram_secret(const char *username, int *iterations, static void mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
char **salt, uint8 *stored_key, uint8 *server_key); int *iterations, int *key_length, char **salt,
uint8 *stored_key, uint8 *server_key);
static bool is_scram_printable(char *p); static bool is_scram_printable(char *p);
static char *sanitize_char(char c); static char *sanitize_char(char c);
static char *sanitize_str(const char *s); static char *sanitize_str(const char *s);
static char *scram_mock_salt(const char *username); static char *scram_mock_salt(const char *username,
pg_cryptohash_type hash_type,
int key_length);
/* /*
* Get a list of SASL mechanisms that this module supports. * Get a list of SASL mechanisms that this module supports.
@ -266,8 +273,11 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
if (password_type == PASSWORD_TYPE_SCRAM_SHA_256) if (password_type == PASSWORD_TYPE_SCRAM_SHA_256)
{ {
if (parse_scram_secret(shadow_pass, &state->iterations, &state->salt, if (parse_scram_secret(shadow_pass, &state->iterations,
state->StoredKey, state->ServerKey)) &state->hash_type, &state->key_length,
&state->salt,
state->StoredKey,
state->ServerKey))
got_secret = true; got_secret = true;
else else
{ {
@ -310,8 +320,10 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
*/ */
if (!got_secret) if (!got_secret)
{ {
mock_scram_secret(state->port->user_name, &state->iterations, mock_scram_secret(state->port->user_name, &state->hash_type,
&state->salt, state->StoredKey, state->ServerKey); &state->iterations, &state->key_length,
&state->salt,
state->StoredKey, state->ServerKey);
state->doomed = true; state->doomed = true;
} }
@ -482,7 +494,8 @@ pg_be_scram_build_secret(const char *password)
(errcode(ERRCODE_INTERNAL_ERROR), (errcode(ERRCODE_INTERNAL_ERROR),
errmsg("could not generate random salt"))); errmsg("could not generate random salt")));
result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN, result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN,
saltbuf, SCRAM_DEFAULT_SALT_LEN,
SCRAM_DEFAULT_ITERATIONS, password, SCRAM_DEFAULT_ITERATIONS, password,
&errstr); &errstr);
@ -505,16 +518,18 @@ scram_verify_plain_password(const char *username, const char *password,
char *salt; char *salt;
int saltlen; int saltlen;
int iterations; int iterations;
uint8 salted_password[SCRAM_KEY_LEN]; int key_length = 0;
uint8 stored_key[SCRAM_KEY_LEN]; pg_cryptohash_type hash_type;
uint8 server_key[SCRAM_KEY_LEN]; uint8 salted_password[SCRAM_MAX_KEY_LEN];
uint8 computed_key[SCRAM_KEY_LEN]; uint8 stored_key[SCRAM_MAX_KEY_LEN];
uint8 server_key[SCRAM_MAX_KEY_LEN];
uint8 computed_key[SCRAM_MAX_KEY_LEN];
char *prep_password; char *prep_password;
pg_saslprep_rc rc; pg_saslprep_rc rc;
const char *errstr = NULL; const char *errstr = NULL;
if (!parse_scram_secret(secret, &iterations, &encoded_salt, if (!parse_scram_secret(secret, &iterations, &hash_type, &key_length,
stored_key, server_key)) &encoded_salt, stored_key, server_key))
{ {
/* /*
* The password looked like a SCRAM secret, but could not be parsed. * The password looked like a SCRAM secret, but could not be parsed.
@ -541,9 +556,11 @@ scram_verify_plain_password(const char *username, const char *password,
password = prep_password; password = prep_password;
/* Compute Server Key based on the user-supplied plaintext password */ /* Compute Server Key based on the user-supplied plaintext password */
if (scram_SaltedPassword(password, salt, saltlen, iterations, if (scram_SaltedPassword(password, hash_type, key_length,
salt, saltlen, iterations,
salted_password, &errstr) < 0 || salted_password, &errstr) < 0 ||
scram_ServerKey(salted_password, computed_key, &errstr) < 0) scram_ServerKey(salted_password, hash_type, key_length,
computed_key, &errstr) < 0)
{ {
elog(ERROR, "could not compute server key: %s", errstr); elog(ERROR, "could not compute server key: %s", errstr);
} }
@ -555,7 +572,7 @@ scram_verify_plain_password(const char *username, const char *password,
* Compare the secret's Server Key with the one computed from the * Compare the secret's Server Key with the one computed from the
* user-supplied password. * user-supplied password.
*/ */
return memcmp(computed_key, server_key, SCRAM_KEY_LEN) == 0; return memcmp(computed_key, server_key, key_length) == 0;
} }
@ -565,14 +582,15 @@ scram_verify_plain_password(const char *username, const char *password,
* On success, the iteration count, salt, stored key, and server key are * On success, the iteration count, salt, stored key, and server key are
* extracted from the secret, and returned to the caller. For 'stored_key' * extracted from the secret, and returned to the caller. For 'stored_key'
* and 'server_key', the caller must pass pre-allocated buffers of size * and 'server_key', the caller must pass pre-allocated buffers of size
* SCRAM_KEY_LEN. Salt is returned as a base64-encoded, null-terminated * SCRAM_MAX_KEY_LEN. Salt is returned as a base64-encoded, null-terminated
* string. The buffer for the salt is palloc'd by this function. * string. The buffer for the salt is palloc'd by this function.
* *
* Returns true if the SCRAM secret has been parsed, and false otherwise. * Returns true if the SCRAM secret has been parsed, and false otherwise.
*/ */
bool bool
parse_scram_secret(const char *secret, int *iterations, char **salt, parse_scram_secret(const char *secret, int *iterations,
uint8 *stored_key, uint8 *server_key) pg_cryptohash_type *hash_type, int *key_length,
char **salt, uint8 *stored_key, uint8 *server_key)
{ {
char *v; char *v;
char *p; char *p;
@ -606,6 +624,8 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
/* Parse the fields */ /* Parse the fields */
if (strcmp(scheme_str, "SCRAM-SHA-256") != 0) if (strcmp(scheme_str, "SCRAM-SHA-256") != 0)
goto invalid_secret; goto invalid_secret;
*hash_type = PG_SHA256;
*key_length = SCRAM_SHA_256_KEY_LEN;
errno = 0; errno = 0;
*iterations = strtol(iterations_str, &p, 10); *iterations = strtol(iterations_str, &p, 10);
@ -631,17 +651,17 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
decoded_stored_buf = palloc(decoded_len); decoded_stored_buf = palloc(decoded_len);
decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str), decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str),
decoded_stored_buf, decoded_len); decoded_stored_buf, decoded_len);
if (decoded_len != SCRAM_KEY_LEN) if (decoded_len != *key_length)
goto invalid_secret; goto invalid_secret;
memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN); memcpy(stored_key, decoded_stored_buf, *key_length);
decoded_len = pg_b64_dec_len(strlen(serverkey_str)); decoded_len = pg_b64_dec_len(strlen(serverkey_str));
decoded_server_buf = palloc(decoded_len); decoded_server_buf = palloc(decoded_len);
decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str), decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
decoded_server_buf, decoded_len); decoded_server_buf, decoded_len);
if (decoded_len != SCRAM_KEY_LEN) if (decoded_len != *key_length)
goto invalid_secret; goto invalid_secret;
memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN); memcpy(server_key, decoded_server_buf, *key_length);
return true; return true;
@ -655,20 +675,25 @@ invalid_secret:
* *
* In a normal authentication, these are extracted from the secret * In a normal authentication, these are extracted from the secret
* stored in the server. This function generates values that look * stored in the server. This function generates values that look
* realistic, for when there is no stored secret. * realistic, for when there is no stored secret, using SCRAM-SHA-256.
* *
* Like in parse_scram_secret(), for 'stored_key' and 'server_key', the * Like in parse_scram_secret(), for 'stored_key' and 'server_key', the
* caller must pass pre-allocated buffers of size SCRAM_KEY_LEN, and * caller must pass pre-allocated buffers of size SCRAM_MAX_KEY_LEN, and
* the buffer for the salt is palloc'd by this function. * the buffer for the salt is palloc'd by this function.
*/ */
static void static void
mock_scram_secret(const char *username, int *iterations, char **salt, mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
int *iterations, int *key_length, char **salt,
uint8 *stored_key, uint8 *server_key) uint8 *stored_key, uint8 *server_key)
{ {
char *raw_salt; char *raw_salt;
char *encoded_salt; char *encoded_salt;
int encoded_len; int encoded_len;
/* Enforce the use of SHA-256, which would be realistic enough */
*hash_type = PG_SHA256;
*key_length = SCRAM_SHA_256_KEY_LEN;
/* /*
* Generate deterministic salt. * Generate deterministic salt.
* *
@ -677,7 +702,7 @@ mock_scram_secret(const char *username, int *iterations, char **salt,
* as the salt generated for mock authentication uses the cluster's nonce * as the salt generated for mock authentication uses the cluster's nonce
* value. * value.
*/ */
raw_salt = scram_mock_salt(username); raw_salt = scram_mock_salt(username, *hash_type, *key_length);
if (raw_salt == NULL) if (raw_salt == NULL)
elog(ERROR, "could not encode salt"); elog(ERROR, "could not encode salt");
@ -695,8 +720,8 @@ mock_scram_secret(const char *username, int *iterations, char **salt,
*iterations = SCRAM_DEFAULT_ITERATIONS; *iterations = SCRAM_DEFAULT_ITERATIONS;
/* StoredKey and ServerKey are not used in a doomed authentication */ /* StoredKey and ServerKey are not used in a doomed authentication */
memset(stored_key, 0, SCRAM_KEY_LEN); memset(stored_key, 0, SCRAM_MAX_KEY_LEN);
memset(server_key, 0, SCRAM_KEY_LEN); memset(server_key, 0, SCRAM_MAX_KEY_LEN);
} }
/* /*
@ -1111,10 +1136,10 @@ verify_final_nonce(scram_state *state)
static bool static bool
verify_client_proof(scram_state *state) verify_client_proof(scram_state *state)
{ {
uint8 ClientSignature[SCRAM_KEY_LEN]; uint8 ClientSignature[SCRAM_MAX_KEY_LEN];
uint8 ClientKey[SCRAM_KEY_LEN]; uint8 ClientKey[SCRAM_MAX_KEY_LEN];
uint8 client_StoredKey[SCRAM_KEY_LEN]; uint8 client_StoredKey[SCRAM_MAX_KEY_LEN];
pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256); pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
int i; int i;
const char *errstr = NULL; const char *errstr = NULL;
@ -1123,7 +1148,7 @@ verify_client_proof(scram_state *state)
* here even when processing the calculations as this could involve a mock * here even when processing the calculations as this could involve a mock
* authentication. * authentication.
*/ */
if (pg_hmac_init(ctx, state->StoredKey, SCRAM_KEY_LEN) < 0 || if (pg_hmac_init(ctx, state->StoredKey, state->key_length) < 0 ||
pg_hmac_update(ctx, pg_hmac_update(ctx,
(uint8 *) state->client_first_message_bare, (uint8 *) state->client_first_message_bare,
strlen(state->client_first_message_bare)) < 0 || strlen(state->client_first_message_bare)) < 0 ||
@ -1135,7 +1160,7 @@ verify_client_proof(scram_state *state)
pg_hmac_update(ctx, pg_hmac_update(ctx,
(uint8 *) state->client_final_message_without_proof, (uint8 *) state->client_final_message_without_proof,
strlen(state->client_final_message_without_proof)) < 0 || strlen(state->client_final_message_without_proof)) < 0 ||
pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0) pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
{ {
elog(ERROR, "could not calculate client signature: %s", elog(ERROR, "could not calculate client signature: %s",
pg_hmac_error(ctx)); pg_hmac_error(ctx));
@ -1144,14 +1169,15 @@ verify_client_proof(scram_state *state)
pg_hmac_free(ctx); pg_hmac_free(ctx);
/* Extract the ClientKey that the client calculated from the proof */ /* Extract the ClientKey that the client calculated from the proof */
for (i = 0; i < SCRAM_KEY_LEN; i++) for (i = 0; i < state->key_length; i++)
ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i]; ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
/* Hash it one more time, and compare with StoredKey */ /* Hash it one more time, and compare with StoredKey */
if (scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey, &errstr) < 0) if (scram_H(ClientKey, state->hash_type, state->key_length,
client_StoredKey, &errstr) < 0)
elog(ERROR, "could not hash stored key: %s", errstr); elog(ERROR, "could not hash stored key: %s", errstr);
if (memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) != 0) if (memcmp(client_StoredKey, state->StoredKey, state->key_length) != 0)
return false; return false;
return true; return true;
@ -1349,12 +1375,12 @@ read_client_final_message(scram_state *state, const char *input)
client_proof_len = pg_b64_dec_len(strlen(value)); client_proof_len = pg_b64_dec_len(strlen(value));
client_proof = palloc(client_proof_len); client_proof = palloc(client_proof_len);
if (pg_b64_decode(value, strlen(value), client_proof, if (pg_b64_decode(value, strlen(value), client_proof,
client_proof_len) != SCRAM_KEY_LEN) client_proof_len) != state->key_length)
ereport(ERROR, ereport(ERROR,
(errcode(ERRCODE_PROTOCOL_VIOLATION), (errcode(ERRCODE_PROTOCOL_VIOLATION),
errmsg("malformed SCRAM message"), errmsg("malformed SCRAM message"),
errdetail("Malformed proof in client-final-message."))); errdetail("Malformed proof in client-final-message.")));
memcpy(state->ClientProof, client_proof, SCRAM_KEY_LEN); memcpy(state->ClientProof, client_proof, state->key_length);
pfree(client_proof); pfree(client_proof);
if (*p != '\0') if (*p != '\0')
@ -1374,13 +1400,13 @@ read_client_final_message(scram_state *state, const char *input)
static char * static char *
build_server_final_message(scram_state *state) build_server_final_message(scram_state *state)
{ {
uint8 ServerSignature[SCRAM_KEY_LEN]; uint8 ServerSignature[SCRAM_MAX_KEY_LEN];
char *server_signature_base64; char *server_signature_base64;
int siglen; int siglen;
pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256); pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
/* calculate ServerSignature */ /* calculate ServerSignature */
if (pg_hmac_init(ctx, state->ServerKey, SCRAM_KEY_LEN) < 0 || if (pg_hmac_init(ctx, state->ServerKey, state->key_length) < 0 ||
pg_hmac_update(ctx, pg_hmac_update(ctx,
(uint8 *) state->client_first_message_bare, (uint8 *) state->client_first_message_bare,
strlen(state->client_first_message_bare)) < 0 || strlen(state->client_first_message_bare)) < 0 ||
@ -1392,7 +1418,7 @@ build_server_final_message(scram_state *state)
pg_hmac_update(ctx, pg_hmac_update(ctx,
(uint8 *) state->client_final_message_without_proof, (uint8 *) state->client_final_message_without_proof,
strlen(state->client_final_message_without_proof)) < 0 || strlen(state->client_final_message_without_proof)) < 0 ||
pg_hmac_final(ctx, ServerSignature, sizeof(ServerSignature)) < 0) pg_hmac_final(ctx, ServerSignature, state->key_length) < 0)
{ {
elog(ERROR, "could not calculate server signature: %s", elog(ERROR, "could not calculate server signature: %s",
pg_hmac_error(ctx)); pg_hmac_error(ctx));
@ -1400,11 +1426,11 @@ build_server_final_message(scram_state *state)
pg_hmac_free(ctx); pg_hmac_free(ctx);
siglen = pg_b64_enc_len(SCRAM_KEY_LEN); siglen = pg_b64_enc_len(state->key_length);
/* don't forget the zero-terminator */ /* don't forget the zero-terminator */
server_signature_base64 = palloc(siglen + 1); server_signature_base64 = palloc(siglen + 1);
siglen = pg_b64_encode((const char *) ServerSignature, siglen = pg_b64_encode((const char *) ServerSignature,
SCRAM_KEY_LEN, server_signature_base64, state->key_length, server_signature_base64,
siglen); siglen);
if (siglen < 0) if (siglen < 0)
elog(ERROR, "could not encode server signature"); elog(ERROR, "could not encode server signature");
@ -1431,10 +1457,11 @@ build_server_final_message(scram_state *state)
* pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN, or NULL. * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN, or NULL.
*/ */
static char * static char *
scram_mock_salt(const char *username) scram_mock_salt(const char *username, pg_cryptohash_type hash_type,
int key_length)
{ {
pg_cryptohash_ctx *ctx; pg_cryptohash_ctx *ctx;
static uint8 sha_digest[PG_SHA256_DIGEST_LENGTH]; static uint8 sha_digest[SCRAM_MAX_KEY_LEN];
char *mock_auth_nonce = GetMockAuthenticationNonce(); char *mock_auth_nonce = GetMockAuthenticationNonce();
/* /*
@ -1446,11 +1473,17 @@ scram_mock_salt(const char *username)
StaticAssertDecl(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN, StaticAssertDecl(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
"salt length greater than SHA256 digest length"); "salt length greater than SHA256 digest length");
ctx = pg_cryptohash_create(PG_SHA256); /*
* This may be worth refreshing if support for more hash methods is\
* added.
*/
Assert(hash_type == PG_SHA256);
ctx = pg_cryptohash_create(hash_type);
if (pg_cryptohash_init(ctx) < 0 || if (pg_cryptohash_init(ctx) < 0 ||
pg_cryptohash_update(ctx, (uint8 *) username, strlen(username)) < 0 || pg_cryptohash_update(ctx, (uint8 *) username, strlen(username)) < 0 ||
pg_cryptohash_update(ctx, (uint8 *) mock_auth_nonce, MOCK_AUTH_NONCE_LEN) < 0 || pg_cryptohash_update(ctx, (uint8 *) mock_auth_nonce, MOCK_AUTH_NONCE_LEN) < 0 ||
pg_cryptohash_final(ctx, sha_digest, sizeof(sha_digest)) < 0) pg_cryptohash_final(ctx, sha_digest, key_length) < 0)
{ {
pg_cryptohash_free(ctx); pg_cryptohash_free(ctx);
return NULL; return NULL;

View File

@ -90,15 +90,17 @@ get_password_type(const char *shadow_pass)
{ {
char *encoded_salt; char *encoded_salt;
int iterations; int iterations;
uint8 stored_key[SCRAM_KEY_LEN]; int key_length = 0;
uint8 server_key[SCRAM_KEY_LEN]; pg_cryptohash_type hash_type;
uint8 stored_key[SCRAM_MAX_KEY_LEN];
uint8 server_key[SCRAM_MAX_KEY_LEN];
if (strncmp(shadow_pass, "md5", 3) == 0 && if (strncmp(shadow_pass, "md5", 3) == 0 &&
strlen(shadow_pass) == MD5_PASSWD_LEN && strlen(shadow_pass) == MD5_PASSWD_LEN &&
strspn(shadow_pass + 3, MD5_PASSWD_CHARSET) == MD5_PASSWD_LEN - 3) strspn(shadow_pass + 3, MD5_PASSWD_CHARSET) == MD5_PASSWD_LEN - 3)
return PASSWORD_TYPE_MD5; return PASSWORD_TYPE_MD5;
if (parse_scram_secret(shadow_pass, &iterations, &encoded_salt, if (parse_scram_secret(shadow_pass, &iterations, &hash_type, &key_length,
stored_key, server_key)) &encoded_salt, stored_key, server_key))
return PASSWORD_TYPE_SCRAM_SHA_256; return PASSWORD_TYPE_SCRAM_SHA_256;
return PASSWORD_TYPE_PLAINTEXT; return PASSWORD_TYPE_PLAINTEXT;
} }

View File

@ -33,6 +33,7 @@
*/ */
int int
scram_SaltedPassword(const char *password, scram_SaltedPassword(const char *password,
pg_cryptohash_type hash_type, int key_length,
const char *salt, int saltlen, int iterations, const char *salt, int saltlen, int iterations,
uint8 *result, const char **errstr) uint8 *result, const char **errstr)
{ {
@ -40,9 +41,9 @@ scram_SaltedPassword(const char *password,
uint32 one = pg_hton32(1); uint32 one = pg_hton32(1);
int i, int i,
j; j;
uint8 Ui[SCRAM_KEY_LEN]; uint8 Ui[SCRAM_MAX_KEY_LEN];
uint8 Ui_prev[SCRAM_KEY_LEN]; uint8 Ui_prev[SCRAM_MAX_KEY_LEN];
pg_hmac_ctx *hmac_ctx = pg_hmac_create(PG_SHA256); pg_hmac_ctx *hmac_ctx = pg_hmac_create(hash_type);
if (hmac_ctx == NULL) if (hmac_ctx == NULL)
{ {
@ -60,30 +61,30 @@ scram_SaltedPassword(const char *password,
if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 || if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 ||
pg_hmac_update(hmac_ctx, (uint8 *) salt, saltlen) < 0 || pg_hmac_update(hmac_ctx, (uint8 *) salt, saltlen) < 0 ||
pg_hmac_update(hmac_ctx, (uint8 *) &one, sizeof(uint32)) < 0 || pg_hmac_update(hmac_ctx, (uint8 *) &one, sizeof(uint32)) < 0 ||
pg_hmac_final(hmac_ctx, Ui_prev, sizeof(Ui_prev)) < 0) pg_hmac_final(hmac_ctx, Ui_prev, key_length) < 0)
{ {
*errstr = pg_hmac_error(hmac_ctx); *errstr = pg_hmac_error(hmac_ctx);
pg_hmac_free(hmac_ctx); pg_hmac_free(hmac_ctx);
return -1; return -1;
} }
memcpy(result, Ui_prev, SCRAM_KEY_LEN); memcpy(result, Ui_prev, key_length);
/* Subsequent iterations */ /* Subsequent iterations */
for (i = 2; i <= iterations; i++) for (i = 2; i <= iterations; i++)
{ {
if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 || if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 ||
pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, SCRAM_KEY_LEN) < 0 || pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, key_length) < 0 ||
pg_hmac_final(hmac_ctx, Ui, sizeof(Ui)) < 0) pg_hmac_final(hmac_ctx, Ui, key_length) < 0)
{ {
*errstr = pg_hmac_error(hmac_ctx); *errstr = pg_hmac_error(hmac_ctx);
pg_hmac_free(hmac_ctx); pg_hmac_free(hmac_ctx);
return -1; return -1;
} }
for (j = 0; j < SCRAM_KEY_LEN; j++) for (j = 0; j < key_length; j++)
result[j] ^= Ui[j]; result[j] ^= Ui[j];
memcpy(Ui_prev, Ui, SCRAM_KEY_LEN); memcpy(Ui_prev, Ui, key_length);
} }
pg_hmac_free(hmac_ctx); pg_hmac_free(hmac_ctx);
@ -92,16 +93,17 @@ scram_SaltedPassword(const char *password,
/* /*
* Calculate SHA-256 hash for a NULL-terminated string. (The NULL terminator is * Calculate hash for a NULL-terminated string. (The NULL terminator is
* not included in the hash). Returns 0 on success, -1 on failure with *errstr * not included in the hash). Returns 0 on success, -1 on failure with *errstr
* pointing to a message about the error details. * pointing to a message about the error details.
*/ */
int int
scram_H(const uint8 *input, int len, uint8 *result, const char **errstr) scram_H(const uint8 *input, pg_cryptohash_type hash_type, int key_length,
uint8 *result, const char **errstr)
{ {
pg_cryptohash_ctx *ctx; pg_cryptohash_ctx *ctx;
ctx = pg_cryptohash_create(PG_SHA256); ctx = pg_cryptohash_create(hash_type);
if (ctx == NULL) if (ctx == NULL)
{ {
*errstr = pg_cryptohash_error(NULL); /* returns OOM */ *errstr = pg_cryptohash_error(NULL); /* returns OOM */
@ -109,8 +111,8 @@ scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
} }
if (pg_cryptohash_init(ctx) < 0 || if (pg_cryptohash_init(ctx) < 0 ||
pg_cryptohash_update(ctx, input, len) < 0 || pg_cryptohash_update(ctx, input, key_length) < 0 ||
pg_cryptohash_final(ctx, result, SCRAM_KEY_LEN) < 0) pg_cryptohash_final(ctx, result, key_length) < 0)
{ {
*errstr = pg_cryptohash_error(ctx); *errstr = pg_cryptohash_error(ctx);
pg_cryptohash_free(ctx); pg_cryptohash_free(ctx);
@ -126,10 +128,11 @@ scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
* pointing to a message about the error details. * pointing to a message about the error details.
*/ */
int int
scram_ClientKey(const uint8 *salted_password, uint8 *result, scram_ClientKey(const uint8 *salted_password,
const char **errstr) pg_cryptohash_type hash_type, int key_length,
uint8 *result, const char **errstr)
{ {
pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256); pg_hmac_ctx *ctx = pg_hmac_create(hash_type);
if (ctx == NULL) if (ctx == NULL)
{ {
@ -137,9 +140,9 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result,
return -1; return -1;
} }
if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 || if (pg_hmac_init(ctx, salted_password, key_length) < 0 ||
pg_hmac_update(ctx, (uint8 *) "Client Key", strlen("Client Key")) < 0 || pg_hmac_update(ctx, (uint8 *) "Client Key", strlen("Client Key")) < 0 ||
pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0) pg_hmac_final(ctx, result, key_length) < 0)
{ {
*errstr = pg_hmac_error(ctx); *errstr = pg_hmac_error(ctx);
pg_hmac_free(ctx); pg_hmac_free(ctx);
@ -155,10 +158,11 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result,
* pointing to a message about the error details. * pointing to a message about the error details.
*/ */
int int
scram_ServerKey(const uint8 *salted_password, uint8 *result, scram_ServerKey(const uint8 *salted_password,
const char **errstr) pg_cryptohash_type hash_type, int key_length,
uint8 *result, const char **errstr)
{ {
pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256); pg_hmac_ctx *ctx = pg_hmac_create(hash_type);
if (ctx == NULL) if (ctx == NULL)
{ {
@ -166,9 +170,9 @@ scram_ServerKey(const uint8 *salted_password, uint8 *result,
return -1; return -1;
} }
if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 || if (pg_hmac_init(ctx, salted_password, key_length) < 0 ||
pg_hmac_update(ctx, (uint8 *) "Server Key", strlen("Server Key")) < 0 || pg_hmac_update(ctx, (uint8 *) "Server Key", strlen("Server Key")) < 0 ||
pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0) pg_hmac_final(ctx, result, key_length) < 0)
{ {
*errstr = pg_hmac_error(ctx); *errstr = pg_hmac_error(ctx);
pg_hmac_free(ctx); pg_hmac_free(ctx);
@ -192,12 +196,13 @@ scram_ServerKey(const uint8 *salted_password, uint8 *result,
* error details. * error details.
*/ */
char * char *
scram_build_secret(const char *salt, int saltlen, int iterations, scram_build_secret(pg_cryptohash_type hash_type, int key_length,
const char *salt, int saltlen, int iterations,
const char *password, const char **errstr) const char *password, const char **errstr)
{ {
uint8 salted_password[SCRAM_KEY_LEN]; uint8 salted_password[SCRAM_MAX_KEY_LEN];
uint8 stored_key[SCRAM_KEY_LEN]; uint8 stored_key[SCRAM_MAX_KEY_LEN];
uint8 server_key[SCRAM_KEY_LEN]; uint8 server_key[SCRAM_MAX_KEY_LEN];
char *result; char *result;
char *p; char *p;
int maxlen; int maxlen;
@ -206,15 +211,22 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
int encoded_server_len; int encoded_server_len;
int encoded_result; int encoded_result;
/* Only this hash method is supported currently */
Assert(hash_type == PG_SHA256);
if (iterations <= 0) if (iterations <= 0)
iterations = SCRAM_DEFAULT_ITERATIONS; iterations = SCRAM_DEFAULT_ITERATIONS;
/* Calculate StoredKey and ServerKey */ /* Calculate StoredKey and ServerKey */
if (scram_SaltedPassword(password, salt, saltlen, iterations, if (scram_SaltedPassword(password, hash_type, key_length,
salt, saltlen, iterations,
salted_password, errstr) < 0 || salted_password, errstr) < 0 ||
scram_ClientKey(salted_password, stored_key, errstr) < 0 || scram_ClientKey(salted_password, hash_type, key_length,
scram_H(stored_key, SCRAM_KEY_LEN, stored_key, errstr) < 0 || stored_key, errstr) < 0 ||
scram_ServerKey(salted_password, server_key, errstr) < 0) scram_H(stored_key, hash_type, key_length,
stored_key, errstr) < 0 ||
scram_ServerKey(salted_password, hash_type, key_length,
server_key, errstr) < 0)
{ {
/* errstr is filled already here */ /* errstr is filled already here */
#ifdef FRONTEND #ifdef FRONTEND
@ -231,8 +243,8 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
*---------- *----------
*/ */
encoded_salt_len = pg_b64_enc_len(saltlen); encoded_salt_len = pg_b64_enc_len(saltlen);
encoded_stored_len = pg_b64_enc_len(SCRAM_KEY_LEN); encoded_stored_len = pg_b64_enc_len(key_length);
encoded_server_len = pg_b64_enc_len(SCRAM_KEY_LEN); encoded_server_len = pg_b64_enc_len(key_length);
maxlen = strlen("SCRAM-SHA-256") + 1 maxlen = strlen("SCRAM-SHA-256") + 1
+ 10 + 1 /* iteration count */ + 10 + 1 /* iteration count */
@ -269,7 +281,7 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
*(p++) = '$'; *(p++) = '$';
/* stored key */ /* stored key */
encoded_result = pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p, encoded_result = pg_b64_encode((char *) stored_key, key_length, p,
encoded_stored_len); encoded_stored_len);
if (encoded_result < 0) if (encoded_result < 0)
{ {
@ -286,7 +298,7 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
*(p++) = ':'; *(p++) = ':';
/* server key */ /* server key */
encoded_result = pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p, encoded_result = pg_b64_encode((char *) server_key, key_length, p,
encoded_server_len); encoded_server_len);
if (encoded_result < 0) if (encoded_result < 0)
{ {

View File

@ -21,7 +21,13 @@
#define SCRAM_SHA_256_PLUS_NAME "SCRAM-SHA-256-PLUS" /* with channel binding */ #define SCRAM_SHA_256_PLUS_NAME "SCRAM-SHA-256-PLUS" /* with channel binding */
/* Length of SCRAM keys (client and server) */ /* Length of SCRAM keys (client and server) */
#define SCRAM_KEY_LEN PG_SHA256_DIGEST_LENGTH #define SCRAM_SHA_256_KEY_LEN PG_SHA256_DIGEST_LENGTH
/*
* Size of buffers used internally by SCRAM routines, that should be the
* maximum of SCRAM_SHA_*_KEY_LEN among the hash methods supported.
*/
#define SCRAM_MAX_KEY_LEN SCRAM_SHA_256_KEY_LEN
/* /*
* Size of random nonce generated in the authentication exchange. This * Size of random nonce generated in the authentication exchange. This
@ -43,17 +49,22 @@
*/ */
#define SCRAM_DEFAULT_ITERATIONS 4096 #define SCRAM_DEFAULT_ITERATIONS 4096
extern int scram_SaltedPassword(const char *password, const char *salt, extern int scram_SaltedPassword(const char *password,
int saltlen, int iterations, uint8 *result, pg_cryptohash_type hash_type, int key_length,
const char **errstr); const char *salt, int saltlen, int iterations,
extern int scram_H(const uint8 *input, int len, uint8 *result, uint8 *result, const char **errstr);
const char **errstr); extern int scram_H(const uint8 *input, pg_cryptohash_type hash_type,
extern int scram_ClientKey(const uint8 *salted_password, uint8 *result, int key_length, uint8 *result,
const char **errstr);
extern int scram_ServerKey(const uint8 *salted_password, uint8 *result,
const char **errstr); const char **errstr);
extern int scram_ClientKey(const uint8 *salted_password,
pg_cryptohash_type hash_type, int key_length,
uint8 *result, const char **errstr);
extern int scram_ServerKey(const uint8 *salted_password,
pg_cryptohash_type hash_type, int key_length,
uint8 *result, const char **errstr);
extern char *scram_build_secret(const char *salt, int saltlen, int iterations, extern char *scram_build_secret(pg_cryptohash_type hash_type, int key_length,
const char *salt, int saltlen, int iterations,
const char *password, const char **errstr); const char *password, const char **errstr);
#endif /* SCRAM_COMMON_H */ #endif /* SCRAM_COMMON_H */

View File

@ -13,6 +13,7 @@
#ifndef PG_SCRAM_H #ifndef PG_SCRAM_H
#define PG_SCRAM_H #define PG_SCRAM_H
#include "common/cryptohash.h"
#include "lib/stringinfo.h" #include "lib/stringinfo.h"
#include "libpq/libpq-be.h" #include "libpq/libpq-be.h"
#include "libpq/sasl.h" #include "libpq/sasl.h"
@ -22,7 +23,10 @@ extern PGDLLIMPORT const pg_be_sasl_mech pg_be_scram_mech;
/* Routines to handle and check SCRAM-SHA-256 secret */ /* Routines to handle and check SCRAM-SHA-256 secret */
extern char *pg_be_scram_build_secret(const char *password); extern char *pg_be_scram_build_secret(const char *password);
extern bool parse_scram_secret(const char *secret, int *iterations, char **salt, extern bool parse_scram_secret(const char *secret,
int *iterations,
pg_cryptohash_type *hash_type,
int *key_length, char **salt,
uint8 *stored_key, uint8 *server_key); uint8 *stored_key, uint8 *server_key);
extern bool scram_verify_plain_password(const char *username, extern bool scram_verify_plain_password(const char *username,
const char *password, const char *secret); const char *password, const char *secret);

View File

@ -58,8 +58,12 @@ typedef struct
char *password; char *password;
char *sasl_mechanism; char *sasl_mechanism;
/* State data depending on the hash type */
pg_cryptohash_type hash_type;
int key_length;
/* We construct these */ /* We construct these */
uint8 SaltedPassword[SCRAM_KEY_LEN]; uint8 SaltedPassword[SCRAM_MAX_KEY_LEN];
char *client_nonce; char *client_nonce;
char *client_first_message_bare; char *client_first_message_bare;
char *client_final_message_without_proof; char *client_final_message_without_proof;
@ -73,7 +77,7 @@ typedef struct
/* These come from the server-final message */ /* These come from the server-final message */
char *server_final_message; char *server_final_message;
char ServerSignature[SCRAM_KEY_LEN]; char ServerSignature[SCRAM_MAX_KEY_LEN];
} fe_scram_state; } fe_scram_state;
static bool read_server_first_message(fe_scram_state *state, char *input); static bool read_server_first_message(fe_scram_state *state, char *input);
@ -106,8 +110,10 @@ scram_init(PGconn *conn,
memset(state, 0, sizeof(fe_scram_state)); memset(state, 0, sizeof(fe_scram_state));
state->conn = conn; state->conn = conn;
state->state = FE_SCRAM_INIT; state->state = FE_SCRAM_INIT;
state->sasl_mechanism = strdup(sasl_mechanism); state->key_length = SCRAM_SHA_256_KEY_LEN;
state->hash_type = PG_SHA256;
state->sasl_mechanism = strdup(sasl_mechanism);
if (!state->sasl_mechanism) if (!state->sasl_mechanism)
{ {
free(state); free(state);
@ -450,7 +456,7 @@ build_client_final_message(fe_scram_state *state)
{ {
PQExpBufferData buf; PQExpBufferData buf;
PGconn *conn = state->conn; PGconn *conn = state->conn;
uint8 client_proof[SCRAM_KEY_LEN]; uint8 client_proof[SCRAM_MAX_KEY_LEN];
char *result; char *result;
int encoded_len; int encoded_len;
const char *errstr = NULL; const char *errstr = NULL;
@ -565,11 +571,11 @@ build_client_final_message(fe_scram_state *state)
} }
appendPQExpBufferStr(&buf, ",p="); appendPQExpBufferStr(&buf, ",p=");
encoded_len = pg_b64_enc_len(SCRAM_KEY_LEN); encoded_len = pg_b64_enc_len(state->key_length);
if (!enlargePQExpBuffer(&buf, encoded_len)) if (!enlargePQExpBuffer(&buf, encoded_len))
goto oom_error; goto oom_error;
encoded_len = pg_b64_encode((char *) client_proof, encoded_len = pg_b64_encode((char *) client_proof,
SCRAM_KEY_LEN, state->key_length,
buf.data + buf.len, buf.data + buf.len,
encoded_len); encoded_len);
if (encoded_len < 0) if (encoded_len < 0)
@ -738,13 +744,14 @@ read_server_final_message(fe_scram_state *state, char *input)
strlen(encoded_server_signature), strlen(encoded_server_signature),
decoded_server_signature, decoded_server_signature,
server_signature_len); server_signature_len);
if (server_signature_len != SCRAM_KEY_LEN) if (server_signature_len != state->key_length)
{ {
free(decoded_server_signature); free(decoded_server_signature);
libpq_append_conn_error(conn, "malformed SCRAM message (invalid server signature)"); libpq_append_conn_error(conn, "malformed SCRAM message (invalid server signature)");
return false; return false;
} }
memcpy(state->ServerSignature, decoded_server_signature, SCRAM_KEY_LEN); memcpy(state->ServerSignature, decoded_server_signature,
state->key_length);
free(decoded_server_signature); free(decoded_server_signature);
return true; return true;
@ -760,13 +767,13 @@ calculate_client_proof(fe_scram_state *state,
const char *client_final_message_without_proof, const char *client_final_message_without_proof,
uint8 *result, const char **errstr) uint8 *result, const char **errstr)
{ {
uint8 StoredKey[SCRAM_KEY_LEN]; uint8 StoredKey[SCRAM_MAX_KEY_LEN];
uint8 ClientKey[SCRAM_KEY_LEN]; uint8 ClientKey[SCRAM_MAX_KEY_LEN];
uint8 ClientSignature[SCRAM_KEY_LEN]; uint8 ClientSignature[SCRAM_MAX_KEY_LEN];
int i; int i;
pg_hmac_ctx *ctx; pg_hmac_ctx *ctx;
ctx = pg_hmac_create(PG_SHA256); ctx = pg_hmac_create(state->hash_type);
if (ctx == NULL) if (ctx == NULL)
{ {
*errstr = pg_hmac_error(NULL); /* returns OOM */ *errstr = pg_hmac_error(NULL); /* returns OOM */
@ -777,18 +784,21 @@ calculate_client_proof(fe_scram_state *state,
* Calculate SaltedPassword, and store it in 'state' so that we can reuse * Calculate SaltedPassword, and store it in 'state' so that we can reuse
* it later in verify_server_signature. * it later in verify_server_signature.
*/ */
if (scram_SaltedPassword(state->password, state->salt, state->saltlen, if (scram_SaltedPassword(state->password, state->hash_type,
state->key_length, state->salt, state->saltlen,
state->iterations, state->SaltedPassword, state->iterations, state->SaltedPassword,
errstr) < 0 || errstr) < 0 ||
scram_ClientKey(state->SaltedPassword, ClientKey, errstr) < 0 || scram_ClientKey(state->SaltedPassword, state->hash_type,
scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey, errstr) < 0) state->key_length, ClientKey, errstr) < 0 ||
scram_H(ClientKey, state->hash_type, state->key_length,
StoredKey, errstr) < 0)
{ {
/* errstr is already filled here */ /* errstr is already filled here */
pg_hmac_free(ctx); pg_hmac_free(ctx);
return false; return false;
} }
if (pg_hmac_init(ctx, StoredKey, SCRAM_KEY_LEN) < 0 || if (pg_hmac_init(ctx, StoredKey, state->key_length) < 0 ||
pg_hmac_update(ctx, pg_hmac_update(ctx,
(uint8 *) state->client_first_message_bare, (uint8 *) state->client_first_message_bare,
strlen(state->client_first_message_bare)) < 0 || strlen(state->client_first_message_bare)) < 0 ||
@ -800,14 +810,14 @@ calculate_client_proof(fe_scram_state *state,
pg_hmac_update(ctx, pg_hmac_update(ctx,
(uint8 *) client_final_message_without_proof, (uint8 *) client_final_message_without_proof,
strlen(client_final_message_without_proof)) < 0 || strlen(client_final_message_without_proof)) < 0 ||
pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0) pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
{ {
*errstr = pg_hmac_error(ctx); *errstr = pg_hmac_error(ctx);
pg_hmac_free(ctx); pg_hmac_free(ctx);
return false; return false;
} }
for (i = 0; i < SCRAM_KEY_LEN; i++) for (i = 0; i < state->key_length; i++)
result[i] = ClientKey[i] ^ ClientSignature[i]; result[i] = ClientKey[i] ^ ClientSignature[i];
pg_hmac_free(ctx); pg_hmac_free(ctx);
@ -825,18 +835,19 @@ static bool
verify_server_signature(fe_scram_state *state, bool *match, verify_server_signature(fe_scram_state *state, bool *match,
const char **errstr) const char **errstr)
{ {
uint8 expected_ServerSignature[SCRAM_KEY_LEN]; uint8 expected_ServerSignature[SCRAM_MAX_KEY_LEN];
uint8 ServerKey[SCRAM_KEY_LEN]; uint8 ServerKey[SCRAM_MAX_KEY_LEN];
pg_hmac_ctx *ctx; pg_hmac_ctx *ctx;
ctx = pg_hmac_create(PG_SHA256); ctx = pg_hmac_create(state->hash_type);
if (ctx == NULL) if (ctx == NULL)
{ {
*errstr = pg_hmac_error(NULL); /* returns OOM */ *errstr = pg_hmac_error(NULL); /* returns OOM */
return false; return false;
} }
if (scram_ServerKey(state->SaltedPassword, ServerKey, errstr) < 0) if (scram_ServerKey(state->SaltedPassword, state->hash_type,
state->key_length, ServerKey, errstr) < 0)
{ {
/* errstr is filled already */ /* errstr is filled already */
pg_hmac_free(ctx); pg_hmac_free(ctx);
@ -844,7 +855,7 @@ verify_server_signature(fe_scram_state *state, bool *match,
} }
/* calculate ServerSignature */ /* calculate ServerSignature */
if (pg_hmac_init(ctx, ServerKey, SCRAM_KEY_LEN) < 0 || if (pg_hmac_init(ctx, ServerKey, state->key_length) < 0 ||
pg_hmac_update(ctx, pg_hmac_update(ctx,
(uint8 *) state->client_first_message_bare, (uint8 *) state->client_first_message_bare,
strlen(state->client_first_message_bare)) < 0 || strlen(state->client_first_message_bare)) < 0 ||
@ -857,7 +868,7 @@ verify_server_signature(fe_scram_state *state, bool *match,
(uint8 *) state->client_final_message_without_proof, (uint8 *) state->client_final_message_without_proof,
strlen(state->client_final_message_without_proof)) < 0 || strlen(state->client_final_message_without_proof)) < 0 ||
pg_hmac_final(ctx, expected_ServerSignature, pg_hmac_final(ctx, expected_ServerSignature,
sizeof(expected_ServerSignature)) < 0) state->key_length) < 0)
{ {
*errstr = pg_hmac_error(ctx); *errstr = pg_hmac_error(ctx);
pg_hmac_free(ctx); pg_hmac_free(ctx);
@ -867,7 +878,8 @@ verify_server_signature(fe_scram_state *state, bool *match,
pg_hmac_free(ctx); pg_hmac_free(ctx);
/* signature processed, so now check after it */ /* signature processed, so now check after it */
if (memcmp(expected_ServerSignature, state->ServerSignature, SCRAM_KEY_LEN) != 0) if (memcmp(expected_ServerSignature, state->ServerSignature,
state->key_length) != 0)
*match = false; *match = false;
else else
*match = true; *match = true;
@ -912,7 +924,8 @@ pg_fe_scram_build_secret(const char *password, const char **errstr)
return NULL; return NULL;
} }
result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN, result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN, saltbuf,
SCRAM_DEFAULT_SALT_LEN,
SCRAM_DEFAULT_ITERATIONS, password, SCRAM_DEFAULT_ITERATIONS, password,
errstr); errstr);