diff --git a/CHANGELOG.md b/CHANGELOG.md index 6783e6fb3..1b4fdeba5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Nothing should go in this section, please add to the latest unreleased version (and update the corresponding date), or add a new version. +## [1.7.19] - 2023-11-02 + +### Added +- Add support for caching_sha256_password to mysql connector (CONJSE-1801) + ## [1.7.18] - 2023-08-22 ### Changed @@ -16,7 +21,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. [cyberark/secretless-broker#1499](https://github.com/cyberark/secretless-broker/pull/1499) ### Security -- Updated jquery to v3.7.1 - Updated github.com/docker/docker to v24.0.5 (CONJSE-1798) ### Added diff --git a/internal/plugin/connectors/tcp/mysql/authentication_handshake.go b/internal/plugin/connectors/tcp/mysql/authentication_handshake.go index 9a6d1fd22..281248079 100644 --- a/internal/plugin/connectors/tcp/mysql/authentication_handshake.go +++ b/internal/plugin/connectors/tcp/mysql/authentication_handshake.go @@ -8,7 +8,6 @@ import ( ) /* - AuthenticationHandshake represents the entire back and forth process between a MySQL client and server during which authentication occurs. Note this is distinct from the various specific handshake packets that are @@ -58,7 +57,6 @@ following source: Secretless->Backend: HandshakeResponse Backend->Secretless: OkPacket Secretless->Client: OkPacket - */ type AuthenticationHandshake struct { connectionDetails *ConnectionDetails @@ -93,15 +91,42 @@ func NewAuthenticationHandshake( // MySQL server and client. When it completes successfully, // AuthenticatedBackendConn will return the raw, authenticated network conn. func (h *AuthenticationHandshake) Run() error { + // The server is the first to communicate. Read the server handshake h.readServerHandshake() + // Pass along the server handshake to the client, with some minor modifications + // + // 1. Remove TLS capability to avoid TLS connections to Secretless. + // 2. Use `mysql_native_password` as the auth plugin between the client and Secretless, to make + // life easier. We actually don't care about the credentials from the client, we just need the rest + // of the packet. h.writeHandshakeToClient() - h.validateServerSSL() + + // Read the client handshake response. + // + // We are done listening to the client! h.readClientHandshakeResponse() + + // Make sure if the connector (not the client) is configured to use TLS + // then the server supports TLS. This must be done after reading the client response otherwise if validation + // fails then the client connection hangs + h.validateServerSSL() + + // Everything beyond this point is in service of responding to the server authentication challenge. + + // Override client capabilities. For example, the connector has secure connection capabilities and supports + // authentication plugins. h.overrideClientCapabilities() + // Inject credentials into the client handshake response h.injectCredentials() + h.handleClientSSLRequest() + + // Write modified client handshake to server, and + // carry out the rest of the authentication dance between Secretless and the server. + // When we're done we just let the client know of the outcome. h.writeClientHandshakeResponseToBackend() - h.verifyAndProxyOkResponse() + h.handleBackendAuthResponse() + return h.err } @@ -128,10 +153,17 @@ func (h *AuthenticationHandshake) writeHandshakeToClient() { return } + serverHandshake := *h.serverHandshake // Remove Client SSL Capability from Server Handshake Packet // to force client to connect to Secretless without SSL // TODO: update this after kumbi's work - packetWithNoSSL, err := protocol.RemoveSSLFromHandshakeV10(h.rawServerHandshake) + serverHandshake.ServerCapabilities &^= protocol.ClientSSL + + // Give client the simplest auth plugin request + // This might work for now, but we'll likely need to add support for other auth plugins + serverHandshake.AuthPlugin = "mysql_native_password" + + packetWithNoSSL, err := protocol.PackHandshakeV10(&serverHandshake) if err != nil { h.err = err return @@ -161,9 +193,10 @@ func (h *AuthenticationHandshake) readClientHandshakeResponse() { return } - // TODO: client requesting SSL results ERROR 2026 (HY000): SSL connection error: protocol version mismatch + // TODO: client requesting SSL results in ERROR 2026 (HY000): SSL connection error: protocol version mismatch h.clientHandshakeResponse, h.err = protocol.UnpackHandshakeResponse41(rawResponse) } + func (h *AuthenticationHandshake) overrideClientCapabilities() { if h.err != nil { return @@ -173,10 +206,11 @@ func (h *AuthenticationHandshake) overrideClientCapabilities() { // TODO: add tests cases for authentication plugins support // Disable CapabilityFlag for authentication plugins support - h.clientHandshakeResponse.CapabilityFlags &^= protocol.ClientPluginAuth + h.clientHandshakeResponse.CapabilityFlags |= protocol.ClientPluginAuth // TODO: add tests cases for client secure connection // Enable CapabilityFlag for client secure connection + // TODO: explore weird heisenbug when this is toggled off: ERROR: 1043 (08S01): Bad handshake h.clientHandshakeResponse.CapabilityFlags |= protocol.ClientSecureConnection // Ensure CapabilityFlag is set when using TLS @@ -193,6 +227,7 @@ func (h *AuthenticationHandshake) injectCredentials() { // TODO: change this to method call on clientHandshakeResponse when Kumbi's work done h.err = protocol.InjectCredentials( + h.serverHandshake.AuthPlugin, h.clientHandshakeResponse, h.serverHandshake.Salt, h.connectionDetails.Username, @@ -213,7 +248,7 @@ func (h *AuthenticationHandshake) handleClientSSLRequest() { // but truncating the username and everything after the username in // the payload, as described here: // - // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_ssl_request.html // // The payload itself breaks down as follows: // @@ -270,6 +305,8 @@ func (h *AuthenticationHandshake) writeClientHandshakeResponseToBackend() { return } + // TODO: We should probably be carrying out a comprehensive unpacking, so that + // we can be selective about the contents of the response packedHandshakeRespPacket, err := protocol.PackHandshakeResponse41(h.clientHandshakeResponse) if err != nil { h.err = err @@ -284,31 +321,141 @@ func (h *AuthenticationHandshake) verifyAndProxyOkResponse() { return } + // This proxying needs to take place to ensure the client gets the OK packet with + // the correct sequence id, the connection keeps track of this information whereas + // Secretless duplex streaming does not. + rawPkt := h.readBackendPacket() + h.writeClientPacket(rawPkt) +} + +func (h *AuthenticationHandshake) handleBackendAuthResponse() { + if h.err != nil { + return + } + rawPkt := h.readBackendPacket() if h.err != nil { return } switch protocol.GetPacketType(rawPkt) { - case protocol.ResponseErr: - // Return after adding the error response to AuthenticationHandshake - // as a protocol.Error type - // - // The protocol.Error type makes it possible - // to have Go errors that can contain rich protocol specific information - // and have the smarts to encode themselves into a MYSQL error packet - err := protocol.UnpackErrResponse(rawPkt) - h.err = err + case protocol.ResponseAuthMoreData: + defer h.verifyAndProxyOkResponse() + + moreDataResp, err := protocol.UnpackAuthMoreDataResponse(rawPkt) + if err != nil { + h.err = err + return + } + + switch moreDataResp.StatusTag { + case protocol.CachingSha2PasswordFastAuthSuccess: + // The user was cached and a fast login was performed successfully. + // Do nothing. An OK packet will be sent by the server immediately + // following this packet. + return + case protocol.CachingSha2PasswordPerformFullAuthentication: + // The server is requesting a full authentication handshake. + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_caching_sha2_authentication_exchanges.html + // https://github.com/go-sql-driver/mysql/blob/master/auth.go#L353 + + // When using caching_sha2_password and TLS is enabled, no need + // to fetch public key and sign password with it, since + // the password is already encrypted in the TLS session. + if h.clientRequestedSSL() { + data, err := protocol.PackAuthSwitchResponse( + h.backendConn.sequenceID, + append([]byte(h.connectionDetails.Password), 0), + ) + if err != nil { + h.err = err + return + } + + h.writeBackendPacket(data) + if h.err != nil { + return + } + return + } + + // Request public key from server + data := protocol.PackAuthRequestPubKeyResponse(h.backendConn.sequenceID) + + h.writeBackendPacket(data) + if h.err != nil { + return + } + + // Read public key from server + pubKeyPkt := h.readBackendPacket() + if h.err != nil { + return + } + + // Unpack public key from packet + pubKey, err := protocol.UnpackAuthRequestPubKeyResponse(pubKeyPkt) + if err != nil { + h.err = err + return + } + + // Encrypt password with public key + enc, err := protocol.EncryptPassword(h.connectionDetails.Password, h.serverHandshake.Salt, pubKey) + if err != nil { + h.err = err + return + } + + // Send encrypted password to server + encPkt := protocol.PackAuthEncryptedPasswordResponse(h.backendConn.sequenceID, enc) + + h.writeBackendPacket(encPkt) + return + } + return - default: - // Verify packet is valid; don't do anything with unpacked - if _, err := protocol.UnpackOkResponse(rawPkt); err != nil { + + case protocol.ResponseAuthSwitchRequest: + defer h.verifyAndProxyOkResponse() + + authSwitchRequest, err := protocol.UnpackAuthSwitchRequest(rawPkt) + if err != nil { + h.err = err + return + } + + salt := authSwitchRequest.PluginData + // This is because the salt seems to actually be 21 bytes, ending in a null byte. + // However the documentation suggests auth switch requests should be an EOF string + // See https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html + if authSwitchRequest.PluginName == "mysql_native_password" { + salt = salt[:20] + } + authResponse, err := protocol.CreateAuthResponse(authSwitchRequest.PluginName, []byte(h.connectionDetails.Password), salt) + if err != nil { + return + } + + authSwitchResponseData, err := protocol.PackAuthSwitchResponse( + authSwitchRequest.SequenceNumber, + authResponse, + ) + if err != nil { h.err = err return } + h.writeBackendPacket(authSwitchResponseData) + + return + + default: + // Let the client deal with it + h.writeClientPacket(rawPkt) + + return } - h.writeClientPacket(rawPkt) } func (h *AuthenticationHandshake) dbSSLMode() *ssl.DbSSLMode { diff --git a/internal/plugin/connectors/tcp/mysql/connector.go b/internal/plugin/connectors/tcp/mysql/connector.go index c1a1e8bf8..f64375891 100644 --- a/internal/plugin/connectors/tcp/mysql/connector.go +++ b/internal/plugin/connectors/tcp/mysql/connector.go @@ -39,9 +39,9 @@ func (connector *SingleUseConnector) sendErrorToClient(err error) { // Connect implements the tcp.Connector func signature // // It is the main method of the SingleUseConnector. It: -// 1. Constructs connection details from the provided credentials map. -// 2. Dials the backend using credentials. -// 3. Runs through the connection phase steps to authenticate. +// 1. Constructs connection details from the provided credentials map. +// 2. Dials the backend using credentials. +// 3. Runs through the connection phase steps to authenticate. // // Connect requires "host", "port", "username" and "password" credentials. func (connector *SingleUseConnector) Connect( diff --git a/internal/plugin/connectors/tcp/mysql/protocol/const.go b/internal/plugin/connectors/tcp/mysql/protocol/const.go index dbc5a4067..f60c155ce 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/const.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/const.go @@ -27,11 +27,13 @@ package protocol // Random constants const ( // MySQL response types - responseEOF = 0xfe - responseOk = 0x00 - responsePrepareOk = 0x00 - ResponseErr = 0xff - responseLocalinfile = 0xfb + ResponseAuthMoreData = 0x01 + ResponseEOF = 0xfe + ResponseOk = 0x00 + ResponsePrepareOk = 0x00 + ResponseAuthSwitchRequest = 0xfe + ResponseErr = 0xff + responseLocalinfile = 0xfb // MySQL field types constants fieldTypeString = 0xfd @@ -49,7 +51,10 @@ const ( // Digits after comma doubleDecodePrecision = 6 - defaultAuthPluginName = "mysql_native_password" + // caching_sha256_password authentication plugin constants + CachingSha2PasswordRequestPublicKey = 0x02 + CachingSha2PasswordFastAuthSuccess = 0x03 + CachingSha2PasswordPerformFullAuthentication = 0x04 ) // Protocol commands diff --git a/internal/plugin/connectors/tcp/mysql/protocol/error.go b/internal/plugin/connectors/tcp/mysql/protocol/error.go index 86669c1c8..b35d52192 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/error.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/error.go @@ -61,7 +61,7 @@ func (e Error) Error() string { } // GetPacket formats an Error into a protocol message. -// https://dev.mysql.com/doc/internals/en/packet-ERR_Packet.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_err_packet.html func (e Error) GetPacket() []byte { data := make([]byte, 4, 4+1+2+1+5+len(e.Message)) data = append(data, 0xff) diff --git a/internal/plugin/connectors/tcp/mysql/protocol/packet.go b/internal/plugin/connectors/tcp/mysql/protocol/packet.go index cf2f28ece..bde538005 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/packet.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/packet.go @@ -93,7 +93,7 @@ func ReadPrepareResponse(conn net.Conn) ([]byte, byte, error) { } switch pkt[4] { - case responsePrepareOk: + case ResponsePrepareOk: numParams := binary.LittleEndian.Uint16(pkt[9:11]) numColumns := binary.LittleEndian.Uint16(pkt[11:13]) packetsExpected := 0 @@ -121,7 +121,7 @@ func ReadPrepareResponse(conn net.Conn) ([]byte, byte, error) { data = append(data, pkt...) } - return data, responseOk, nil + return data, ResponseOk, nil case ResponseErr: return pkt, ResponseErr, nil @@ -148,8 +148,8 @@ func ReadResponse(conn net.Conn, deprecateEOF bool) ([]byte, byte, error) { } switch pkt[4] { - case responseOk: - return pkt, responseOk, nil + case ResponseOk: + return pkt, ResponseOk, nil case ResponseErr: return pkt, ResponseErr, nil @@ -185,7 +185,7 @@ func ReadResponse(conn net.Conn, deprecateEOF bool) ([]byte, byte, error) { data = append(data, pkt...) - if pkt[4] == responseEOF { + if pkt[4] == ResponseEOF { break } } @@ -197,7 +197,7 @@ func ReadResponse(conn net.Conn, deprecateEOF bool) ([]byte, byte, error) { func ReadPacket(conn net.Conn) ([]byte, error) { // Read packet header - header := []byte{0, 0, 0, 0} + header := make([]byte, 4) if _, err := io.ReadFull(conn, header); err != nil { return nil, err } diff --git a/internal/plugin/connectors/tcp/mysql/protocol/protocol.go b/internal/plugin/connectors/tcp/mysql/protocol/protocol.go index 7fa8032e5..fbe24b6f9 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/protocol.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/protocol.go @@ -26,9 +26,15 @@ package protocol import ( "bytes" + "crypto/rand" + "crypto/rsa" "crypto/sha1" + "crypto/sha256" + "crypto/x509" "encoding/binary" + "encoding/pem" "errors" + "fmt" "io" ) @@ -49,10 +55,12 @@ var ErrFieldTypeNotImplementedYet = errors.New("Protocol: Required field type no // int<1> PacketType (0xFF) // int<2> ErrorCode // if clientCapabilities & clientProtocol41 -// { -// string<1> SqlStateMarker (#) -// string<5> SqlState -// } +// +// { +// string<1> SqlStateMarker (#) +// string<5> SqlState +// } +// // string Error func UnpackErrResponse(data []byte) error { // Min packet length = @@ -96,10 +104,10 @@ func UnpackErrResponse(data []byte) error { // GetPacketType extracts the PacketType byte // Part of basic packet structure shown below. // -// int<3> PacketLength -// int<1> PacketNumber -// int<1> PacketType (0xFF) -// ... more ... +// int<3> PacketLength +// int<1> PacketNumber +// int<1> PacketType (0xFF) +// ... more ... func GetPacketType(packet []byte) byte { return packet[4] } @@ -142,7 +150,7 @@ func UnpackOkResponse(packet []byte) (*OkResponse, error) { if err != nil { return nil, err } - if packetType != responseOk { + if packetType != ResponseOk { return nil, errors.New("Malformed packet") } @@ -184,8 +192,11 @@ func UnpackOkResponse(packet []byte) (*OkResponse, error) { // See https://mariadb.com/kb/en/mariadb/1-connecting-connecting/#initial-handshake-packet type HandshakeV10 struct { ProtocolVersion byte + SequenceID uint8 ServerVersion string ConnectionID uint32 + StatusFlags uint16 + CharacterSet uint8 ServerCapabilities uint32 AuthPlugin string Salt []byte @@ -207,27 +218,35 @@ type HandshakeV10 struct { // int<2> StatusFlags // int<2> ServerCapabilities (2nd part) // if capabilities & clientPluginAuth -// { -// int<1> AuthPluginDataLength -// } +// +// { +// int<1> AuthPluginDataLength +// } +// // else -// { -// int<1> 0x00 -// } +// +// { +// int<1> 0x00 +// } +// // string<10> Reserved (all 0x00) // if capabilities & clientSecureConnection -// { -// string[$len] AuthPluginDataPart2 ($len=MAX(13, AuthPluginDataLength - 8)) -// } +// +// { +// string[$len] AuthPluginDataPart2 ($len=MAX(13, AuthPluginDataLength - 8)) +// } +// // if capabilities & clientPluginAuth -// { -// string[NUL] AuthPluginName -// } +// +// { +// string[NUL] AuthPluginName +// } func UnpackHandshakeV10(packet []byte) (*HandshakeV10, error) { r := bytes.NewReader(packet) - // Skip packet header - if _, err := GetPacketHeader(r); err != nil { + // Header + header, err := GetPacketHeader(r) + if err != nil { return nil, err } @@ -263,10 +282,16 @@ func UnpackHandshakeV10(packet []byte) (*HandshakeV10, error) { return nil, err } - // Skip ServerDefaultCollation and StatusFlags - if _, err := r.Seek(3, io.SeekCurrent); err != nil { + // Read ServerCharacterSet and StatusFlags + serverCharacterSet, err := r.ReadByte() + if err != nil { return nil, err } + serverStatusFlagsBuf := make([]byte, 2) + if _, err := r.Read(serverStatusFlagsBuf); err != nil { + return nil, err + } + serverStatusFlags := binary.LittleEndian.Uint16(serverStatusFlagsBuf) // Read ExServerCapabilities serverCapabilitiesHigherBuf := make([]byte, 2) @@ -326,15 +351,85 @@ func UnpackHandshakeV10(packet []byte) (*HandshakeV10, error) { } return &HandshakeV10{ + SequenceID: header[3], ProtocolVersion: protoVersion, ServerVersion: serverVersion, ConnectionID: connectionID, ServerCapabilities: serverCapabilities, AuthPlugin: authPlugin, Salt: salt, + StatusFlags: serverStatusFlags, + CharacterSet: serverCharacterSet, }, nil } +// PackHandshakeV10 takes in a HandshakeResponse41 object and +// returns a handshake response packet +func PackHandshakeV10(serverHandshake *HandshakeV10) ([]byte, error) { + // Create a buffer to write the packet data + buffer := new(bytes.Buffer) + + // Write ProtocolVersion (int<1>) + binary.Write(buffer, binary.LittleEndian, serverHandshake.ProtocolVersion) + + // Write ServerVersion (string) + buffer.WriteString(serverHandshake.ServerVersion) + buffer.WriteByte(0) + + // Write ConnectionID (int<4>) + binary.Write(buffer, binary.LittleEndian, serverHandshake.ConnectionID) + + // Write AuthPluginDataPart1 (string<8>) + buffer.Write(serverHandshake.Salt[:8]) // Write the first 8 bytes of the salt + + // Write Reserved (int<1>) + buffer.WriteByte(0) + + // Write ServerCapabilities (int<2>) + binary.Write(buffer, binary.LittleEndian, uint16(serverHandshake.ServerCapabilities&0xFFFF)) + + // Write ServerCharacterSet (int<1>) + buffer.WriteByte(serverHandshake.CharacterSet) + + // Write StatusFlags (int<2>) + binary.Write(buffer, binary.LittleEndian, serverHandshake.StatusFlags) + + // Write ServerCapabilities (int<2>), the higher part + binary.Write(buffer, binary.LittleEndian, uint16(serverHandshake.ServerCapabilities>>16)) + + // Write AuthPluginDataLength (int<1>) if required + if serverHandshake.ServerCapabilities&ClientPluginAuth > 0 { + buffer.WriteByte(byte(len(serverHandshake.Salt) + 1)) + } + + // Write Reserved (string<10>) + buffer.Write(make([]byte, 10)) + + // Calculate the length of AuthPluginDataPart2 + var authPluginDataLength byte + if serverHandshake.ServerCapabilities&ClientSecureConnection != 0 { + numBytes := len(serverHandshake.Salt) - 8 + if numBytes > 13 { + numBytes = 13 + } + authPluginDataLength = byte(numBytes) + } + + // Write AuthPluginDataPart2 (string[$len]) if required + if serverHandshake.ServerCapabilities&ClientSecureConnection != 0 { + buffer.Write(serverHandshake.Salt[8 : 8+int(authPluginDataLength)]) + buffer.WriteByte(0) + } + + // Write AuthPluginName (string) if required + if serverHandshake.ServerCapabilities&ClientPluginAuth > 0 { + buffer.WriteString(serverHandshake.AuthPlugin) + buffer.WriteByte(0) + } + + return AddHeaderToPacket(serverHandshake.SequenceID, buffer.Bytes()), nil +} + // RemoveSSLFromHandshakeV10 removes Client SSL Capability from Server // Handshake Packet. Secretless needs to do this to force the client to // communicate with Secretless without using SSL. That half of the connection @@ -422,23 +517,23 @@ func writeUint16(data []byte, pos int, value uint16) { // HandshakeResponse41 represents handshake response packet sent by 4.1+ clients supporting clientProtocol41 capability, // if the server announced it in its initial handshake packet. -// See http://imysql.com/mysql-internal-manual/connection-phase-packets.html#packet-Protocol::HandshakeResponse41 +// See https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_response.html#sect_protocol_connection_phase_packets_protocol_handshake_response41 // // The format of the header is also described here: // -// https://dev.mysql.com/doc/internals/en/mysql-packet.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_packets.html // -// +-------------+----------------+---------------------------------------------+ -// | Type | Name | Description | -// +-------------+----------------+---------------------------------------------+ -// | int<3> | payload_length | Length of the payload. The number of bytes | -// | | | in the packet beyond the initial 4 bytes | -// | | | that make up the packet header. | -// | int<1> | sequence_id | Sequence ID | -// | string | payload | [len=payload_length] payload of the packet | -// +-------------+----------------+---------------------------------------------+ +// +-------------+----------------+---------------------------------------------+ +// | Type | Name | Description | +// +-------------+----------------+---------------------------------------------+ +// | int<3> | payload_length | Length of the payload. The number of bytes | +// | | | in the packet beyond the initial 4 bytes | +// | | | that make up the packet header. | +// | int<1> | sequence_id | Sequence ID | +// | string | payload | [len=payload_length] payload of the packet | +// +-------------+----------------+---------------------------------------------+ type HandshakeResponse41 struct { - Header []byte + SequenceID uint8 CapabilityFlags uint32 MaxPacketSize uint32 ClientCharset uint8 @@ -542,7 +637,7 @@ func UnpackHandshakeResponse41(packet []byte) (*HandshakeResponse41, error) { } return &HandshakeResponse41{ - Header: header, + SequenceID: header[3], CapabilityFlags: capabilityFlags, MaxPacketSize: maxPacketSize, ClientCharset: charset, @@ -554,27 +649,37 @@ func UnpackHandshakeResponse41(packet []byte) (*HandshakeResponse41, error) { PacketTail: packetTail}, nil } -// InjectCredentials takes in a HandshakeResponse41 from the client, the -// salt from the server, and a username / password, and uses the salt -// from the server handshake to inject the username / password credentials into -// the client handshake response -func InjectCredentials(clientHandshake *HandshakeResponse41, salt []byte, username string, password string) (err error) { +// CreateAuthResponse creates an auth response for the given auth plugin +func CreateAuthResponse(authPlugin string, password []byte, salt []byte) ([]byte, error) { + var authResponse []byte + var err error - authResponse, err := NativePassword([]byte(password), salt) - if err != nil { - return + switch authPlugin { + case "mysql_native_password": + authResponse, err = NativePassword([]byte(password), salt) + case "caching_sha2_password": + authResponse = scrambleSHA256Password([]byte(password), salt) + default: + err = fmt.Errorf("Unknown auth plugin: %s", authPlugin) } - // Reset the payload length for the packet - payloadLengthDiff := int32(len(username) - len(clientHandshake.Username)) - payloadLengthDiff += int32(len(authResponse) - int(clientHandshake.AuthLength)) - payloadLengthDiff += int32(len(defaultAuthPluginName) - len(clientHandshake.AuthPluginName)) + if err != nil { + return nil, err + } + return authResponse, nil +} - clientHandshake.Header, err = UpdateHeaderPayloadLength(clientHandshake.Header, payloadLengthDiff) +// InjectCredentials takes in a HandshakeResponse41 from the client, the +// salt from the server, and a username / password, and uses the salt +// from the server handshake to inject the username / password credentials into +// the client handshake response +func InjectCredentials(authPlugin string, clientHandshake *HandshakeResponse41, salt []byte, username string, password string) (err error) { + authResponse, err := CreateAuthResponse(authPlugin, []byte(password), salt) if err != nil { return } + clientHandshake.AuthPluginName = authPlugin clientHandshake.Username = username clientHandshake.AuthLength = int64(len(authResponse)) clientHandshake.AuthResponse = authResponse @@ -582,15 +687,40 @@ func InjectCredentials(clientHandshake *HandshakeResponse41, salt []byte, userna return } +// Hash password using MySQL 8+ method (SHA256) +func scrambleSHA256Password(password []byte, scramble []byte) []byte { + if len(password) == 0 { + return nil + } + + // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) + + crypt := sha256.New() + crypt.Write(password) + message1 := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1) + message1Hash := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1Hash) + crypt.Write(scramble) + message2 := crypt.Sum(nil) + + for i := range message1 { + message1[i] ^= message2[i] + } + + return message1 +} + // PackHandshakeResponse41 takes in a HandshakeResponse41 object and // returns a handshake response packet -func PackHandshakeResponse41(clientHandshake *HandshakeResponse41) (packet []byte, err error) { +func PackHandshakeResponse41(clientHandshake *HandshakeResponse41) ([]byte, error) { var buf bytes.Buffer - // write the header (same as the original) - buf.Write(clientHandshake.Header) - // write the capability flags capabilityFlagsBuf := make([]byte, 4) binary.LittleEndian.PutUint32(capabilityFlagsBuf, clientHandshake.CapabilityFlags) @@ -633,7 +763,7 @@ func PackHandshakeResponse41(clientHandshake *HandshakeResponse41) (packet []byt } // write auth plugin name - buf.WriteString(defaultAuthPluginName) + buf.WriteString(clientHandshake.AuthPluginName) buf.WriteByte(0) // write tail of packet (if set) @@ -641,9 +771,7 @@ func PackHandshakeResponse41(clientHandshake *HandshakeResponse41) (packet []byt buf.Write(clientHandshake.PacketTail) } - packet = buf.Bytes() - - return + return AddHeaderToPacket(clientHandshake.SequenceID, buf.Bytes()), nil } // GetLenEncodedIntegerSize returns bytes count for length encoded integer @@ -744,6 +872,65 @@ func ReadNullTerminatedString(r *bytes.Reader) string { } } +// AuthSwitchRequest represents a request from the server to switch to a different authentication method. +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html +type AuthSwitchRequest struct { + SequenceNumber uint8 + PluginName string + PluginData []byte +} + +// UnpackAuthSwitchRequest decodes an AuthSwitchRequest packet from the provided data. +func UnpackAuthSwitchRequest(data []byte) (*AuthSwitchRequest, error) { + sequenceNumber := data[3] + data = data[4:] + + // Find the position of the null-terminated plugin name + nullTerminatorIndex := 1 + for i := 1; i < len(data); i++ { + if data[i] == 0x00 { + nullTerminatorIndex = i + break + } + } + + // Extract the plugin name + if nullTerminatorIndex == 1 { + return nil, fmt.Errorf("Invalid AuthSwitchRequest packet: Missing plugin name") + } + pluginName := string(data[1:nullTerminatorIndex]) + + // Extract the plugin provided data + var pluginData []byte + if nullTerminatorIndex+1 < len(data) { + pluginData = data[nullTerminatorIndex+1:] + } + + return &AuthSwitchRequest{ + SequenceNumber: sequenceNumber, + PluginName: pluginName, + PluginData: pluginData, + }, nil +} + +// UnpackAuthRequestPubKeyResponse decodes a response from the server to a request for its public key. +func UnpackAuthRequestPubKeyResponse(data []byte) (*rsa.PublicKey, error) { + // Parse public key + if data[4] != ResponseAuthMoreData { + return nil, fmt.Errorf("expected ResponseAuthMoreData packet") + } + + block, rest := pem.Decode(data[5:]) + if block == nil { + return nil, fmt.Errorf("no pem data found, data: %s", rest) + } + pkix, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %s", err) + } + return pkix.(*rsa.PublicKey), nil +} + // ReadNullTerminatedBytes reads bytes from reader until 0x00 byte func ReadNullTerminatedBytes(r *bytes.Reader) (str []byte) { for { @@ -779,7 +966,7 @@ func CheckPacketLength(expected int, packet []byte) error { } // NativePassword calculates native password expected by server in HandshakeResponse41 -// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41 +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_response.html#sect_protocol_connection_phase_packets_protocol_handshake_response41 // SHA1( password ) XOR SHA1( "20-bytes random data from server" SHA1( SHA1( password ) ) ) func NativePassword(password []byte, salt []byte) (nativePassword []byte, err error) { sha1 := sha1.New() @@ -804,38 +991,104 @@ func NativePassword(password []byte, salt []byte) (nativePassword []byte, err er return } -// UpdateHeaderPayloadLength takes in a 4 byte header and a difference -// in length, and returns a new header -func UpdateHeaderPayloadLength(origHeader []byte, diff int32) (header []byte, err error) { +// AddHeaderToPacket adds a header to a packet +func AddHeaderToPacket(sequenceID uint8, restOfPacket []byte) []byte { + // Calculate the packet length (excluding the length field itself) + packetLength := len(restOfPacket) - initialPayloadLength, err := ReadUint24(origHeader[0:3]) + // Create a header buffer and write the packet length (int<3>) + headerBuffer := make([]byte, 4) + headerBuffer[0] = byte(packetLength & 0xFF) + headerBuffer[1] = byte((packetLength >> 8) & 0xFF) + headerBuffer[2] = byte((packetLength >> 16) & 0xFF) + headerBuffer[3] = sequenceID + + // Combine the header and packet data to create the final packet + return append(headerBuffer, restOfPacket...) +} + +// PackAuthSwitchResponse creates an AuthSwitchResponse packet with the provided response data. +func PackAuthSwitchResponse(authSwitchRequestSequenceID uint8, data []byte) ([]byte, error) { + // Create a buffer to write the packet data + buffer := new(bytes.Buffer) + + // Write the response data to the buffer + buffer.Write(data) + + return AddHeaderToPacket(authSwitchRequestSequenceID, buffer.Bytes()), nil +} + +// AuthMoreDataResponse represents a packet sent from the server to request more auth data from the client. +type AuthMoreDataResponse struct { + SequenceID uint8 + PacketType byte + StatusTag byte +} + +// UnpackAuthMoreDataResponse decodes AuthMoreData from server. +// Basic packet structure shown below. +// +// int<3> PacketLength +// int<1> PacketNumber +// int<1> PacketType (0x01) +// int<1> StatusTag (0x03 or 0x04) +// string AuthenticationMethodData (unused by secretless) +func UnpackAuthMoreDataResponse(packet []byte) (*AuthMoreDataResponse, error) { + + // Min packet length = header(4 bytes) + PacketType(1 byte) + if err := CheckPacketLength(5, packet); err != nil { + return nil, err + } + + r := bytes.NewReader(packet) + + header, err := GetPacketHeader(r) + if err != nil { + return nil, err + } + + // Read header, validate OK + packetType, err := r.ReadByte() if err != nil { return nil, err } - updatedPayloadLength := int32(initialPayloadLength) + diff - if updatedPayloadLength < 0 { + if packetType != ResponseAuthMoreData { return nil, errors.New("Malformed packet") } - header = append(WriteUint24(uint32(updatedPayloadLength)), origHeader[3]) - return -} - -// ReadUint24 takes in a byte slice and returns a uint32 -func ReadUint24(b []byte) (uint32, error) { - if len(b) < 3 { - return 0, errors.New("Invalid packet") + // Read status tag + statusTag, err := r.ReadByte() + if err != nil { + return nil, err } - return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16, nil + return &AuthMoreDataResponse{ + SequenceID: header[3], + PacketType: packetType, + StatusTag: statusTag, + }, nil } -// WriteUint24 takes in a uint32 and returns a byte slice -func WriteUint24(u uint32) (b []byte) { - b = make([]byte, 3) - b[0] = byte(u) - b[1] = byte(u >> 8) - b[2] = byte(u >> 16) +// PackAuthRequestPubKeyResponse encodes the request for the server's public key +func PackAuthRequestPubKeyResponse(sequenceID uint8) []byte { + return AddHeaderToPacket(sequenceID, []byte{CachingSha2PasswordRequestPublicKey}) +} - return +// PackAuthEncryptedPasswordResponse encodes the encrypted password response packet +func PackAuthEncryptedPasswordResponse(sequenceID uint8, encPwd []byte) []byte { + return AddHeaderToPacket(sequenceID, encPwd) +} + +// EncryptPassword encrypts a password using the provided seed and public key. +func EncryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { + // For this stage of the authentication, we must use sha1. See + // https://github.com/go-sql-driver/mysql/blob/19171b59bf90e6bf7a5bdf979e5e24a84b328b8a/auth.go#L217-L226 + plain := make([]byte, len(password)+1) + copy(plain, password) + for i := range plain { + j := i % len(seed) + plain[i] ^= seed[j] + } + sha1 := sha1.New() + return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) } diff --git a/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go b/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go index 33318f31a..eb57d1ce4 100644 --- a/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go +++ b/internal/plugin/connectors/tcp/mysql/protocol/protocol_test.go @@ -26,7 +26,11 @@ package protocol import ( "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" "encoding/binary" + "encoding/pem" "testing" "github.com/stretchr/testify/assert" @@ -193,9 +197,27 @@ func TestUnpackHandshakeV10(t *testing.T) { } } +func TestPackHandshakeV10(t *testing.T) { + input := &HandshakeV10{ + ProtocolVersion: byte(10), + ServerVersion: "5.5.56", + ConnectionID: uint32(1630), + AuthPlugin: "mysql_native_password", + ServerCapabilities: binary.LittleEndian.Uint32([]byte{255, 247, 15, 128}), + Salt: []byte{0x48, 0x6a, 0x5b, 0x6a, 0x24, 0x71, 0x30, 0x3a, 0x6f, 0x43, 0x40, 0x56, 0x6e, 0x4b, + 0x68, 0x4a, 0x79, 0x46, 0x30, 0x5a}, + } + + output, err := PackHandshakeV10(input) + newInput, err := UnpackHandshakeV10(output) + + assert.Equal(t, input, newInput) + assert.Equal(t, nil, err) +} + func TestUnpackHandshakeResponse41(t *testing.T) { expected := HandshakeResponse41{ - Header: []byte{0xaa, 0x0, 0x0, 0x1}, + SequenceID: 1, CapabilityFlags: uint32(33464965), MaxPacketSize: uint32(1073741824), ClientCharset: uint8(8), @@ -247,25 +269,25 @@ func TestInjectCredentials(t *testing.T) { expectedAuth := []byte{0xf, 0xf8, 0xe1, 0xa3, 0xe7, 0xe3, 0x5f, 0xd2, 0xb1, 0x69, 0x8c, 0x39, 0x5b, 0xfa, 0x99, 0x4f, 0x53, 0xdd, 0xe5, 0x35} // 20 - expectedHeader := []byte{0x99, 0x0, 0x0, 0x1} - // expectedHeader[0] = 0xaa + (8 - 14) + (20 - 20) + (21 - 32) + expectedHeader := []byte{0xa2, 0x0, 0x0, 0x1} + // expectedHeader[0] = 0xaa + (8 - 14) + (20 - 20) + (21 - 23) // test with handshake response that already has auth set to another value handshake := HandshakeResponse41{ AuthLength: int64(20), - AuthPluginName: "non_native_mysql_native_password", // 32 + AuthPluginName: "caching_sha256_password", // 23 AuthResponse: []byte{0xc0, 0xb, 0xbc, 0xb6, 0x6, 0xf5, 0x4f, 0x4e, 0xf4, 0x1b, 0x87, 0xc0, 0xb8, 0x89, 0xae, 0xc4, 0x49, 0x7c, 0x46, 0xf3}, // 20 - Username: "madeupusername", // 14 - Header: []byte{0xaa, 0x0, 0x0, 0x1}, + Username: "madeupusername", // 14 + SequenceID: 1, } - err := InjectCredentials(&handshake, salt, username, password) + err := InjectCredentials("mysql_native_password", &handshake, salt, username, password) assert.Equal(t, username, handshake.Username) assert.Equal(t, int64(20), handshake.AuthLength) assert.Equal(t, expectedAuth, handshake.AuthResponse) - assert.Equal(t, expectedHeader, handshake.Header) + assert.Equal(t, expectedHeader[3], handshake.SequenceID) assert.Equal(t, nil, err) // test with handshake response with empty auth and mysql_native_password @@ -276,21 +298,21 @@ func TestInjectCredentials(t *testing.T) { AuthPluginName: "mysql_native_password", // 21 AuthResponse: []byte{}, // 0 Username: "madeupusername", // 14 - Header: []byte{0xaa, 0x0, 0x0, 0x1}, + SequenceID: 1, } - err = InjectCredentials(&handshake, salt, username, password) + err = InjectCredentials("mysql_native_password", &handshake, salt, username, password) assert.Equal(t, username, handshake.Username) assert.Equal(t, int64(20), handshake.AuthLength) assert.Equal(t, expectedAuth, handshake.AuthResponse) - assert.Equal(t, expectedHeader, handshake.Header) + assert.Equal(t, expectedHeader[3], handshake.SequenceID) assert.Equal(t, nil, err) } func TestPackHandshakeResponse41(t *testing.T) { - input := HandshakeResponse41{ - Header: []byte{0xaa, 0x0, 0x0, 0x1}, + input := &HandshakeResponse41{ + SequenceID: 1, CapabilityFlags: uint32(33464965), MaxPacketSize: uint32(1073741824), ClientCharset: uint8(8), @@ -310,26 +332,11 @@ func TestPackHandshakeResponse41(t *testing.T) { 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x6, 0x78, 0x38, 0x36, 0x5f, 0x36, 0x34}, } - expected := []byte{0xaa, 0x0, 0x0, 0x1, 0x85, 0xa2, 0xfe, 0x1, 0x0, - 0x0, 0x0, 0x40, 0x8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x72, 0x6f, 0x67, 0x65, 0x72, 0x0, 0x14, 0xc0, - 0xb, 0xbc, 0xb6, 0x6, 0xf5, 0x4f, 0x4e, 0xf4, 0x1b, 0x87, 0xc0, - 0xb8, 0x89, 0xae, 0xc4, 0x49, 0x7c, 0x46, 0xf3, 0x6d, 0x79, 0x73, - 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70, - 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x0, 0x58, 0x3, 0x5f, - 0x6f, 0x73, 0xa, 0x6d, 0x61, 0x63, 0x6f, 0x73, 0x31, 0x30, 0x2e, - 0x31, 0x32, 0xc, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, - 0x6e, 0x61, 0x6d, 0x65, 0x8, 0x6c, 0x69, 0x62, 0x6d, 0x79, 0x73, - 0x71, 0x6c, 0x4, 0x5f, 0x70, 0x69, 0x64, 0x5, 0x36, 0x36, 0x34, - 0x37, 0x39, 0xf, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, - 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x6, 0x35, 0x2e, 0x37, - 0x2e, 0x32, 0x30, 0x9, 0x5f, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, - 0x72, 0x6d, 0x6, 0x78, 0x38, 0x36, 0x5f, 0x36, 0x34} - output, err := PackHandshakeResponse41(&input) + output, err := PackHandshakeResponse41(input) + newInput, err := UnpackHandshakeResponse41(output) - assert.Equal(t, expected, output) + assert.Equal(t, input, newInput) assert.Equal(t, nil, err) } @@ -436,51 +443,188 @@ func TestNativePassword(t *testing.T) { assert.Equal(t, nil, err) } -func TestUpdateHeaderPayloadLength(t *testing.T) { - // Test with a valid negative value - expectedHeader := []byte{170, 0, 0, 0} - inputHeader := []byte{173, 0, 0, 0} - inputLength := int32(-3) - - output, err := UpdateHeaderPayloadLength(inputHeader, inputLength) - - assert.Equal(t, expectedHeader, output) - assert.Equal(t, nil, err) - - // Test with a valid positive value - expectedHeader = []byte{176, 0, 0, 0} - inputHeader = []byte{173, 0, 0, 0} - inputLength = int32(3) +func TestCreateAuthResponse(t *testing.T) { + testCases := []struct { + authPlugin string + password []byte + salt []byte + expectedLen int + expectErr bool + }{ + { + authPlugin: "mysql_native_password", + password: []byte("password"), + salt: []byte("salt"), + expectedLen: 20, + }, + { + authPlugin: "caching_sha2_password", + password: []byte("password"), + salt: []byte("salt"), + expectedLen: 32, + }, + { + authPlugin: "unknown_auth_plugin", + password: []byte("password"), + salt: []byte("salt"), + expectErr: true, + }, + } - output, err = UpdateHeaderPayloadLength(inputHeader, inputLength) + for _, tc := range testCases { + actual, err := CreateAuthResponse(tc.authPlugin, tc.password, tc.salt) + if tc.expectErr { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + assert.Equal(t, tc.expectedLen, len(actual)) + } + } +} - assert.Equal(t, expectedHeader, output) - assert.Equal(t, nil, err) +func TestUnpackAuthSwitchRequest(t *testing.T) { + testCases := []struct { + name string + input []byte + expectedError string + expectedReq *AuthSwitchRequest + }{ + { + name: "valid AuthSwitchRequest packet", + input: []byte{ + 0x02, 0x00, 0x00, 0x01, // Header + 0x01, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x00, // Plugin name ("plugin") + 0x01, 0x02, 0x03, // Plugin data + }, + expectedReq: &AuthSwitchRequest{ + SequenceNumber: 1, + PluginName: "plugin", + PluginData: []byte{0x01, 0x02, 0x03}, + }, + }, + { + name: "missing plugin name", + input: []byte{0x00, 0x00, 0x00, 0x01}, + expectedError: "Invalid AuthSwitchRequest packet: Missing plugin name", + }, + } - // Test with an invalid value for the length difference - inputHeader = []byte{173, 0, 0, 0} - inputLength = int32(-180) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, err := UnpackAuthSwitchRequest(tc.input) + + if tc.expectedError != "" { + assert.EqualError(t, err, tc.expectedError) + assert.Nil(t, req) + } else { + assert.NoError(t, err) + assert.NotNil(t, req) + assert.Equal(t, tc.expectedReq.SequenceNumber, req.SequenceNumber) + assert.Equal(t, tc.expectedReq.PluginName, req.PluginName) + assert.Equal(t, tc.expectedReq.PluginData, req.PluginData) + } + }) + } +} - output, err = UpdateHeaderPayloadLength(inputHeader, inputLength) +func TestUnpackAuthRequestPubKeyResponse(t *testing.T) { + // Generate a test RSA public key + testPubKey, _ := rsa.GenerateKey(rand.Reader, 256) + testPubKeyBytes, _ := x509.MarshalPKIXPublicKey(&testPubKey.PublicKey) + testPubKeyPem := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: testPubKeyBytes, + }) + testPubKeyBytes = append([]byte{0x02, 0x00, 0x00, 0x04, 0x01}, testPubKeyPem...) + + testCases := []struct { + name string + input []byte + expectedError string + expectedResp *rsa.PublicKey + }{ + { + name: "valid AuthRequestPubKeyResponse packet", + input: testPubKeyBytes, + expectedResp: &testPubKey.PublicKey, + }, + { + name: "missing RSA public key", + input: []byte{0x02, 0x00, 0x00, 0x04, 0x01}, + expectedError: "no pem data found, data: ", + }, + { + name: "invalid RSA public key", + input: []byte{0x02, 0x00, 0x00, 0x04, 0x01, 0x01, 0x02}, + expectedError: "no pem data found, data: \x01\x02", + }, + } - assert.EqualError(t, err, "Malformed packet") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resp, err := UnpackAuthRequestPubKeyResponse(tc.input) + + if tc.expectedError != "" { + assert.EqualError(t, err, tc.expectedError) + assert.Nil(t, resp) + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.EqualValues(t, tc.expectedResp, resp) + } + }) + } } -func TestReadUint24(t *testing.T) { - expected := uint32(173) - input := []byte{173, 0, 0} +func TestPackAuthSwitchResponse(t *testing.T) { + data := []byte{0x01, 0x02, 0x03} + seqID := uint8(9) - output, err := ReadUint24(input) + expected := []byte{ + 0x03, 0x00, 0x00, 0x09, // Header (including sequence number) + 0x01, 0x02, 0x03, // Data + } + output, err := PackAuthSwitchResponse(seqID, data) + assert.NoError(t, err) assert.Equal(t, expected, output) - assert.Equal(t, nil, err) } -func TestWriteUint24(t *testing.T) { - expected := []byte{173, 0, 0} - input := uint32(173) - - output := WriteUint24(input) +func TestUnpackAuthMoreDataResponse(t *testing.T) { + testCases := []struct { + name string + input []byte + expectedError string + expectedResp *AuthMoreDataResponse + }{ + { + name: "valid AuthMoreDataResponse packet", + input: []byte{0x01, 0x00, 0x00, 0x09, 0x01, 0x04}, + expectedResp: &AuthMoreDataResponse{ + SequenceID: 9, + PacketType: 1, + StatusTag: 4, + }, + }, + { + name: "missing data", + input: []byte{0x01, 0x00, 0x00, 0x09}, + expectedError: ErrInvalidPacketLength.Error(), + }, + } - assert.Equal(t, expected, output) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resp, err := UnpackAuthMoreDataResponse(tc.input) + + if tc.expectedError != "" { + assert.EqualError(t, err, tc.expectedError) + assert.Nil(t, resp) + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, tc.expectedResp, resp) + } + }) + } } diff --git a/test/connector/tcp/mysql/Dockerfile.dev b/test/connector/tcp/mysql/Dockerfile.dev index 99cec2ca6..cb2f88bec 100644 --- a/test/connector/tcp/mysql/Dockerfile.dev +++ b/test/connector/tcp/mysql/Dockerfile.dev @@ -1,4 +1,3 @@ FROM secretless-dev -RUN apt-get update && \ - apt-get install -y default-mysql-client +COPY --from=mysql:8.1 /usr/bin/mysql /usr/bin/mysql diff --git a/test/connector/tcp/mysql/Dockerfile.mysql b/test/connector/tcp/mysql/Dockerfile.mysql deleted file mode 100644 index ee12922ee..000000000 --- a/test/connector/tcp/mysql/Dockerfile.mysql +++ /dev/null @@ -1,8 +0,0 @@ -FROM mysql/mysql-server:5.7 -RUN yum install openssl -y -RUN mkdir -p /etc/mysql/mysql.conf.d/ - -COPY etc/toggle_ssl.sh /docker-entrypoint-initdb.d/ -COPY etc/test.sql /docker-entrypoint-initdb.d/ - -COPY ./ssl /ssl diff --git a/test/connector/tcp/mysql/README.md b/test/connector/tcp/mysql/README.md index 3e71cd912..f64d0127d 100644 --- a/test/connector/tcp/mysql/README.md +++ b/test/connector/tcp/mysql/README.md @@ -1,46 +1,58 @@ -# MySQL Handler Development +# MySQL Connector Development ## Usage / known limitations -- The MySQL handler is currently limited to connections via Unix domain socket +- The MySQL connector is currently limited to connections via Unix domain socket + +### To use the MySQL connector -### To use the MySQL handler: #### Start your MySQL server + From this directory, call -``` + +```sh docker-compose up -d mysql ``` + This will automatically start a MySQL server in a Docker container at `localhost:$(docker-compose port mysql 3306)`. It will also configure the MySQL server as follows: + - Create a `testuser` user (with password `testpass`) - Authorize the `testuser` user to connect to the database server from any IP and access any schema - Create a table `test` in the `testdb` schema and add two rows #### Start and configure secretless-broker + From the root project directory, build the Secretless Broker binaries for your platform: -``` + +```sh platform=$(go run test/print_platform.go) ./bin/build $platform amd64 ``` From this directory, start Secretless Broker: -``` -./run_dev + +```sh +./dev ``` -#### Log in to the MySQL server via the MySQL handler +#### Log in to the MySQL server via the MySQL connector + In another terminal, navigate to the `test/mysql_handler` directory and send a MySQL request via Unix socket: _Note: Since the Secretless Broker container runs the daemon as a limited user, sockets should be mounted to `/sock` directory._ -``` +```sh mysql --socket=sock/mysql.sock ``` + or via TCP: + +```sh +mysql -h 0.0.0.0 -P 13306 -u testuser --ssl-mode=DISABLED ``` -mysql -h 0.0.0.0 -P 13306 --ssl-mode=DISABLED -``` + You may be prompted for a password, but you don't need to enter one; just hit return to continue. Once logged in, you should be able to `SELECT * FROM testdb.test` and see the rows that were added to the sample table. @@ -48,11 +60,13 @@ Once logged in, you should be able to `SELECT * FROM testdb.test` and see the ro Note: this assumes you have a MySQL client installed locally on your machine. In the examples above and when you run the test suite locally, it is assumed you use one like [mysqlsh](https://dev.mysql.com/doc/refman/5.7/en/mysqlsh.html), which assumes SSL connections when possible by default (and has an `--ssl-mode` flag you can use to disable SSL). If you use `mysqlsh`, you will need to create an executable `mysql` file in your `PATH` that contains the following in order to be able to run `run_dev_test` locally: -``` + +```sh #!/bin/bash -ex mysqlsh --sql "$@" ``` + This will run the MySQL shell as a client in SQL mode. ## MySQL Handler Development @@ -62,7 +76,8 @@ This will run the MySQL shell as a client in SQL mode. The easiest way to do Secretless Broker development is to use the VS Code debugger. As above, you will want to start up your MySQL server container before beginning development. To configure the Secretless Broker, you can provide VS Code with a `launch.json` file for debugging by copying the sample file below to `.vscode/launch.json`, replacing `[YOUR MYSQL PORT]` with the actual exposed port of your MySQL Docker container. Sample `launch.json`: -``` + +```json { // Use IntelliSense to learn about possible attributes. // Hover to view descriptions of existing attributes. @@ -70,29 +85,30 @@ Sample `launch.json`: "version": "0.2.0", "configurations": [ { - "name": "MySQL Handler", + "name": "MySQL Connector", "type": "go", "request": "launch", "mode": "debug", "remotePath": "", "port": 2345, "host": "127.0.0.1", - "program": "${workspaceFolder}/cmd/secretless/", + "program": "${workspaceFolder}/cmd/secretless-broker/", "env": { "MYSQL_HOST": "localhost", "MYSQL_PORT": "[YOUR MYSQL PORT]", "MYSQL_PASSWORD": "testpass" }, - "args": [ "-f", "/Users/gjennings/go/src/github.com/cyberark/secretless-broker/test/mysql_handler/secretless.dev.yml"], + "args": [ "-f", "${workspaceFolder}/test/connector/tcp/mysql/fixtures/secretless.dev.yml"], "showLog": true } ] } ``` -Once you start the debugger (which will automatically start the Secretless Broker with the dev MySQL Handler configuration), you can send requests to the MySQL server via a client as described above. +Once you start the debugger (which will automatically start the Secretless Broker with the dev MySQL Connector configuration), you can send requests to the MySQL server via a client as described above. ### Using Docker You can also run: -``` -cd test/mysql_handler/ + +```sh +cd test/connector/tcp/mysql/ ./start docker-compose run --rm secretless-dev ``` @@ -106,14 +122,17 @@ to connect via Unix socket. ## Running the test suite #### Run the tests in Docker + Make sure you have built updated Secretless Broker binaries for Linux and updated Docker images before running the test suite. To run the test suite in Docker, run: -``` + +```sh ./stop # Remove all existing project containers ./start # Stand up MySQL and Secretless Broker servers ./test # Run tests in a test container ``` + Make sure you build the project by running `./bin/build` in the project root before running the tests so that the test container will be using updated code. If you want to run using your local changes, you can run `./test -l` diff --git a/test/connector/tcp/mysql/docker-compose.yml b/test/connector/tcp/mysql/docker-compose.yml index 550b8a388..c2beab9e6 100644 --- a/test/connector/tcp/mysql/docker-compose.yml +++ b/test/connector/tcp/mysql/docker-compose.yml @@ -1,10 +1,8 @@ version: '3.0' services: - mysql: - build: - context: . - dockerfile: Dockerfile.mysql + mysql_no_tls: &mysql_no_tls + image: mysql:8.1 ports: - 3306 healthcheck: @@ -12,23 +10,17 @@ services: interval: 1s timeout: 30s environment: - NO_SSL: "false" - MYSQL_ROOT_PASSWORD: securerootpass - - mysql_no_tls: - build: - context: . - dockerfile: Dockerfile.mysql - ports: - - 3306 - healthcheck: - test: ["CMD-SHELL", "mysqladmin -psecurerootpass status"] - interval: 1s - timeout: 30s - environment: - NO_SSL: "true" MYSQL_ROOT_PASSWORD: securerootpass + volumes: + - ./etc/test.sql:/docker-entrypoint-initdb.d/test.sql + - ./etc/no-ssl.cnf:/etc/mysql/conf.d/no-ssl.cnf + mysql: + <<: *mysql_no_tls + volumes: + - ./etc/test.sql:/docker-entrypoint-initdb.d/test.sql + - ./etc/ssl.cnf:/etc/mysql/conf.d/ssl.cnf + - ./ssl:/etc/mysql-ssl secretless-dev: image: secretless-dev command: ./bin/reflex diff --git a/test/connector/tcp/mysql/etc/no-ssl.cnf b/test/connector/tcp/mysql/etc/no-ssl.cnf new file mode 100644 index 000000000..d46c18c66 --- /dev/null +++ b/test/connector/tcp/mysql/etc/no-ssl.cnf @@ -0,0 +1,2 @@ +[mysqld] +tls_version='' diff --git a/test/connector/tcp/mysql/etc/ssl.cnf b/test/connector/tcp/mysql/etc/ssl.cnf new file mode 100644 index 000000000..67733f279 --- /dev/null +++ b/test/connector/tcp/mysql/etc/ssl.cnf @@ -0,0 +1,4 @@ +[mysqld] +ssl-ca=/etc/mysql-ssl/ca.pem +ssl-cert=/etc/mysql-ssl/server.pem +ssl-key=/etc/mysql-ssl/server-key.pem diff --git a/test/connector/tcp/mysql/etc/test.sql b/test/connector/tcp/mysql/etc/test.sql index b43b4725f..907c6533c 100644 --- a/test/connector/tcp/mysql/etc/test.sql +++ b/test/connector/tcp/mysql/etc/test.sql @@ -1,5 +1,17 @@ -GRANT ALL PRIVILEGES ON *.* TO 'testuser'@'localhost' IDENTIFIED BY 'testpass'; -GRANT ALL PRIVILEGES ON *.* TO 'testuser'@'%' IDENTIFIED BY 'testpass'; +-- Create a user with the old mysql_native_password plugin which was the default in MySQL 5.7 +-- This allows us to log in using legacy MySQL clients +CREATE USER 'testuser_native_password'@'localhost' IDENTIFIED WITH mysql_native_password BY 'testpass'; +GRANT ALL PRIVILEGES ON *.* TO 'testuser_native_password'@'localhost'; + +CREATE USER 'testuser_native_password'@'%' IDENTIFIED WITH mysql_native_password BY 'testpass'; +GRANT ALL PRIVILEGES ON *.* TO 'testuser_native_password'@'%'; + +-- Create a user with the new caching_sha2_password plugin which is the default in MySQL 8.0 +CREATE USER 'testuser'@'localhost' IDENTIFIED BY 'testpass'; +GRANT ALL PRIVILEGES ON *.* TO 'testuser'@'localhost'; + +CREATE USER 'testuser'@'%' IDENTIFIED BY 'testpass'; +GRANT ALL PRIVILEGES ON *.* TO 'testuser'@'%'; CREATE DATABASE testdb; diff --git a/test/connector/tcp/mysql/etc/toggle_ssl.sh b/test/connector/tcp/mysql/etc/toggle_ssl.sh deleted file mode 100755 index ef9abc3bf..000000000 --- a/test/connector/tcp/mysql/etc/toggle_ssl.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash -e -# This script is used to build the ROOT/test/mysql_handler mysql container image -# The script expects ROOT/test/util/ssl to contain the pre-generated -# shared SSL fixtures used during testing -# -# This script is housed in /docker-entrypoint-initdb.d/ inside the container image -# The envvar NO_SSL is used to toggle SSL for the mysql container image at startup - -if [[ "$NO_SSL" = "true" ]] -then - echo "removing SSL support" - rm /var/lib/mysql/*.pem - echo "ssl=0" >> /etc/my.cnf -else - cp /ssl/ca.pem /var/lib/mysql/ca.pem - cp /ssl/ca-key.pem /var/lib/mysql/ca-key.pem - cp /ssl/client.pem /var/lib/mysql/client-cert.pem - cp /ssl/client-key.pem /var/lib/mysql/client-key.pem - cp /ssl/server.pem /var/lib/mysql/server-cert.pem - cp /ssl/server-key.pem /var/lib/mysql/server-key.pem -fi diff --git a/test/connector/tcp/mysql/pkg/run_test_case.go b/test/connector/tcp/mysql/pkg/run_test_case.go index a622a7657..44d9af4b6 100644 --- a/test/connector/tcp/mysql/pkg/run_test_case.go +++ b/test/connector/tcp/mysql/pkg/run_test_case.go @@ -18,7 +18,7 @@ func RunQuery( args := []string{"-e", "select count(*) from testdb.test"} if clientConfig.SSL { - args = append(args, "--ssl", "--ssl-verify-server-cert=TRUE") + args = append(args, "--ssl-mode", "VERIFY_CA") } if clientConfig.Username != "" { args = append(args, fmt.Sprintf("--user=%s", clientConfig.Username)) @@ -36,9 +36,6 @@ func RunQuery( panic("Listener Type can only be TCP or Socket") } - // ensures mysql can handle non-native auth - args = append(args, "--default-auth=mysql_clear_password") - // Pre command logs println("") println("---<< EXECUTED") diff --git a/test/connector/tcp/mysql/tests/essentials_test.go b/test/connector/tcp/mysql/tests/essentials_test.go index 26772567b..478775008 100644 --- a/test/connector/tcp/mysql/tests/essentials_test.go +++ b/test/connector/tcp/mysql/tests/essentials_test.go @@ -73,7 +73,7 @@ func TestEssentials(t *testing.T) { Password: "wrongpassword", SSL: true, }, - CmdOutput: StringPointer("ERROR 2026 (HY000): TLS/SSL error: SSL is required, but the server does not support it"), + CmdOutput: StringPointer("SSL is required but the server doesn't support it"), }, }, t) @@ -92,7 +92,7 @@ func TestEssentials(t *testing.T) { Password: "wrongpassword", SSL: true, }, - CmdOutput: StringPointer("ERROR 2026 (HY000): TLS/SSL error: SSL is required, but the server does not support it"), + CmdOutput: StringPointer("SSL is required but the server doesn't support it"), }, }, t) diff --git a/test/connector/tcp/pg/README.md b/test/connector/tcp/pg/README.md index 4c137b152..f525871fd 100644 --- a/test/connector/tcp/pg/README.md +++ b/test/connector/tcp/pg/README.md @@ -1,19 +1,24 @@ -# Postgresql Handler Development +# Postgresql Connector Development + +## TLDR -##TLDR The following two steps are all you need to know: + 1. A single command starts the dev workflow: ```sh-session $ ./dev ``` + 2. Another one runs the tests: + ```sh-session $ ./test ``` + So while developing, you'll do a single `./dev` and then many `./test` runs as you iteratively change the code or add new tests. -##Additional Details +## Additional Details `./dev` uses `docker-compose` to start both `pg` containers, the `secretless` container, and the `test` container (where tests are run from).