Skip to content
This repository was archived by the owner on Dec 12, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions pkg/authentication/scramcredentials/scram_credentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package scramcredentials

import (
"crypto/hmac"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"hash"

"github.com/xdg/stringprep"
)

const (
RFC5802MandatedSaltSize = 4

clientKeyInput = "Client Key" // specified in RFC 5802
serverKeyInput = "Server Key" // specified in RFC 5802

// using the default MongoDB values for the number of iterations depending on mechanism
scramSha1Iterations = 10000
scramSha256Iterations = 15000
)

type ScramCreds struct {
IterationCount int `json:"iterationCount"`
Salt string `json:"salt"`
ServerKey string `json:"serverKey"`
StoredKey string `json:"storedKey"`
}

func ComputeScramSha256Creds(password string, salt []byte) (ScramCreds, error) {
base64EncodedSalt := base64.StdEncoding.EncodeToString(salt)
return computeScramCredentials(sha256.New, scramSha256Iterations, base64EncodedSalt, password)
}

func ComputeScramSha1Creds(username, password string, salt []byte) (ScramCreds, error) {
base64EncodedSalt := base64.StdEncoding.EncodeToString(salt)
password = md5Hex(username + ":mongo:" + password)
return computeScramCredentials(sha1.New, scramSha1Iterations, base64EncodedSalt, password)
}

func md5Hex(s string) string {
h := md5.New()
h.Write([]byte(s))
return hex.EncodeToString(h.Sum(nil))
}

func generateSaltedPassword(hashConstructor func() hash.Hash, password string, salt []byte, iterationCount int) ([]byte, error) {
preparedPassword, err := stringprep.SASLprep.Prepare(password)
if err != nil {
return nil, fmt.Errorf("error SASLprep'ing password: %s", err)
}

result, err := hmacIteration(hashConstructor, []byte(preparedPassword), salt, iterationCount)
if err != nil {
return nil, fmt.Errorf("error running hmacIteration: %s", err)
}
return result, nil
}

func hmacIteration(hashConstructor func() hash.Hash, input, salt []byte, iterationCount int) ([]byte, error) {
hashSize := hashConstructor().Size()

// incorrect salt size will pass validation, but the credentials will be invalid. i.e. it will not
// be possible to auth with the password provided to create the credentials.
if len(salt) != hashSize-RFC5802MandatedSaltSize {
return nil, fmt.Errorf("salt should have a size of %v bytes, but instead has a size of %v bytes", hashSize-RFC5802MandatedSaltSize, len(salt))
}

startKey := append(salt, 0, 0, 0, 1)
result := make([]byte, hashSize)

hmacHash := hmac.New(hashConstructor, input)
if _, err := hmacHash.Write(startKey); err != nil {
return nil, fmt.Errorf("error running hmacHash: %s", err)
}

intermediateDigest := hmacHash.Sum(nil)

for i := 0; i < len(intermediateDigest); i++ {
result[i] = intermediateDigest[i]
}

for i := 1; i < iterationCount; i++ {
hmacHash.Reset()
if _, err := hmacHash.Write(intermediateDigest); err != nil {
return nil, fmt.Errorf("error running hmacHash: %s", err)
}

intermediateDigest = hmacHash.Sum(nil)

for i := 0; i < len(intermediateDigest); i++ {
result[i] ^= intermediateDigest[i]
}
}

return result, nil
}

func generateClientOrServerKey(hashConstructor func() hash.Hash, saltedPassword []byte, input string) ([]byte, error) {
hmacHash := hmac.New(hashConstructor, saltedPassword)
if _, err := hmacHash.Write([]byte(input)); err != nil {
return nil, fmt.Errorf("error running hmacHash: %s", err)
}

return hmacHash.Sum(nil), nil
}

func generateStoredKey(hashConstructor func() hash.Hash, clientKey []byte) ([]byte, error) {
h := hashConstructor()
if _, err := h.Write(clientKey); err != nil {
return nil, fmt.Errorf("error hashing: %s", err)
}
return h.Sum(nil), nil
}

func generateSecrets(hashConstructor func() hash.Hash, password string, salt []byte, iterationCount int) (storedKey, serverKey []byte, err error) {
saltedPassword, err := generateSaltedPassword(hashConstructor, password, salt, iterationCount)
if err != nil {
return nil, nil, fmt.Errorf("error generating salted password: %s", err)
}

clientKey, err := generateClientOrServerKey(hashConstructor, saltedPassword, clientKeyInput)
if err != nil {
return nil, nil, fmt.Errorf("error generating client key: %s", err)
}

storedKey, err = generateStoredKey(hashConstructor, clientKey)
if err != nil {
return nil, nil, fmt.Errorf("error generating stored key: %s", err)
}

serverKey, err = generateClientOrServerKey(hashConstructor, saltedPassword, serverKeyInput)
if err != nil {
return nil, nil, fmt.Errorf("error generating server key: %s", err)
}

return storedKey, serverKey, err
}

func generateB64EncodedSecrets(hashConstructor func() hash.Hash, password, b64EncodedSalt string, iterationCount int) (storedKey, serverKey string, err error) {
salt, err := base64.StdEncoding.DecodeString(b64EncodedSalt)
if err != nil {
return "", "", fmt.Errorf("error decoding salt: %s", err)
}

unencodedStoredKey, unencodedServerKey, err := generateSecrets(hashConstructor, password, salt, iterationCount)
if err != nil {
return "", "", fmt.Errorf("error generating secrets: %s", err)
}

storedKey = base64.StdEncoding.EncodeToString(unencodedStoredKey)
serverKey = base64.StdEncoding.EncodeToString(unencodedServerKey)
return storedKey, serverKey, nil
}

// password should be encrypted in the case of SCRAM-SHA-1 and unencrypted in the case of SCRAM-SHA-256
func computeScramCredentials(hashConstructor func() hash.Hash, iterationCount int, base64EncodedSalt string, password string) (ScramCreds, error) {
storedKey, serverKey, err := generateB64EncodedSecrets(hashConstructor, password, base64EncodedSalt, iterationCount)
if err != nil {
return ScramCreds{}, fmt.Errorf("error generating SCRAM-SHA keys: %s", err)
}

return ScramCreds{IterationCount: iterationCount, Salt: base64EncodedSalt, StoredKey: storedKey, ServerKey: serverKey}, nil
}
29 changes: 29 additions & 0 deletions pkg/authentication/scramcredentials/scram_credentials_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package scramcredentials

import (
"crypto/sha1"
"crypto/sha256"
"hash"
"testing"

"github.com/stretchr/testify/assert"
)

func TestScramSha1SecretsMatch(t *testing.T) {
assertSecretsMatch(t, sha1.New, "caeec61ba3b15b15b188d29e876514e8", 10, "S3cuk2Rnu/MlbewzxrmmVA==", "sYBa3XlSPKNrgjzhOuEuRlJY4dQ=", "zuAxRSQb3gZkbaB1IGlusK4jy1M=")
assertSecretsMatch(t, sha1.New, "4d9625b297999b3ca786d4a9622d04f1", 10, "kW9KbCQiCOll5Ljd44cjkQ==", "VJ8fFVHkPltibvT//mG/OWw44Hc=", "ceDRsgj9HezpZ4/vkZX8GZNNN50=")
assertSecretsMatch(t, sha1.New, "fd0a78e418dcef39f8c768222810b894", 10, "hhX6xsoID6FeWjXncuNgAg==", "TxgaZJ4cIn+S9EfTcc9IOEG7RGc=", "d6/qjwBs0qkPKfUAjSh5eemsySE=")
}
func TestScramSha256SecretsMatch(t *testing.T) {
assertSecretsMatch(t, sha256.New, "Gy4ZNMr-SYEsEpAEZv", 15000, "ajdf1E1QTsNAQdBEodB4vzQOFuvcw9K6PmouVg==", "/pBk9XBwSm9UyeQmyJ3LfogfHu9Z/XTjGmRhQDHx/4I=", "Avm8mjtMyg659LAyeD4VmuzQb5lxL5iy3dCuzfscfMc=")
assertSecretsMatch(t, sha256.New, "Y9SPYSJYUJB_", 15000, "Oplsu3uju+lYyX4apKb0K6xfHpmFtH99Oyk4Ow==", "oTJhml8KKZUSt9k4tg+tS6D/ygR+a2Xfo8JKjTpQoAI=", "SUfA2+SKL35u665WY5NnJJmA9L5dHu/TnWXX/0nm42Y=")
assertSecretsMatch(t, sha256.New, "157VDZr0h-Pz-wj72", 15000, "P/4xs3anygxu3/l2p35CSBe4Z47IV/FtE/e44A==", "jOb27nFF72SQoY7WUqKXOTR4e8jETXxMS67SONrcbjA=", "3FnslkgUweautAfPRCOEjhS+YbUYUNmdDQUGxB+oaFE=")
assertSecretsMatch(t, sha256.New, "P8z1sDfELCePTNbVqX", 15000, "RPNhenwTHlqW5OE597XpuwvPLaiecPpYFa58Pg==", "sJ8UhQRszLNo15cOe62+HLjt2NxmSkJGjdJpclTIMBs=", "CSg02ODAvh9+swUHoimXcDsT9lLp/A5IhQXavXl7+qA=")
}
Comment on lines +17 to +22
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test cases are using the salt, stored key and sever keys of users created through the mongo shell. E.g.

 db.createUser({user: "user4", 
pwd: "P8z1sDfELCePTNbVqX", 
roles: [ { role: "readWrite", db: "admin" } ]})

db.system.users.findOne({user: "user4"})

outputs

example-mongodb:PRIMARY> db.system.users.findOne({user: "user4"})
{
	"_id" : "admin.user4",
	"userId" : UUID("6b5aec1f-312b-491a-aeaa-b00cc374cf52"),
	"user" : "user4",
	"db" : "admin",
	"credentials" : {
		"SCRAM-SHA-1" : {
			"iterationCount" : 10000,
			"salt" : "Cd9RJHbdRt5b+kPf8YR0WQ==",
			"storedKey" : "v3BB0CLs5AwLiAgIfE/yqDujoI4=",
			"serverKey" : "yLqlYgg5GW04nAeBmFpgLU/vtto="
		},
		"SCRAM-SHA-256" : {
			"iterationCount" : 15000,
			"salt" : "RPNhenwTHlqW5OE597XpuwvPLaiecPpYFa58Pg==",
			"storedKey" : "sJ8UhQRszLNo15cOe62+HLjt2NxmSkJGjdJpclTIMBs=",
			"serverKey" : "CSg02ODAvh9+swUHoimXcDsT9lLp/A5IhQXavXl7+qA="
		}
	},
	"roles" : [
		{
			"role" : "readWrite",
			"db" : "admin"
		}
	]
}


func assertSecretsMatch(t *testing.T, hash func() hash.Hash, passwordHash string, iterationCount int, salt, storedKey, serverKey string) {
computedStoredKey, computedServerKey, err := generateB64EncodedSecrets(hash, passwordHash, salt, iterationCount)
assert.NoError(t, err)
assert.Equal(t, computedStoredKey, storedKey)
assert.Equal(t, computedServerKey, serverKey)
}