Misc SCRAM code cleanups.
authorHeikki Linnakangas
Fri, 28 Apr 2017 12:04:02 +0000 (15:04 +0300)
committerHeikki Linnakangas
Fri, 28 Apr 2017 12:22:38 +0000 (15:22 +0300)
* Move computation of SaltedPassword to a separate function from
  scram_ClientOrServerKey(). This saves a lot of cycles in libpq, by
  computing SaltedPassword only once per authentication. (Computing
  SaltedPassword is expensive by design.)

* Split scram_ClientOrServerKey() into two functions. Improves
  readability, by making the calling code less verbose.

* Rename "server proof" to "server signature", to better match the
  nomenclature used in RFC 5802.

* Rename SCRAM_SALT_LEN to SCRAM_DEFAULT_SALT_LEN, to make it more clear
  that the salt can be of any length, and the constant only specifies how
  long a salt we use when we generate a new verifier. Also rename
  SCRAM_ITERATIONS_DEFAULT to SCRAM_DEFAULT_ITERATIONS, for consistency.

These things caught my eye while working on other upcoming changes.

src/backend/libpq/auth-scram.c
src/common/scram-common.c
src/include/common/scram-common.h
src/interfaces/libpq/fe-auth-scram.c

index 16bea446e37dae75750b41faa360337fd91234be..5c85af943cdcbd2e62a6d073fdf38b8081a969ff 100644 (file)
@@ -396,7 +396,8 @@ scram_build_verifier(const char *username, const char *password,
 {
    char       *prep_password = NULL;
    pg_saslprep_rc rc;
-   char        saltbuf[SCRAM_SALT_LEN];
+   char        saltbuf[SCRAM_DEFAULT_SALT_LEN];
+   uint8       salted_password[SCRAM_KEY_LEN];
    uint8       keybuf[SCRAM_KEY_LEN];
    char       *encoded_salt;
    char       *encoded_storedkey;
@@ -414,10 +415,10 @@ scram_build_verifier(const char *username, const char *password,
        password = (const char *) prep_password;
 
    if (iterations <= 0)
-       iterations = SCRAM_ITERATIONS_DEFAULT;
+       iterations = SCRAM_DEFAULT_ITERATIONS;
 
    /* Generate salt, and encode it in base64 */
-   if (!pg_backend_random(saltbuf, SCRAM_SALT_LEN))
+   if (!pg_backend_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
    {
        ereport(LOG,
                (errcode(ERRCODE_INTERNAL_ERROR),
@@ -425,13 +426,14 @@ scram_build_verifier(const char *username, const char *password,
        return NULL;
    }
 
-   encoded_salt = palloc(pg_b64_enc_len(SCRAM_SALT_LEN) + 1);
-   encoded_len = pg_b64_encode(saltbuf, SCRAM_SALT_LEN, encoded_salt);
+   encoded_salt = palloc(pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN) + 1);
+   encoded_len = pg_b64_encode(saltbuf, SCRAM_DEFAULT_SALT_LEN, encoded_salt);
    encoded_salt[encoded_len] = '\0';
 
    /* Calculate StoredKey, and encode it in base64 */
-   scram_ClientOrServerKey(password, saltbuf, SCRAM_SALT_LEN,
-                           iterations, SCRAM_CLIENT_KEY_NAME, keybuf);
+   scram_SaltedPassword(password, saltbuf, SCRAM_DEFAULT_SALT_LEN,
+                        iterations, salted_password);
+   scram_ClientKey(salted_password, keybuf);
    scram_H(keybuf, SCRAM_KEY_LEN, keybuf);     /* StoredKey */
 
    encoded_storedkey = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1);
@@ -440,8 +442,7 @@ scram_build_verifier(const char *username, const char *password,
    encoded_storedkey[encoded_len] = '\0';
 
    /* And same for ServerKey */
-   scram_ClientOrServerKey(password, saltbuf, SCRAM_SALT_LEN, iterations,
-                           SCRAM_SERVER_KEY_NAME, keybuf);
+   scram_ServerKey(salted_password, keybuf);
 
    encoded_serverkey = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1);
    encoded_len = pg_b64_encode((const char *) keybuf, SCRAM_KEY_LEN,
@@ -473,6 +474,7 @@ scram_verify_plain_password(const char *username, const char *password,
    char       *salt;
    int         saltlen;
    int         iterations;
+   uint8       salted_password[SCRAM_KEY_LEN];
    uint8       stored_key[SCRAM_KEY_LEN];
    uint8       server_key[SCRAM_KEY_LEN];
    uint8       computed_key[SCRAM_KEY_LEN];
@@ -502,9 +504,9 @@ scram_verify_plain_password(const char *username, const char *password,
    if (rc == SASLPREP_SUCCESS)
        password = prep_password;
 
-   /* Compute Server key based on the user-supplied plaintext password */
-   scram_ClientOrServerKey(password, salt, saltlen, iterations,
-                           SCRAM_SERVER_KEY_NAME, computed_key);
+   /* Compute Server Key based on the user-supplied plaintext password */
+   scram_SaltedPassword(password, salt, saltlen, iterations, salted_password);
+   scram_ServerKey(salted_password, computed_key);
 
    if (prep_password)
        pfree(prep_password);
@@ -630,12 +632,12 @@ mock_scram_verifier(const char *username, int *iterations, char **salt,
    /* Generate deterministic salt */
    raw_salt = scram_MockSalt(username);
 
-   encoded_salt = (char *) palloc(pg_b64_enc_len(SCRAM_SALT_LEN) + 1);
-   encoded_len = pg_b64_encode(raw_salt, SCRAM_SALT_LEN, encoded_salt);
+   encoded_salt = (char *) palloc(pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN) + 1);
+   encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt);
    encoded_salt[encoded_len] = '\0';
 
    *salt = encoded_salt;
-   *iterations = SCRAM_ITERATIONS_DEFAULT;
+   *iterations = SCRAM_DEFAULT_ITERATIONS;
 
    /* StoredKey and ServerKey are not used in a doomed authentication */
    memset(stored_key, 0, SCRAM_KEY_LEN);
@@ -1179,7 +1181,7 @@ build_server_final_message(scram_state *state)
 /*
  * Determinisitcally generate salt for mock authentication, using a SHA256
  * hash based on the username and a cluster-level secret key.  Returns a
- * pointer to a static buffer of size SCRAM_SALT_LEN.
+ * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN.
  */
 static char *
 scram_MockSalt(const char *username)
@@ -1194,7 +1196,7 @@ scram_MockSalt(const char *username)
     * not larger the SHA256 digest length. If the salt is smaller, the caller
     * will just ignore the extra data))
     */
-   StaticAssertStmt(PG_SHA256_DIGEST_LENGTH >= SCRAM_SALT_LEN,
+   StaticAssertStmt(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
                     "salt length greater than SHA256 digest length");
 
    pg_sha256_init(&ctx);
index df9f0eaa90d1c6106613fbec6e995d5b963b6384..a8ea44944c493749ce88b82b73fa3322efdc3b21 100644 (file)
@@ -98,14 +98,16 @@ scram_HMAC_final(uint8 *result, scram_HMAC_ctx *ctx)
 }
 
 /*
- * Iterate hash calculation of HMAC entry using given salt.
- * scram_Hi() is essentially PBKDF2 (see RFC2898) with HMAC() as the
- * pseudorandom function.
+ * Calculate SaltedPassword.
+ *
+ * The password should already be normalized by SASLprep.
  */
-static void
-scram_Hi(const char *str, const char *salt, int saltlen, int iterations, uint8 *result)
+void
+scram_SaltedPassword(const char *password,
+                    const char *salt, int saltlen, int iterations,
+                    uint8 *result)
 {
-   int         str_len = strlen(str);
+   int         password_len = strlen(password);
    uint32      one = htonl(1);
    int         i,
                j;
@@ -113,8 +115,14 @@ scram_Hi(const char *str, const char *salt, int saltlen, int iterations, uint8 *
    uint8       Ui_prev[SCRAM_KEY_LEN];
    scram_HMAC_ctx hmac_ctx;
 
+   /*
+    * Iterate hash calculation of HMAC entry using given salt.  This is
+    * essentially PBKDF2 (see RFC2898) with HMAC() as the pseudorandom
+    * function.
+    */
+
    /* First iteration */
-   scram_HMAC_init(&hmac_ctx, (uint8 *) str, str_len);
+   scram_HMAC_init(&hmac_ctx, (uint8 *) password, password_len);
    scram_HMAC_update(&hmac_ctx, salt, saltlen);
    scram_HMAC_update(&hmac_ctx, (char *) &one, sizeof(uint32));
    scram_HMAC_final(Ui_prev, &hmac_ctx);
@@ -123,7 +131,7 @@ scram_Hi(const char *str, const char *salt, int saltlen, int iterations, uint8 *
    /* Subsequent iterations */
    for (i = 2; i <= iterations; i++)
    {
-       scram_HMAC_init(&hmac_ctx, (uint8 *) str, str_len);
+       scram_HMAC_init(&hmac_ctx, (uint8 *) password, password_len);
        scram_HMAC_update(&hmac_ctx, (const char *) Ui_prev, SCRAM_KEY_LEN);
        scram_HMAC_final(Ui, &hmac_ctx);
        for (j = 0; j < SCRAM_KEY_LEN; j++)
@@ -148,20 +156,27 @@ scram_H(const uint8 *input, int len, uint8 *result)
 }
 
 /*
- * Calculate ClientKey or ServerKey.
- *
- * The password should already be normalized by SASLprep.
+ * Calculate ClientKey.
  */
 void
-scram_ClientOrServerKey(const char *password,
-                       const char *salt, int saltlen, int iterations,
-                       const char *keystr, uint8 *result)
+scram_ClientKey(const uint8 *salted_password, uint8 *result)
+{
+   scram_HMAC_ctx ctx;
+
+   scram_HMAC_init(&ctx, salted_password, SCRAM_KEY_LEN);
+   scram_HMAC_update(&ctx, "Client Key", strlen("Client Key"));
+   scram_HMAC_final(result, &ctx);
+}
+
+/*
+ * Calculate ServerKey.
+ */
+void
+scram_ServerKey(const uint8 *salted_password, uint8 *result)
 {
-   uint8       keybuf[SCRAM_KEY_LEN];
    scram_HMAC_ctx ctx;
 
-   scram_Hi(password, salt, saltlen, iterations, keybuf);
-   scram_HMAC_init(&ctx, keybuf, SCRAM_KEY_LEN);
-   scram_HMAC_update(&ctx, keystr, strlen(keystr));
+   scram_HMAC_init(&ctx, salted_password, SCRAM_KEY_LEN);
+   scram_HMAC_update(&ctx, "Server Key", strlen("Server Key"));
    scram_HMAC_final(result, &ctx);
 }
index 6740069eee18368715748600c18a07dcddc19060..656d9e1e6b1378e70d6bfb5f1c2d29c84282fa02 100644 (file)
 #define SCRAM_RAW_NONCE_LEN            10
 
 /* length of salt when generating new verifiers */
-#define SCRAM_SALT_LEN             10
+#define SCRAM_DEFAULT_SALT_LEN     10
 
 /* default number of iterations when generating verifier */
-#define SCRAM_ITERATIONS_DEFAULT   4096
-
-/* Base name of keys used for proof generation */
-#define SCRAM_SERVER_KEY_NAME "Server Key"
-#define SCRAM_CLIENT_KEY_NAME "Client Key"
+#define SCRAM_DEFAULT_ITERATIONS   4096
 
 /*
  * Context data for HMAC used in SCRAM authentication.
@@ -51,9 +47,10 @@ extern void scram_HMAC_init(scram_HMAC_ctx *ctx, const uint8 *key, int keylen);
 extern void scram_HMAC_update(scram_HMAC_ctx *ctx, const char *str, int slen);
 extern void scram_HMAC_final(uint8 *result, scram_HMAC_ctx *ctx);
 
+extern void scram_SaltedPassword(const char *password, const char *salt,
+                       int saltlen, int iterations, uint8 *result);
 extern void scram_H(const uint8 *str, int len, uint8 *result);
-extern void scram_ClientOrServerKey(const char *password, const char *salt,
-                       int saltlen, int iterations,
-                       const char *keystr, uint8 *result);
+extern void scram_ClientKey(const uint8 *salted_password, uint8 *result);
+extern void scram_ServerKey(const uint8 *salted_password, uint8 *result);
 
 #endif   /* SCRAM_COMMON_H */
index c56e91e0e04bdccd314496144f052f69833cf4a4..be271ce8ac01d3544c8648effd97f141f8f889f1 100644 (file)
@@ -46,6 +46,7 @@ typedef struct
    char       *password;
 
    /* We construct these */
+   uint8       SaltedPassword[SCRAM_KEY_LEN];
    char       *client_nonce;
    char       *client_first_message_bare;
    char       *client_final_message_without_proof;
@@ -59,7 +60,7 @@ typedef struct
 
    /* These come from the server-final message */
    char       *server_final_message;
-   char        ServerProof[SCRAM_KEY_LEN];
+   char        ServerSignature[SCRAM_KEY_LEN];
 } fe_scram_state;
 
 static bool read_server_first_message(fe_scram_state *state, char *input,
@@ -70,7 +71,7 @@ static char *build_client_first_message(fe_scram_state *state,
                           PQExpBuffer errormessage);
 static char *build_client_final_message(fe_scram_state *state,
                           PQExpBuffer errormessage);
-static bool verify_server_proof(fe_scram_state *state);
+static bool verify_server_signature(fe_scram_state *state);
 static void calculate_client_proof(fe_scram_state *state,
                       const char *client_final_message_without_proof,
                       uint8 *result);
@@ -216,12 +217,12 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
                goto error;
 
            /*
-            * Verify server proof, to make sure we're talking to the genuine
-            * server.  XXX: A fake server could simply not require
+            * Verify server signature, to make sure we're talking to the
+            * genuine server.  XXX: A fake server could simply not require
             * authentication, though.  There is currently no option in libpq
             * to reject a connection, if SCRAM authentication did not happen.
             */
-           if (verify_server_proof(state))
+           if (verify_server_signature(state))
                *success = true;
            else
            {
@@ -486,12 +487,11 @@ read_server_first_message(fe_scram_state *state, char *input,
  * Read the final exchange message coming from the server.
  */
 static bool
-read_server_final_message(fe_scram_state *state,
-                         char *input,
+read_server_final_message(fe_scram_state *state, char *input,
                          PQExpBuffer errormessage)
 {
-   char       *encoded_server_proof;
-   int         server_proof_len;
+   char       *encoded_server_signature;
+   int         server_signature_len;
 
    state->server_final_message = strdup(input);
    if (!state->server_final_message)
@@ -513,8 +513,8 @@ read_server_final_message(fe_scram_state *state,
    }
 
    /* Parse the message. */
-   encoded_server_proof = read_attr_value(&input, 'v', errormessage);
-   if (encoded_server_proof == NULL)
+   encoded_server_signature = read_attr_value(&input, 'v', errormessage);
+   if (encoded_server_signature == NULL)
    {
        /* read_attr_value() has generated an error message */
        return false;
@@ -524,13 +524,13 @@ read_server_final_message(fe_scram_state *state,
        printfPQExpBuffer(errormessage,
                          libpq_gettext("malformed SCRAM message (garbage at end of server-final-message)\n"));
 
-   server_proof_len = pg_b64_decode(encoded_server_proof,
-                                    strlen(encoded_server_proof),
-                                    state->ServerProof);
-   if (server_proof_len != SCRAM_KEY_LEN)
+   server_signature_len = pg_b64_decode(encoded_server_signature,
+                                        strlen(encoded_server_signature),
+                                        state->ServerSignature);
+   if (server_signature_len != SCRAM_KEY_LEN)
    {
        printfPQExpBuffer(errormessage,
-         libpq_gettext("malformed SCRAM message (invalid server proof)\n"));
+                         libpq_gettext("malformed SCRAM message (invalid server signature)\n"));
        return false;
    }
 
@@ -552,8 +552,14 @@ calculate_client_proof(fe_scram_state *state,
    int         i;
    scram_HMAC_ctx ctx;
 
-   scram_ClientOrServerKey(state->password, state->salt, state->saltlen,
-                       state->iterations, SCRAM_CLIENT_KEY_NAME, ClientKey);
+   /*
+    * Calculate SaltedPassword, and store it in 'state' so that we can reuse
+    * it later in verify_server_signature.
+    */
+   scram_SaltedPassword(state->password, state->salt, state->saltlen,
+                        state->iterations, state->SaltedPassword);
+
+   scram_ClientKey(state->SaltedPassword, ClientKey);
    scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey);
 
    scram_HMAC_init(&ctx, StoredKey, SCRAM_KEY_LEN);
@@ -575,19 +581,17 @@ calculate_client_proof(fe_scram_state *state,
 }
 
 /*
- * Validate the server proof, received as part of the final exchange message
- * received from the server.
+ * Validate the server signature, received as part of the final exchange
+ * message received from the server.
  */
 static bool
-verify_server_proof(fe_scram_state *state)
+verify_server_signature(fe_scram_state *state)
 {
-   uint8       ServerSignature[SCRAM_KEY_LEN];
+   uint8       expected_ServerSignature[SCRAM_KEY_LEN];
    uint8       ServerKey[SCRAM_KEY_LEN];
    scram_HMAC_ctx ctx;
 
-   scram_ClientOrServerKey(state->password, state->salt, state->saltlen,
-                           state->iterations, SCRAM_SERVER_KEY_NAME,
-                           ServerKey);
+   scram_ServerKey(state->SaltedPassword, ServerKey);
 
    /* calculate ServerSignature */
    scram_HMAC_init(&ctx, ServerKey, SCRAM_KEY_LEN);
@@ -602,9 +606,9 @@ verify_server_proof(fe_scram_state *state)
    scram_HMAC_update(&ctx,
                      state->client_final_message_without_proof,
                      strlen(state->client_final_message_without_proof));
-   scram_HMAC_final(ServerSignature, &ctx);
+   scram_HMAC_final(expected_ServerSignature, &ctx);
 
-   if (memcmp(ServerSignature, state->ServerProof, SCRAM_KEY_LEN) != 0)
+   if (memcmp(expected_ServerSignature, state->ServerSignature, SCRAM_KEY_LEN) != 0)
        return false;
 
    return true;