From 3082e5d66a96b4141481bda1f7473ff1606bebfe Mon Sep 17 00:00:00 2001 From: Huy Doan Date: Fri, 29 Jul 2022 04:43:05 +0700 Subject: [PATCH 1/3] working on channel binding --- scram/private/types.nim | 16 +++++++++++++- scram/private/utils.nim | 43 ++++++++++++++++++++++++++++++++++-- scram/server.nim | 48 ++++++++++++++++++++++++++++------------- 3 files changed, 89 insertions(+), 18 deletions(-) diff --git a/scram/private/types.nim b/scram/private/types.nim index e916bc5..974c60b 100644 --- a/scram/private/types.nim +++ b/scram/private/types.nim @@ -1,3 +1,8 @@ +from net import Socket +from asyncnet import AsyncSocket + +export Socket, AsyncSocket + type ScramError* = object of CatchableError @@ -14,8 +19,17 @@ type FIRST_CLIENT_MESSAGE_HANDLED ENDED + AnySocket* = Socket|AsyncSocket + + ChannelType* = enum + TLS_NONE + TLS_SERVER_END_POINT + TLS_UNIQUE + TLS_UNIQUE_FOR_TELNET + TLS_EXPORT + const GS2_HEADER* = "n,," INT_1* = "\x00\x00\x00\x01" CLIENT_KEY* = "Client Key" - SERVER_KEY* = "Server Key" + SERVER_KEY* = "Server Key" \ No newline at end of file diff --git a/scram/private/utils.nim b/scram/private/utils.nim index d7225db..8ea64de 100644 --- a/scram/private/utils.nim +++ b/scram/private/utils.nim @@ -1,8 +1,12 @@ -import random, base64, strutils, types, hmac, bitops +import random, base64, strutils, types, hmac, bitops, openssl from md5 import MD5Digest from sha1 import Sha1Digest from nimSHA2 import Sha224Digest, Sha256Digest, Sha384Digest, Sha512Digest + +proc SSL_get_finished*(ssl: SslCtx, buf: cstring, count: csize_t): csize_t {.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_get_peer_finished*(ssl: SslCtx, buf: cstring, count: csize_t): csize_t {.cdecl, dynlib: DLLSSLName, importc.} + randomize() proc `$%`*[T](input: T): string = @@ -83,4 +87,39 @@ proc hi*[T](password, salt: string, iterations: int): T = result = previous for _ in 1.. TLS_EXPORT: + raise newException(ScramError, "Channel type " & $channel & " is not supported") + + if socket.isNil: + raise newException(ScramError, "Socket is not initialized") + + if not socket.isSsl or socket.sslContext == nil: + raise newException(ScramError, "Socket is not wrapped in an SSL context") + +proc getCBData*(channel: ChannelType, socket: AnySocket, isServer = true): string = + when not defined(ssl): + raise newException(ScramError, "SSL required for channel binding") + else: + result = newString(1024) + if channel == TLS_UNIQUE: + var ret: csize_t + if isServer: + ret = SSL_get_peer_finished(socket.sslContext, result.cstring, 1024) + else: + ret = SSL_get_finished(socket.sslContext, result.cstring, 1024) + + if ret == 0: + raise newException(ScramError, "SSLError: handshake has not reached the finished message") + + result.setLen(ret) diff --git a/scram/server.nim b/scram/server.nim index 68a7db4..a6bc43c 100644 --- a/scram/server.nim +++ b/scram/server.nim @@ -1,14 +1,15 @@ -import strformat, strutils -import base64, pegs, strutils, hmac, nimSHA2, private/[utils,types] +import strformat, strutils, base64, pegs, hmac, nimSHA2, private/[utils,types] type - ScramServer*[T] = ref object of RootObj - serverNonce*: string + ScramServer[T] = ref object of RootObj + serverNonce: string clientFirstMessageBare: string serverFirstMessage: string - state*: ScramState + state: ScramState isSuccessful: bool userData: UserData + cbType: ChannelType + cbData: string UserData* = object salt*: string @@ -45,24 +46,37 @@ proc initUserData*(salt: string, iterations: int, serverKey, storedKey: string): result.serverKey = serverKey result.storedKey = storedKey -proc newScramServer*[T](): ScramServer[T] {.deprecated: "use `new ScramServer[T]` instead".} = - new ScramServer[T] +proc newScramServer*[T](): ScramServer[T] = + result = new ScramServer[T] + result.state = INITIAL + result.isSuccessful = false + result.cbType = TLS_NONE + +proc newScramServer*[T](channel = TLS_UNIQUE, socket: AnySocket): ScramServer[T] = + result = newScramServer[T]() + validateCB(channel, socket) + result.cbType = channel + result.cbData = getCBData(channel, socket) proc handleClientFirstMessage*[T](s: ScramServer[T],clientFirstMessage: string): string = let parts = clientFirstMessage.split(',', 2) - var matches: array[3, string] - if not match(clientFirstMessage, CLIENT_FIRST_MESSAGE, matches) or not parts.len == 3: - s.state = ENDED - return - s.clientFirstMessageBare = parts[2] + #var matches: array[3, string] + #if not match(clientFirstMessage, CLIENT_FIRST_MESSAGE, matches) or not parts.len == 3: + # s.state = ENDED + # return - s.state = FIRST_CLIENT_MESSAGE_HANDLED + #if (gs2Header == "n"): + + + s.clientFirstMessageBare = parts[2] for kv in s.clientFirstMessageBare.split(','): if kv[0..1] == "n=": result = kv[2..^1] elif kv[0..1] == "r=": s.serverNonce = kv[2..^1] & makeNonce() + s.state = FIRST_CLIENT_MESSAGE_HANDLED + proc prepareFirstMessage*(s: ScramServer, userData: UserData): string = s.state = FIRST_PREPARED s.userData = userData @@ -124,13 +138,13 @@ proc getState*(s: ScramServer): ScramState = s.state when isMainModule: - import client as c + import client as c, net var username = "bob" password = "secret" userdata = initUserData(password) - server = new ScramServer[SHA256Digest] + server = newScramServer[SHA256Digest]() client = newScramClient[SHA256Digest]() assert(server.state == INITIAL) @@ -153,3 +167,7 @@ when isMainModule: assert client.verifyServerFinalMessage(serverFinalMessage) == true echo "Client is successful: ", client.isSuccessful() == true + + var + socket = newSocket() + server1 = newScramServer[SHA256Digest](TLS_UNIQUE, socket) From 873cc3a011dba7221af514c33bfd47483f3497e4 Mon Sep 17 00:00:00 2001 From: Huy Doan Date: Tue, 2 Aug 2022 02:13:12 +0700 Subject: [PATCH 2/3] initial channel binding support (WIP) --- scram/client.nim | 78 +++++++++++++++++++-------------------- scram/private/types.nim | 28 ++++++++++---- scram/private/utils.nim | 64 ++++++++++++++++++-------------- scram/server.nim | 82 +++++++++++++++++++++++++++++------------ tests/config.nims | 4 +- tests/test_both.nim | 4 +- tests/test_cb.nim | 35 ++++++++++++++++++ tests/test_server.nim | 10 ++--- 8 files changed, 197 insertions(+), 108 deletions(-) create mode 100644 tests/test_cb.nim diff --git a/scram/client.nim b/scram/client.nim index 127b062..13fcd8b 100644 --- a/scram/client.nim +++ b/scram/client.nim @@ -1,7 +1,7 @@ -import strformat -import base64, pegs, strutils, hmac, sha1, nimSHA2, md5, private/[utils,types] +import base64, strformat, strutils, hmac, sha1, nimSHA2, md5, private/[utils,types] export MD5Digest, SHA1Digest, SHA224Digest, SHA256Digest, SHA384Digest, SHA512Digest, Keccak512Digest +export ChannelType type ScramClient[T] = ref object of RootObj @@ -10,29 +10,25 @@ type state: ScramState isSuccessful: bool serverSignature: T - -when compileOption("threads"): - var - SERVER_FIRST_MESSAGE_VAL: ptr Peg - SERVER_FINAL_MESSAGE_VAL: ptr Peg - template SERVER_FIRST_MESSAGE: Peg = - if SERVER_FIRST_MESSAGE_VAL.isNil: - SERVER_FIRST_MESSAGE_VAL = cast[ptr Peg](allocShared0(sizeof(Peg))) - SERVER_FIRST_MESSAGE_VAL[] = peg"'r='{[^,]*}',s='{[^,]*}',i='{\d+}$" - SERVER_FIRST_MESSAGE_VAL[] - template SERVER_FINAL_MESSAGE: Peg = - if SERVER_FINAL_MESSAGE_VAL.isNil: - SERVER_FINAL_MESSAGE_VAL = cast[ptr Peg](allocShared0(sizeof(Peg))) - SERVER_FINAL_MESSAGE_VAL[] = peg"'v='{[^,]*}$" - SERVER_FINAL_MESSAGE_VAL[] -else: - let - SERVER_FIRST_MESSAGE = peg"'r='{[^,]*}',s='{[^,]*}',i='{\d+}$" - SERVER_FINAL_MESSAGE = peg"'v='{[^,]*}$" + cbType: ChannelType + cbData: string proc newScramClient*[T](): ScramClient[T] = - result = new(ScramClient[T]) + result = new ScramClient[T] result.clientNonce = makeNonce() + result.cbType = TLS_NONE + + +proc newScramClient*[T](socket: AnySocket, channel = TLS_UNIQUE): ScramClient[T] = + result = newScramClient[T]() + if socket != nil: + validateCB(channel, socket) + result.cbType = channel + result.cbData = getCBData(channel, socket) + +proc setCBindType*[T](s: ScramClient[T], channel: ChannelType) = s.cbType = channel + +proc setCBindData*[T](s: ScramClient[T], data: string) = s.cbData = data proc prepareFirstMessage*(s: ScramClient, username: string): string {.raises: [ScramError]} = if username.len == 0: @@ -44,7 +40,7 @@ proc prepareFirstMessage*(s: ScramClient, username: string): string {.raises: [S s.clientFirstMessageBare.add(s.clientNonce) s.state = FIRST_PREPARED - GS2_HEADER & s.clientFirstMessageBare + result = makeGS2Header(s.cbType) & s.clientFirstMessageBare proc prepareFinalMessage*[T](s: ScramClient[T], password, serverFirstMessage: string): string = if s.state != FIRST_PREPARED: @@ -53,17 +49,15 @@ proc prepareFinalMessage*[T](s: ScramClient[T], password, serverFirstMessage: st nonce, salt: string iterations: int var matches: array[3, string] - if match(serverFirstMessage, SERVER_FIRST_MESSAGE, matches): - for kv in serverFirstMessage.split(','): - if kv[0..1] == "i=": - iterations = parseInt(kv[2..^1]) - elif kv[0..1] == "r=": - nonce = kv[2..^1] - elif kv[0..1] == "s=": - salt = base64.decode(kv[2..^1]) - else: - s.state = ENDED - return "" + + for kv in serverFirstMessage.split(','): + if kv[0..1] == "i=": + iterations = parseInt(kv[2..^1]) + elif kv[0..1] == "r=": + nonce = kv[2..^1] + elif kv[0..1] == "s=": + salt = base64.decode(kv[2..^1]) + if not nonce.startsWith(s.clientNonce): raise newException(ScramError, "Security error: invalid nonce received from server. Possible man-in-the-middle attack.") @@ -76,7 +70,7 @@ proc prepareFinalMessage*[T](s: ScramClient[T], password, serverFirstMessage: st clientKey = HMAC[T]($%saltedPassword, CLIENT_KEY) storedKey = HASH[T]($%clientKey) serverKey = HMAC[T]($%saltedPassword, SERVER_KEY) - clientFinalMessageWithoutProof = "c=biws,r=" & nonce + clientFinalMessageWithoutProof = makeCBind(s.cbType, s.cbData) & ",r=" & nonce authMessage =[s.clientFirstMessageBare, serverFirstMessage, clientFinalMessageWithoutProof].join(",") clientSignature = HMAC[T]($%storedKey, authMessage) s.serverSignature = HMAC[T]($%serverKey, authMessage) @@ -93,12 +87,14 @@ proc verifyServerFinalMessage*(s: ScramClient, serverFinalMessage: string): bool raise newException(ScramError, "You can call this method only once after calling prepareFinalMessage()") s.state = ENDED var matches: array[1, string] - if match(serverFinalMessage, SERVER_FINAL_MESSAGE, matches): - var proposedServerSignature: string - for kv in serverFinalMessage.split(','): - if kv[0..1] == "v=": - proposedServerSignature = base64.decode(kv[2..^1]) - s.isSuccessful = proposedServerSignature == $%s.serverSignature + + var proposedServerSignature: string + for kv in serverFinalMessage.split(','): + if kv[0..1] == "e=": + raise newException(ScramError, "ServerError: " & kv[2..^1]) + elif kv[0..1] == "v=": + proposedServerSignature = base64.decode(kv[2..^1]) + s.isSuccessful = proposedServerSignature == $%s.serverSignature s.isSuccessful proc isSuccessful*(s: ScramClient): bool = diff --git a/scram/private/types.nim b/scram/private/types.nim index 974c60b..bd5640c 100644 --- a/scram/private/types.nim +++ b/scram/private/types.nim @@ -22,14 +22,28 @@ type AnySocket* = Socket|AsyncSocket ChannelType* = enum - TLS_NONE - TLS_SERVER_END_POINT - TLS_UNIQUE - TLS_UNIQUE_FOR_TELNET - TLS_EXPORT + TLS_NONE = "" + TLS_SERVER_END_POINT = "tls-server-end-point" + TLS_UNIQUE = "tls-unique" + TLS_UNIQUE_FOR_TELNET = "tls-server-for-telnet" + TLS_EXPORT = "tls-export" + + ServerError* = enum + SERVER_ERROR_NO_ERROR = "" + SERVER_ERROR_INVALID_ENCODING = "invalid-encoding" + SERVER_ERROR_EXTENSIONS_NOT_SUPPORTED = "extensions-not-supported" + SERVER_ERROR_INVALID_PROOF = "invalid-proof" + SERVER_ERROR_CHANNEL_BINDINGS_DONT_MATCH = "channel-bindings-dont-match" + SERVER_ERROR_SERVER_DOES_SUPPORT_CHANNEL_BINDING = "server-does-support-channel-binding" + SERVER_ERROR_CHANNEL_BINDING_NOT_SUPPORTED = "channel-binding-not-supported" + SERVER_ERROR_UNSUPPORTED_CHANNEL_BINDING_TYPE = "unsupported-channel-binding-type" + SERVER_ERROR_UNKNOWN_USER = "unknown-user" + SERVER_ERROR_INVALID_USERNAME_ENCODING = "invalid-username-encoding" + SERVER_ERROR_NO_RESOURCES = "no-resources" + SERVER_ERROR_OTHER_ERROR = "other-error" + const - GS2_HEADER* = "n,," INT_1* = "\x00\x00\x00\x01" CLIENT_KEY* = "Client Key" - SERVER_KEY* = "Server Key" \ No newline at end of file + SERVER_KEY* = "Server Key" diff --git a/scram/private/utils.nim b/scram/private/utils.nim index 8ea64de..7e8f996 100644 --- a/scram/private/utils.nim +++ b/scram/private/utils.nim @@ -1,11 +1,11 @@ -import random, base64, strutils, types, hmac, bitops, openssl +import random, base64, strutils, types, hmac, bitops, openssl, net, asyncnet from md5 import MD5Digest from sha1 import Sha1Digest from nimSHA2 import Sha224Digest, Sha256Digest, Sha384Digest, Sha512Digest -proc SSL_get_finished*(ssl: SslCtx, buf: cstring, count: csize_t): csize_t {.cdecl, dynlib: DLLSSLName, importc.} -proc SSL_get_peer_finished*(ssl: SslCtx, buf: cstring, count: csize_t): csize_t {.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_get_finished*(ssl: SslPtr, buf: cstring, count: csize_t): csize_t {.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_get_peer_finished*(ssl: SslPtr, buf: cstring, count: csize_t): csize_t {.cdecl, dynlib: DLLSSLName, importc.} randomize() @@ -89,37 +89,45 @@ proc hi*[T](password, salt: string, iterations: int): T = previous = HMAC[T](password, $%previous) result ^= previous +proc makeGS2Header*(channel: ChannelType): string = + result = case channel + of TLS_UNIQUE: "p=tls-unique,," + of TLS_SERVER_END_POINT: "p=tls-server-end-point,," + of TLS_UNIQUE_FOR_TELNET: "p=tls-server-for-telnet,," + of TLS_EXPORT: "p=tls-export,," + else: "n,," + +proc makeCBind*(channel: ChannelType, data: string = ""): string = + if channel == TLS_NONE: + result = "c=biws" + else: + result = "c=" & base64.encode(makeGS2Header(channel) & data) + proc validateCB*(channel: ChannelType, socket: AnySocket) = if channel == TLS_NONE: return - when not defined(ssl): - raise newException(ScramError, "SSL required for channel binding") - else: - - if channel > TLS_EXPORT: - raise newException(ScramError, "Channel type " & $channel & " is not supported") + if channel > TLS_EXPORT: + raise newException(ScramChannelBindingError, "Channel type " & $channel & " is not supported") - if socket.isNil: - raise newException(ScramError, "Socket is not initialized") + if socket.isNil: + raise newException(ScramChannelBindingError, "Socket is not initialized") - if not socket.isSsl or socket.sslContext == nil: - raise newException(ScramError, "Socket is not wrapped in an SSL context") + if not socket.isSsl or socket.sslHandle() == nil: + raise newException(ScramChannelBindingError, "Socket is not wrapped in a SSL context") proc getCBData*(channel: ChannelType, socket: AnySocket, isServer = true): string = - when not defined(ssl): - raise newException(ScramError, "SSL required for channel binding") - else: - result = newString(1024) - if channel == TLS_UNIQUE: - var ret: csize_t - if isServer: - ret = SSL_get_peer_finished(socket.sslContext, result.cstring, 1024) - else: - ret = SSL_get_finished(socket.sslContext, result.cstring, 1024) - - if ret == 0: - raise newException(ScramError, "SSLError: handshake has not reached the finished message") - - result.setLen(ret) + + result = newString(1024) + if channel == TLS_UNIQUE: + var ret: csize_t + if isServer: + ret = SSL_get_peer_finished(socket.sslHandle(), result.cstring, 1024) + else: + ret = SSL_get_finished(socket.sslHandle(), result.cstring, 1024) + + if ret == 0: + raise newException(ScramChannelBindingError, "SSLError: handshake has not reached the finished message") + + result.setLen(ret) diff --git a/scram/server.nim b/scram/server.nim index a6bc43c..14a6a86 100644 --- a/scram/server.nim +++ b/scram/server.nim @@ -1,4 +1,6 @@ -import strformat, strutils, base64, pegs, hmac, nimSHA2, private/[utils,types] +import strformat, strutils, base64, hmac, nimSHA2, private/[utils,types] + +export ChannelType type ScramServer[T] = ref object of RootObj @@ -8,6 +10,8 @@ type state: ScramState isSuccessful: bool userData: UserData + serverError: ServerError + serverErrorValueExt: string cbType: ChannelType cbData: string @@ -17,10 +21,6 @@ type serverKey*: string storedKey*: string -let - CLIENT_FIRST_MESSAGE = peg"^([pny]'='?([^,]*)','([^,]*)','){('m='([^,]*)',')?'n='{[^,]*}',r='{[^,]*}(','(.*))*}$" - CLIENT_FINAL_MESSAGE = peg"{'c='{[^,]*}',r='{[^,]*}}',p='{.*}$" - proc initUserData*[T](typ: typedesc[T], password: string, iterations = 4096): UserData = var iterations = iterations if password.len == 0: @@ -52,21 +52,41 @@ proc newScramServer*[T](): ScramServer[T] = result.isSuccessful = false result.cbType = TLS_NONE -proc newScramServer*[T](channel = TLS_UNIQUE, socket: AnySocket): ScramServer[T] = - result = newScramServer[T]() +proc newScramServer*[T](socket: AnySocket, channel = TLS_UNIQUE): ScramServer[T] = validateCB(channel, socket) + + result = newScramServer[T]() result.cbType = channel result.cbData = getCBData(channel, socket) -proc handleClientFirstMessage*[T](s: ScramServer[T],clientFirstMessage: string): string = - let parts = clientFirstMessage.split(',', 2) - #var matches: array[3, string] - #if not match(clientFirstMessage, CLIENT_FIRST_MESSAGE, matches) or not parts.len == 3: - # s.state = ENDED - # return +proc setCBindType*[T](s: ScramServer[T], channel: ChannelType) = s.cbType = channel - #if (gs2Header == "n"): +proc setCBindData*[T](s: ScramServer[T], data: string) = s.cbData = data +proc setServerNonce*[T](s: ScramServer[T], nonce: string) = s.serverNonce = nonce + +proc handleClientFirstMessage*[T](s: ScramServer[T], clientFirstMessage: string): string = + let parts = clientFirstMessage.split(',', 2) + if parts.len != 3: + s.state = ENDED + return + + let gs2CBindFlag = parts[0] + if (gs2CBindFlag[0] == 'n'): + if s.cbType != TLS_NONE: + s.serverError = SERVER_ERROR_SERVER_DOES_SUPPORT_CHANNEL_BINDING + elif (gs2CBindFlag[0] == 'y'): + if s.cbType != TLS_NONE: + s.serverError = SERVER_ERROR_SERVER_DOES_SUPPORT_CHANNEL_BINDING + elif (gs2CBindFlag[0] == 'p'): + if s.cbType == TLS_NONE: + s.serverError = SERVER_ERROR_CHANNEL_BINDING_NOT_SUPPORTED + let cbName = gs2CBindFlag.split("=")[1] + if cbName != $s.cbType: + s.serverError = SERVER_ERROR_UNSUPPORTED_CHANNEL_BINDING_TYPE + else: + s.serverError = SERVER_ERROR_OTHER_ERROR + s.serverErrorValueExt = "Invalid GS2 flag: " & gs2CBindFlag[0] s.clientFirstMessageBare = parts[2] for kv in s.clientFirstMessageBare.split(','): @@ -83,12 +103,22 @@ proc prepareFirstMessage*(s: ScramServer, userData: UserData): string = s.serverFirstMessage = "r=$#,s=$#,i=$#" % [s.serverNonce, userData.salt, $userData.iterations] s.serverFirstMessage + +proc makeError(error: ServerError, ext: string = ""): string = + if error != SERVER_ERROR_NO_ERROR: + result = "e=" & $error + if ext.len != 0: + result &= " " & ext + proc prepareFinalMessage*[T](s: ScramServer[T], clientFinalMessage: string): string = - var matches: array[4, string] - if not match(clientFinalMessage, CLIENT_FINAL_MESSAGE, matches): + + if s.serverError != SERVER_ERROR_NO_ERROR: + result = makeError(s.serverError, s.serverErrorValueExt) s.state = ENDED return - var clientFinalMessageWithoutProof, nonce, proof: string + + var clientFinalMessageWithoutProof, nonce, proof, cbind: string + for kv in clientFinalMessage.split(','): if kv[0..1] == "p=": proof = kv[2..^1] @@ -98,8 +128,16 @@ proc prepareFinalMessage*[T](s: ScramServer[T], clientFinalMessage: string): str clientFinalMessageWithoutProof.add(kv) if kv[0..1] == "r=": nonce = kv[2..^1] + elif kv[0..1] == "c=": + cbind = kv + + if cbind != makeCBind(s.cbType, s.cbData): + result = makeError(SERVER_ERROR_CHANNEL_BINDINGS_DONT_MATCH) + s.state = ENDED + return if nonce != s.serverNonce: + result = makeError(SERVER_ERROR_OTHER_ERROR, "Server nonce does not match") s.state = ENDED return @@ -110,12 +148,12 @@ proc prepareFinalMessage*[T](s: ScramServer[T], clientFinalMessage: string): str serverSignature = HMAC[T](decode(s.userData.serverKey), authMessage) decodedProof = base64.decode(proof) clientKey = custom_xor(clientSignature, decodedProof) - let resultKey = HASH[T](clientKey).raw_str + resultKey = HASH[T](clientKey).raw_str # SECURITY: constant time HMAC check if not constantTimeEqual(resultKey, storedKey): - let k1 = base64.encode(resultKey) - let k2 = base64.encode(storedKey) + result = makeError(SERVER_ERROR_OTHER_ERROR, "constant time hmac check failed") + s.state = ENDED return s.isSuccessful = true @@ -167,7 +205,3 @@ when isMainModule: assert client.verifyServerFinalMessage(serverFinalMessage) == true echo "Client is successful: ", client.isSuccessful() == true - - var - socket = newSocket() - server1 = newScramServer[SHA256Digest](TLS_UNIQUE, socket) diff --git a/tests/config.nims b/tests/config.nims index 3bb69f8..640b184 100644 --- a/tests/config.nims +++ b/tests/config.nims @@ -1 +1,3 @@ -switch("path", "$projectDir/../src") \ No newline at end of file +switch("path", "$projectDir/../src") +switch("d", "ssl") +switch("threads", "on") \ No newline at end of file diff --git a/tests/test_both.nim b/tests/test_both.nim index 1f6ae93..43285be 100644 --- a/tests/test_both.nim +++ b/tests/test_both.nim @@ -3,10 +3,10 @@ import unittest, scram/server, scram/client, sha1, nimSHA2, base64, scram/privat proc test[T](user, password: string) = var client = newScramClient[T]() - var server = new ScramServer[T] + var server = newScramServer[T]() let cfirst = client.prepareFirstMessage(user) assert server.handleClientFirstMessage(cfirst) == user, "incorrect detected username" - assert server.state == FIRST_CLIENT_MESSAGE_HANDLED, "incorrect state" + assert server.getState() == FIRST_CLIENT_MESSAGE_HANDLED, "incorrect state" let sfirst = server.prepareFirstMessage(initUserData(T, password)) let cfinal = client.prepareFinalMessage(password, sfirst) let sfinal = server.prepareFinalMessage(cfinal) diff --git a/tests/test_cb.nim b/tests/test_cb.nim new file mode 100644 index 0000000..70da23b --- /dev/null +++ b/tests/test_cb.nim @@ -0,0 +1,35 @@ +import unittest, scram/[server,client], sha1, nimSHA2 +import scram/private/types + +const FAKE_CBDATA = "xxxxxxxxxxxxxxxx" + +proc test[T](user, password: string) = + var client = newScramClient[T]() + var server = newScramServer[T]() + + client.setCBindType(TLS_UNIQUE) + client.setCBindData(FAKE_CBDATA) + + server.setCBindType(TLS_UNIQUE) + server.setCBindData(FAKE_CBDATA) + + let cfirst = client.prepareFirstMessage(user) + assert server.handleClientFirstMessage(cfirst) == user, "incorrect detected username" + assert server.getState() == FIRST_CLIENT_MESSAGE_HANDLED, "incorrect state" + let sfirst = server.prepareFirstMessage(initUserData(T, password)) + let cfinal = client.prepareFinalMessage(password, sfirst) + let sfinal = server.prepareFinalMessage(cfinal) + assert client.verifyServerFinalMessage(sfinal), "incorrect server final message" + +suite "Scram Client-Server tests": + test "SCRAM-SHA1-PLUS": + test[Sha1Digest]( + "user", + "pencil" + ) + + test "SCRAM-SHA256-PLUS": + test[Sha256Digest]( + "bob", + "secret" + ) \ No newline at end of file diff --git a/tests/test_server.nim b/tests/test_server.nim index 2f82fc2..337605b 100644 --- a/tests/test_server.nim +++ b/tests/test_server.nim @@ -1,11 +1,11 @@ -import unittest, scram/server, sha1, nimSHA2, base64, scram/private/[utils,types] - +import unittest, scram/server, sha1, nimSHA2, base64 +import scram/private/[utils, types] proc test[T](user, password, nonce, salt, cfirst, sfirst, cfinal, sfinal: string) = - var server = new ScramServer[T] + var server = newScramServer[T]() assert server.handleClientFirstMessage(cfirst) == user, "incorrect detected username" - assert server.state == FIRST_CLIENT_MESSAGE_HANDLED, "incorrect state" - server.serverNonce = nonce + assert server.getState() == FIRST_CLIENT_MESSAGE_HANDLED, "incorrect state" + server.setServerNonce(nonce) let iterations = 4096 decodedSalt = base64.decode(salt) From a2decdd127d31ad6cdfca0668943dc219d7820b1 Mon Sep 17 00:00:00 2001 From: Huy Doan Date: Tue, 2 Aug 2022 23:22:52 +0700 Subject: [PATCH 3/3] Added Channel Binding supports #11 --- README.md | 52 ++++++++++++++++++++---- scram.nimble | 2 +- scram/client.nim | 14 ++----- scram/private/types.nim | 7 ---- scram/private/utils.nim | 90 +++++++++++++++++++++++++++++++++++------ scram/server.nim | 16 +++----- tests/test_cb.nim | 71 +++++++++++++++++++++++++++----- 7 files changed, 191 insertions(+), 61 deletions(-) diff --git a/README.md b/README.md index 526aae5..fc118de 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,51 @@ [![Build Status](https://travis-ci.org/ba0f3/scram.nim.svg?branch=master)](https://travis-ci.org/ba0f3/scram.nim) -# scram +# scram.nim Salted Challenge Response Authentication Mechanism (SCRAM) -```nim -var s = newScramClient[Sha256Digest]() -s.clientNonce = "VeAOLsQ22fn/tjalHQIz7cQT" +### Supported Mechanisms: +* SCRAM-SHA-1 +* SCRAM-SHA-1-PLUS +* SCRAM-SHA-256 +* SCRAM-SHA-256-PLUS +* SCRAM-SHA-384 +* SCRAM-SHA-384-PLUS +* SCRAM-SHA-512 +* SCRAM-SHA-512-PLUS +* SCRAM-SHA3-512 +* SCRAM-SHA3-512-PLUS + +### Supported Channel Binding Types +* TLS_UNIQUE +* TLS_SERVER_END_POINT + +### Examples -echo s.prepareFirstMessage("bob") -let finalMessage = s.prepareFinalMessage("secret", "r=VeAOLsQ22fn/tjalHQIz7cQTmeE5qJh8qKEe8wALMut1,s=ldZSefTzKxPNJhP73AmW/A==,i=4096") -echo finalMessage -assert(finalMessage == "c=biws,r=VeAOLsQ22fn/tjalHQIz7cQTmeE5qJh8qKEe8wALMut1,p=AtNtxGzsMA8evcWBM0MXFjxN8OcG1KRkLkFyoHlupOU=") +#### Client +```nim +var client = newScramClient[Sha256Digest]() +assert client.prepareFirstMessage(user) == cfirst, "incorrect first message" +let fmsg = client.prepareFinalMessage(password, sfirst) +assert fmsg == cfinal, "incorrect final message" +assert client.verifyServerFinalMessage(sfinal), "incorrect server final message" ``` + +#### Channel Binding + +Helper proc `getChannelBindingData` added to helps you getting channel binding data from existing Socket/AsyncSocket + +```nim +var + ctx = newContext() + socket = newSocket() +ctx.wrapSocket(socket) +socket.connect(...) +# .... +let cbData = getChannelBindingData(TLS_UNIQUE, socket) + +var client = newScramClient[Sha256Digest]() +client.setChannelBindingType(TLS_UNIQUE) +client.setChannelBindingData(cbData) +echo client.prepareFirstMessage(user) +``` \ No newline at end of file diff --git a/scram.nimble b/scram.nimble index 7aee5e9..8d7a9f1 100644 --- a/scram.nimble +++ b/scram.nimble @@ -1,4 +1,4 @@ -version = "0.1.14" +version = "0.2.0" author = "Huy Doan" description = "Salted Challenge Response Authentication Mechanism (SCRAM) " license = "MIT" diff --git a/scram/client.nim b/scram/client.nim index 13fcd8b..1a9380b 100644 --- a/scram/client.nim +++ b/scram/client.nim @@ -1,7 +1,7 @@ import base64, strformat, strutils, hmac, sha1, nimSHA2, md5, private/[utils,types] export MD5Digest, SHA1Digest, SHA224Digest, SHA256Digest, SHA384Digest, SHA512Digest, Keccak512Digest -export ChannelType +export getChannelBindingData type ScramClient[T] = ref object of RootObj @@ -18,17 +18,9 @@ proc newScramClient*[T](): ScramClient[T] = result.clientNonce = makeNonce() result.cbType = TLS_NONE +proc setChannelBindingType*[T](s: ScramClient[T], channel: ChannelType) = s.cbType = channel -proc newScramClient*[T](socket: AnySocket, channel = TLS_UNIQUE): ScramClient[T] = - result = newScramClient[T]() - if socket != nil: - validateCB(channel, socket) - result.cbType = channel - result.cbData = getCBData(channel, socket) - -proc setCBindType*[T](s: ScramClient[T], channel: ChannelType) = s.cbType = channel - -proc setCBindData*[T](s: ScramClient[T], data: string) = s.cbData = data +proc setChannelBindingData*[T](s: ScramClient[T], data: string) = s.cbData = data proc prepareFirstMessage*(s: ScramClient, username: string): string {.raises: [ScramError]} = if username.len == 0: diff --git a/scram/private/types.nim b/scram/private/types.nim index bd5640c..489fca7 100644 --- a/scram/private/types.nim +++ b/scram/private/types.nim @@ -1,8 +1,3 @@ -from net import Socket -from asyncnet import AsyncSocket - -export Socket, AsyncSocket - type ScramError* = object of CatchableError @@ -19,8 +14,6 @@ type FIRST_CLIENT_MESSAGE_HANDLED ENDED - AnySocket* = Socket|AsyncSocket - ChannelType* = enum TLS_NONE = "" TLS_SERVER_END_POINT = "tls-server-end-point" diff --git a/scram/private/utils.nim b/scram/private/utils.nim index 7e8f996..1965fe3 100644 --- a/scram/private/utils.nim +++ b/scram/private/utils.nim @@ -4,8 +4,37 @@ from sha1 import Sha1Digest from nimSHA2 import Sha224Digest, Sha256Digest, Sha384Digest, Sha512Digest -proc SSL_get_finished*(ssl: SslPtr, buf: cstring, count: csize_t): csize_t {.cdecl, dynlib: DLLSSLName, importc.} -proc SSL_get_peer_finished*(ssl: SslPtr, buf: cstring, count: csize_t): csize_t {.cdecl, dynlib: DLLSSLName, importc.} +#from net import Socket +#from asyncnet import AsyncSocket + +#export Socket, AsyncSocket + +type + AnySocket* = Socket|AsyncSocket + +const + NID_md5 = 4 + NID_md5_sha1 = 114 + EVP_MAX_MD_SIZE = 64 + +{.push cdecl, dynlib: DLLSSLName, importc.} + +proc SSL_get_finished(ssl: SslPtr, buf: cstring, count: csize_t): csize_t +proc SSL_get_peer_finished(ssl: SslPtr, buf: cstring, count: csize_t): csize_t + +proc SSL_get_certificate(ssl: SslPtr): PX509 +proc SSL_get_peer_certificate(ssl: SslPtr): PX509 + +proc X509_get_signature_nid(x: PX509): int32 +proc OBJ_find_sigid_algs(signature: int32, pdigest: pointer, pencryption: pointer): int32 +proc OBJ_nid2sn(n: int): cstring + +proc EVP_sha256(): PEVP_MD +proc EVP_get_digestbynid(): PEVP_MD + +proc X509_digest(data: PX509, kind: PEVP_MD, md: ptr char, len: ptr uint32): int32 + +{.pop.} randomize() @@ -104,30 +133,67 @@ proc makeCBind*(channel: ChannelType, data: string = ""): string = result = "c=" & base64.encode(makeGS2Header(channel) & data) -proc validateCB*(channel: ChannelType, socket: AnySocket) = +proc validateChannelBinding*(channel: ChannelType, socket: AnySocket) = if channel == TLS_NONE: return if channel > TLS_EXPORT: - raise newException(ScramChannelBindingError, "Channel type " & $channel & " is not supported") + raise newException(ScramError, "Channel type " & $channel & " is not supported") if socket.isNil: - raise newException(ScramChannelBindingError, "Socket is not initialized") + raise newException(ScramError, "Socket is not initialized") if not socket.isSsl or socket.sslHandle() == nil: - raise newException(ScramChannelBindingError, "Socket is not wrapped in a SSL context") + raise newException(ScramError, "Socket is not wrapped in a SSL context") + +proc getChannelBindingData*(channel: ChannelType, socket: AnySocket, isServer = true): string = + # Ref: https://paquier.xyz/postgresql-2/channel-binding-openssl/ -proc getCBData*(channel: ChannelType, socket: AnySocket, isServer = true): string = + validateChannelBinding(channel, socket) - result = newString(1024) + result = newString(EVP_MAX_MD_SIZE) if channel == TLS_UNIQUE: var ret: csize_t if isServer: - ret = SSL_get_peer_finished(socket.sslHandle(), result.cstring, 1024) + ret = SSL_get_peer_finished(socket.sslHandle(), result.cstring, EVP_MAX_MD_SIZE) else: - ret = SSL_get_finished(socket.sslHandle(), result.cstring, 1024) + ret = SSL_get_finished(socket.sslHandle(), result.cstring, EVP_MAX_MD_SIZE) if ret == 0: - raise newException(ScramChannelBindingError, "SSLError: handshake has not reached the finished message") - + raise newException(ScramError, "SSLError: handshake has not reached the finished message") result.setLen(ret) + + elif channel == TLS_SERVER_END_POINT: + var + serverCert: PX509 + algoNid: int32 + algoType: PEVP_MD + hash: array[EVP_MAX_MD_SIZE, char] + hashSize: int32 + + if isServer: + serverCert = cast[PX509](SSL_get_certificate(socket.sslHandle())) + else: + serverCert = cast[PX509](SSL_get_peer_certificate(socket.sslHandle())) + + if serverCert == nil: + raise newException(ScramError, "SSLError: could not load server certtificate") + + if OBJ_find_sigid_algs(X509_get_signature_nid(serverCert), addr algoNid, nil) == 0: + raise newException(ScramError, "SSLError: could not determine server certificate signature algorithm") + + if algoNid == NID_md5 or algoNid == NID_md5_sha1: + algoType = EVP_sha256() + else: + algoType = EVP_get_digestbynid(algoNid) + if algoType == nil: + raise newException(ScramError, "SSLError: could not find digest for NID " & OBJ_nid2sn(algoNid)) + + if X509_digest(serverCert, algoType, hash, addr hashSize) == 0: + raise newException(ScramError, "SSLError: could not generate server certificate hash") + + copyMem(addr result[0], hash, hashSize) + result.setLen(hashSize) + + else: + raise newException(ScramError, "Channel " & $channel & " is not supported yet") \ No newline at end of file diff --git a/scram/server.nim b/scram/server.nim index 14a6a86..e2268de 100644 --- a/scram/server.nim +++ b/scram/server.nim @@ -1,6 +1,7 @@ -import strformat, strutils, base64, hmac, nimSHA2, private/[utils,types] +import base64, strformat, strutils, hmac, sha1, nimSHA2, md5, private/[utils,types] -export ChannelType +export MD5Digest, SHA1Digest, SHA224Digest, SHA256Digest, SHA384Digest, SHA512Digest, Keccak512Digest +export getChannelBindingData type ScramServer[T] = ref object of RootObj @@ -52,16 +53,9 @@ proc newScramServer*[T](): ScramServer[T] = result.isSuccessful = false result.cbType = TLS_NONE -proc newScramServer*[T](socket: AnySocket, channel = TLS_UNIQUE): ScramServer[T] = - validateCB(channel, socket) +proc setChannelBindingType*[T](s: ScramServer[T], channel: ChannelType) = s.cbType = channel - result = newScramServer[T]() - result.cbType = channel - result.cbData = getCBData(channel, socket) - -proc setCBindType*[T](s: ScramServer[T], channel: ChannelType) = s.cbType = channel - -proc setCBindData*[T](s: ScramServer[T], data: string) = s.cbData = data +proc setChannelBindingData*[T](s: ScramServer[T], data: string) = s.cbData = data proc setServerNonce*[T](s: ScramServer[T], nonce: string) = s.serverNonce = nonce diff --git a/tests/test_cb.nim b/tests/test_cb.nim index 70da23b..99ad9ef 100644 --- a/tests/test_cb.nim +++ b/tests/test_cb.nim @@ -3,15 +3,17 @@ import scram/private/types const FAKE_CBDATA = "xxxxxxxxxxxxxxxx" -proc test[T](user, password: string) = +proc test[T](user, password: string, clientChannel = TLS_NONE, serverChannel = TLS_NONE, clientCbData = FAKE_CBDATA, serverCbData = FAKE_CBDATA) = var client = newScramClient[T]() var server = newScramServer[T]() - client.setCBindType(TLS_UNIQUE) - client.setCBindData(FAKE_CBDATA) + if clientChannel != TLS_NONE: + client.setChannelBindingType(clientChannel) + client.setChannelBindingData(clientCbData) - server.setCBindType(TLS_UNIQUE) - server.setCBindData(FAKE_CBDATA) + if serverChannel != TLS_NONE: + server.setChannelBindingType(serverChannel) + server.setChannelBindingData(serverCbData) let cfirst = client.prepareFirstMessage(user) assert server.handleClientFirstMessage(cfirst) == user, "incorrect detected username" @@ -21,15 +23,62 @@ proc test[T](user, password: string) = let sfinal = server.prepareFinalMessage(cfinal) assert client.verifyServerFinalMessage(sfinal), "incorrect server final message" -suite "Scram Client-Server tests": - test "SCRAM-SHA1-PLUS": +suite "Scram Channel Binding tests": + test "SCRAM-SHA1-PLUS tls-unique": test[Sha1Digest]( "user", - "pencil" + "pencil", + TLS_UNIQUE, + TLS_UNIQUE ) - test "SCRAM-SHA256-PLUS": + test "SCRAM-SHA256-PLUS: tls-unique": test[Sha256Digest]( "bob", - "secret" - ) \ No newline at end of file + "secret", + TLS_UNIQUE, + TLS_UNIQUE + ) + + test "SCRAM-SHA1-PLUS tls-server-end-point": + test[Sha1Digest]( + "user", + "pencil", + TLS_SERVER_END_POINT, + TLS_SERVER_END_POINT + ) + + test "SCRAM-SHA256-PLUS: tls-server-end-point": + test[Sha256Digest]( + "bob", + "secret", + TLS_SERVER_END_POINT, + TLS_SERVER_END_POINT + ) + + test "client-support-server-do-not": + expect ScramError: + test[Sha256Digest]( + "bob", + "secret", + TLS_UNIQUE + ) + + test "server-do-not-suport-client-channel-binding-type": + expect ScramError: + test[Sha256Digest]( + "bob", + "secret", + TLS_UNIQUE, + TLS_SERVER_END_POINT + ) + test "channel-bindings-dont-match": + expect ScramError: + test[Sha256Digest]( + "bob", + "secret", + TLS_UNIQUE, + TLS_UNIQUE, + "xxxx", + "zzzz" + ) \ No newline at end of file