Skip to content
This repository was archived by the owner on Dec 12, 2025. It is now read-only.

Commit 658585c

Browse files
authored
Add functions to generate Scram Credentials (#102)
1 parent 9afd5f5 commit 658585c

File tree

2 files changed

+197
-0
lines changed

2 files changed

+197
-0
lines changed
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
package scramcredentials
2+
3+
import (
4+
"crypto/hmac"
5+
"crypto/md5"
6+
"crypto/sha1"
7+
"crypto/sha256"
8+
"encoding/base64"
9+
"encoding/hex"
10+
"fmt"
11+
"hash"
12+
13+
"github.com/xdg/stringprep"
14+
)
15+
16+
const (
17+
RFC5802MandatedSaltSize = 4
18+
19+
clientKeyInput = "Client Key" // specified in RFC 5802
20+
serverKeyInput = "Server Key" // specified in RFC 5802
21+
22+
// using the default MongoDB values for the number of iterations depending on mechanism
23+
scramSha1Iterations = 10000
24+
scramSha256Iterations = 15000
25+
)
26+
27+
type ScramCreds struct {
28+
IterationCount int `json:"iterationCount"`
29+
Salt string `json:"salt"`
30+
ServerKey string `json:"serverKey"`
31+
StoredKey string `json:"storedKey"`
32+
}
33+
34+
func ComputeScramSha256Creds(password string, salt []byte) (ScramCreds, error) {
35+
base64EncodedSalt := base64.StdEncoding.EncodeToString(salt)
36+
return computeScramCredentials(sha256.New, scramSha256Iterations, base64EncodedSalt, password)
37+
}
38+
39+
func ComputeScramSha1Creds(username, password string, salt []byte) (ScramCreds, error) {
40+
base64EncodedSalt := base64.StdEncoding.EncodeToString(salt)
41+
password = md5Hex(username + ":mongo:" + password)
42+
return computeScramCredentials(sha1.New, scramSha1Iterations, base64EncodedSalt, password)
43+
}
44+
45+
func md5Hex(s string) string {
46+
h := md5.New()
47+
h.Write([]byte(s))
48+
return hex.EncodeToString(h.Sum(nil))
49+
}
50+
51+
func generateSaltedPassword(hashConstructor func() hash.Hash, password string, salt []byte, iterationCount int) ([]byte, error) {
52+
preparedPassword, err := stringprep.SASLprep.Prepare(password)
53+
if err != nil {
54+
return nil, fmt.Errorf("error SASLprep'ing password: %s", err)
55+
}
56+
57+
result, err := hmacIteration(hashConstructor, []byte(preparedPassword), salt, iterationCount)
58+
if err != nil {
59+
return nil, fmt.Errorf("error running hmacIteration: %s", err)
60+
}
61+
return result, nil
62+
}
63+
64+
func hmacIteration(hashConstructor func() hash.Hash, input, salt []byte, iterationCount int) ([]byte, error) {
65+
hashSize := hashConstructor().Size()
66+
67+
// incorrect salt size will pass validation, but the credentials will be invalid. i.e. it will not
68+
// be possible to auth with the password provided to create the credentials.
69+
if len(salt) != hashSize-RFC5802MandatedSaltSize {
70+
return nil, fmt.Errorf("salt should have a size of %v bytes, but instead has a size of %v bytes", hashSize-RFC5802MandatedSaltSize, len(salt))
71+
}
72+
73+
startKey := append(salt, 0, 0, 0, 1)
74+
result := make([]byte, hashSize)
75+
76+
hmacHash := hmac.New(hashConstructor, input)
77+
if _, err := hmacHash.Write(startKey); err != nil {
78+
return nil, fmt.Errorf("error running hmacHash: %s", err)
79+
}
80+
81+
intermediateDigest := hmacHash.Sum(nil)
82+
83+
for i := 0; i < len(intermediateDigest); i++ {
84+
result[i] = intermediateDigest[i]
85+
}
86+
87+
for i := 1; i < iterationCount; i++ {
88+
hmacHash.Reset()
89+
if _, err := hmacHash.Write(intermediateDigest); err != nil {
90+
return nil, fmt.Errorf("error running hmacHash: %s", err)
91+
}
92+
93+
intermediateDigest = hmacHash.Sum(nil)
94+
95+
for i := 0; i < len(intermediateDigest); i++ {
96+
result[i] ^= intermediateDigest[i]
97+
}
98+
}
99+
100+
return result, nil
101+
}
102+
103+
func generateClientOrServerKey(hashConstructor func() hash.Hash, saltedPassword []byte, input string) ([]byte, error) {
104+
hmacHash := hmac.New(hashConstructor, saltedPassword)
105+
if _, err := hmacHash.Write([]byte(input)); err != nil {
106+
return nil, fmt.Errorf("error running hmacHash: %s", err)
107+
}
108+
109+
return hmacHash.Sum(nil), nil
110+
}
111+
112+
func generateStoredKey(hashConstructor func() hash.Hash, clientKey []byte) ([]byte, error) {
113+
h := hashConstructor()
114+
if _, err := h.Write(clientKey); err != nil {
115+
return nil, fmt.Errorf("error hashing: %s", err)
116+
}
117+
return h.Sum(nil), nil
118+
}
119+
120+
func generateSecrets(hashConstructor func() hash.Hash, password string, salt []byte, iterationCount int) (storedKey, serverKey []byte, err error) {
121+
saltedPassword, err := generateSaltedPassword(hashConstructor, password, salt, iterationCount)
122+
if err != nil {
123+
return nil, nil, fmt.Errorf("error generating salted password: %s", err)
124+
}
125+
126+
clientKey, err := generateClientOrServerKey(hashConstructor, saltedPassword, clientKeyInput)
127+
if err != nil {
128+
return nil, nil, fmt.Errorf("error generating client key: %s", err)
129+
}
130+
131+
storedKey, err = generateStoredKey(hashConstructor, clientKey)
132+
if err != nil {
133+
return nil, nil, fmt.Errorf("error generating stored key: %s", err)
134+
}
135+
136+
serverKey, err = generateClientOrServerKey(hashConstructor, saltedPassword, serverKeyInput)
137+
if err != nil {
138+
return nil, nil, fmt.Errorf("error generating server key: %s", err)
139+
}
140+
141+
return storedKey, serverKey, err
142+
}
143+
144+
func generateB64EncodedSecrets(hashConstructor func() hash.Hash, password, b64EncodedSalt string, iterationCount int) (storedKey, serverKey string, err error) {
145+
salt, err := base64.StdEncoding.DecodeString(b64EncodedSalt)
146+
if err != nil {
147+
return "", "", fmt.Errorf("error decoding salt: %s", err)
148+
}
149+
150+
unencodedStoredKey, unencodedServerKey, err := generateSecrets(hashConstructor, password, salt, iterationCount)
151+
if err != nil {
152+
return "", "", fmt.Errorf("error generating secrets: %s", err)
153+
}
154+
155+
storedKey = base64.StdEncoding.EncodeToString(unencodedStoredKey)
156+
serverKey = base64.StdEncoding.EncodeToString(unencodedServerKey)
157+
return storedKey, serverKey, nil
158+
}
159+
160+
// password should be encrypted in the case of SCRAM-SHA-1 and unencrypted in the case of SCRAM-SHA-256
161+
func computeScramCredentials(hashConstructor func() hash.Hash, iterationCount int, base64EncodedSalt string, password string) (ScramCreds, error) {
162+
storedKey, serverKey, err := generateB64EncodedSecrets(hashConstructor, password, base64EncodedSalt, iterationCount)
163+
if err != nil {
164+
return ScramCreds{}, fmt.Errorf("error generating SCRAM-SHA keys: %s", err)
165+
}
166+
167+
return ScramCreds{IterationCount: iterationCount, Salt: base64EncodedSalt, StoredKey: storedKey, ServerKey: serverKey}, nil
168+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package scramcredentials
2+
3+
import (
4+
"crypto/sha1"
5+
"crypto/sha256"
6+
"hash"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
func TestScramSha1SecretsMatch(t *testing.T) {
13+
assertSecretsMatch(t, sha1.New, "caeec61ba3b15b15b188d29e876514e8", 10, "S3cuk2Rnu/MlbewzxrmmVA==", "sYBa3XlSPKNrgjzhOuEuRlJY4dQ=", "zuAxRSQb3gZkbaB1IGlusK4jy1M=")
14+
assertSecretsMatch(t, sha1.New, "4d9625b297999b3ca786d4a9622d04f1", 10, "kW9KbCQiCOll5Ljd44cjkQ==", "VJ8fFVHkPltibvT//mG/OWw44Hc=", "ceDRsgj9HezpZ4/vkZX8GZNNN50=")
15+
assertSecretsMatch(t, sha1.New, "fd0a78e418dcef39f8c768222810b894", 10, "hhX6xsoID6FeWjXncuNgAg==", "TxgaZJ4cIn+S9EfTcc9IOEG7RGc=", "d6/qjwBs0qkPKfUAjSh5eemsySE=")
16+
}
17+
func TestScramSha256SecretsMatch(t *testing.T) {
18+
assertSecretsMatch(t, sha256.New, "Gy4ZNMr-SYEsEpAEZv", 15000, "ajdf1E1QTsNAQdBEodB4vzQOFuvcw9K6PmouVg==", "/pBk9XBwSm9UyeQmyJ3LfogfHu9Z/XTjGmRhQDHx/4I=", "Avm8mjtMyg659LAyeD4VmuzQb5lxL5iy3dCuzfscfMc=")
19+
assertSecretsMatch(t, sha256.New, "Y9SPYSJYUJB_", 15000, "Oplsu3uju+lYyX4apKb0K6xfHpmFtH99Oyk4Ow==", "oTJhml8KKZUSt9k4tg+tS6D/ygR+a2Xfo8JKjTpQoAI=", "SUfA2+SKL35u665WY5NnJJmA9L5dHu/TnWXX/0nm42Y=")
20+
assertSecretsMatch(t, sha256.New, "157VDZr0h-Pz-wj72", 15000, "P/4xs3anygxu3/l2p35CSBe4Z47IV/FtE/e44A==", "jOb27nFF72SQoY7WUqKXOTR4e8jETXxMS67SONrcbjA=", "3FnslkgUweautAfPRCOEjhS+YbUYUNmdDQUGxB+oaFE=")
21+
assertSecretsMatch(t, sha256.New, "P8z1sDfELCePTNbVqX", 15000, "RPNhenwTHlqW5OE597XpuwvPLaiecPpYFa58Pg==", "sJ8UhQRszLNo15cOe62+HLjt2NxmSkJGjdJpclTIMBs=", "CSg02ODAvh9+swUHoimXcDsT9lLp/A5IhQXavXl7+qA=")
22+
}
23+
24+
func assertSecretsMatch(t *testing.T, hash func() hash.Hash, passwordHash string, iterationCount int, salt, storedKey, serverKey string) {
25+
computedStoredKey, computedServerKey, err := generateB64EncodedSecrets(hash, passwordHash, salt, iterationCount)
26+
assert.NoError(t, err)
27+
assert.Equal(t, computedStoredKey, storedKey)
28+
assert.Equal(t, computedServerKey, serverKey)
29+
}

0 commit comments

Comments
 (0)