From 25617310bff196d549c9072e7aa10d5b9ea5e455 Mon Sep 17 00:00:00 2001 From: Kumbirai Tanekha Date: Wed, 27 Sep 2023 12:22:05 +0100 Subject: [PATCH 1/8] Update postgres test README --- test/connector/tcp/pg/README.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/test/connector/tcp/pg/README.md b/test/connector/tcp/pg/README.md index 4c137b152..f525871fd 100644 --- a/test/connector/tcp/pg/README.md +++ b/test/connector/tcp/pg/README.md @@ -1,19 +1,24 @@ -# Postgresql Handler Development +# Postgresql Connector Development + +## TLDR -##TLDR The following two steps are all you need to know: + 1. A single command starts the dev workflow: ```sh-session $ ./dev ``` + 2. Another one runs the tests: + ```sh-session $ ./test ``` + So while developing, you'll do a single `./dev` and then many `./test` runs as you iteratively change the code or add new tests. -##Additional Details +## Additional Details `./dev` uses `docker-compose` to start both `pg` containers, the `secretless` container, and the `test` container (where tests are run from). From 835429845dbd9992138b63b58e6301c25b039854 Mon Sep 17 00:00:00 2001 From: Kumbirai Tanekha Date: Wed, 27 Sep 2023 15:38:04 +0100 Subject: [PATCH 2/8] Add caching_sha256_password to mysql connector --- CHANGELOG.md | 6 +- .../tcp/mysql/authentication_handshake.go | 179 ++++++++- .../plugin/connectors/tcp/mysql/connector.go | 6 +- .../connectors/tcp/mysql/protocol/const.go | 16 +- .../connectors/tcp/mysql/protocol/packet.go | 4 +- .../connectors/tcp/mysql/protocol/protocol.go | 379 ++++++++++++++++-- test/connector/tcp/mysql/Dockerfile.dev | 3 +- test/connector/tcp/mysql/Dockerfile.mysql | 8 - test/connector/tcp/mysql/README.md | 59 ++- test/connector/tcp/mysql/docker-compose.yml | 34 +- test/connector/tcp/mysql/etc/no-ssl.cnf | 2 + test/connector/tcp/mysql/etc/ssl.cnf | 4 + test/connector/tcp/mysql/etc/test.sql | 16 +- test/connector/tcp/mysql/etc/toggle_ssl.sh | 21 - test/connector/tcp/mysql/pkg/run_test_case.go | 4 +- .../tcp/mysql/tests/essentials_test.go | 26 +- 16 files changed, 614 insertions(+), 153 deletions(-) delete mode 100644 test/connector/tcp/mysql/Dockerfile.mysql create mode 100644 test/connector/tcp/mysql/etc/no-ssl.cnf create mode 100644 test/connector/tcp/mysql/etc/ssl.cnf delete mode 100755 test/connector/tcp/mysql/etc/toggle_ssl.sh diff --git a/CHANGELOG.md b/CHANGELOG.md index 6783e6fb3..5292ff83d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Nothing should go in this section, please add to the latest unreleased version (and update the corresponding date), or add a new version. +## [1.7.19] - 2023-10-01 + +### Added +- Add support for caching_sha256_password to mysql connector + ## [1.7.18] - 2023-08-22 ### Changed @@ -16,7 +21,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. [cyberark/secretless-broker#1499](https://github.com/cyberark/secretless-broker/pull/1499) ### Security -- Updated jquery to v3.7.1 - Updated github.com/docker/docker to v24.0.5 (CONJSE-1798) ### Added diff --git a/internal/plugin/connectors/tcp/mysql/authentication_handshake.go b/internal/plugin/connectors/tcp/mysql/authentication_handshake.go index 9a6d1fd22..f851b83df 100644 --- a/internal/plugin/connectors/tcp/mysql/authentication_handshake.go +++ b/internal/plugin/connectors/tcp/mysql/authentication_handshake.go @@ -1,6 +1,11 @@ package mysql import ( + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" "net" "github.com/cyberark/secretless-broker/internal/plugin/connectors/tcp/mysql/protocol" @@ -8,7 +13,6 @@ import ( ) /* - AuthenticationHandshake represents the entire back and forth process between a MySQL client and server during which authentication occurs. Note this is distinct from the various specific handshake packets that are @@ -58,7 +62,6 @@ following source: Secretless->Backend: HandshakeResponse Backend->Secretless: OkPacket Secretless->Client: OkPacket - */ type AuthenticationHandshake struct { connectionDetails *ConnectionDetails @@ -94,17 +97,60 @@ func NewAuthenticationHandshake( // AuthenticatedBackendConn will return the raw, authenticated network conn. func (h *AuthenticationHandshake) Run() error { h.readServerHandshake() + // Pass along the handshake but make sure client doesn't use TLS to connect to Secretless h.writeHandshakeToClient() + + // Does the server support TLS when needed? h.validateServerSSL() + // Get the client handshake response. I thought it would be good to only make this + // use the simpliest auth mechanism, but no need since the server won't do the entire dance. + // It will only ever return success or error, not auth switch or anything else h.readClientHandshakeResponse() - h.overrideClientCapabilities() - h.injectCredentials() - h.handleClientSSLRequest() - h.writeClientHandshakeResponseToBackend() - h.verifyAndProxyOkResponse() + + // Everything in here is about responding to the server authentication challenge + // No need to talk to the client an + if true { + h.overrideClientCapabilities() + h.injectCredentials() + // Deal with this later + h.handleClientSSLRequest() + h.writeClientHandshakeResponseToBackend() + h.verifyAndProxyOkResponse() + } + return h.err } +func (h *AuthenticationHandshake) NewRun() error { + backendPacket := h.readBackendPacket() + if h.err != nil { + return h.err + } + + serverHandshake, err := protocol.UnpackHandshakeV10(backendPacket) + if err != nil { + return err + } + fmt.Println("serverHandshake:", serverHandshake) + + backendPacket1, err := protocol.PackHandshakeV10(serverHandshake) + if err != nil { + return err + } + fmt.Println("backendPacket:", backendPacket) + fmt.Println("backendPacket:", backendPacket1) + + serverHandshake, err = protocol.UnpackHandshakeV10(backendPacket) + if err != nil { + return err + } + fmt.Println("serverHandshake:", serverHandshake) + + return h.clientConn.write(backendPacket1) + + // return nil +} + // AuthenticatedBackendConn returns an already authenticated connection // to the MySQL server. Intended to be called after Run() has completed. func (h *AuthenticationHandshake) AuthenticatedBackendConn() net.Conn { @@ -128,10 +174,17 @@ func (h *AuthenticationHandshake) writeHandshakeToClient() { return } + serverHandshake := *h.serverHandshake // Remove Client SSL Capability from Server Handshake Packet // to force client to connect to Secretless without SSL // TODO: update this after kumbi's work - packetWithNoSSL, err := protocol.RemoveSSLFromHandshakeV10(h.rawServerHandshake) + serverHandshake.ServerCapabilities &^= protocol.ClientSSL + + // Give client the simplest auth plugin request + // This might work for now, but we'll likely need to add support for other auth plugins + serverHandshake.AuthPlugin = "mysql_native_password" + + packetWithNoSSL, err := protocol.PackHandshakeV10(&serverHandshake) if err != nil { h.err = err return @@ -161,7 +214,8 @@ func (h *AuthenticationHandshake) readClientHandshakeResponse() { return } - // TODO: client requesting SSL results ERROR 2026 (HY000): SSL connection error: protocol version mismatch + // TODO: client requesting SSL results in ERROR 2026 (HY000): SSL connection error: protocol version mismatch + // TODO: Find out if this is a client request SSL after we advertise not supporting SSL ? h.clientHandshakeResponse, h.err = protocol.UnpackHandshakeResponse41(rawResponse) } func (h *AuthenticationHandshake) overrideClientCapabilities() { @@ -173,7 +227,7 @@ func (h *AuthenticationHandshake) overrideClientCapabilities() { // TODO: add tests cases for authentication plugins support // Disable CapabilityFlag for authentication plugins support - h.clientHandshakeResponse.CapabilityFlags &^= protocol.ClientPluginAuth + h.clientHandshakeResponse.CapabilityFlags |= protocol.ClientPluginAuth // TODO: add tests cases for client secure connection // Enable CapabilityFlag for client secure connection @@ -193,6 +247,7 @@ func (h *AuthenticationHandshake) injectCredentials() { // TODO: change this to method call on clientHandshakeResponse when Kumbi's work done h.err = protocol.InjectCredentials( + h.serverHandshake.AuthPlugin, h.clientHandshakeResponse, h.serverHandshake.Salt, h.connectionDetails.Username, @@ -270,6 +325,8 @@ func (h *AuthenticationHandshake) writeClientHandshakeResponseToBackend() { return } + // TODO: We should probably be carrying out a comprehensive unpacking, so that + // 1. we can be selective about the contents of the response packedHandshakeRespPacket, err := protocol.PackHandshakeResponse41(h.clientHandshakeResponse) if err != nil { h.err = err @@ -289,7 +346,9 @@ func (h *AuthenticationHandshake) verifyAndProxyOkResponse() { return } - switch protocol.GetPacketType(rawPkt) { + packetType := protocol.GetPacketType(rawPkt) + fmt.Println("packetType", packetType) + switch packetType { case protocol.ResponseErr: // Return after adding the error response to AuthenticationHandshake // as a protocol.Error type @@ -300,6 +359,104 @@ func (h *AuthenticationHandshake) verifyAndProxyOkResponse() { err := protocol.UnpackErrResponse(rawPkt) h.err = err return + case protocol.ResponseAuthMoreData: + moreDataResp, err := protocol.UnpackAuthMoreDataResponse(rawPkt) + if err != nil { + h.err = err + return + } + + switch moreDataResp.StatusTag { + case protocol.CachingSha2PasswordFastAuthSuccess: + // The user was cached and a fast login was performed successfully. + // Do nothing. An OK packet will be sent by the server immediately + // following this packet. + return + case protocol.CachingSha2PasswordPerformFullAuthentication: + // The server is requesting a full authentication handshake. + //https://dev.mysql.com/doc/dev/mysql-server/latest/page_caching_sha2_authentication_exchanges.html + //https://github.com/go-sql-driver/mysql/blob/master/auth.go#L353 + + // request public key from server + data := protocol.PackAuthRequestPubKeyResponse() + + h.writeBackendPacket(data) + if h.err != nil { + return + } + + // read public key from server + pubKeyPkt := h.readBackendPacket() + if h.err != nil { + return + } + + // parse public key + if pubKeyPkt[4] != protocol.ResponseAuthMoreData { + h.err = errors.New("expected ResponseAuthMoreData packet") + //TODO: For some reason we're getting "28000Access denied for user..." packets here in some cases + return + } + + block, rest := pem.Decode(pubKeyPkt[5:]) + if block == nil { + h.err = fmt.Errorf("no pem data found, data: %s", rest) + return + } + pkix, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + h.err = err + return + } + pubKey := pkix.(*rsa.PublicKey) + + // encrypt password with public key + enc, err := protocol.EncryptPassword(h.connectionDetails.Password, h.serverHandshake.Salt, pubKey) + if err != nil { + h.err = err + return + } + + encPkt := protocol.PackAuthEncryptedPasswordResponse(enc) + + h.writeBackendPacket(encPkt) + + return + } + return + case 0xfe: + authSwitchRequest, err := protocol.UnpackAuthSwitchRequest(rawPkt) + if err != nil { + h.err = err + return + } + + salt := authSwitchRequest.PluginData + // This is because the salt seems to actually be 21 bytes, ending in a null byte. + // However the documentation suggests auth switch requests should be an EOF string + // See https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html + if authSwitchRequest.PluginName == "mysql_native_password" { + salt = salt[:20] + } + authResponse, err := protocol.CreateAuthResponse(authSwitchRequest.PluginName, []byte(h.connectionDetails.Password), salt) + if err != nil { + return + } + + authSwitchResponseData, err := protocol.PackAuthSwitchResponse( + authSwitchRequest.SequenceNumber, + authResponse, + ) + if err != nil { + h.err = err + return + } + h.writeBackendPacket(authSwitchResponseData) + + // TODO: Avoid infinite recursion? + h.verifyAndProxyOkResponse() + return + default: // Verify packet is valid; don't do anything with unpacked if _, err := protocol.UnpackOkResponse(rawPkt); err != nil { diff --git a/internal/plugin/connectors/tcp/mysql/connector.go b/internal/plugin/connectors/tcp/mysql/connector.go index c1a1e8bf8..f64375891 100644 --- a/internal/plugin/connectors/tcp/mysql/connector.go +++ b/internal/plugin/connectors/tcp/mysql/connector.go @@ -39,9 +39,9 @@ func (connector *SingleUseConnector) sendErrorToClient(err error) { // Connect implements the tcp.Connector func signature // // It is the main method of the SingleUseConnector. It: -// 1. Constructs connection details from the provided credentials map. -// 2. Dials the backend using credentials. -// 3. Runs through the connection phase steps to authenticate. +// 1. Constructs connection details from the provided credentials map. +// 2. Dials the backend using credentials. +// 3. Runs through the connection phase steps to authenticate. // // Connect requires "host", "port", "username" and "password" credentials. func (connector *SingleUseConnector) Connect( diff --git a/internal/plugin/connectors/tcp/mysql/protocol/const.go b/internal/plugin/connectors/tcp/mysql/protocol/const.go index dbc5a4067..3801b6984 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/const.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/const.go @@ -27,11 +27,12 @@ package protocol // Random constants const ( // MySQL response types - responseEOF = 0xfe - responseOk = 0x00 - responsePrepareOk = 0x00 - ResponseErr = 0xff - responseLocalinfile = 0xfb + ResponseAuthMoreData = 0x01 + ResponseEOF = 0xfe + responseOk = 0x00 + responsePrepareOk = 0x00 + ResponseErr = 0xff + responseLocalinfile = 0xfb // MySQL field types constants fieldTypeString = 0xfd @@ -49,7 +50,10 @@ const ( // Digits after comma doubleDecodePrecision = 6 - defaultAuthPluginName = "mysql_native_password" + defaultAuthPluginName = "mysql_native_password" + CachingSha2PasswordRequestPublicKey = 2 + CachingSha2PasswordFastAuthSuccess = 3 + CachingSha2PasswordPerformFullAuthentication = 4 ) // Protocol commands diff --git a/internal/plugin/connectors/tcp/mysql/protocol/packet.go b/internal/plugin/connectors/tcp/mysql/protocol/packet.go index cf2f28ece..7d717a895 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/packet.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/packet.go @@ -185,7 +185,7 @@ func ReadResponse(conn net.Conn, deprecateEOF bool) ([]byte, byte, error) { data = append(data, pkt...) - if pkt[4] == responseEOF { + if pkt[4] == ResponseEOF { break } } @@ -197,7 +197,7 @@ func ReadResponse(conn net.Conn, deprecateEOF bool) ([]byte, byte, error) { func ReadPacket(conn net.Conn) ([]byte, error) { // Read packet header - header := []byte{0, 0, 0, 0} + header := make([]byte, 4) if _, err := io.ReadFull(conn, header); err != nil { return nil, err } diff --git a/internal/plugin/connectors/tcp/mysql/protocol/protocol.go b/internal/plugin/connectors/tcp/mysql/protocol/protocol.go index 7fa8032e5..b1e728d34 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/protocol.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/protocol.go @@ -26,9 +26,13 @@ package protocol import ( "bytes" + "crypto/rand" + "crypto/rsa" "crypto/sha1" + "crypto/sha256" "encoding/binary" "errors" + "fmt" "io" ) @@ -49,10 +53,12 @@ var ErrFieldTypeNotImplementedYet = errors.New("Protocol: Required field type no // int<1> PacketType (0xFF) // int<2> ErrorCode // if clientCapabilities & clientProtocol41 -// { -// string<1> SqlStateMarker (#) -// string<5> SqlState -// } +// +// { +// string<1> SqlStateMarker (#) +// string<5> SqlState +// } +// // string Error func UnpackErrResponse(data []byte) error { // Min packet length = @@ -96,10 +102,10 @@ func UnpackErrResponse(data []byte) error { // GetPacketType extracts the PacketType byte // Part of basic packet structure shown below. // -// int<3> PacketLength -// int<1> PacketNumber -// int<1> PacketType (0xFF) -// ... more ... +// int<3> PacketLength +// int<1> PacketNumber +// int<1> PacketType (0xFF) +// ... more ... func GetPacketType(packet []byte) byte { return packet[4] } @@ -186,6 +192,8 @@ type HandshakeV10 struct { ProtocolVersion byte ServerVersion string ConnectionID uint32 + StatusFlags uint16 + CharacterSet uint8 ServerCapabilities uint32 AuthPlugin string Salt []byte @@ -207,22 +215,29 @@ type HandshakeV10 struct { // int<2> StatusFlags // int<2> ServerCapabilities (2nd part) // if capabilities & clientPluginAuth -// { -// int<1> AuthPluginDataLength -// } +// +// { +// int<1> AuthPluginDataLength +// } +// // else -// { -// int<1> 0x00 -// } +// +// { +// int<1> 0x00 +// } +// // string<10> Reserved (all 0x00) // if capabilities & clientSecureConnection -// { -// string[$len] AuthPluginDataPart2 ($len=MAX(13, AuthPluginDataLength - 8)) -// } +// +// { +// string[$len] AuthPluginDataPart2 ($len=MAX(13, AuthPluginDataLength - 8)) +// } +// // if capabilities & clientPluginAuth -// { -// string[NUL] AuthPluginName -// } +// +// { +// string[NUL] AuthPluginName +// } func UnpackHandshakeV10(packet []byte) (*HandshakeV10, error) { r := bytes.NewReader(packet) @@ -263,10 +278,16 @@ func UnpackHandshakeV10(packet []byte) (*HandshakeV10, error) { return nil, err } - // Skip ServerDefaultCollation and StatusFlags - if _, err := r.Seek(3, io.SeekCurrent); err != nil { + // Read ServerCharacterSet and StatusFlags + serverCharacterSet, err := r.ReadByte() + if err != nil { return nil, err } + serverStatusFlagsBuf := make([]byte, 2) + if _, err := r.Read(serverStatusFlagsBuf); err != nil { + return nil, err + } + serverStatusFlags := binary.LittleEndian.Uint16(serverStatusFlagsBuf) // Read ExServerCapabilities serverCapabilitiesHigherBuf := make([]byte, 2) @@ -332,9 +353,91 @@ func UnpackHandshakeV10(packet []byte) (*HandshakeV10, error) { ServerCapabilities: serverCapabilities, AuthPlugin: authPlugin, Salt: salt, + StatusFlags: serverStatusFlags, + CharacterSet: serverCharacterSet, }, nil } +// PackHandshakeV10 takes in a HandshakeResponse41 object and +// returns a handshake response packet +func PackHandshakeV10(serverHandshake *HandshakeV10) ([]byte, error) { + // Create a buffer to write the packet data + buffer := new(bytes.Buffer) + + // Write ProtocolVersion (int<1>) + binary.Write(buffer, binary.LittleEndian, serverHandshake.ProtocolVersion) + + // Write ServerVersion (string) + buffer.WriteString(serverHandshake.ServerVersion) + buffer.WriteByte(0) + + // Write ConnectionID (int<4>) + binary.Write(buffer, binary.LittleEndian, serverHandshake.ConnectionID) + + // Write AuthPluginDataPart1 (string<8>) + buffer.Write(serverHandshake.Salt[:8]) // Write the first 8 bytes of the salt + + // Write Reserved (int<1>) + buffer.WriteByte(0) + + // Write ServerCapabilities (int<2>) + binary.Write(buffer, binary.LittleEndian, uint16(serverHandshake.ServerCapabilities&0xFFFF)) + + // Write ServerCharacterSet (int<1>) + buffer.WriteByte(serverHandshake.CharacterSet) + + // Write StatusFlags (int<2>) + binary.Write(buffer, binary.LittleEndian, serverHandshake.StatusFlags) + + // Write ServerCapabilities (int<2>), the higher part + binary.Write(buffer, binary.LittleEndian, uint16(serverHandshake.ServerCapabilities>>16)) + + // Write AuthPluginDataLength (int<1>) if required + if serverHandshake.ServerCapabilities&ClientPluginAuth > 0 { + buffer.WriteByte(byte(len(serverHandshake.Salt) + 1)) + } + + // Write Reserved (string<10>) + buffer.Write(make([]byte, 10)) + + // Calculate the length of AuthPluginDataPart2 + var authPluginDataLength byte + if serverHandshake.ServerCapabilities&ClientSecureConnection != 0 { + numBytes := len(serverHandshake.Salt) - 8 + if numBytes > 13 { + numBytes = 13 + } + authPluginDataLength = byte(numBytes) + } + + // Write AuthPluginDataPart2 (string[$len]) if required + if serverHandshake.ServerCapabilities&ClientSecureConnection != 0 { + buffer.Write(serverHandshake.Salt[8 : 8+int(authPluginDataLength)]) + buffer.WriteByte(0) + } + + // Write AuthPluginName (string) if required + if serverHandshake.ServerCapabilities&ClientPluginAuth > 0 { + buffer.WriteString(serverHandshake.AuthPlugin) + buffer.WriteByte(0) + } + + // Calculate the packet length (excluding the length field itself) + packetLength := buffer.Len() + + // Create a header buffer and write the packet length (int<3>) + headerBuffer := make([]byte, 4) + headerBuffer[0] = byte(packetLength & 0xFF) + headerBuffer[1] = byte((packetLength >> 8) & 0xFF) + headerBuffer[2] = byte((packetLength >> 16) & 0xFF) + headerBuffer[3] = 0 + + // Combine the header and packet data to create the final packet + finalPacket := append(headerBuffer, buffer.Bytes()...) + + return finalPacket, nil +} + // RemoveSSLFromHandshakeV10 removes Client SSL Capability from Server // Handshake Packet. Secretless needs to do this to force the client to // communicate with Secretless without using SSL. That half of the connection @@ -426,17 +529,17 @@ func writeUint16(data []byte, pos int, value uint16) { // // The format of the header is also described here: // -// https://dev.mysql.com/doc/internals/en/mysql-packet.html +// https://dev.mysql.com/doc/internals/en/mysql-packet.html // -// +-------------+----------------+---------------------------------------------+ -// | Type | Name | Description | -// +-------------+----------------+---------------------------------------------+ -// | int<3> | payload_length | Length of the payload. The number of bytes | -// | | | in the packet beyond the initial 4 bytes | -// | | | that make up the packet header. | -// | int<1> | sequence_id | Sequence ID | -// | string | payload | [len=payload_length] payload of the packet | -// +-------------+----------------+---------------------------------------------+ +// +-------------+----------------+---------------------------------------------+ +// | Type | Name | Description | +// +-------------+----------------+---------------------------------------------+ +// | int<3> | payload_length | Length of the payload. The number of bytes | +// | | | in the packet beyond the initial 4 bytes | +// | | | that make up the packet header. | +// | int<1> | sequence_id | Sequence ID | +// | string | payload | [len=payload_length] payload of the packet | +// +-------------+----------------+---------------------------------------------+ type HandshakeResponse41 struct { Header []byte CapabilityFlags uint32 @@ -554,13 +657,34 @@ func UnpackHandshakeResponse41(packet []byte) (*HandshakeResponse41, error) { PacketTail: packetTail}, nil } +func CreateAuthResponse(authPlugin string, password []byte, salt []byte) ([]byte, error) { + var authResponse []byte + var err error + + switch authPlugin { + case "mysql_native_password": + authResponse, err = NativePassword([]byte(password), salt) + case "caching_sha2_password": + authResponse = scrambleSHA256Password([]byte(password), salt) + default: + err = fmt.Errorf("Unknown auth plugin: %s", authPlugin) + } + + if err != nil { + return nil, err + } + return authResponse, nil +} + // InjectCredentials takes in a HandshakeResponse41 from the client, the // salt from the server, and a username / password, and uses the salt // from the server handshake to inject the username / password credentials into // the client handshake response -func InjectCredentials(clientHandshake *HandshakeResponse41, salt []byte, username string, password string) (err error) { - - authResponse, err := NativePassword([]byte(password), salt) +// TODO: we can do a much better job than this. We just need to calculate the paylod and not do all this resetting BS! +// TODO: the plugin name should come from the server, not the client! +func InjectCredentials(authPlugin string, clientHandshake *HandshakeResponse41, salt []byte, username string, password string) (err error) { + fmt.Println("Injecting credentials to", authPlugin) + authResponse, err := CreateAuthResponse(authPlugin, []byte(password), salt) if err != nil { return } @@ -568,13 +692,14 @@ func InjectCredentials(clientHandshake *HandshakeResponse41, salt []byte, userna // Reset the payload length for the packet payloadLengthDiff := int32(len(username) - len(clientHandshake.Username)) payloadLengthDiff += int32(len(authResponse) - int(clientHandshake.AuthLength)) - payloadLengthDiff += int32(len(defaultAuthPluginName) - len(clientHandshake.AuthPluginName)) + payloadLengthDiff += int32(len(authPlugin) - len(clientHandshake.AuthPluginName)) clientHandshake.Header, err = UpdateHeaderPayloadLength(clientHandshake.Header, payloadLengthDiff) if err != nil { return } + clientHandshake.AuthPluginName = authPlugin clientHandshake.Username = username clientHandshake.AuthLength = int64(len(authResponse)) clientHandshake.AuthResponse = authResponse @@ -582,14 +707,42 @@ func InjectCredentials(clientHandshake *HandshakeResponse41, salt []byte, userna return } +// Hash password using MySQL 8+ method (SHA256) +func scrambleSHA256Password(password []byte, scramble []byte) []byte { + if len(password) == 0 { + return nil + } + + // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) + + crypt := sha256.New() + crypt.Write(password) + message1 := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1) + message1Hash := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1Hash) + crypt.Write(scramble) + message2 := crypt.Sum(nil) + + for i := range message1 { + message1[i] ^= message2[i] + } + + return message1 +} + // PackHandshakeResponse41 takes in a HandshakeResponse41 object and // returns a handshake response packet -func PackHandshakeResponse41(clientHandshake *HandshakeResponse41) (packet []byte, err error) { +func PackHandshakeResponse41(clientHandshake *HandshakeResponse41) ([]byte, error) { var buf bytes.Buffer // write the header (same as the original) - buf.Write(clientHandshake.Header) + // buf.Write(clientHandshake.Header) // write the capability flags capabilityFlagsBuf := make([]byte, 4) @@ -633,7 +786,7 @@ func PackHandshakeResponse41(clientHandshake *HandshakeResponse41) (packet []byt } // write auth plugin name - buf.WriteString(defaultAuthPluginName) + buf.WriteString(clientHandshake.AuthPluginName) buf.WriteByte(0) // write tail of packet (if set) @@ -641,9 +794,20 @@ func PackHandshakeResponse41(clientHandshake *HandshakeResponse41) (packet []byt buf.Write(clientHandshake.PacketTail) } - packet = buf.Bytes() + // Calculate the packet length (excluding the length field itself) + packetLength := buf.Len() - return + // Create a header buffer and write the packet length (int<3>) + headerBuffer := make([]byte, 4) + headerBuffer[0] = byte(packetLength & 0xFF) + headerBuffer[1] = byte((packetLength >> 8) & 0xFF) + headerBuffer[2] = byte((packetLength >> 16) & 0xFF) + headerBuffer[3] = clientHandshake.Header[3] + + // Combine the header and packet data to create the final packet + finalPacket := append(headerBuffer, buf.Bytes()...) + + return finalPacket, nil } // GetLenEncodedIntegerSize returns bytes count for length encoded integer @@ -744,6 +908,45 @@ func ReadNullTerminatedString(r *bytes.Reader) string { } } +type AuthSwitchRequest struct { + SequenceNumber uint8 + PluginName string + PluginData []byte +} + +// UnpackAuthSwitchRequest decodes an AuthSwitchRequest packet from the provided data. +func UnpackAuthSwitchRequest(data []byte) (*AuthSwitchRequest, error) { + sequenceNumber := data[3] + data = data[4:] + + // Find the position of the null-terminated plugin name + nullTerminatorIndex := 1 + for i := 1; i < len(data); i++ { + if data[i] == 0x00 { + nullTerminatorIndex = i + break + } + } + + // Extract the plugin name + if nullTerminatorIndex == 1 { + return nil, fmt.Errorf("Invalid AuthSwitchRequest packet: Missing plugin name") + } + pluginName := string(data[1:nullTerminatorIndex]) + + // Extract the plugin provided data + var pluginData []byte + if nullTerminatorIndex+1 < len(data) { + pluginData = data[nullTerminatorIndex+1:] + } + + return &AuthSwitchRequest{ + SequenceNumber: sequenceNumber, + PluginName: pluginName, + PluginData: pluginData, + }, nil +} + // ReadNullTerminatedBytes reads bytes from reader until 0x00 byte func ReadNullTerminatedBytes(r *bytes.Reader) (str []byte) { for { @@ -839,3 +1042,97 @@ func WriteUint24(u uint32) (b []byte) { return } + +// PackAuthSwitchResponse creates an AuthSwitchResponse packet with the provided response data. +func PackAuthSwitchResponse(authSwitchRequestSequenceId uint8, data []byte) ([]byte, error) { + // Create a buffer to write the packet data + buffer := new(bytes.Buffer) + + // Write the response data to the buffer + buffer.Write(data) + + // Calculate the packet length (excluding the length field itself) + packetLength := buffer.Len() + + // Create a header buffer and write the packet length (int<3>) + headerBuffer := make([]byte, 4) + headerBuffer[0] = byte(packetLength & 0xFF) + headerBuffer[1] = byte((packetLength >> 8) & 0xFF) + headerBuffer[2] = byte((packetLength >> 16) & 0xFF) + headerBuffer[3] = 0 + + // Combine the header and packet data to create the final packet + finalPacket := append(headerBuffer, buffer.Bytes()...) + + return finalPacket, nil +} + +type AuthMoreDataResponse struct { + PacketType byte + StatusTag byte +} + +// UnpackAuthMoreDataResponse decodes AuthMoreData from server. +// Basic packet structure shown below. +// +// int<3> PacketLength +// int<1> PacketNumber +// int<1> PacketType (0x01) +// int<1> StatusTag (0x03 or 0x04) +// string AuthenticationMethodData (unused by secretless) +func UnpackAuthMoreDataResponse(packet []byte) (*AuthMoreDataResponse, error) { + + // Min packet length = header(4 bytes) + PacketType(1 byte) + if err := CheckPacketLength(5, packet); err != nil { + return nil, err + } + + r := bytes.NewReader(packet) + + // Skip packet header + if _, err := GetPacketHeader(r); err != nil { + return nil, err + } + + // Read header, validate OK + packetType, err := r.ReadByte() + if err != nil { + return nil, err + } + if packetType != ResponseAuthMoreData { + return nil, errors.New("Malformed packet") + } + + // Read status tag + statusTag, err := r.ReadByte() + if err != nil { + return nil, err + } + + return &AuthMoreDataResponse{ + PacketType: packetType, + StatusTag: statusTag, + }, nil +} + +func PackAuthRequestPubKeyResponse() (packet []byte) { + packet = append(WriteUint24(uint32(1)), 3, CachingSha2PasswordRequestPublicKey) + return +} + +func PackAuthEncryptedPasswordResponse(encPwd []byte) (packet []byte) { + packet = append(WriteUint24(uint32(len(encPwd))), 5) + packet = append(packet, encPwd...) + return +} + +func EncryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { + plain := make([]byte, len(password)+1) + copy(plain, password) + for i := range plain { + j := i % len(seed) + plain[i] ^= seed[j] + } + sha1 := sha1.New() + return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) +} diff --git a/test/connector/tcp/mysql/Dockerfile.dev b/test/connector/tcp/mysql/Dockerfile.dev index 99cec2ca6..cb2f88bec 100644 --- a/test/connector/tcp/mysql/Dockerfile.dev +++ b/test/connector/tcp/mysql/Dockerfile.dev @@ -1,4 +1,3 @@ FROM secretless-dev -RUN apt-get update && \ - apt-get install -y default-mysql-client +COPY --from=mysql:8.1 /usr/bin/mysql /usr/bin/mysql diff --git a/test/connector/tcp/mysql/Dockerfile.mysql b/test/connector/tcp/mysql/Dockerfile.mysql deleted file mode 100644 index ee12922ee..000000000 --- a/test/connector/tcp/mysql/Dockerfile.mysql +++ /dev/null @@ -1,8 +0,0 @@ -FROM mysql/mysql-server:5.7 -RUN yum install openssl -y -RUN mkdir -p /etc/mysql/mysql.conf.d/ - -COPY etc/toggle_ssl.sh /docker-entrypoint-initdb.d/ -COPY etc/test.sql /docker-entrypoint-initdb.d/ - -COPY ./ssl /ssl diff --git a/test/connector/tcp/mysql/README.md b/test/connector/tcp/mysql/README.md index 3e71cd912..f64d0127d 100644 --- a/test/connector/tcp/mysql/README.md +++ b/test/connector/tcp/mysql/README.md @@ -1,46 +1,58 @@ -# MySQL Handler Development +# MySQL Connector Development ## Usage / known limitations -- The MySQL handler is currently limited to connections via Unix domain socket +- The MySQL connector is currently limited to connections via Unix domain socket + +### To use the MySQL connector -### To use the MySQL handler: #### Start your MySQL server + From this directory, call -``` + +```sh docker-compose up -d mysql ``` + This will automatically start a MySQL server in a Docker container at `localhost:$(docker-compose port mysql 3306)`. It will also configure the MySQL server as follows: + - Create a `testuser` user (with password `testpass`) - Authorize the `testuser` user to connect to the database server from any IP and access any schema - Create a table `test` in the `testdb` schema and add two rows #### Start and configure secretless-broker + From the root project directory, build the Secretless Broker binaries for your platform: -``` + +```sh platform=$(go run test/print_platform.go) ./bin/build $platform amd64 ``` From this directory, start Secretless Broker: -``` -./run_dev + +```sh +./dev ``` -#### Log in to the MySQL server via the MySQL handler +#### Log in to the MySQL server via the MySQL connector + In another terminal, navigate to the `test/mysql_handler` directory and send a MySQL request via Unix socket: _Note: Since the Secretless Broker container runs the daemon as a limited user, sockets should be mounted to `/sock` directory._ -``` +```sh mysql --socket=sock/mysql.sock ``` + or via TCP: + +```sh +mysql -h 0.0.0.0 -P 13306 -u testuser --ssl-mode=DISABLED ``` -mysql -h 0.0.0.0 -P 13306 --ssl-mode=DISABLED -``` + You may be prompted for a password, but you don't need to enter one; just hit return to continue. Once logged in, you should be able to `SELECT * FROM testdb.test` and see the rows that were added to the sample table. @@ -48,11 +60,13 @@ Once logged in, you should be able to `SELECT * FROM testdb.test` and see the ro Note: this assumes you have a MySQL client installed locally on your machine. In the examples above and when you run the test suite locally, it is assumed you use one like [mysqlsh](https://dev.mysql.com/doc/refman/5.7/en/mysqlsh.html), which assumes SSL connections when possible by default (and has an `--ssl-mode` flag you can use to disable SSL). If you use `mysqlsh`, you will need to create an executable `mysql` file in your `PATH` that contains the following in order to be able to run `run_dev_test` locally: -``` + +```sh #!/bin/bash -ex mysqlsh --sql "$@" ``` + This will run the MySQL shell as a client in SQL mode. ## MySQL Handler Development @@ -62,7 +76,8 @@ This will run the MySQL shell as a client in SQL mode. The easiest way to do Secretless Broker development is to use the VS Code debugger. As above, you will want to start up your MySQL server container before beginning development. To configure the Secretless Broker, you can provide VS Code with a `launch.json` file for debugging by copying the sample file below to `.vscode/launch.json`, replacing `[YOUR MYSQL PORT]` with the actual exposed port of your MySQL Docker container. Sample `launch.json`: -``` + +```json { // Use IntelliSense to learn about possible attributes. // Hover to view descriptions of existing attributes. @@ -70,29 +85,30 @@ Sample `launch.json`: "version": "0.2.0", "configurations": [ { - "name": "MySQL Handler", + "name": "MySQL Connector", "type": "go", "request": "launch", "mode": "debug", "remotePath": "", "port": 2345, "host": "127.0.0.1", - "program": "${workspaceFolder}/cmd/secretless/", + "program": "${workspaceFolder}/cmd/secretless-broker/", "env": { "MYSQL_HOST": "localhost", "MYSQL_PORT": "[YOUR MYSQL PORT]", "MYSQL_PASSWORD": "testpass" }, - "args": [ "-f", "/Users/gjennings/go/src/github.com/cyberark/secretless-broker/test/mysql_handler/secretless.dev.yml"], + "args": [ "-f", "${workspaceFolder}/test/connector/tcp/mysql/fixtures/secretless.dev.yml"], "showLog": true } ] } ``` -Once you start the debugger (which will automatically start the Secretless Broker with the dev MySQL Handler configuration), you can send requests to the MySQL server via a client as described above. +Once you start the debugger (which will automatically start the Secretless Broker with the dev MySQL Connector configuration), you can send requests to the MySQL server via a client as described above. ### Using Docker You can also run: -``` -cd test/mysql_handler/ + +```sh +cd test/connector/tcp/mysql/ ./start docker-compose run --rm secretless-dev ``` @@ -106,14 +122,17 @@ to connect via Unix socket. ## Running the test suite #### Run the tests in Docker + Make sure you have built updated Secretless Broker binaries for Linux and updated Docker images before running the test suite. To run the test suite in Docker, run: -``` + +```sh ./stop # Remove all existing project containers ./start # Stand up MySQL and Secretless Broker servers ./test # Run tests in a test container ``` + Make sure you build the project by running `./bin/build` in the project root before running the tests so that the test container will be using updated code. If you want to run using your local changes, you can run `./test -l` diff --git a/test/connector/tcp/mysql/docker-compose.yml b/test/connector/tcp/mysql/docker-compose.yml index 550b8a388..70dbc3fe7 100644 --- a/test/connector/tcp/mysql/docker-compose.yml +++ b/test/connector/tcp/mysql/docker-compose.yml @@ -1,10 +1,8 @@ version: '3.0' services: - mysql: - build: - context: . - dockerfile: Dockerfile.mysql + mysql_no_tls: &mysql_no_tls + image: mysql:8.1 ports: - 3306 healthcheck: @@ -12,23 +10,17 @@ services: interval: 1s timeout: 30s environment: - NO_SSL: "false" - MYSQL_ROOT_PASSWORD: securerootpass - - mysql_no_tls: - build: - context: . - dockerfile: Dockerfile.mysql - ports: - - 3306 - healthcheck: - test: ["CMD-SHELL", "mysqladmin -psecurerootpass status"] - interval: 1s - timeout: 30s - environment: - NO_SSL: "true" MYSQL_ROOT_PASSWORD: securerootpass + volumes: + - ./etc/test.sql:/docker-entrypoint-initdb.d/test.sql + - ./etc/no-ssl.cnf:/etc/mysql/conf.d/no-ssl.cnf + mysql: + <<: *mysql_no_tls + volumes: + - ./etc/test.sql:/docker-entrypoint-initdb.d/test.sql + - ./etc/ssl.cnf:/etc/mysql/conf.d/ssl.cnf + - ./ssl:/etc/mysql-ssl secretless-dev: image: secretless-dev command: ./bin/reflex @@ -36,7 +28,7 @@ services: DB_HOST_NO_TLS: mysql_no_tls DB_HOST_TLS: mysql DB_PORT: 3306 - DB_USER: testuser + DB_USER: testuser_native_password DB_PASSWORD: testpass volumes: - ../../../../:/secretless @@ -65,7 +57,7 @@ services: DB_HOST_TLS: mysql DB_HOST_NO_TLS: mysql_no_tls DB_PORT: 3306 - DB_USER: testuser + DB_USER: testuser_native_password DB_PASSWORD: testpass SECRETLESS_HOST: VERBOSE: diff --git a/test/connector/tcp/mysql/etc/no-ssl.cnf b/test/connector/tcp/mysql/etc/no-ssl.cnf new file mode 100644 index 000000000..d46c18c66 --- /dev/null +++ b/test/connector/tcp/mysql/etc/no-ssl.cnf @@ -0,0 +1,2 @@ +[mysqld] +tls_version='' diff --git a/test/connector/tcp/mysql/etc/ssl.cnf b/test/connector/tcp/mysql/etc/ssl.cnf new file mode 100644 index 000000000..67733f279 --- /dev/null +++ b/test/connector/tcp/mysql/etc/ssl.cnf @@ -0,0 +1,4 @@ +[mysqld] +ssl-ca=/etc/mysql-ssl/ca.pem +ssl-cert=/etc/mysql-ssl/server.pem +ssl-key=/etc/mysql-ssl/server-key.pem diff --git a/test/connector/tcp/mysql/etc/test.sql b/test/connector/tcp/mysql/etc/test.sql index b43b4725f..907c6533c 100644 --- a/test/connector/tcp/mysql/etc/test.sql +++ b/test/connector/tcp/mysql/etc/test.sql @@ -1,5 +1,17 @@ -GRANT ALL PRIVILEGES ON *.* TO 'testuser'@'localhost' IDENTIFIED BY 'testpass'; -GRANT ALL PRIVILEGES ON *.* TO 'testuser'@'%' IDENTIFIED BY 'testpass'; +-- Create a user with the old mysql_native_password plugin which was the default in MySQL 5.7 +-- This allows us to log in using legacy MySQL clients +CREATE USER 'testuser_native_password'@'localhost' IDENTIFIED WITH mysql_native_password BY 'testpass'; +GRANT ALL PRIVILEGES ON *.* TO 'testuser_native_password'@'localhost'; + +CREATE USER 'testuser_native_password'@'%' IDENTIFIED WITH mysql_native_password BY 'testpass'; +GRANT ALL PRIVILEGES ON *.* TO 'testuser_native_password'@'%'; + +-- Create a user with the new caching_sha2_password plugin which is the default in MySQL 8.0 +CREATE USER 'testuser'@'localhost' IDENTIFIED BY 'testpass'; +GRANT ALL PRIVILEGES ON *.* TO 'testuser'@'localhost'; + +CREATE USER 'testuser'@'%' IDENTIFIED BY 'testpass'; +GRANT ALL PRIVILEGES ON *.* TO 'testuser'@'%'; CREATE DATABASE testdb; diff --git a/test/connector/tcp/mysql/etc/toggle_ssl.sh b/test/connector/tcp/mysql/etc/toggle_ssl.sh deleted file mode 100755 index ef9abc3bf..000000000 --- a/test/connector/tcp/mysql/etc/toggle_ssl.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash -e -# This script is used to build the ROOT/test/mysql_handler mysql container image -# The script expects ROOT/test/util/ssl to contain the pre-generated -# shared SSL fixtures used during testing -# -# This script is housed in /docker-entrypoint-initdb.d/ inside the container image -# The envvar NO_SSL is used to toggle SSL for the mysql container image at startup - -if [[ "$NO_SSL" = "true" ]] -then - echo "removing SSL support" - rm /var/lib/mysql/*.pem - echo "ssl=0" >> /etc/my.cnf -else - cp /ssl/ca.pem /var/lib/mysql/ca.pem - cp /ssl/ca-key.pem /var/lib/mysql/ca-key.pem - cp /ssl/client.pem /var/lib/mysql/client-cert.pem - cp /ssl/client-key.pem /var/lib/mysql/client-key.pem - cp /ssl/server.pem /var/lib/mysql/server-cert.pem - cp /ssl/server-key.pem /var/lib/mysql/server-key.pem -fi diff --git a/test/connector/tcp/mysql/pkg/run_test_case.go b/test/connector/tcp/mysql/pkg/run_test_case.go index a622a7657..479bae1ec 100644 --- a/test/connector/tcp/mysql/pkg/run_test_case.go +++ b/test/connector/tcp/mysql/pkg/run_test_case.go @@ -18,7 +18,7 @@ func RunQuery( args := []string{"-e", "select count(*) from testdb.test"} if clientConfig.SSL { - args = append(args, "--ssl", "--ssl-verify-server-cert=TRUE") + args = append(args, "--ssl-mode", "VERIFY_CA") } if clientConfig.Username != "" { args = append(args, fmt.Sprintf("--user=%s", clientConfig.Username)) @@ -37,7 +37,7 @@ func RunQuery( } // ensures mysql can handle non-native auth - args = append(args, "--default-auth=mysql_clear_password") + // args = append(args, "--default-auth=mysql_clear_password") // Pre command logs println("") diff --git a/test/connector/tcp/mysql/tests/essentials_test.go b/test/connector/tcp/mysql/tests/essentials_test.go index 26772567b..32e724501 100644 --- a/test/connector/tcp/mysql/tests/essentials_test.go +++ b/test/connector/tcp/mysql/tests/essentials_test.go @@ -13,16 +13,16 @@ func TestEssentials(t *testing.T) { Description: "with username, wrong password", ShouldPass: true, ClientConfiguration: ClientConfiguration{ - Username: "testuser", - Password: "wrongpassword", + Username: "placeholder", + Password: "placeholder", }, }, { Description: "with wrong username, wrong password", ShouldPass: true, ClientConfiguration: ClientConfiguration{ - Username: "wrongusername", - Password: "wrongpassword", + Username: "placeholder", + Password: "placeholder", }, }, { @@ -69,11 +69,11 @@ func TestEssentials(t *testing.T) { Description: "Socket, client -> TLS -> secretless", ShouldPass: false, ClientConfiguration: ClientConfiguration{ - Username: "wrongusername", - Password: "wrongpassword", + Username: "placeholder", + Password: "placeholder", SSL: true, }, - CmdOutput: StringPointer("ERROR 2026 (HY000): TLS/SSL error: SSL is required, but the server does not support it"), + CmdOutput: StringPointer("SSL is required but the server doesn't support it"), }, }, t) @@ -88,11 +88,11 @@ func TestEssentials(t *testing.T) { Description: "TCP, client -> TLS -> secretless", ShouldPass: false, ClientConfiguration: ClientConfiguration{ - Username: "wrongusername", - Password: "wrongpassword", + Username: "placeholder", + Password: "placeholder", SSL: true, }, - CmdOutput: StringPointer("ERROR 2026 (HY000): TLS/SSL error: SSL is required, but the server does not support it"), + CmdOutput: StringPointer("SSL is required but the server doesn't support it"), }, }, t) @@ -108,10 +108,10 @@ func TestEssentials(t *testing.T) { Description: "secretless using invalid credentials", ShouldPass: false, ClientConfiguration: ClientConfiguration{ - Username: "testuser", - Password: "wrongpassword", + Username: "placeholder", + Password: "placeholder", }, - CmdOutput: StringPointer("ERROR 1045 (28000): Access denied for user 'testuser'@"), + CmdOutput: StringPointer("ERROR 1045 (28000): Access denied for user 'testuser_native_password'@"), }, }, t) }) From cb46995b2fb2a281e47471cb5fd46d3a5c0d5521 Mon Sep 17 00:00:00 2001 From: Kumbirai Tanekha Date: Thu, 28 Sep 2023 08:57:52 +0100 Subject: [PATCH 3/8] Update mysql unit tests --- .../tcp/mysql/protocol/protocol_test.go | 35 ++++++------------- test/connector/tcp/mysql/pkg/run_test_case.go | 3 -- 2 files changed, 10 insertions(+), 28 deletions(-) diff --git a/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go b/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go index 33318f31a..2a0089af0 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go @@ -247,20 +247,20 @@ func TestInjectCredentials(t *testing.T) { expectedAuth := []byte{0xf, 0xf8, 0xe1, 0xa3, 0xe7, 0xe3, 0x5f, 0xd2, 0xb1, 0x69, 0x8c, 0x39, 0x5b, 0xfa, 0x99, 0x4f, 0x53, 0xdd, 0xe5, 0x35} // 20 - expectedHeader := []byte{0x99, 0x0, 0x0, 0x1} - // expectedHeader[0] = 0xaa + (8 - 14) + (20 - 20) + (21 - 32) + expectedHeader := []byte{0xa2, 0x0, 0x0, 0x1} + // expectedHeader[0] = 0xaa + (8 - 14) + (20 - 20) + (21 - 23) // test with handshake response that already has auth set to another value handshake := HandshakeResponse41{ AuthLength: int64(20), - AuthPluginName: "non_native_mysql_native_password", // 32 + AuthPluginName: "caching_sha256_password", // 23 AuthResponse: []byte{0xc0, 0xb, 0xbc, 0xb6, 0x6, 0xf5, 0x4f, 0x4e, 0xf4, 0x1b, 0x87, 0xc0, 0xb8, 0x89, 0xae, 0xc4, 0x49, 0x7c, 0x46, 0xf3}, // 20 Username: "madeupusername", // 14 Header: []byte{0xaa, 0x0, 0x0, 0x1}, } - err := InjectCredentials(&handshake, salt, username, password) + err := InjectCredentials("mysql_native_password", &handshake, salt, username, password) assert.Equal(t, username, handshake.Username) assert.Equal(t, int64(20), handshake.AuthLength) @@ -279,7 +279,7 @@ func TestInjectCredentials(t *testing.T) { Header: []byte{0xaa, 0x0, 0x0, 0x1}, } - err = InjectCredentials(&handshake, salt, username, password) + err = InjectCredentials("mysql_native_password", &handshake, salt, username, password) assert.Equal(t, username, handshake.Username) assert.Equal(t, int64(20), handshake.AuthLength) @@ -289,8 +289,8 @@ func TestInjectCredentials(t *testing.T) { } func TestPackHandshakeResponse41(t *testing.T) { - input := HandshakeResponse41{ - Header: []byte{0xaa, 0x0, 0x0, 0x1}, + input := &HandshakeResponse41{ + Header: []byte{0x95, 0x0, 0x0, 0x1}, CapabilityFlags: uint32(33464965), MaxPacketSize: uint32(1073741824), ClientCharset: uint8(8), @@ -310,26 +310,11 @@ func TestPackHandshakeResponse41(t *testing.T) { 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x6, 0x78, 0x38, 0x36, 0x5f, 0x36, 0x34}, } - expected := []byte{0xaa, 0x0, 0x0, 0x1, 0x85, 0xa2, 0xfe, 0x1, 0x0, - 0x0, 0x0, 0x40, 0x8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x72, 0x6f, 0x67, 0x65, 0x72, 0x0, 0x14, 0xc0, - 0xb, 0xbc, 0xb6, 0x6, 0xf5, 0x4f, 0x4e, 0xf4, 0x1b, 0x87, 0xc0, - 0xb8, 0x89, 0xae, 0xc4, 0x49, 0x7c, 0x46, 0xf3, 0x6d, 0x79, 0x73, - 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70, - 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x0, 0x58, 0x3, 0x5f, - 0x6f, 0x73, 0xa, 0x6d, 0x61, 0x63, 0x6f, 0x73, 0x31, 0x30, 0x2e, - 0x31, 0x32, 0xc, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, - 0x6e, 0x61, 0x6d, 0x65, 0x8, 0x6c, 0x69, 0x62, 0x6d, 0x79, 0x73, - 0x71, 0x6c, 0x4, 0x5f, 0x70, 0x69, 0x64, 0x5, 0x36, 0x36, 0x34, - 0x37, 0x39, 0xf, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, - 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x6, 0x35, 0x2e, 0x37, - 0x2e, 0x32, 0x30, 0x9, 0x5f, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, - 0x72, 0x6d, 0x6, 0x78, 0x38, 0x36, 0x5f, 0x36, 0x34} - output, err := PackHandshakeResponse41(&input) + output, err := PackHandshakeResponse41(input) + newInput, err := UnpackHandshakeResponse41(output) - assert.Equal(t, expected, output) + assert.Equal(t, input, newInput) assert.Equal(t, nil, err) } diff --git a/test/connector/tcp/mysql/pkg/run_test_case.go b/test/connector/tcp/mysql/pkg/run_test_case.go index 479bae1ec..44d9af4b6 100644 --- a/test/connector/tcp/mysql/pkg/run_test_case.go +++ b/test/connector/tcp/mysql/pkg/run_test_case.go @@ -36,9 +36,6 @@ func RunQuery( panic("Listener Type can only be TCP or Socket") } - // ensures mysql can handle non-native auth - // args = append(args, "--default-auth=mysql_clear_password") - // Pre command logs println("") println("---<< EXECUTED") From a89f57f446ee2eef05ac0b2a747cc52093c66995 Mon Sep 17 00:00:00 2001 From: Kumbirai Tanekha Date: Thu, 28 Sep 2023 08:58:26 +0100 Subject: [PATCH 4/8] Bugfix: mysql: Read client response before validating SSL --- .../tcp/mysql/authentication_handshake.go | 37 ++----------------- 1 file changed, 4 insertions(+), 33 deletions(-) diff --git a/internal/plugin/connectors/tcp/mysql/authentication_handshake.go b/internal/plugin/connectors/tcp/mysql/authentication_handshake.go index f851b83df..f9d24ccf4 100644 --- a/internal/plugin/connectors/tcp/mysql/authentication_handshake.go +++ b/internal/plugin/connectors/tcp/mysql/authentication_handshake.go @@ -100,13 +100,14 @@ func (h *AuthenticationHandshake) Run() error { // Pass along the handshake but make sure client doesn't use TLS to connect to Secretless h.writeHandshakeToClient() - // Does the server support TLS when needed? - h.validateServerSSL() // Get the client handshake response. I thought it would be good to only make this // use the simpliest auth mechanism, but no need since the server won't do the entire dance. // It will only ever return success or error, not auth switch or anything else h.readClientHandshakeResponse() + // Does the server support TLS when needed? + h.validateServerSSL() + // Everything in here is about responding to the server authentication challenge // No need to talk to the client an if true { @@ -121,36 +122,6 @@ func (h *AuthenticationHandshake) Run() error { return h.err } -func (h *AuthenticationHandshake) NewRun() error { - backendPacket := h.readBackendPacket() - if h.err != nil { - return h.err - } - - serverHandshake, err := protocol.UnpackHandshakeV10(backendPacket) - if err != nil { - return err - } - fmt.Println("serverHandshake:", serverHandshake) - - backendPacket1, err := protocol.PackHandshakeV10(serverHandshake) - if err != nil { - return err - } - fmt.Println("backendPacket:", backendPacket) - fmt.Println("backendPacket:", backendPacket1) - - serverHandshake, err = protocol.UnpackHandshakeV10(backendPacket) - if err != nil { - return err - } - fmt.Println("serverHandshake:", serverHandshake) - - return h.clientConn.write(backendPacket1) - - // return nil -} - // AuthenticatedBackendConn returns an already authenticated connection // to the MySQL server. Intended to be called after Run() has completed. func (h *AuthenticationHandshake) AuthenticatedBackendConn() net.Conn { @@ -347,7 +318,7 @@ func (h *AuthenticationHandshake) verifyAndProxyOkResponse() { } packetType := protocol.GetPacketType(rawPkt) - fmt.Println("packetType", packetType) + switch packetType { case protocol.ResponseErr: // Return after adding the error response to AuthenticationHandshake From fc5c6abeeb40a3a71519994a5360ebfc1ca500a2 Mon Sep 17 00:00:00 2001 From: Kumbirai Tanekha Date: Thu, 28 Sep 2023 09:23:35 +0100 Subject: [PATCH 5/8] Clean up and refactor mysql protocol --- .../connectors/tcp/mysql/protocol/protocol.go | 115 ++++-------------- .../tcp/mysql/protocol/protocol_test.go | 63 ++-------- 2 files changed, 28 insertions(+), 150 deletions(-) diff --git a/internal/plugin/connectors/tcp/mysql/protocol/protocol.go b/internal/plugin/connectors/tcp/mysql/protocol/protocol.go index b1e728d34..8f72a5d1f 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/protocol.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/protocol.go @@ -422,20 +422,7 @@ func PackHandshakeV10(serverHandshake *HandshakeV10) ([]byte, error) { buffer.WriteByte(0) } - // Calculate the packet length (excluding the length field itself) - packetLength := buffer.Len() - - // Create a header buffer and write the packet length (int<3>) - headerBuffer := make([]byte, 4) - headerBuffer[0] = byte(packetLength & 0xFF) - headerBuffer[1] = byte((packetLength >> 8) & 0xFF) - headerBuffer[2] = byte((packetLength >> 16) & 0xFF) - headerBuffer[3] = 0 - - // Combine the header and packet data to create the final packet - finalPacket := append(headerBuffer, buffer.Bytes()...) - - return finalPacket, nil + return AddHeaderToPacket(0, buffer.Bytes()), nil } // RemoveSSLFromHandshakeV10 removes Client SSL Capability from Server @@ -541,7 +528,7 @@ func writeUint16(data []byte, pos int, value uint16) { // | string | payload | [len=payload_length] payload of the packet | // +-------------+----------------+---------------------------------------------+ type HandshakeResponse41 struct { - Header []byte + SequenceID uint8 CapabilityFlags uint32 MaxPacketSize uint32 ClientCharset uint8 @@ -645,7 +632,7 @@ func UnpackHandshakeResponse41(packet []byte) (*HandshakeResponse41, error) { } return &HandshakeResponse41{ - Header: header, + SequenceID: header[3], CapabilityFlags: capabilityFlags, MaxPacketSize: maxPacketSize, ClientCharset: charset, @@ -683,22 +670,11 @@ func CreateAuthResponse(authPlugin string, password []byte, salt []byte) ([]byte // TODO: we can do a much better job than this. We just need to calculate the paylod and not do all this resetting BS! // TODO: the plugin name should come from the server, not the client! func InjectCredentials(authPlugin string, clientHandshake *HandshakeResponse41, salt []byte, username string, password string) (err error) { - fmt.Println("Injecting credentials to", authPlugin) authResponse, err := CreateAuthResponse(authPlugin, []byte(password), salt) if err != nil { return } - // Reset the payload length for the packet - payloadLengthDiff := int32(len(username) - len(clientHandshake.Username)) - payloadLengthDiff += int32(len(authResponse) - int(clientHandshake.AuthLength)) - payloadLengthDiff += int32(len(authPlugin) - len(clientHandshake.AuthPluginName)) - - clientHandshake.Header, err = UpdateHeaderPayloadLength(clientHandshake.Header, payloadLengthDiff) - if err != nil { - return - } - clientHandshake.AuthPluginName = authPlugin clientHandshake.Username = username clientHandshake.AuthLength = int64(len(authResponse)) @@ -794,20 +770,7 @@ func PackHandshakeResponse41(clientHandshake *HandshakeResponse41) ([]byte, erro buf.Write(clientHandshake.PacketTail) } - // Calculate the packet length (excluding the length field itself) - packetLength := buf.Len() - - // Create a header buffer and write the packet length (int<3>) - headerBuffer := make([]byte, 4) - headerBuffer[0] = byte(packetLength & 0xFF) - headerBuffer[1] = byte((packetLength >> 8) & 0xFF) - headerBuffer[2] = byte((packetLength >> 16) & 0xFF) - headerBuffer[3] = clientHandshake.Header[3] - - // Combine the header and packet data to create the final packet - finalPacket := append(headerBuffer, buf.Bytes()...) - - return finalPacket, nil + return AddHeaderToPacket(clientHandshake.SequenceID, buf.Bytes()), nil } // GetLenEncodedIntegerSize returns bytes count for length encoded integer @@ -1007,40 +970,20 @@ func NativePassword(password []byte, salt []byte) (nativePassword []byte, err er return } -// UpdateHeaderPayloadLength takes in a 4 byte header and a difference -// in length, and returns a new header -func UpdateHeaderPayloadLength(origHeader []byte, diff int32) (header []byte, err error) { - - initialPayloadLength, err := ReadUint24(origHeader[0:3]) - if err != nil { - return nil, err - } - updatedPayloadLength := int32(initialPayloadLength) + diff - if updatedPayloadLength < 0 { - return nil, errors.New("Malformed packet") - } - header = append(WriteUint24(uint32(updatedPayloadLength)), origHeader[3]) - - return -} - -// ReadUint24 takes in a byte slice and returns a uint32 -func ReadUint24(b []byte) (uint32, error) { - if len(b) < 3 { - return 0, errors.New("Invalid packet") - } - - return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16, nil -} +// AddHeaderToPacket adds a header to a packet +func AddHeaderToPacket(sequenceId uint8, restOfPacket []byte) []byte { + // Calculate the packet length (excluding the length field itself) + packetLength := len(restOfPacket) -// WriteUint24 takes in a uint32 and returns a byte slice -func WriteUint24(u uint32) (b []byte) { - b = make([]byte, 3) - b[0] = byte(u) - b[1] = byte(u >> 8) - b[2] = byte(u >> 16) + // Create a header buffer and write the packet length (int<3>) + headerBuffer := make([]byte, 4) + headerBuffer[0] = byte(packetLength & 0xFF) + headerBuffer[1] = byte((packetLength >> 8) & 0xFF) + headerBuffer[2] = byte((packetLength >> 16) & 0xFF) + headerBuffer[3] = sequenceId - return + // Combine the header and packet data to create the final packet + return append(headerBuffer, restOfPacket...) } // PackAuthSwitchResponse creates an AuthSwitchResponse packet with the provided response data. @@ -1051,20 +994,7 @@ func PackAuthSwitchResponse(authSwitchRequestSequenceId uint8, data []byte) ([]b // Write the response data to the buffer buffer.Write(data) - // Calculate the packet length (excluding the length field itself) - packetLength := buffer.Len() - - // Create a header buffer and write the packet length (int<3>) - headerBuffer := make([]byte, 4) - headerBuffer[0] = byte(packetLength & 0xFF) - headerBuffer[1] = byte((packetLength >> 8) & 0xFF) - headerBuffer[2] = byte((packetLength >> 16) & 0xFF) - headerBuffer[3] = 0 - - // Combine the header and packet data to create the final packet - finalPacket := append(headerBuffer, buffer.Bytes()...) - - return finalPacket, nil + return AddHeaderToPacket(0, buffer.Bytes()), nil } type AuthMoreDataResponse struct { @@ -1115,15 +1045,12 @@ func UnpackAuthMoreDataResponse(packet []byte) (*AuthMoreDataResponse, error) { }, nil } -func PackAuthRequestPubKeyResponse() (packet []byte) { - packet = append(WriteUint24(uint32(1)), 3, CachingSha2PasswordRequestPublicKey) - return +func PackAuthRequestPubKeyResponse() []byte { + return AddHeaderToPacket(3, []byte{CachingSha2PasswordRequestPublicKey}) } -func PackAuthEncryptedPasswordResponse(encPwd []byte) (packet []byte) { - packet = append(WriteUint24(uint32(len(encPwd))), 5) - packet = append(packet, encPwd...) - return +func PackAuthEncryptedPasswordResponse(encPwd []byte) []byte { + return AddHeaderToPacket(5, encPwd) } func EncryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { diff --git a/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go b/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go index 2a0089af0..90bdbb4e5 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go @@ -195,7 +195,7 @@ func TestUnpackHandshakeV10(t *testing.T) { func TestUnpackHandshakeResponse41(t *testing.T) { expected := HandshakeResponse41{ - Header: []byte{0xaa, 0x0, 0x0, 0x1}, + SequenceID: 1, CapabilityFlags: uint32(33464965), MaxPacketSize: uint32(1073741824), ClientCharset: uint8(8), @@ -256,8 +256,8 @@ func TestInjectCredentials(t *testing.T) { AuthPluginName: "caching_sha256_password", // 23 AuthResponse: []byte{0xc0, 0xb, 0xbc, 0xb6, 0x6, 0xf5, 0x4f, 0x4e, 0xf4, 0x1b, 0x87, 0xc0, 0xb8, 0x89, 0xae, 0xc4, 0x49, 0x7c, 0x46, 0xf3}, // 20 - Username: "madeupusername", // 14 - Header: []byte{0xaa, 0x0, 0x0, 0x1}, + Username: "madeupusername", // 14 + SequenceID: 1, } err := InjectCredentials("mysql_native_password", &handshake, salt, username, password) @@ -265,7 +265,7 @@ func TestInjectCredentials(t *testing.T) { assert.Equal(t, username, handshake.Username) assert.Equal(t, int64(20), handshake.AuthLength) assert.Equal(t, expectedAuth, handshake.AuthResponse) - assert.Equal(t, expectedHeader, handshake.Header) + assert.Equal(t, expectedHeader[3], handshake.SequenceID) assert.Equal(t, nil, err) // test with handshake response with empty auth and mysql_native_password @@ -276,7 +276,7 @@ func TestInjectCredentials(t *testing.T) { AuthPluginName: "mysql_native_password", // 21 AuthResponse: []byte{}, // 0 Username: "madeupusername", // 14 - Header: []byte{0xaa, 0x0, 0x0, 0x1}, + SequenceID: 1, } err = InjectCredentials("mysql_native_password", &handshake, salt, username, password) @@ -284,13 +284,13 @@ func TestInjectCredentials(t *testing.T) { assert.Equal(t, username, handshake.Username) assert.Equal(t, int64(20), handshake.AuthLength) assert.Equal(t, expectedAuth, handshake.AuthResponse) - assert.Equal(t, expectedHeader, handshake.Header) + assert.Equal(t, expectedHeader[3], handshake.SequenceID) assert.Equal(t, nil, err) } func TestPackHandshakeResponse41(t *testing.T) { input := &HandshakeResponse41{ - Header: []byte{0x95, 0x0, 0x0, 0x1}, + SequenceID: 1, CapabilityFlags: uint32(33464965), MaxPacketSize: uint32(1073741824), ClientCharset: uint8(8), @@ -420,52 +420,3 @@ func TestNativePassword(t *testing.T) { assert.Equal(t, expected, output) assert.Equal(t, nil, err) } - -func TestUpdateHeaderPayloadLength(t *testing.T) { - // Test with a valid negative value - expectedHeader := []byte{170, 0, 0, 0} - inputHeader := []byte{173, 0, 0, 0} - inputLength := int32(-3) - - output, err := UpdateHeaderPayloadLength(inputHeader, inputLength) - - assert.Equal(t, expectedHeader, output) - assert.Equal(t, nil, err) - - // Test with a valid positive value - expectedHeader = []byte{176, 0, 0, 0} - inputHeader = []byte{173, 0, 0, 0} - inputLength = int32(3) - - output, err = UpdateHeaderPayloadLength(inputHeader, inputLength) - - assert.Equal(t, expectedHeader, output) - assert.Equal(t, nil, err) - - // Test with an invalid value for the length difference - inputHeader = []byte{173, 0, 0, 0} - inputLength = int32(-180) - - output, err = UpdateHeaderPayloadLength(inputHeader, inputLength) - - assert.EqualError(t, err, "Malformed packet") -} - -func TestReadUint24(t *testing.T) { - expected := uint32(173) - input := []byte{173, 0, 0} - - output, err := ReadUint24(input) - - assert.Equal(t, expected, output) - assert.Equal(t, nil, err) -} - -func TestWriteUint24(t *testing.T) { - expected := []byte{173, 0, 0} - input := uint32(173) - - output := WriteUint24(input) - - assert.Equal(t, expected, output) -} From b3c5c4064dba2141dc246d922cebffb9d7a6983c Mon Sep 17 00:00:00 2001 From: Kumbirai Tanekha Date: Thu, 28 Sep 2023 14:57:07 +0100 Subject: [PATCH 6/8] Handle caching_sha256_password with TLS --- .../tcp/mysql/authentication_handshake.go | 124 +++++++++++------- .../connectors/tcp/mysql/protocol/const.go | 21 +-- .../connectors/tcp/mysql/protocol/packet.go | 8 +- .../connectors/tcp/mysql/protocol/protocol.go | 27 ++-- 4 files changed, 109 insertions(+), 71 deletions(-) diff --git a/internal/plugin/connectors/tcp/mysql/authentication_handshake.go b/internal/plugin/connectors/tcp/mysql/authentication_handshake.go index f9d24ccf4..57f0758af 100644 --- a/internal/plugin/connectors/tcp/mysql/authentication_handshake.go +++ b/internal/plugin/connectors/tcp/mysql/authentication_handshake.go @@ -96,28 +96,41 @@ func NewAuthenticationHandshake( // MySQL server and client. When it completes successfully, // AuthenticatedBackendConn will return the raw, authenticated network conn. func (h *AuthenticationHandshake) Run() error { + // The server is the first to communicate. Read the server handshake h.readServerHandshake() - // Pass along the handshake but make sure client doesn't use TLS to connect to Secretless + // Pass along the server handshake to the client, with some minor modifications + // + // 1. Remove TLS capability to avoid TLS connections to Secretless. + // 2. Use `mysql_native_password` as the auth plugin between the client and Secretless, to make + // life easier. We actually don't care about the credentials from the client, we just need the rest + // of the packet. h.writeHandshakeToClient() - // Get the client handshake response. I thought it would be good to only make this - // use the simpliest auth mechanism, but no need since the server won't do the entire dance. - // It will only ever return success or error, not auth switch or anything else + // Read the client handshake response. + // + // We are done listening to the client! h.readClientHandshakeResponse() - // Does the server support TLS when needed? + // Make sure if the connector (not the client) is configured to use TLS + // then the server supports TLS. This must be done after reading the client response otherwise if validation + // fails then the client connection hangs h.validateServerSSL() - // Everything in here is about responding to the server authentication challenge - // No need to talk to the client an - if true { - h.overrideClientCapabilities() - h.injectCredentials() - // Deal with this later - h.handleClientSSLRequest() - h.writeClientHandshakeResponseToBackend() - h.verifyAndProxyOkResponse() - } + // Everything beyond this point is in service of responding to the server authentication challenge. + + // Override client capabilities. For example, the connector has secure connection capabilities and supports + // authentication plugins. + h.overrideClientCapabilities() + // Inject credentials into the client handshake response + h.injectCredentials() + + h.handleClientSSLRequest() + + // Write modified client handshake to server, and + // carry out the rest of the authentication dance between Secretless and the server. + // When we're done we just let the client know of the outcome. + h.writeClientHandshakeResponseToBackend() + h.handleBackendAuthResponse() return h.err } @@ -186,9 +199,9 @@ func (h *AuthenticationHandshake) readClientHandshakeResponse() { } // TODO: client requesting SSL results in ERROR 2026 (HY000): SSL connection error: protocol version mismatch - // TODO: Find out if this is a client request SSL after we advertise not supporting SSL ? h.clientHandshakeResponse, h.err = protocol.UnpackHandshakeResponse41(rawResponse) } + func (h *AuthenticationHandshake) overrideClientCapabilities() { if h.err != nil { return @@ -202,6 +215,7 @@ func (h *AuthenticationHandshake) overrideClientCapabilities() { // TODO: add tests cases for client secure connection // Enable CapabilityFlag for client secure connection + // TODO: explore weird heisenbug when this is toggled off: ERROR: 1043 (08S01): Bad handshake h.clientHandshakeResponse.CapabilityFlags |= protocol.ClientSecureConnection // Ensure CapabilityFlag is set when using TLS @@ -312,25 +326,27 @@ func (h *AuthenticationHandshake) verifyAndProxyOkResponse() { return } + // This proxying needs to take place to ensure the client gets the OK packet with + // the correct sequence id, the connection keeps track of this information whereas + // Secretless duplex streaming does not. rawPkt := h.readBackendPacket() + h.writeClientPacket(rawPkt) +} + +func (h *AuthenticationHandshake) handleBackendAuthResponse() { if h.err != nil { return } - packetType := protocol.GetPacketType(rawPkt) - - switch packetType { - case protocol.ResponseErr: - // Return after adding the error response to AuthenticationHandshake - // as a protocol.Error type - // - // The protocol.Error type makes it possible - // to have Go errors that can contain rich protocol specific information - // and have the smarts to encode themselves into a MYSQL error packet - err := protocol.UnpackErrResponse(rawPkt) - h.err = err + rawPkt := h.readBackendPacket() + if h.err != nil { return + } + + switch protocol.GetPacketType(rawPkt) { case protocol.ResponseAuthMoreData: + defer h.verifyAndProxyOkResponse() + moreDataResp, err := protocol.UnpackAuthMoreDataResponse(rawPkt) if err != nil { h.err = err @@ -345,24 +361,41 @@ func (h *AuthenticationHandshake) verifyAndProxyOkResponse() { return case protocol.CachingSha2PasswordPerformFullAuthentication: // The server is requesting a full authentication handshake. - //https://dev.mysql.com/doc/dev/mysql-server/latest/page_caching_sha2_authentication_exchanges.html - //https://github.com/go-sql-driver/mysql/blob/master/auth.go#L353 + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_caching_sha2_authentication_exchanges.html + // https://github.com/go-sql-driver/mysql/blob/master/auth.go#L353 + + if h.clientRequestedSSL() { + data, err := protocol.PackAuthSwitchResponse( + h.backendConn.sequenceID, + append([]byte(h.connectionDetails.Password), 0), + ) + if err != nil { + h.err = err + return + } + + h.writeBackendPacket(data) + if h.err != nil { + return + } + return + } - // request public key from server - data := protocol.PackAuthRequestPubKeyResponse() + // Request public key from server + data := protocol.PackAuthRequestPubKeyResponse(h.backendConn.sequenceID) h.writeBackendPacket(data) if h.err != nil { return } - // read public key from server + // Read public key from server pubKeyPkt := h.readBackendPacket() if h.err != nil { return } - // parse public key + // Parse public key if pubKeyPkt[4] != protocol.ResponseAuthMoreData { h.err = errors.New("expected ResponseAuthMoreData packet") //TODO: For some reason we're getting "28000Access denied for user..." packets here in some cases @@ -381,21 +414,24 @@ func (h *AuthenticationHandshake) verifyAndProxyOkResponse() { } pubKey := pkix.(*rsa.PublicKey) - // encrypt password with public key + // Encrypt password with public key enc, err := protocol.EncryptPassword(h.connectionDetails.Password, h.serverHandshake.Salt, pubKey) if err != nil { h.err = err return } - encPkt := protocol.PackAuthEncryptedPasswordResponse(enc) + encPkt := protocol.PackAuthEncryptedPasswordResponse(h.backendConn.sequenceID, enc) h.writeBackendPacket(encPkt) - return } + return - case 0xfe: + + case protocol.ResponseAuthSwitchRequest: + defer h.verifyAndProxyOkResponse() + authSwitchRequest, err := protocol.UnpackAuthSwitchRequest(rawPkt) if err != nil { h.err = err @@ -424,19 +460,15 @@ func (h *AuthenticationHandshake) verifyAndProxyOkResponse() { } h.writeBackendPacket(authSwitchResponseData) - // TODO: Avoid infinite recursion? - h.verifyAndProxyOkResponse() return default: - // Verify packet is valid; don't do anything with unpacked - if _, err := protocol.UnpackOkResponse(rawPkt); err != nil { - h.err = err - return - } + // Let the client deal with it + h.writeClientPacket(rawPkt) + + return } - h.writeClientPacket(rawPkt) } func (h *AuthenticationHandshake) dbSSLMode() *ssl.DbSSLMode { diff --git a/internal/plugin/connectors/tcp/mysql/protocol/const.go b/internal/plugin/connectors/tcp/mysql/protocol/const.go index 3801b6984..f60c155ce 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/const.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/const.go @@ -27,12 +27,13 @@ package protocol // Random constants const ( // MySQL response types - ResponseAuthMoreData = 0x01 - ResponseEOF = 0xfe - responseOk = 0x00 - responsePrepareOk = 0x00 - ResponseErr = 0xff - responseLocalinfile = 0xfb + ResponseAuthMoreData = 0x01 + ResponseEOF = 0xfe + ResponseOk = 0x00 + ResponsePrepareOk = 0x00 + ResponseAuthSwitchRequest = 0xfe + ResponseErr = 0xff + responseLocalinfile = 0xfb // MySQL field types constants fieldTypeString = 0xfd @@ -50,10 +51,10 @@ const ( // Digits after comma doubleDecodePrecision = 6 - defaultAuthPluginName = "mysql_native_password" - CachingSha2PasswordRequestPublicKey = 2 - CachingSha2PasswordFastAuthSuccess = 3 - CachingSha2PasswordPerformFullAuthentication = 4 + // caching_sha256_password authentication plugin constants + CachingSha2PasswordRequestPublicKey = 0x02 + CachingSha2PasswordFastAuthSuccess = 0x03 + CachingSha2PasswordPerformFullAuthentication = 0x04 ) // Protocol commands diff --git a/internal/plugin/connectors/tcp/mysql/protocol/packet.go b/internal/plugin/connectors/tcp/mysql/protocol/packet.go index 7d717a895..bde538005 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/packet.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/packet.go @@ -93,7 +93,7 @@ func ReadPrepareResponse(conn net.Conn) ([]byte, byte, error) { } switch pkt[4] { - case responsePrepareOk: + case ResponsePrepareOk: numParams := binary.LittleEndian.Uint16(pkt[9:11]) numColumns := binary.LittleEndian.Uint16(pkt[11:13]) packetsExpected := 0 @@ -121,7 +121,7 @@ func ReadPrepareResponse(conn net.Conn) ([]byte, byte, error) { data = append(data, pkt...) } - return data, responseOk, nil + return data, ResponseOk, nil case ResponseErr: return pkt, ResponseErr, nil @@ -148,8 +148,8 @@ func ReadResponse(conn net.Conn, deprecateEOF bool) ([]byte, byte, error) { } switch pkt[4] { - case responseOk: - return pkt, responseOk, nil + case ResponseOk: + return pkt, ResponseOk, nil case ResponseErr: return pkt, ResponseErr, nil diff --git a/internal/plugin/connectors/tcp/mysql/protocol/protocol.go b/internal/plugin/connectors/tcp/mysql/protocol/protocol.go index 8f72a5d1f..c979c41c5 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/protocol.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/protocol.go @@ -148,7 +148,7 @@ func UnpackOkResponse(packet []byte) (*OkResponse, error) { if err != nil { return nil, err } - if packetType != responseOk { + if packetType != ResponseOk { return nil, errors.New("Malformed packet") } @@ -190,6 +190,7 @@ func UnpackOkResponse(packet []byte) (*OkResponse, error) { // See https://mariadb.com/kb/en/mariadb/1-connecting-connecting/#initial-handshake-packet type HandshakeV10 struct { ProtocolVersion byte + SequenceId uint8 ServerVersion string ConnectionID uint32 StatusFlags uint16 @@ -241,8 +242,9 @@ type HandshakeV10 struct { func UnpackHandshakeV10(packet []byte) (*HandshakeV10, error) { r := bytes.NewReader(packet) - // Skip packet header - if _, err := GetPacketHeader(r); err != nil { + // Header + header, err := GetPacketHeader(r) + if err != nil { return nil, err } @@ -347,6 +349,7 @@ func UnpackHandshakeV10(packet []byte) (*HandshakeV10, error) { } return &HandshakeV10{ + SequenceId: header[3], ProtocolVersion: protoVersion, ServerVersion: serverVersion, ConnectionID: connectionID, @@ -422,7 +425,7 @@ func PackHandshakeV10(serverHandshake *HandshakeV10) ([]byte, error) { buffer.WriteByte(0) } - return AddHeaderToPacket(0, buffer.Bytes()), nil + return AddHeaderToPacket(serverHandshake.SequenceId, buffer.Bytes()), nil } // RemoveSSLFromHandshakeV10 removes Client SSL Capability from Server @@ -994,10 +997,11 @@ func PackAuthSwitchResponse(authSwitchRequestSequenceId uint8, data []byte) ([]b // Write the response data to the buffer buffer.Write(data) - return AddHeaderToPacket(0, buffer.Bytes()), nil + return AddHeaderToPacket(authSwitchRequestSequenceId, buffer.Bytes()), nil } type AuthMoreDataResponse struct { + SequenceId uint8 PacketType byte StatusTag byte } @@ -1019,8 +1023,8 @@ func UnpackAuthMoreDataResponse(packet []byte) (*AuthMoreDataResponse, error) { r := bytes.NewReader(packet) - // Skip packet header - if _, err := GetPacketHeader(r); err != nil { + header, err := GetPacketHeader(r) + if err != nil { return nil, err } @@ -1040,17 +1044,18 @@ func UnpackAuthMoreDataResponse(packet []byte) (*AuthMoreDataResponse, error) { } return &AuthMoreDataResponse{ + SequenceId: header[3], PacketType: packetType, StatusTag: statusTag, }, nil } -func PackAuthRequestPubKeyResponse() []byte { - return AddHeaderToPacket(3, []byte{CachingSha2PasswordRequestPublicKey}) +func PackAuthRequestPubKeyResponse(sequenceId uint8) []byte { + return AddHeaderToPacket(sequenceId, []byte{CachingSha2PasswordRequestPublicKey}) } -func PackAuthEncryptedPasswordResponse(encPwd []byte) []byte { - return AddHeaderToPacket(5, encPwd) +func PackAuthEncryptedPasswordResponse(sequenceId uint8, encPwd []byte) []byte { + return AddHeaderToPacket(sequenceId, encPwd) } func EncryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { From c757e0d2d949e65efaf7da7b8e2dc411970030c4 Mon Sep 17 00:00:00 2001 From: Kumbirai Tanekha Date: Thu, 28 Sep 2023 17:00:41 +0100 Subject: [PATCH 7/8] Add unit test for mysql#PackHandshakeV10 --- .../tcp/mysql/protocol/protocol_test.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go b/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go index 90bdbb4e5..9bdfd878a 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go @@ -193,6 +193,24 @@ func TestUnpackHandshakeV10(t *testing.T) { } } +func TestPackHandshakeV10(t *testing.T) { + input := &HandshakeV10{ + ProtocolVersion: byte(10), + ServerVersion: "5.5.56", + ConnectionID: uint32(1630), + AuthPlugin: "mysql_native_password", + ServerCapabilities: binary.LittleEndian.Uint32([]byte{255, 247, 15, 128}), + Salt: []byte{0x48, 0x6a, 0x5b, 0x6a, 0x24, 0x71, 0x30, 0x3a, 0x6f, 0x43, 0x40, 0x56, 0x6e, 0x4b, + 0x68, 0x4a, 0x79, 0x46, 0x30, 0x5a}, + } + + output, err := PackHandshakeV10(input) + newInput, err := UnpackHandshakeV10(output) + + assert.Equal(t, input, newInput) + assert.Equal(t, nil, err) +} + func TestUnpackHandshakeResponse41(t *testing.T) { expected := HandshakeResponse41{ SequenceID: 1, From cc5e82f5897b65a010128eff459da08ab118b6e8 Mon Sep 17 00:00:00 2001 From: Shlomo Heigh Date: Mon, 2 Oct 2023 12:03:58 -0400 Subject: [PATCH 8/8] Clean up and add unit tests --- CHANGELOG.md | 4 +- .../tcp/mysql/authentication_handshake.go | 29 +-- .../connectors/tcp/mysql/protocol/error.go | 2 +- .../connectors/tcp/mysql/protocol/protocol.go | 66 ++++-- .../tcp/mysql/protocol/protocol_test.go | 190 ++++++++++++++++++ test/connector/tcp/mysql/docker-compose.yml | 4 +- .../tcp/mysql/tests/essentials_test.go | 22 +- 7 files changed, 259 insertions(+), 58 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5292ff83d..1b4fdeba5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,10 +9,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Nothing should go in this section, please add to the latest unreleased version (and update the corresponding date), or add a new version. -## [1.7.19] - 2023-10-01 +## [1.7.19] - 2023-11-02 ### Added -- Add support for caching_sha256_password to mysql connector +- Add support for caching_sha256_password to mysql connector (CONJSE-1801) ## [1.7.18] - 2023-08-22 diff --git a/internal/plugin/connectors/tcp/mysql/authentication_handshake.go b/internal/plugin/connectors/tcp/mysql/authentication_handshake.go index 57f0758af..281248079 100644 --- a/internal/plugin/connectors/tcp/mysql/authentication_handshake.go +++ b/internal/plugin/connectors/tcp/mysql/authentication_handshake.go @@ -1,11 +1,6 @@ package mysql import ( - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "errors" - "fmt" "net" "github.com/cyberark/secretless-broker/internal/plugin/connectors/tcp/mysql/protocol" @@ -253,7 +248,7 @@ func (h *AuthenticationHandshake) handleClientSSLRequest() { // but truncating the username and everything after the username in // the payload, as described here: // - // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_ssl_request.html // // The payload itself breaks down as follows: // @@ -311,7 +306,7 @@ func (h *AuthenticationHandshake) writeClientHandshakeResponseToBackend() { } // TODO: We should probably be carrying out a comprehensive unpacking, so that - // 1. we can be selective about the contents of the response + // we can be selective about the contents of the response packedHandshakeRespPacket, err := protocol.PackHandshakeResponse41(h.clientHandshakeResponse) if err != nil { h.err = err @@ -364,6 +359,9 @@ func (h *AuthenticationHandshake) handleBackendAuthResponse() { // https://dev.mysql.com/doc/dev/mysql-server/latest/page_caching_sha2_authentication_exchanges.html // https://github.com/go-sql-driver/mysql/blob/master/auth.go#L353 + // When using caching_sha2_password and TLS is enabled, no need + // to fetch public key and sign password with it, since + // the password is already encrypted in the TLS session. if h.clientRequestedSSL() { data, err := protocol.PackAuthSwitchResponse( h.backendConn.sequenceID, @@ -395,24 +393,12 @@ func (h *AuthenticationHandshake) handleBackendAuthResponse() { return } - // Parse public key - if pubKeyPkt[4] != protocol.ResponseAuthMoreData { - h.err = errors.New("expected ResponseAuthMoreData packet") - //TODO: For some reason we're getting "28000Access denied for user..." packets here in some cases - return - } - - block, rest := pem.Decode(pubKeyPkt[5:]) - if block == nil { - h.err = fmt.Errorf("no pem data found, data: %s", rest) - return - } - pkix, err := x509.ParsePKIXPublicKey(block.Bytes) + // Unpack public key from packet + pubKey, err := protocol.UnpackAuthRequestPubKeyResponse(pubKeyPkt) if err != nil { h.err = err return } - pubKey := pkix.(*rsa.PublicKey) // Encrypt password with public key enc, err := protocol.EncryptPassword(h.connectionDetails.Password, h.serverHandshake.Salt, pubKey) @@ -421,6 +407,7 @@ func (h *AuthenticationHandshake) handleBackendAuthResponse() { return } + // Send encrypted password to server encPkt := protocol.PackAuthEncryptedPasswordResponse(h.backendConn.sequenceID, enc) h.writeBackendPacket(encPkt) diff --git a/internal/plugin/connectors/tcp/mysql/protocol/error.go b/internal/plugin/connectors/tcp/mysql/protocol/error.go index 86669c1c8..b35d52192 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/error.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/error.go @@ -61,7 +61,7 @@ func (e Error) Error() string { } // GetPacket formats an Error into a protocol message. -// https://dev.mysql.com/doc/internals/en/packet-ERR_Packet.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_err_packet.html func (e Error) GetPacket() []byte { data := make([]byte, 4, 4+1+2+1+5+len(e.Message)) data = append(data, 0xff) diff --git a/internal/plugin/connectors/tcp/mysql/protocol/protocol.go b/internal/plugin/connectors/tcp/mysql/protocol/protocol.go index c979c41c5..fbe24b6f9 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/protocol.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/protocol.go @@ -30,7 +30,9 @@ import ( "crypto/rsa" "crypto/sha1" "crypto/sha256" + "crypto/x509" "encoding/binary" + "encoding/pem" "errors" "fmt" "io" @@ -190,7 +192,7 @@ func UnpackOkResponse(packet []byte) (*OkResponse, error) { // See https://mariadb.com/kb/en/mariadb/1-connecting-connecting/#initial-handshake-packet type HandshakeV10 struct { ProtocolVersion byte - SequenceId uint8 + SequenceID uint8 ServerVersion string ConnectionID uint32 StatusFlags uint16 @@ -349,7 +351,7 @@ func UnpackHandshakeV10(packet []byte) (*HandshakeV10, error) { } return &HandshakeV10{ - SequenceId: header[3], + SequenceID: header[3], ProtocolVersion: protoVersion, ServerVersion: serverVersion, ConnectionID: connectionID, @@ -425,7 +427,7 @@ func PackHandshakeV10(serverHandshake *HandshakeV10) ([]byte, error) { buffer.WriteByte(0) } - return AddHeaderToPacket(serverHandshake.SequenceId, buffer.Bytes()), nil + return AddHeaderToPacket(serverHandshake.SequenceID, buffer.Bytes()), nil } // RemoveSSLFromHandshakeV10 removes Client SSL Capability from Server @@ -515,11 +517,11 @@ func writeUint16(data []byte, pos int, value uint16) { // HandshakeResponse41 represents handshake response packet sent by 4.1+ clients supporting clientProtocol41 capability, // if the server announced it in its initial handshake packet. -// See http://imysql.com/mysql-internal-manual/connection-phase-packets.html#packet-Protocol::HandshakeResponse41 +// See https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_response.html#sect_protocol_connection_phase_packets_protocol_handshake_response41 // // The format of the header is also described here: // -// https://dev.mysql.com/doc/internals/en/mysql-packet.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_packets.html // // +-------------+----------------+---------------------------------------------+ // | Type | Name | Description | @@ -647,6 +649,7 @@ func UnpackHandshakeResponse41(packet []byte) (*HandshakeResponse41, error) { PacketTail: packetTail}, nil } +// CreateAuthResponse creates an auth response for the given auth plugin func CreateAuthResponse(authPlugin string, password []byte, salt []byte) ([]byte, error) { var authResponse []byte var err error @@ -670,8 +673,6 @@ func CreateAuthResponse(authPlugin string, password []byte, salt []byte) ([]byte // salt from the server, and a username / password, and uses the salt // from the server handshake to inject the username / password credentials into // the client handshake response -// TODO: we can do a much better job than this. We just need to calculate the paylod and not do all this resetting BS! -// TODO: the plugin name should come from the server, not the client! func InjectCredentials(authPlugin string, clientHandshake *HandshakeResponse41, salt []byte, username string, password string) (err error) { authResponse, err := CreateAuthResponse(authPlugin, []byte(password), salt) if err != nil { @@ -720,9 +721,6 @@ func PackHandshakeResponse41(clientHandshake *HandshakeResponse41) ([]byte, erro var buf bytes.Buffer - // write the header (same as the original) - // buf.Write(clientHandshake.Header) - // write the capability flags capabilityFlagsBuf := make([]byte, 4) binary.LittleEndian.PutUint32(capabilityFlagsBuf, clientHandshake.CapabilityFlags) @@ -874,6 +872,8 @@ func ReadNullTerminatedString(r *bytes.Reader) string { } } +// AuthSwitchRequest represents a request from the server to switch to a different authentication method. +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html type AuthSwitchRequest struct { SequenceNumber uint8 PluginName string @@ -913,6 +913,24 @@ func UnpackAuthSwitchRequest(data []byte) (*AuthSwitchRequest, error) { }, nil } +// UnpackAuthRequestPubKeyResponse decodes a response from the server to a request for its public key. +func UnpackAuthRequestPubKeyResponse(data []byte) (*rsa.PublicKey, error) { + // Parse public key + if data[4] != ResponseAuthMoreData { + return nil, fmt.Errorf("expected ResponseAuthMoreData packet") + } + + block, rest := pem.Decode(data[5:]) + if block == nil { + return nil, fmt.Errorf("no pem data found, data: %s", rest) + } + pkix, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %s", err) + } + return pkix.(*rsa.PublicKey), nil +} + // ReadNullTerminatedBytes reads bytes from reader until 0x00 byte func ReadNullTerminatedBytes(r *bytes.Reader) (str []byte) { for { @@ -948,7 +966,7 @@ func CheckPacketLength(expected int, packet []byte) error { } // NativePassword calculates native password expected by server in HandshakeResponse41 -// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41 +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_response.html#sect_protocol_connection_phase_packets_protocol_handshake_response41 // SHA1( password ) XOR SHA1( "20-bytes random data from server" SHA1( SHA1( password ) ) ) func NativePassword(password []byte, salt []byte) (nativePassword []byte, err error) { sha1 := sha1.New() @@ -974,7 +992,7 @@ func NativePassword(password []byte, salt []byte) (nativePassword []byte, err er } // AddHeaderToPacket adds a header to a packet -func AddHeaderToPacket(sequenceId uint8, restOfPacket []byte) []byte { +func AddHeaderToPacket(sequenceID uint8, restOfPacket []byte) []byte { // Calculate the packet length (excluding the length field itself) packetLength := len(restOfPacket) @@ -983,25 +1001,26 @@ func AddHeaderToPacket(sequenceId uint8, restOfPacket []byte) []byte { headerBuffer[0] = byte(packetLength & 0xFF) headerBuffer[1] = byte((packetLength >> 8) & 0xFF) headerBuffer[2] = byte((packetLength >> 16) & 0xFF) - headerBuffer[3] = sequenceId + headerBuffer[3] = sequenceID // Combine the header and packet data to create the final packet return append(headerBuffer, restOfPacket...) } // PackAuthSwitchResponse creates an AuthSwitchResponse packet with the provided response data. -func PackAuthSwitchResponse(authSwitchRequestSequenceId uint8, data []byte) ([]byte, error) { +func PackAuthSwitchResponse(authSwitchRequestSequenceID uint8, data []byte) ([]byte, error) { // Create a buffer to write the packet data buffer := new(bytes.Buffer) // Write the response data to the buffer buffer.Write(data) - return AddHeaderToPacket(authSwitchRequestSequenceId, buffer.Bytes()), nil + return AddHeaderToPacket(authSwitchRequestSequenceID, buffer.Bytes()), nil } +// AuthMoreDataResponse represents a packet sent from the server to request more auth data from the client. type AuthMoreDataResponse struct { - SequenceId uint8 + SequenceID uint8 PacketType byte StatusTag byte } @@ -1044,21 +1063,26 @@ func UnpackAuthMoreDataResponse(packet []byte) (*AuthMoreDataResponse, error) { } return &AuthMoreDataResponse{ - SequenceId: header[3], + SequenceID: header[3], PacketType: packetType, StatusTag: statusTag, }, nil } -func PackAuthRequestPubKeyResponse(sequenceId uint8) []byte { - return AddHeaderToPacket(sequenceId, []byte{CachingSha2PasswordRequestPublicKey}) +// PackAuthRequestPubKeyResponse encodes the request for the server's public key +func PackAuthRequestPubKeyResponse(sequenceID uint8) []byte { + return AddHeaderToPacket(sequenceID, []byte{CachingSha2PasswordRequestPublicKey}) } -func PackAuthEncryptedPasswordResponse(sequenceId uint8, encPwd []byte) []byte { - return AddHeaderToPacket(sequenceId, encPwd) +// PackAuthEncryptedPasswordResponse encodes the encrypted password response packet +func PackAuthEncryptedPasswordResponse(sequenceID uint8, encPwd []byte) []byte { + return AddHeaderToPacket(sequenceID, encPwd) } +// EncryptPassword encrypts a password using the provided seed and public key. func EncryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { + // For this stage of the authentication, we must use sha1. See + // https://github.com/go-sql-driver/mysql/blob/19171b59bf90e6bf7a5bdf979e5e24a84b328b8a/auth.go#L217-L226 plain := make([]byte, len(password)+1) copy(plain, password) for i := range plain { diff --git a/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go b/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go index 9bdfd878a..eb57d1ce4 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go @@ -26,7 +26,11 @@ package protocol import ( "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" "encoding/binary" + "encoding/pem" "testing" "github.com/stretchr/testify/assert" @@ -438,3 +442,189 @@ func TestNativePassword(t *testing.T) { assert.Equal(t, expected, output) assert.Equal(t, nil, err) } + +func TestCreateAuthResponse(t *testing.T) { + testCases := []struct { + authPlugin string + password []byte + salt []byte + expectedLen int + expectErr bool + }{ + { + authPlugin: "mysql_native_password", + password: []byte("password"), + salt: []byte("salt"), + expectedLen: 20, + }, + { + authPlugin: "caching_sha2_password", + password: []byte("password"), + salt: []byte("salt"), + expectedLen: 32, + }, + { + authPlugin: "unknown_auth_plugin", + password: []byte("password"), + salt: []byte("salt"), + expectErr: true, + }, + } + + for _, tc := range testCases { + actual, err := CreateAuthResponse(tc.authPlugin, tc.password, tc.salt) + if tc.expectErr { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + assert.Equal(t, tc.expectedLen, len(actual)) + } + } +} + +func TestUnpackAuthSwitchRequest(t *testing.T) { + testCases := []struct { + name string + input []byte + expectedError string + expectedReq *AuthSwitchRequest + }{ + { + name: "valid AuthSwitchRequest packet", + input: []byte{ + 0x02, 0x00, 0x00, 0x01, // Header + 0x01, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x00, // Plugin name ("plugin") + 0x01, 0x02, 0x03, // Plugin data + }, + expectedReq: &AuthSwitchRequest{ + SequenceNumber: 1, + PluginName: "plugin", + PluginData: []byte{0x01, 0x02, 0x03}, + }, + }, + { + name: "missing plugin name", + input: []byte{0x00, 0x00, 0x00, 0x01}, + expectedError: "Invalid AuthSwitchRequest packet: Missing plugin name", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, err := UnpackAuthSwitchRequest(tc.input) + + if tc.expectedError != "" { + assert.EqualError(t, err, tc.expectedError) + assert.Nil(t, req) + } else { + assert.NoError(t, err) + assert.NotNil(t, req) + assert.Equal(t, tc.expectedReq.SequenceNumber, req.SequenceNumber) + assert.Equal(t, tc.expectedReq.PluginName, req.PluginName) + assert.Equal(t, tc.expectedReq.PluginData, req.PluginData) + } + }) + } +} + +func TestUnpackAuthRequestPubKeyResponse(t *testing.T) { + // Generate a test RSA public key + testPubKey, _ := rsa.GenerateKey(rand.Reader, 256) + testPubKeyBytes, _ := x509.MarshalPKIXPublicKey(&testPubKey.PublicKey) + testPubKeyPem := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: testPubKeyBytes, + }) + testPubKeyBytes = append([]byte{0x02, 0x00, 0x00, 0x04, 0x01}, testPubKeyPem...) + + testCases := []struct { + name string + input []byte + expectedError string + expectedResp *rsa.PublicKey + }{ + { + name: "valid AuthRequestPubKeyResponse packet", + input: testPubKeyBytes, + expectedResp: &testPubKey.PublicKey, + }, + { + name: "missing RSA public key", + input: []byte{0x02, 0x00, 0x00, 0x04, 0x01}, + expectedError: "no pem data found, data: ", + }, + { + name: "invalid RSA public key", + input: []byte{0x02, 0x00, 0x00, 0x04, 0x01, 0x01, 0x02}, + expectedError: "no pem data found, data: \x01\x02", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resp, err := UnpackAuthRequestPubKeyResponse(tc.input) + + if tc.expectedError != "" { + assert.EqualError(t, err, tc.expectedError) + assert.Nil(t, resp) + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.EqualValues(t, tc.expectedResp, resp) + } + }) + } +} + +func TestPackAuthSwitchResponse(t *testing.T) { + data := []byte{0x01, 0x02, 0x03} + seqID := uint8(9) + + expected := []byte{ + 0x03, 0x00, 0x00, 0x09, // Header (including sequence number) + 0x01, 0x02, 0x03, // Data + } + + output, err := PackAuthSwitchResponse(seqID, data) + assert.NoError(t, err) + assert.Equal(t, expected, output) +} + +func TestUnpackAuthMoreDataResponse(t *testing.T) { + testCases := []struct { + name string + input []byte + expectedError string + expectedResp *AuthMoreDataResponse + }{ + { + name: "valid AuthMoreDataResponse packet", + input: []byte{0x01, 0x00, 0x00, 0x09, 0x01, 0x04}, + expectedResp: &AuthMoreDataResponse{ + SequenceID: 9, + PacketType: 1, + StatusTag: 4, + }, + }, + { + name: "missing data", + input: []byte{0x01, 0x00, 0x00, 0x09}, + expectedError: ErrInvalidPacketLength.Error(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resp, err := UnpackAuthMoreDataResponse(tc.input) + + if tc.expectedError != "" { + assert.EqualError(t, err, tc.expectedError) + assert.Nil(t, resp) + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, tc.expectedResp, resp) + } + }) + } +} diff --git a/test/connector/tcp/mysql/docker-compose.yml b/test/connector/tcp/mysql/docker-compose.yml index 70dbc3fe7..c2beab9e6 100644 --- a/test/connector/tcp/mysql/docker-compose.yml +++ b/test/connector/tcp/mysql/docker-compose.yml @@ -28,7 +28,7 @@ services: DB_HOST_NO_TLS: mysql_no_tls DB_HOST_TLS: mysql DB_PORT: 3306 - DB_USER: testuser_native_password + DB_USER: testuser DB_PASSWORD: testpass volumes: - ../../../../:/secretless @@ -57,7 +57,7 @@ services: DB_HOST_TLS: mysql DB_HOST_NO_TLS: mysql_no_tls DB_PORT: 3306 - DB_USER: testuser_native_password + DB_USER: testuser DB_PASSWORD: testpass SECRETLESS_HOST: VERBOSE: diff --git a/test/connector/tcp/mysql/tests/essentials_test.go b/test/connector/tcp/mysql/tests/essentials_test.go index 32e724501..478775008 100644 --- a/test/connector/tcp/mysql/tests/essentials_test.go +++ b/test/connector/tcp/mysql/tests/essentials_test.go @@ -13,16 +13,16 @@ func TestEssentials(t *testing.T) { Description: "with username, wrong password", ShouldPass: true, ClientConfiguration: ClientConfiguration{ - Username: "placeholder", - Password: "placeholder", + Username: "testuser", + Password: "wrongpassword", }, }, { Description: "with wrong username, wrong password", ShouldPass: true, ClientConfiguration: ClientConfiguration{ - Username: "placeholder", - Password: "placeholder", + Username: "wrongusername", + Password: "wrongpassword", }, }, { @@ -69,8 +69,8 @@ func TestEssentials(t *testing.T) { Description: "Socket, client -> TLS -> secretless", ShouldPass: false, ClientConfiguration: ClientConfiguration{ - Username: "placeholder", - Password: "placeholder", + Username: "wrongusername", + Password: "wrongpassword", SSL: true, }, CmdOutput: StringPointer("SSL is required but the server doesn't support it"), @@ -88,8 +88,8 @@ func TestEssentials(t *testing.T) { Description: "TCP, client -> TLS -> secretless", ShouldPass: false, ClientConfiguration: ClientConfiguration{ - Username: "placeholder", - Password: "placeholder", + Username: "wrongusername", + Password: "wrongpassword", SSL: true, }, CmdOutput: StringPointer("SSL is required but the server doesn't support it"), @@ -108,10 +108,10 @@ func TestEssentials(t *testing.T) { Description: "secretless using invalid credentials", ShouldPass: false, ClientConfiguration: ClientConfiguration{ - Username: "placeholder", - Password: "placeholder", + Username: "testuser", + Password: "wrongpassword", }, - CmdOutput: StringPointer("ERROR 1045 (28000): Access denied for user 'testuser_native_password'@"), + CmdOutput: StringPointer("ERROR 1045 (28000): Access denied for user 'testuser'@"), }, }, t) })