diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 000000000..14c7119a9 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1 @@ +FROM mcr.microsoft.com/vscode/devcontainers/go:0-1.18 \ No newline at end of file diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 000000000..8dd1b30a2 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,25 @@ +{ + "name": "Go & Mongo DB", + "dockerComposeFile": "docker-compose.yml", + "service": "app", + "workspaceFolder": "/workspace", + "runArgs": [ + "--cap-add=SYS_PTRACE", + "--security-opt", + "seccomp=unconfined" + ], + "customizations": { + "vscode": { + "settings": { + "go.toolsManagement.checkForUpdates": "local", + "go.useLanguageServer": true, + "go.gopath": "/go" + }, + "extensions": [ + "golang.Go", + "mongodb.mongodb-vscode" + ] + } + }, + "remoteUser": "vscode" +} \ No newline at end of file diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml new file mode 100644 index 000000000..deb58bf2a --- /dev/null +++ b/.devcontainer/docker-compose.yml @@ -0,0 +1,39 @@ +version: '3.8' + +services: + app: + build: + context: . + dockerfile: Dockerfile + volumes: + - ..:/workspace:cached + + # Overrides default command so things don't shut down after the process ends. + command: sleep infinity + + # Runs app on the same network as the database container, allows "forwardPorts" in devcontainer.json function. + network_mode: service:db + + # Uncomment the next line to use a non-root user for all processes. + # user: node + + # Use "forwardPorts" in **devcontainer.json** to forward an app port locally. + # (Adding the "ports" property to this file will not forward from a Codespace.) + + db: + image: mongo:latest + restart: unless-stopped + volumes: + - mongodb-data:/data/db + + # Uncomment to change startup options + # environment: + # MONGO_INITDB_ROOT_USERNAME: root + # MONGO_INITDB_ROOT_PASSWORD: example + # MONGO_INITDB_DATABASE: your-database-here + + # Add "forwardPorts": ["27017"] to **devcontainer.json** to forward MongoDB locally. + # (Adding the "ports" property to this file will not forward from a Codespace.) + +volumes: + mongodb-data: \ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..d921d0ffd --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,7 @@ +version: 2 +updates: +- package-ecosystem: gomod + directory: "/" + schedule: + interval: daily + open-pull-requests-limit: 10 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 000000000..7de7c5db2 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,72 @@ +name: CI +on: + push: + tags: + - v* + branches: + - master + - main + pull_request: + branches: + - master + - main +permissions: + contents: read + +jobs: + golangci: + permissions: + contents: read # for actions/checkout to fetch code + pull-requests: read # for golangci/golangci-lint-action to fetch pull requests + name: Linter + runs-on: ubuntu-latest + steps: + - name: Checkout source code + uses: actions/checkout@v2 + - name: Setup Go + uses: actions/setup-go@v2 + with: + go-version: '1.18' + - name: Install golangci-lint + run: | + curl -sSLO https://github.com/golangci/golangci-lint/releases/download/v$GOLANGCI_LINT_VERSION/golangci-lint-$GOLANGCI_LINT_VERSION-linux-amd64.tar.gz + shasum -a 256 golangci-lint-$GOLANGCI_LINT_VERSION-linux-amd64.tar.gz | grep "^$GOLANGCI_LINT_SHA256 " > /dev/null + tar -xf golangci-lint-$GOLANGCI_LINT_VERSION-linux-amd64.tar.gz + sudo mv golangci-lint-$GOLANGCI_LINT_VERSION-linux-amd64/golangci-lint /usr/local/bin/golangci-lint + rm -rf golangci-lint-$GOLANGCI_LINT_VERSION-linux-amd64* + env: + GOLANGCI_LINT_VERSION: '1.46.2' + GOLANGCI_LINT_SHA256: '242cd4f2d6ac0556e315192e8555784d13da5d1874e51304711570769c4f2b9b' + - name: Test style + run: make test-style + + build: + name: build + runs-on: ubuntu-latest + strategy: + matrix: + go: [1.18] + fix-version: + - FIX_TEST= + - FIX_TEST=fix40 + - FIX_TEST=fix41 + - FIX_TEST=fix42 + - FIX_TEST=fix43 + - FIX_TEST=fix44 + - FIX_TEST=fix50 + - FIX_TEST=fix50sp1 + - FIX_TEST=fix50sp2 + steps: + - name: Setup + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + - name: Check out source + uses: actions/checkout@v2 + - name: Run Mongo + run: docker run -d -p 27017:27017 mongo + - name: Test + env: + GO111MODULE: "on" + MONGODB_TEST_CXN: "localhost" + run: make generate; if [ -z "$FIX_TEST" ]; then make build; make; else make build_accept; make $FIX_TEST; fi diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 000000000..4e51f8c12 --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,71 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ main ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ main ] + schedule: + - cron: '42 21 * * 3' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'go' ] + # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ] + # Learn more: + # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + # queries: ./path/to/local/query, your-org/your-repo/queries@main + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v1 + + # ℹī¸ Command-line programs to run using the OS shell. + # 📚 https://git.io/JvXDl + + # ✏ī¸ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v1 diff --git a/.gitignore b/.gitignore index c6137fd7a..ff2f9b241 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ *~ *.swp *.swo +.idea vendor _test/test _test/echo_server diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 000000000..15aa0a7b1 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,29 @@ +run: + timeout: 10m + skip-dirs: + - gen + - vendor + +linters: + disable-all: true + enable: + - deadcode + - dupl + - gofmt + - goimports + - gosimple + - govet + - ineffassign + - misspell + - revive + - unused + - varcheck + - staticcheck + +linters-settings: + gofmt: + simplify: true + goimports: + local-prefixes: github.com/quickfixgo/quickfix + dupl: + threshold: 400 \ No newline at end of file diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 5849b2f95..000000000 --- a/.travis.yml +++ /dev/null @@ -1,43 +0,0 @@ -language: go -sudo: false - -services: - - mongodb - -go: - - 1.11.x - - 1.12.x - - 1.13 - - tip - -env: - global: - - GO111MODULE=on - - MONGODB_TEST_CXN=localhost - matrix: - - FIX_TEST= - - FIX_TEST=fix40 - - FIX_TEST=fix41 - - FIX_TEST=fix42 - - FIX_TEST=fix43 - - FIX_TEST=fix44 - - FIX_TEST=fix50 - - FIX_TEST=fix50sp1 - - FIX_TEST=fix50sp2 - -matrix: - allow_failures: - - go: tip - fast_finish: true - exclude: - - go: 1.13.x - env: GO111MODULE=on - - go: tip - env: GO111MODULE=on - -before_install: - - if [[ "${GO111MODULE}" = "on" ]]; then mkdir "${HOME}/go"; export GOPATH="${HOME}/go"; fi - - if [[ "${GO111MODULE}" = "on" ]]; then export PATH="${GOPATH}/bin:${GOROOT}/bin:${PATH}"; fi - - go mod download - -script: make generate; if [ -z "$FIX_TEST" ]; then make build; make; else make build_accept; make $FIX_TEST; fi diff --git a/Makefile b/Makefile index c883e5bd0..ece1c82af 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,26 @@ +GOBIN = $(shell go env GOBIN) +ifeq ($(GOBIN),) +GOBIN = $(shell go env GOPATH)/bin +endif +GOX = $(GOBIN)/gox +GOIMPORTS = $(GOBIN)/goimports +ARCH = $(shell uname -p) + +# ------------------------------------------------------------------------------ +# dependencies + +# If go install is run from inside the project directory it will add the +# dependencies to the go.mod file. To avoid that we change to a directory +# without a go.mod file when downloading the following dependencies + +$(GOX): + (cd /; GO111MODULE=on go install github.com/mitchellh/gox@latest) + +$(GOIMPORTS): + (cd /; GO111MODULE=on go install golang.org/x/tools/cmd/goimports@latest) + +# ------------------------------------------------------------------------------ + all: vet test clean: @@ -5,22 +28,26 @@ clean: generate: clean mkdir -p gen; cd gen; go run ../cmd/generate-fix/generate-fix.go ../spec/*.xml + go get -u all generate-dist: cd ..; go run quickfix/cmd/generate-fix/generate-fix.go quickfix/spec/*.xml +test-style: + GO111MODULE=on golangci-lint run + +.PHONY: format +format: $(GOIMPORTS) + GO111MODULE=on go list -f '{{.Dir}}' ./... | xargs $(GOIMPORTS) -w -local github.com/quickfixgo/quickfix + fmt: go fmt `go list ./... | grep -v quickfix/gen` vet: go vet `go list ./... | grep -v quickfix/gen` -lint: - go get github.com/golang/lint/golint - golint . - test: - go test -v -cover . ./datadictionary ./internal + MONGODB_TEST_CXN=localhost go test -v -cover . ./datadictionary ./internal _build_all: go build -v `go list ./...` diff --git a/README.md b/README.md index d615b4db7..fbfd16f81 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ QuickFIX/Go =========== -[![GoDoc](https://godoc.org/github.com/quickfixgo/quickfix?status.png)](https://godoc.org/github.com/quickfixgo/quickfix) [![Build Status](https://travis-ci.org/quickfixgo/quickfix.svg?branch=master)](https://travis-ci.org/quickfixgo/quickfix) [![Go Report Card](https://goreportcard.com/badge/github.com/quickfixgo/quickfix)](https://goreportcard.com/report/github.com/quickfixgo/quickfix) +[![Build Status](https://github.com/quickfixgo/quickfix/workflows/CI/badge.svg)](https://github.com/quickfixgo/quickfix/actions) [![GoDoc](https://godoc.org/github.com/quickfixgo/quickfix?status.png)](https://godoc.org/github.com/quickfixgo/quickfix) [![Go Report Card](https://goreportcard.com/badge/github.com/quickfixgo/quickfix)](https://goreportcard.com/report/github.com/quickfixgo/quickfix) - Website: http://www.quickfixgo.org - Mailing list: [Google Groups](https://groups.google.com/forum/#!forum/quickfixgo) diff --git a/accepter_test.go b/accepter_test.go new file mode 100644 index 000000000..5e032cbb1 --- /dev/null +++ b/accepter_test.go @@ -0,0 +1,70 @@ +package quickfix + +import ( + "net" + "testing" + + "github.com/quickfixgo/quickfix/config" + + proxyproto "github.com/armon/go-proxyproto" + "github.com/stretchr/testify/assert" +) + +func TestAcceptor_Start(t *testing.T) { + sessionSettings := NewSessionSettings() + sessionSettings.Set(config.BeginString, BeginStringFIX42) + sessionSettings.Set(config.SenderCompID, "sender") + sessionSettings.Set(config.TargetCompID, "target") + + settingsWithTCPProxy := NewSettings() + settingsWithTCPProxy.GlobalSettings().Set("UseTCPProxy", "Y") + + settingsWithNoTCPProxy := NewSettings() + settingsWithNoTCPProxy.GlobalSettings().Set("UseTCPProxy", "N") + + genericSettings := NewSettings() + + const ( + GenericListener = iota + ProxyListener + ) + + acceptorStartTests := []struct { + name string + settings *Settings + listenerType int + }{ + {"with TCP proxy set", settingsWithTCPProxy, ProxyListener}, + {"with no TCP proxy set", settingsWithNoTCPProxy, GenericListener}, + {"no TCP proxy configuration set", genericSettings, GenericListener}, + } + + for _, tt := range acceptorStartTests { + t.Run(tt.name, func(t *testing.T) { + tt.settings.GlobalSettings().Set("SocketAcceptPort", "5001") + if _, err := tt.settings.AddSession(sessionSettings); err != nil { + assert.Nil(t, err) + } + + acceptor := &Acceptor{settings: tt.settings} + if err := acceptor.Start(); err != nil { + assert.NotNil(t, err) + } + assert.Len(t, acceptor.listeners, 1) + + for _, listener := range acceptor.listeners { + if tt.listenerType == ProxyListener { + _, ok := listener.(*proxyproto.Listener) + assert.True(t, ok) + } + + if tt.listenerType == GenericListener { + _, ok := listener.(*net.TCPListener) + assert.True(t, ok) + } + } + + acceptor.Stop() + }) + } +} diff --git a/acceptor.go b/acceptor.go index 62d17c91d..0ce8e8724 100644 --- a/acceptor.go +++ b/acceptor.go @@ -10,6 +10,8 @@ import ( "strconv" "sync" + proxyproto "github.com/armon/go-proxyproto" + "github.com/quickfixgo/quickfix/config" ) @@ -22,54 +24,78 @@ type Acceptor struct { globalLog Log sessions map[SessionID]*session sessionGroup sync.WaitGroup - listener net.Listener listenerShutdown sync.WaitGroup dynamicSessions bool dynamicQualifier bool dynamicQualifierCount int dynamicSessionChan chan *session - sessionAddr map[SessionID]net.Addr + sessionAddr sync.Map + sessionHostPort map[SessionID]int + listeners map[string]net.Listener + connectionValidator ConnectionValidator sessionFactory } +// ConnectionValidator is an interface allowing to implement a custom authentication logic. +type ConnectionValidator interface { + // Validate the connection for validity. This can be a part of authentication process. + // For example, you may tie up a SenderCompID to an IP range, or to a specific TLS certificate as a part of mTLS. + Validate(netConn net.Conn, session SessionID) error +} + //Start accepting connections. -func (a *Acceptor) Start() error { +func (a *Acceptor) Start() (err error) { socketAcceptHost := "" if a.settings.GlobalSettings().HasSetting(config.SocketAcceptHost) { - var err error if socketAcceptHost, err = a.settings.GlobalSettings().Setting(config.SocketAcceptHost); err != nil { - return err + return } } - socketAcceptPort, err := a.settings.GlobalSettings().IntSetting(config.SocketAcceptPort) - if err != nil { - return err + a.sessionHostPort = make(map[SessionID]int) + a.listeners = make(map[string]net.Listener) + for sessionID, sessionSettings := range a.settings.SessionSettings() { + if sessionSettings.HasSetting(config.SocketAcceptPort) { + if a.sessionHostPort[sessionID], err = sessionSettings.IntSetting(config.SocketAcceptPort); err != nil { + return + } + } else if a.sessionHostPort[sessionID], err = a.settings.GlobalSettings().IntSetting(config.SocketAcceptPort); err != nil { + return + } + address := net.JoinHostPort(socketAcceptHost, strconv.Itoa(a.sessionHostPort[sessionID])) + a.listeners[address] = nil } var tlsConfig *tls.Config if tlsConfig, err = loadTLSConfig(a.settings.GlobalSettings()); err != nil { - return err + return } - address := net.JoinHostPort(socketAcceptHost, strconv.Itoa(socketAcceptPort)) - if tlsConfig != nil { - if a.listener, err = tls.Listen("tcp", address, tlsConfig); err != nil { - return err + var useTCPProxy bool + if a.settings.GlobalSettings().HasSetting(config.UseTCPProxy) { + if useTCPProxy, err = a.settings.GlobalSettings().BoolSetting(config.UseTCPProxy); err != nil { + return } - } else { - if a.listener, err = net.Listen("tcp", address); err != nil { - return err + } + + for address := range a.listeners { + if tlsConfig != nil { + if a.listeners[address], err = tls.Listen("tcp", address, tlsConfig); err != nil { + return + } + } else if a.listeners[address], err = net.Listen("tcp", address); err != nil { + return + } else if useTCPProxy { + a.listeners[address] = &proxyproto.Listener{Listener: a.listeners[address]} } } - for sessionID := range a.sessions { - session := a.sessions[sessionID] + for _, s := range a.sessions { a.sessionGroup.Add(1) - go func() { - session.run() + go func(s *session) { + s.run() a.sessionGroup.Done() - }() + }(s) } if a.dynamicSessions { a.dynamicSessionChan = make(chan *session) @@ -79,9 +105,11 @@ func (a *Acceptor) Start() error { a.sessionGroup.Done() }() } - a.listenerShutdown.Add(1) - go a.listenForConnections() - return nil + a.listenerShutdown.Add(len(a.listeners)) + for _, listener := range a.listeners { + go a.listenForConnections(listener) + } + return } //Stop logs out existing sessions, close their connections, and stop accepting new connections. @@ -90,7 +118,9 @@ func (a *Acceptor) Stop() { _ = recover() // suppress sending on closed channel error }() - a.listener.Close() + for _, listener := range a.listeners { + listener.Close() + } a.listenerShutdown.Wait() if a.dynamicSessions { close(a.dynamicSessionChan) @@ -101,21 +131,26 @@ func (a *Acceptor) Stop() { a.sessionGroup.Wait() } -//Get remote IP address for a given session. +// RemoteAddr gets remote IP address for a given session. func (a *Acceptor) RemoteAddr(sessionID SessionID) (net.Addr, bool) { - addr, ok := a.sessionAddr[sessionID] - return addr, ok + addr, ok := a.sessionAddr.Load(sessionID) + if !ok || addr == nil { + return nil, false + } + val, ok := addr.(net.Addr) + return val, ok } //NewAcceptor creates and initializes a new Acceptor. func NewAcceptor(app Application, storeFactory MessageStoreFactory, settings *Settings, logFactory LogFactory) (a *Acceptor, err error) { a = &Acceptor{ - app: app, - storeFactory: storeFactory, - settings: settings, - logFactory: logFactory, - sessions: make(map[SessionID]*session), - sessionAddr: make(map[SessionID]net.Addr), + app: app, + storeFactory: storeFactory, + settings: settings, + logFactory: logFactory, + sessions: make(map[SessionID]*session), + sessionHostPort: make(map[SessionID]int), + listeners: make(map[string]net.Listener), } if a.settings.GlobalSettings().HasSetting(config.DynamicSessions) { if a.dynamicSessions, err = settings.globalSettings.BoolSetting(config.DynamicSessions); err != nil { @@ -149,11 +184,11 @@ func NewAcceptor(app Application, storeFactory MessageStoreFactory, settings *Se return } -func (a *Acceptor) listenForConnections() { +func (a *Acceptor) listenForConnections(listener net.Listener) { defer a.listenerShutdown.Done() for { - netConn, err := a.listener.Accept() + netConn, err := listener.Accept() if err != nil { return } @@ -253,6 +288,21 @@ func (a *Acceptor) handleConnection(netConn net.Conn) { SenderCompID: string(targetCompID), SenderSubID: string(targetSubID), SenderLocationID: string(targetLocationID), TargetCompID: string(senderCompID), TargetSubID: string(senderSubID), TargetLocationID: string(senderLocationID), } + + localConnectionPort := netConn.LocalAddr().(*net.TCPAddr).Port + if expectedPort, ok := a.sessionHostPort[sessID]; !ok || expectedPort != localConnectionPort { + a.globalLog.OnEventf("Session %v not found for incoming message: %s", sessID, msgBytes) + return + } + + // We have a session ID and a network connection. This seems to be a good place for any custom authentication logic. + if a.connectionValidator != nil { + if err := a.connectionValidator.Validate(netConn, sessID); err != nil { + a.globalLog.OnEventf("Unable to validate a connection %v", err.Error()) + return + } + } + if a.dynamicQualifier { a.dynamicQualifierCount++ sessID.Qualifier = strconv.Itoa(a.dynamicQualifierCount) @@ -273,7 +323,7 @@ func (a *Acceptor) handleConnection(netConn net.Conn) { defer session.stop() } - a.sessionAddr[sessID] = netConn.RemoteAddr() + a.sessionAddr.Store(sessID, netConn.RemoteAddr()) msgIn := make(chan fixIn) msgOut := make(chan []byte) @@ -320,7 +370,7 @@ LOOP: case id := <-complete: session, ok := sessions[id] if ok { - delete(a.sessionAddr, session.sessionID) + a.sessionAddr.Delete(session.sessionID) delete(sessions, id) } else { a.globalLog.OnEventf("Missing dynamic session %v!", id) @@ -339,3 +389,12 @@ LOOP: } } } + +// SetConnectionValidator sets an optional connection validator. +// Use it when you need a custom authentication logic that includes lower level interactions, +// like mTLS auth or IP whitelistening. +// To remove a previously set validator call it with a nil value: +// a.SetConnectionValidator(nil) +func (a *Acceptor) SetConnectionValidator(validator ConnectionValidator) { + a.connectionValidator = validator +} diff --git a/application.go b/application.go index 09b3b49a3..e91ab1021 100644 --- a/application.go +++ b/application.go @@ -1,26 +1,26 @@ package quickfix -//The Application interface should be implemented by FIX Applications. +//Application interface should be implemented by FIX Applications. //This is the primary interface for processing messages from a FIX Session. type Application interface { - //Notification of a session begin created. + //OnCreate notification of a session begin created. OnCreate(sessionID SessionID) - //Notification of a session successfully logging on. + //OnLogon notification of a session successfully logging on. OnLogon(sessionID SessionID) - //Notification of a session logging off or disconnecting. + //OnLogout notification of a session logging off or disconnecting. OnLogout(sessionID SessionID) - //Notification of admin message being sent to target. + //ToAdmin notification of admin message being sent to target. ToAdmin(message *Message, sessionID SessionID) - //Notification of app message being sent to target. + //ToApp notification of app message being sent to target. ToApp(message *Message, sessionID SessionID) error - //Notification of admin message being received from target. + //FromAdmin notification of admin message being received from target. FromAdmin(message *Message, sessionID SessionID) MessageRejectError - //Notification of app message being received from target. + //FromApp notification of app message being received from target. FromApp(message *Message, sessionID SessionID) MessageRejectError } diff --git a/config/configuration.go b/config/configuration.go index ac74ff68a..ed7f1303e 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -29,6 +29,7 @@ const ( ProxyPort string = "ProxyPort" ProxyUser string = "ProxyUser" ProxyPassword string = "ProxyPassword" + UseTCPProxy string = "UseTCPProxy" DefaultApplVerID string = "DefaultApplVerID" StartTime string = "StartTime" EndTime string = "EndTime" @@ -46,6 +47,7 @@ const ( LogoutTimeout string = "LogoutTimeout" LogonTimeout string = "LogonTimeout" HeartBtInt string = "HeartBtInt" + HeartBtIntOverride string = "HeartBtIntOverride" FileLogPath string = "FileLogPath" FileStorePath string = "FileStorePath" SQLStoreDriver string = "SQLStoreDriver" diff --git a/config/doc.go b/config/doc.go index 0538b7c5f..c7223bbad 100644 --- a/config/doc.go +++ b/config/doc.go @@ -244,7 +244,15 @@ Defaults to 10 HeartBtInt -Heartbeat interval in seconds. Only used for initiators. Value must be positive integer. +Heartbeat interval in seconds. Only used for initiators (unless HeartBtIntOverride is Y). Value must be positive integer. + +HeartBtIntOverride + +If set to Y, will use the HeartBtInt interval rather than what the initiator dictates. Only used for acceptors. Valid Values: + Y + N + +Defaults to N. SocketConnectPort @@ -325,6 +333,12 @@ ProxyPassword Proxy password +UseTCPProxy + +Use TCP proxy for servers listening behind HAProxy of Amazon ELB load balancers. The server can then receive the address of the client instead of the load balancer's. Valid Values: + Y + N + PersistMessages If set to N, no messages will be persisted. This will force QuickFIX/Go to always send GapFills instead of resending messages. Use this if you know you never want to resend a message. Useful for market data streams. Valid Values: diff --git a/datadictionary/component_type_test.go b/datadictionary/component_type_test.go index 6de2b345a..c9774764b 100644 --- a/datadictionary/component_type_test.go +++ b/datadictionary/component_type_test.go @@ -3,8 +3,9 @@ package datadictionary_test import ( "testing" - "github.com/quickfixgo/quickfix/datadictionary" "github.com/stretchr/testify/assert" + + "github.com/quickfixgo/quickfix/datadictionary" ) func TestNewComponentType(t *testing.T) { diff --git a/datadictionary/datadictionary.go b/datadictionary/datadictionary.go index 410ca8ee1..ea93f6b25 100644 --- a/datadictionary/datadictionary.go +++ b/datadictionary/datadictionary.go @@ -3,7 +3,10 @@ package datadictionary import ( "encoding/xml" + "io" "os" + + "github.com/pkg/errors" ) //DataDictionary models FIX messages, components, and fields. @@ -197,26 +200,7 @@ func (f FieldDef) childTags() []int { for _, f := range f.Fields { tags = append(tags, f.Tag()) - for _, t := range f.childTags() { - tags = append(tags, t) - } - } - - return tags -} - -func (f FieldDef) requiredChildTags() []int { - var tags []int - - for _, f := range f.Fields { - if !f.Required() { - continue - } - - tags = append(tags, f.Tag()) - for _, t := range f.requiredChildTags() { - tags = append(tags, t) - } + tags = append(tags, f.childTags()...) } return tags @@ -315,24 +299,30 @@ func NewMessageDef(name, msgType string, parts []MessagePart) *MessageDef { return &msg } -//Parse loads and and build a datadictionary instance from an xml file. +//Parse loads and build a datadictionary instance from an xml file. func Parse(path string) (*DataDictionary, error) { var xmlFile *os.File - xmlFile, err := os.Open(path) + var err error + xmlFile, err = os.Open(path) if err != nil { - return nil, err + return nil, errors.Wrapf(err, "problem opening file: %v", path) } defer xmlFile.Close() + return ParseSrc(xmlFile) +} + +//ParseSrc loads and build a datadictionary instance from an xml source. +func ParseSrc(xmlSrc io.Reader) (*DataDictionary, error) { doc := new(XMLDoc) - decoder := xml.NewDecoder(xmlFile) + decoder := xml.NewDecoder(xmlSrc) if err := decoder.Decode(doc); err != nil { - return nil, err + return nil, errors.Wrapf(err, "problem parsing XML file") } b := new(builder) - var dict *DataDictionary - if dict, err = b.build(doc); err != nil { + dict, err := b.build(doc) + if err != nil { return nil, err } diff --git a/datadictionary/datadictionary_test.go b/datadictionary/datadictionary_test.go index 4179716c5..11d98bfca 100644 --- a/datadictionary/datadictionary_test.go +++ b/datadictionary/datadictionary_test.go @@ -82,7 +82,7 @@ func TestFieldsByTag(t *testing.T) { func TestEnumFieldsByTag(t *testing.T) { d, _ := dict() - f, _ := d.FieldTypeByTag[658] + f := d.FieldTypeByTag[658] var tests = []struct { Value string @@ -141,7 +141,7 @@ func TestDataDictionaryTrailer(t *testing.T) { func TestMessageRequiredTags(t *testing.T) { d, _ := dict() - nos, _ := d.Messages["D"] + nos := d.Messages["D"] var tests = []struct { *MessageDef @@ -169,7 +169,7 @@ func TestMessageRequiredTags(t *testing.T) { func TestMessageTags(t *testing.T) { d, _ := dict() - nos, _ := d.Messages["D"] + nos := d.Messages["D"] var tests = []struct { *MessageDef diff --git a/datadictionary/field_def_test.go b/datadictionary/field_def_test.go index 9c0a68abb..cb80fa131 100644 --- a/datadictionary/field_def_test.go +++ b/datadictionary/field_def_test.go @@ -3,8 +3,9 @@ package datadictionary_test import ( "testing" - "github.com/quickfixgo/quickfix/datadictionary" "github.com/stretchr/testify/assert" + + "github.com/quickfixgo/quickfix/datadictionary" ) func TestNewFieldDef(t *testing.T) { diff --git a/datadictionary/field_type_test.go b/datadictionary/field_type_test.go index 2a4328dff..00c651eef 100644 --- a/datadictionary/field_type_test.go +++ b/datadictionary/field_type_test.go @@ -3,8 +3,9 @@ package datadictionary_test import ( "testing" - "github.com/quickfixgo/quickfix/datadictionary" "github.com/stretchr/testify/assert" + + "github.com/quickfixgo/quickfix/datadictionary" ) func TestNewFieldType(t *testing.T) { diff --git a/datadictionary/group_field_def_test.go b/datadictionary/group_field_def_test.go index e1e963298..87147ce55 100644 --- a/datadictionary/group_field_def_test.go +++ b/datadictionary/group_field_def_test.go @@ -3,8 +3,9 @@ package datadictionary_test import ( "testing" - "github.com/quickfixgo/quickfix/datadictionary" "github.com/stretchr/testify/assert" + + "github.com/quickfixgo/quickfix/datadictionary" ) func TestNewGroupField(t *testing.T) { diff --git a/datadictionary/message_def_test.go b/datadictionary/message_def_test.go index 666751fed..73de16452 100644 --- a/datadictionary/message_def_test.go +++ b/datadictionary/message_def_test.go @@ -3,8 +3,9 @@ package datadictionary_test import ( "testing" - "github.com/quickfixgo/quickfix/datadictionary" "github.com/stretchr/testify/assert" + + "github.com/quickfixgo/quickfix/datadictionary" ) func TestNewMessageDef(t *testing.T) { diff --git a/dialer.go b/dialer.go index 1645076bf..1eb28ad04 100644 --- a/dialer.go +++ b/dialer.go @@ -4,8 +4,9 @@ import ( "fmt" "net" - "github.com/quickfixgo/quickfix/config" "golang.org/x/net/proxy" + + "github.com/quickfixgo/quickfix/config" ) func loadDialerConfig(settings *SessionSettings) (dialer proxy.Dialer, err error) { diff --git a/dialer_test.go b/dialer_test.go index 510c18f04..47bf79d08 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -5,8 +5,9 @@ import ( "testing" "time" - "github.com/quickfixgo/quickfix/config" "github.com/stretchr/testify/suite" + + "github.com/quickfixgo/quickfix/config" ) type DialerTestSuite struct { diff --git a/errors.go b/errors.go index af949ebd4..19ca16e50 100644 --- a/errors.go +++ b/errors.go @@ -33,6 +33,7 @@ type MessageRejectError interface { //RejectReason, tag 373 for session rejects, tag 380 for business rejects. RejectReason() int + BusinessRejectRefID() string RefTagID() *Tag IsBusinessReject() bool } @@ -50,20 +51,25 @@ func (RejectLogon) RefTagID() *Tag { return nil } //RejectReason implements MessageRejectError func (RejectLogon) RejectReason() int { return 0 } +//BusinessRejectRefID implements MessageRejectError +func (RejectLogon) BusinessRejectRefID() string { return "" } + //IsBusinessReject implements MessageRejectError func (RejectLogon) IsBusinessReject() bool { return false } type messageRejectError struct { - rejectReason int - text string - refTagID *Tag - isBusinessReject bool + rejectReason int + text string + businessRejectRefID string + refTagID *Tag + isBusinessReject bool } -func (e messageRejectError) Error() string { return e.text } -func (e messageRejectError) RefTagID() *Tag { return e.refTagID } -func (e messageRejectError) RejectReason() int { return e.rejectReason } -func (e messageRejectError) IsBusinessReject() bool { return e.isBusinessReject } +func (e messageRejectError) Error() string { return e.text } +func (e messageRejectError) RefTagID() *Tag { return e.refTagID } +func (e messageRejectError) RejectReason() int { return e.rejectReason } +func (e messageRejectError) BusinessRejectRefID() string { return e.businessRejectRefID } +func (e messageRejectError) IsBusinessReject() bool { return e.isBusinessReject } //NewMessageRejectError returns a MessageRejectError with the given error message, reject reason, and optional reftagid func NewMessageRejectError(err string, rejectReason int, refTagID *Tag) MessageRejectError { @@ -76,6 +82,12 @@ func NewBusinessMessageRejectError(err string, rejectReason int, refTagID *Tag) return messageRejectError{text: err, rejectReason: rejectReason, refTagID: refTagID, isBusinessReject: true} } +//NewBusinessMessageRejectErrorWithRefID returns a MessageRejectError with the given error mesage, reject reason, refID, and optional reftagid. +//Reject is treated as a business level reject +func NewBusinessMessageRejectErrorWithRefID(err string, rejectReason int, businessRejectRefID string, refTagID *Tag) MessageRejectError { + return messageRejectError{text: err, rejectReason: rejectReason, refTagID: refTagID, businessRejectRefID: businessRejectRefID, isBusinessReject: true} +} + //IncorrectDataFormatForValue returns an error indicating a field that cannot be parsed as the type required. func IncorrectDataFormatForValue(tag Tag) MessageRejectError { return NewMessageRejectError("Incorrect data format for value", rejectReasonIncorrectDataFormatForValue, &tag) diff --git a/errors_test.go b/errors_test.go index 56554af51..e612d3242 100644 --- a/errors_test.go +++ b/errors_test.go @@ -52,6 +52,33 @@ func TestNewBusinessMessageRejectError(t *testing.T) { } } +func TestNewBusinessMessageRejectErrorWithRefID(t *testing.T) { + var ( + expectedErrorString = "Custom error" + expectedRejectReason = 5 + expectedbusinessRejectRefID = "1" + expectedRefTagID Tag = 44 + expectedIsBusinessReject = true + ) + msgRej := NewBusinessMessageRejectErrorWithRefID(expectedErrorString, expectedRejectReason, expectedbusinessRejectRefID, &expectedRefTagID) + + if strings.Compare(msgRej.Error(), expectedErrorString) != 0 { + t.Errorf("expected: %s, got: %s\n", expectedErrorString, msgRej.Error()) + } + if msgRej.RejectReason() != expectedRejectReason { + t.Errorf("expected: %d, got: %d\n", expectedRejectReason, msgRej.RejectReason()) + } + if strings.Compare(msgRej.BusinessRejectRefID(), expectedbusinessRejectRefID) != 0 { + t.Errorf("expected: %s, got: %s\n", expectedbusinessRejectRefID, msgRej.BusinessRejectRefID()) + } + if *msgRej.RefTagID() != expectedRefTagID { + t.Errorf("expected: %d, got: %d\n", expectedRefTagID, msgRej.RefTagID()) + } + if msgRej.IsBusinessReject() != expectedIsBusinessReject { + t.Error("Expected IsBusinessReject to be true\n") + } +} + func TestIncorrectDataFormatForValue(t *testing.T) { var ( expectedErrorString = "Incorrect data format for value" diff --git a/field.go b/field.go index 044e31418..43b9e4d86 100644 --- a/field.go +++ b/field.go @@ -2,13 +2,14 @@ package quickfix //FieldValueWriter is an interface for writing field values type FieldValueWriter interface { - //Writes out the contents of the FieldValue to a []byte + //Write writes out the contents of the FieldValue to a []byte Write() []byte } //FieldValueReader is an interface for reading field values type FieldValueReader interface { - //Reads the contents of the []byte into FieldValue. Returns an error if there are issues in the data processing + //Read reads the contents of the []byte into FieldValue. + //Returns an error if there are issues in the data processing Read([]byte) error } diff --git a/field_map.go b/field_map.go index 4284fc280..4bf02e6ab 100644 --- a/field_map.go +++ b/field_map.go @@ -3,6 +3,7 @@ package quickfix import ( "bytes" "sort" + "sync" "time" ) @@ -39,6 +40,7 @@ func (t tagSort) Less(i, j int) bool { return t.compare(t.tags[i], t.tags[j]) } type FieldMap struct { tagLookup map[Tag]field tagSort + rwLock *sync.RWMutex } // ascending tags @@ -49,12 +51,16 @@ func (m *FieldMap) init() { } func (m *FieldMap) initWithOrdering(ordering tagOrder) { + m.rwLock = &sync.RWMutex{} m.tagLookup = make(map[Tag]field) m.compare = ordering } //Tags returns all of the Field Tags in this FieldMap func (m FieldMap) Tags() []Tag { + m.rwLock.RLock() + defer m.rwLock.RUnlock() + tags := make([]Tag, 0, len(m.tagLookup)) for t := range m.tagLookup { tags = append(tags, t) @@ -70,12 +76,18 @@ func (m FieldMap) Get(parser Field) MessageRejectError { //Has returns true if the Tag is present in this FieldMap func (m FieldMap) Has(tag Tag) bool { + m.rwLock.RLock() + defer m.rwLock.RUnlock() + _, ok := m.tagLookup[tag] return ok } //GetField parses of a field with Tag tag. Returned reject may indicate the field is not present, or the field value is invalid. func (m FieldMap) GetField(tag Tag, parser FieldValueReader) MessageRejectError { + m.rwLock.RLock() + defer m.rwLock.RUnlock() + f, ok := m.tagLookup[tag] if !ok { return ConditionallyRequiredFieldMissing(tag) @@ -90,6 +102,9 @@ func (m FieldMap) GetField(tag Tag, parser FieldValueReader) MessageRejectError //GetBytes is a zero-copy GetField wrapper for []bytes fields func (m FieldMap) GetBytes(tag Tag) ([]byte, MessageRejectError) { + m.rwLock.RLock() + defer m.rwLock.RUnlock() + f, ok := m.tagLookup[tag] if !ok { return nil, ConditionallyRequiredFieldMissing(tag) @@ -124,6 +139,9 @@ func (m FieldMap) GetInt(tag Tag) (int, MessageRejectError) { //GetTime is a GetField wrapper for utc timestamp fields func (m FieldMap) GetTime(tag Tag) (t time.Time, err MessageRejectError) { + m.rwLock.RLock() + defer m.rwLock.RUnlock() + bytes, err := m.GetBytes(tag) if err != nil { return @@ -148,6 +166,9 @@ func (m FieldMap) GetString(tag Tag) (string, MessageRejectError) { //GetGroup is a Get function specific to Group Fields. func (m FieldMap) GetGroup(parser FieldGroupReader) MessageRejectError { + m.rwLock.RLock() + defer m.rwLock.RUnlock() + f, ok := m.tagLookup[parser.Tag()] if !ok { return ConditionallyRequiredFieldMissing(parser.Tag()) @@ -193,6 +214,9 @@ func (m *FieldMap) SetString(tag Tag, value string) *FieldMap { //Clear purges all fields from field map func (m *FieldMap) Clear() { + m.rwLock.Lock() + defer m.rwLock.Unlock() + m.tags = m.tags[0:0] for k := range m.tagLookup { delete(m.tagLookup, k) @@ -201,6 +225,9 @@ func (m *FieldMap) Clear() { //CopyInto overwrites the given FieldMap with this one func (m *FieldMap) CopyInto(to *FieldMap) { + m.rwLock.RLock() + defer m.rwLock.RUnlock() + to.tagLookup = make(map[Tag]field) for tag, f := range m.tagLookup { clone := make(field, 1) @@ -213,6 +240,9 @@ func (m *FieldMap) CopyInto(to *FieldMap) { } func (m *FieldMap) add(f field) { + m.rwLock.Lock() + defer m.rwLock.Unlock() + t := fieldTag(f) if _, ok := m.tagLookup[t]; !ok { m.tags = append(m.tags, t) @@ -222,6 +252,9 @@ func (m *FieldMap) add(f field) { } func (m *FieldMap) getOrCreate(tag Tag) field { + m.rwLock.Lock() + defer m.rwLock.Unlock() + if f, ok := m.tagLookup[tag]; ok { f = f[:1] return f @@ -242,6 +275,9 @@ func (m *FieldMap) Set(field FieldWriter) *FieldMap { //SetGroup is a setter specific to group fields func (m *FieldMap) SetGroup(field FieldGroupWriter) *FieldMap { + m.rwLock.Lock() + defer m.rwLock.Unlock() + _, ok := m.tagLookup[field.Tag()] if !ok { m.tags = append(m.tags, field.Tag()) @@ -256,6 +292,9 @@ func (m *FieldMap) sortedTags() []Tag { } func (m FieldMap) write(buffer *bytes.Buffer) { + m.rwLock.Lock() + defer m.rwLock.Unlock() + for _, tag := range m.sortedTags() { if f, ok := m.tagLookup[tag]; ok { writeField(f, buffer) @@ -264,6 +303,9 @@ func (m FieldMap) write(buffer *bytes.Buffer) { } func (m FieldMap) total() int { + m.rwLock.RLock() + defer m.rwLock.RUnlock() + total := 0 for _, fields := range m.tagLookup { for _, tv := range fields { @@ -279,6 +321,9 @@ func (m FieldMap) total() int { } func (m FieldMap) length() int { + m.rwLock.RLock() + defer m.rwLock.RUnlock() + length := 0 for _, fields := range m.tagLookup { for _, tv := range fields { diff --git a/field_map_test.go b/field_map_test.go index c984f7c45..f07f6f396 100644 --- a/field_map_test.go +++ b/field_map_test.go @@ -157,7 +157,7 @@ func TestFieldMap_CopyInto(t *testing.T) { assert.Equal(t, "a", s) // old fields cleared - s, err = fMapB.GetString(3) + _, err = fMapB.GetString(3) assert.NotNil(t, err) // check that ordering is overwritten diff --git a/file_log_test.go b/file_log_test.go index b807dca42..3358dc20a 100644 --- a/file_log_test.go +++ b/file_log_test.go @@ -11,7 +11,7 @@ import ( func TestFileLog_NewFileLogFactory(t *testing.T) { - factory, err := NewFileLogFactory(NewSettings()) + _, err := NewFileLogFactory(NewSettings()) if err == nil { t.Error("Should expect error when settings have no file log path") @@ -39,7 +39,7 @@ SessionQualifier=BS stringReader := strings.NewReader(cfg) settings, _ := ParseSettings(stringReader) - factory, err = NewFileLogFactory(settings) + factory, err := NewFileLogFactory(settings) if err != nil { t.Error("Did not expect error", err) diff --git a/filestore.go b/filestore.go index aa2f7cde3..976273754 100644 --- a/filestore.go +++ b/filestore.go @@ -2,12 +2,15 @@ package quickfix import ( "fmt" + "io" "io/ioutil" "os" "path" "strconv" "time" + "github.com/pkg/errors" + "github.com/quickfixgo/quickfix/config" ) @@ -81,9 +84,12 @@ func newFileStore(sessionID SessionID, dirname string) (*fileStore, error) { // Reset deletes the store files and sets the seqnums back to 1 func (store *fileStore) Reset() error { - store.cache.Reset() + if err := store.cache.Reset(); err != nil { + return errors.Wrap(err, "cache reset") + } + if err := store.Close(); err != nil { - return err + return errors.Wrap(err, "close") } if err := removeFile(store.bodyFname); err != nil { return err @@ -105,7 +111,10 @@ func (store *fileStore) Reset() error { // Refresh closes the store files and then reloads from them func (store *fileStore) Refresh() (err error) { - store.cache.Reset() + if err = store.cache.Reset(); err != nil { + err = errors.Wrap(err, "cache reset") + return + } if err = store.Close(); err != nil { return err @@ -138,8 +147,13 @@ func (store *fileStore) Refresh() (err error) { } } - store.SetNextSenderMsgSeqNum(store.NextSenderMsgSeqNum()) - store.SetNextTargetMsgSeqNum(store.NextTargetMsgSeqNum()) + if err := store.SetNextSenderMsgSeqNum(store.NextSenderMsgSeqNum()); err != nil { + return errors.Wrap(err, "set next sender") + } + + if err := store.SetNextTargetMsgSeqNum(store.NextTargetMsgSeqNum()); err != nil { + return errors.Wrap(err, "set next target") + } return nil } @@ -166,13 +180,17 @@ func (store *fileStore) populateCache() (creationTimePopulated bool, err error) if senderSeqNumBytes, err := ioutil.ReadFile(store.senderSeqNumsFname); err == nil { if senderSeqNum, err := strconv.Atoi(string(senderSeqNumBytes)); err == nil { - store.cache.SetNextSenderMsgSeqNum(senderSeqNum) + if err = store.cache.SetNextSenderMsgSeqNum(senderSeqNum); err != nil { + return creationTimePopulated, errors.Wrap(err, "cache set next sender") + } } } if targetSeqNumBytes, err := ioutil.ReadFile(store.targetSeqNumsFname); err == nil { if targetSeqNum, err := strconv.Atoi(string(targetSeqNumBytes)); err == nil { - store.cache.SetNextTargetMsgSeqNum(targetSeqNum) + if err = store.cache.SetNextTargetMsgSeqNum(targetSeqNum); err != nil { + return creationTimePopulated, errors.Wrap(err, "cache set next target") + } } } @@ -180,7 +198,7 @@ func (store *fileStore) populateCache() (creationTimePopulated bool, err error) } func (store *fileStore) setSession() error { - if _, err := store.sessionFile.Seek(0, os.SEEK_SET); err != nil { + if _, err := store.sessionFile.Seek(0, io.SeekStart); err != nil { return fmt.Errorf("unable to rewind file: %s: %s", store.sessionFname, err.Error()) } @@ -198,7 +216,7 @@ func (store *fileStore) setSession() error { } func (store *fileStore) setSeqNum(f *os.File, seqNum int) error { - if _, err := f.Seek(0, os.SEEK_SET); err != nil { + if _, err := f.Seek(0, io.SeekStart); err != nil { return fmt.Errorf("unable to rewind file: %s: %s", f.Name(), err.Error()) } if _, err := fmt.Fprintf(f, "%019d", seqNum); err != nil { @@ -222,25 +240,33 @@ func (store *fileStore) NextTargetMsgSeqNum() int { // SetNextSenderMsgSeqNum sets the next MsgSeqNum that will be sent func (store *fileStore) SetNextSenderMsgSeqNum(next int) error { - store.cache.SetNextSenderMsgSeqNum(next) + if err := store.cache.SetNextSenderMsgSeqNum(next); err != nil { + return errors.Wrap(err, "cache") + } return store.setSeqNum(store.senderSeqNumsFile, next) } // SetNextTargetMsgSeqNum sets the next MsgSeqNum that should be received func (store *fileStore) SetNextTargetMsgSeqNum(next int) error { - store.cache.SetNextTargetMsgSeqNum(next) + if err := store.cache.SetNextTargetMsgSeqNum(next); err != nil { + return errors.Wrap(err, "cache") + } return store.setSeqNum(store.targetSeqNumsFile, next) } // IncrNextSenderMsgSeqNum increments the next MsgSeqNum that will be sent func (store *fileStore) IncrNextSenderMsgSeqNum() error { - store.cache.IncrNextSenderMsgSeqNum() + if err := store.cache.IncrNextSenderMsgSeqNum(); err != nil { + return errors.Wrap(err, "cache") + } return store.setSeqNum(store.senderSeqNumsFile, store.cache.NextSenderMsgSeqNum()) } // IncrNextTargetMsgSeqNum increments the next MsgSeqNum that should be received func (store *fileStore) IncrNextTargetMsgSeqNum() error { - store.cache.IncrNextTargetMsgSeqNum() + if err := store.cache.IncrNextTargetMsgSeqNum(); err != nil { + return errors.Wrap(err, "cache") + } return store.setSeqNum(store.targetSeqNumsFile, store.cache.NextTargetMsgSeqNum()) } @@ -250,11 +276,11 @@ func (store *fileStore) CreationTime() time.Time { } func (store *fileStore) SaveMessage(seqNum int, msg []byte) error { - offset, err := store.bodyFile.Seek(0, os.SEEK_END) + offset, err := store.bodyFile.Seek(0, io.SeekEnd) if err != nil { return fmt.Errorf("unable to seek to end of file: %s: %s", store.bodyFname, err.Error()) } - if _, err := store.headerFile.Seek(0, os.SEEK_END); err != nil { + if _, err := store.headerFile.Seek(0, io.SeekEnd); err != nil { return fmt.Errorf("unable to seek to end of file: %s: %s", store.headerFname, err.Error()) } if _, err := fmt.Fprintf(store.headerFile, "%d,%d,%d\n", seqNum, offset, len(msg)); err != nil { diff --git a/fileutil.go b/fileutil.go index 9fa11c964..5334f271c 100644 --- a/fileutil.go +++ b/fileutil.go @@ -4,6 +4,8 @@ import ( "fmt" "os" "strings" + + "github.com/pkg/errors" ) func sessionIDFilenamePrefix(s SessionID) string { @@ -44,9 +46,8 @@ func closeFile(f *os.File) error { // removeFile behaves like os.Remove, except that no error is returned if the file does not exist func removeFile(fname string) error { - err := os.Remove(fname) - if (err != nil) && !os.IsNotExist(err) { - return err + if err := os.Remove(fname); (err != nil) && !os.IsNotExist(err) { + return errors.Wrapf(err, "remove %v", fname) } return nil } diff --git a/fileutil_test.go b/fileutil_test.go index 4817c7862..f634651df 100644 --- a/fileutil_test.go +++ b/fileutil_test.go @@ -53,6 +53,7 @@ func TestOpenOrCreateFile(t *testing.T) { // Then it should be created f, err := openOrCreateFile(fname, 0664) + require.Nil(t, err) requireFileExists(t, fname) // When the file already exists diff --git a/fix_int_test.go b/fix_int_test.go index 173cb5ee5..64142c045 100644 --- a/fix_int_test.go +++ b/fix_int_test.go @@ -32,6 +32,6 @@ func BenchmarkFIXInt_Read(b *testing.B) { var field FIXInt for i := 0; i < b.N; i++ { - field.Read(intBytes) + _ = field.Read(intBytes) } } diff --git a/go.mod b/go.mod index 7a1c07b7d..a47647044 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,19 @@ module github.com/quickfixgo/quickfix -go 1.13 +go 1.18 require ( + github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8 - github.com/mattn/go-sqlite3 v1.11.0 - github.com/shopspring/decimal v0.0.0-20190905144223-a36b5d85f337 - github.com/stretchr/testify v1.4.0 - golang.org/x/net v0.0.0-20190926025831-c00fd9afed17 + github.com/mattn/go-sqlite3 v1.14.14 + github.com/pkg/errors v0.9.1 + github.com/shopspring/decimal v1.3.1 + github.com/stretchr/testify v1.8.0 + golang.org/x/net v0.0.0-20220708220712-1185a9018129 + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.4.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ee9f475a9..d567c931c 100644 --- a/go.sum +++ b/go.sum @@ -1,23 +1,54 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a h1:AP/vsCIvJZ129pdm9Ek7bH7yutN3hByqsMoNrWAxRQc= +github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a/go.mod h1:QmP9hvJ91BbJmGVGSbutW19IC0Q9phDCLGaomwTJbgU= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8 h1:DujepqpGd1hyOd7aW59XpK7Qymp8iy83xq74fLr21is= github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q= -github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q= -github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-sqlite3 v1.14.14 h1:qZgc/Rwetq+MtyE18WhzjokPD93dNqLGNT3QJuLvBGw= +github.com/mattn/go-sqlite3 v1.14.14/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/shopspring/decimal v0.0.0-20190905144223-a36b5d85f337 h1:Da9XEUfFxgyDOqUfwgoTDcWzmnlOnCGi6i4iPS+8Fbw= -github.com/shopspring/decimal v0.0.0-20190905144223-a36b5d85f337/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= -github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= +github.com/quickfixgo/enum v0.0.0-20210629025633-9afc8539baba h1:ysNRAW5kAhJ76Wo6LMAtbuFq/64s2E5KK00cS/rR/Po= +github.com/quickfixgo/enum v0.0.0-20210629025633-9afc8539baba/go.mod h1:65gdG2/8vr6uOYcjZBObVHMuTEYc5rr/+aKVWTrFIrQ= +github.com/quickfixgo/field v0.0.0-20171007195410-74cea5ec78c7 h1:a/qsvkJNoj1vcSFTzgLqNcwTRuiM1VjchoRjDOIMNyY= +github.com/quickfixgo/field v0.0.0-20171007195410-74cea5ec78c7/go.mod h1:7kiKeQwJLOrwVXqvt2GAnk4vOH0jErrB3qO6SERmq7c= +github.com/quickfixgo/fix40 v0.0.0-20171007200002-cce875b2c2e7 h1:csHnaP2l65lrchDUvpk2LbA7BF23wsl4aqFqGVX8r10= +github.com/quickfixgo/fix40 v0.0.0-20171007200002-cce875b2c2e7/go.mod h1:RC0yl+6EULF8t40eFen3UdouFVdZu1uVwMsk7O/6i5s= +github.com/quickfixgo/fix41 v0.0.0-20171007212429-b272ca343ed2 h1:dofXBfr8IrzTXvHyu0gV9KIgzBzjVyBd5rmxERw50rE= +github.com/quickfixgo/fix41 v0.0.0-20171007212429-b272ca343ed2/go.mod h1:2CARkVxpb7YV3Ib6qRjI2UFzvlIyhP0lx/nN6GDoZ7w= +github.com/quickfixgo/fix42 v0.0.0-20171007212724-86a4567f1c77 h1:mD360ECTwAK4jbOIrqAH2MdJF3RrH5KuboFNdxSmIDM= +github.com/quickfixgo/fix42 v0.0.0-20171007212724-86a4567f1c77/go.mod h1:RyVaOPb/+NasHAt2e5UJ9PjXT+5AeqyKZfNPOB4UspE= +github.com/quickfixgo/fix43 v0.0.0-20171007213001-a7ff4f2a2470 h1:1bkIwYMs5XrjvKouIiqn2Iiw46XKcNTS1as9hCZqnMg= +github.com/quickfixgo/fix43 v0.0.0-20171007213001-a7ff4f2a2470/go.mod h1:qpAUIIjRXEQRtuMpJolR+8VkDajoU8J7Iw55Gnwm15s= +github.com/quickfixgo/fix44 v0.0.0-20171007213039-f090a1006218 h1:zrm7CRhis2ArB/xjOj0EQJOuHq5fI0JpS4AILtjwUug= +github.com/quickfixgo/fix44 v0.0.0-20171007213039-f090a1006218/go.mod h1:KFN4LkI1sidKgWnUvmpDdvsa+aNvcbExtS8iPQvA9ys= +github.com/quickfixgo/fixt11 v0.0.0-20171007213433-d9788ca97f5d h1:PlymcwOkKZXnOI3uJZu0lpnvnweTkML9YxnTuYhhzIM= +github.com/quickfixgo/fixt11 v0.0.0-20171007213433-d9788ca97f5d/go.mod h1:/oN4Arv+/8zKshQTj+ggpWfjXuVv9bMUO93zOO6R+oM= +github.com/quickfixgo/tag v0.0.0-20171007194743-cbb465760521 h1:RXfjXtjvXb4wgzBHTTFbyW5uBP04Nrlky9e9lAe8mIE= +github.com/quickfixgo/tag v0.0.0-20171007194743-cbb465760521/go.mod h1:EKAI2kkSaIuSywW0WbIgXIcuA9vS4IXfCga9U9Oax2E= +github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= +github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/net v0.0.0-20190926025831-c00fd9afed17 h1:qPnAdmjNA41t3QBTx2mFGf/SD1IoslhYu7AmdsVzCcs= -golang.org/x/net v0.0.0-20190926025831-c00fd9afed17/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +golang.org/x/net v0.0.0-20220708220712-1185a9018129 h1:vucSRfWwTsoXro7P+3Cjlr6flUMtzCwzlvkxEQtHHB0= +golang.org/x/net v0.0.0-20220708220712-1185a9018129/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/in_session.go b/in_session.go index d4dfc701d..3ca45d7bd 100644 --- a/in_session.go +++ b/in_session.go @@ -239,7 +239,7 @@ func (state inSession) resendMessages(session *session, beginSeqNo, endSeqNo int session.log.OnEventf("Resending Message: %v", sentMessageSeqNum) msgBytes = msg.build() - session.sendBytes(msgBytes) + session.EnqueueBytesAndSend(msgBytes) seqNum = sentMessageSeqNum + 1 nextSeqNum = seqNum @@ -382,7 +382,7 @@ func (state *inSession) generateSequenceReset(session *session, beginSeqNo int, msgBytes := sequenceReset.build() - session.sendBytes(msgBytes) + session.EnqueueBytesAndSend(msgBytes) session.log.OnEventf("Sent SequenceReset TO: %v", endSeqNo) return diff --git a/in_session_test.go b/in_session_test.go index c8c0b1523..80e827dec 100644 --- a/in_session_test.go +++ b/in_session_test.go @@ -4,8 +4,9 @@ import ( "testing" "time" - "github.com/quickfixgo/quickfix/internal" "github.com/stretchr/testify/suite" + + "github.com/quickfixgo/quickfix/internal" ) type InSessionTestSuite struct { diff --git a/initiator.go b/initiator.go index 030ea9570..3b6672395 100644 --- a/initiator.go +++ b/initiator.go @@ -109,7 +109,7 @@ func (i *Initiator) waitForInSessionTime(session *session) bool { return true } -//watiForReconnectInterval returns true if a reconnect should be re-attempted, false if handler should stop +//waitForReconnectInterval returns true if a reconnect should be re-attempted, false if handler should stop func (i *Initiator) waitForReconnectInterval(reconnectInterval time.Duration) bool { select { case <-time.After(reconnectInterval): diff --git a/internal/buffer_pool.go b/internal/buffer_pool.go deleted file mode 100644 index f0be7c7e5..000000000 --- a/internal/buffer_pool.go +++ /dev/null @@ -1,37 +0,0 @@ -package internal - -import ( - "bytes" - "sync" -) - -//BufferPool is an concurrently safe pool for byte buffers. Used to constructing inbound messages and writing outbound messages -type BufferPool struct { - b []*bytes.Buffer - sync.Mutex -} - -//Get returns a buffer from the pool, or creates a new buffer if the pool is empty -func (p *BufferPool) Get() (buf *bytes.Buffer) { - p.Lock() - if len(p.b) > 0 { - buf, p.b = p.b[len(p.b)-1], p.b[:len(p.b)-1] - } else { - buf = new(bytes.Buffer) - } - - p.Unlock() - - return -} - -//Put returns adds a buffer to the pool -func (p *BufferPool) Put(buf *bytes.Buffer) { - if buf == nil { - panic("Nil Buffer inserted into pool") - } - - p.Lock() - p.b = append(p.b, buf) - p.Unlock() -} diff --git a/internal/session_settings.go b/internal/session_settings.go index 4c82b000d..7547c0fa6 100644 --- a/internal/session_settings.go +++ b/internal/session_settings.go @@ -9,6 +9,7 @@ type SessionSettings struct { ResetOnLogout bool ResetOnDisconnect bool HeartBtInt time.Duration + HeartBtIntOverride bool SessionTime *TimeRange InitiateLogon bool ResendRequestChunkSize int diff --git a/internal/time_range.go b/internal/time_range.go index b7fa4569b..e69f2fd21 100644 --- a/internal/time_range.go +++ b/internal/time_range.go @@ -1,8 +1,9 @@ package internal import ( - "errors" "time" + + "github.com/pkg/errors" ) //TimeOfDay represents the time of day @@ -13,8 +14,6 @@ type TimeOfDay struct { const shortForm = "15:04:05" -var errParseTime = errors.New("Time must be in the format HH:MM:SS") - //NewTimeOfDay returns a newly initialized TimeOfDay func NewTimeOfDay(hour, minute, second int) TimeOfDay { d := time.Duration(second)*time.Second + @@ -28,7 +27,7 @@ func NewTimeOfDay(hour, minute, second int) TimeOfDay { func ParseTimeOfDay(str string) (TimeOfDay, error) { t, err := time.Parse(shortForm, str) if err != nil { - return TimeOfDay{}, errParseTime + return TimeOfDay{}, errors.Wrap(err, "time must be in the format HH:MM:SS") } return NewTimeOfDay(t.Clock()), nil diff --git a/internal/time_range_test.go b/internal/time_range_test.go index df392118a..a28a495f5 100644 --- a/internal/time_range_test.go +++ b/internal/time_range_test.go @@ -374,6 +374,7 @@ func TestTimeRangeIsInSameRangeWithDay(t *testing.T) { time1 = time.Date(2004, time.July, 27, 3, 0, 0, 0, time.UTC) time2 = time.Date(2004, time.July, 27, 3, 0, 0, 0, time.UTC) + assert.True(t, NewUTCWeekRange(startTime, endTime, startDay, endDay).IsInSameRange(time1, time2)) time1 = time.Date(2004, time.July, 26, 10, 0, 0, 0, time.UTC) time2 = time.Date(2004, time.July, 27, 3, 0, 0, 0, time.UTC) diff --git a/log.go b/log.go index 6527e92dc..6b1fa04bb 100644 --- a/log.go +++ b/log.go @@ -2,24 +2,24 @@ package quickfix //Log is a generic interface for logging FIX messages and events. type Log interface { - //log incoming fix message + //OnIncoming log incoming fix message OnIncoming([]byte) - //log outgoing fix message + //OnOutgoing log outgoing fix message OnOutgoing([]byte) - //log fix event + //OnEvent log fix event OnEvent(string) - //log fix event according to format specifier + //OnEventf log fix event according to format specifier OnEventf(string, ...interface{}) } //The LogFactory interface creates global and session specific Log instances type LogFactory interface { - //global log + //Create global log Create() (Log, error) - //session specific log + //CreateSessionLog session specific log CreateSessionLog(sessionID SessionID) (Log, error) } diff --git a/logon_state_test.go b/logon_state_test.go index 3ba29af1d..cea45618b 100644 --- a/logon_state_test.go +++ b/logon_state_test.go @@ -5,8 +5,9 @@ import ( "testing" "time" - "github.com/quickfixgo/quickfix/internal" "github.com/stretchr/testify/suite" + + "github.com/quickfixgo/quickfix/internal" ) type LogonStateTestSuite struct { @@ -76,12 +77,14 @@ func (s *LogonStateTestSuite) TestFixMsgInLogon() { s.MockApp.On("FromAdmin").Return(nil) s.MockApp.On("OnLogon") s.MockApp.On("ToAdmin") + s.Zero(s.session.HeartBtInt) s.fixMsgIn(s.session, logon) s.MockApp.AssertExpectations(s.T()) s.State(inSession{}) - s.Equal(32*time.Second, s.session.HeartBtInt) + s.Equal(32*time.Second, s.session.HeartBtInt) //should be written from logon message + s.False(s.session.HeartBtIntOverride) s.LastToAdminMessageSent() s.MessageType(string(msgTypeLogon), s.MockApp.lastToAdmin) @@ -91,6 +94,35 @@ func (s *LogonStateTestSuite) TestFixMsgInLogon() { s.NextSenderMsgSeqNum(3) } +func (s *LogonStateTestSuite) TestFixMsgInLogonHeartBtIntOverride() { + s.IncrNextSenderMsgSeqNum() + s.MessageFactory.seqNum = 1 + s.IncrNextTargetMsgSeqNum() + + logon := s.Logon() + logon.Body.SetField(tagHeartBtInt, FIXInt(32)) + + s.MockApp.On("FromAdmin").Return(nil) + s.MockApp.On("OnLogon") + s.MockApp.On("ToAdmin") + s.session.HeartBtIntOverride = true + s.session.HeartBtInt = time.Second + s.fixMsgIn(s.session, logon) + + s.MockApp.AssertExpectations(s.T()) + + s.State(inSession{}) + s.Equal(time.Second, s.session.HeartBtInt) //should not have changed + s.True(s.session.HeartBtIntOverride) + + s.LastToAdminMessageSent() + s.MessageType(string(msgTypeLogon), s.MockApp.lastToAdmin) + s.FieldEquals(tagHeartBtInt, 1, s.MockApp.lastToAdmin.Body) + + s.NextTargetMsgSeqNum(3) + s.NextSenderMsgSeqNum(3) +} + func (s *LogonStateTestSuite) TestFixMsgInLogonEnableLastMsgSeqNumProcessed() { s.session.EnableLastMsgSeqNumProcessed = true diff --git a/logout_state_test.go b/logout_state_test.go index ce1ae84b6..5b074f329 100644 --- a/logout_state_test.go +++ b/logout_state_test.go @@ -3,8 +3,9 @@ package quickfix import ( "testing" - "github.com/quickfixgo/quickfix/internal" "github.com/stretchr/testify/suite" + + "github.com/quickfixgo/quickfix/internal" ) type LogoutStateTestSuite struct { diff --git a/message.go b/message.go index f800cbebf..d76e5ca44 100644 --- a/message.go +++ b/message.go @@ -114,7 +114,7 @@ func NewMessage() *Message { return m } -// CopyInto erases the dest messages and copies the curreny message content +// CopyInto erases the dest messages and copies the currency message content // into it. func (m *Message) CopyInto(to *Message) { m.Header.CopyInto(&to.Header.FieldMap) diff --git a/message_pool.go b/message_pool.go deleted file mode 100644 index 519062516..000000000 --- a/message_pool.go +++ /dev/null @@ -1,26 +0,0 @@ -package quickfix - -type messagePool struct { - m []*Message -} - -func (p *messagePool) New() *Message { - msg := NewMessage() - return msg -} - -func (p *messagePool) Get() (msg *Message) { - if len(p.m) > 0 { - msg, p.m = p.m[len(p.m)-1], p.m[:len(p.m)-1] - } else { - msg = p.New() - } - - msg.keepMessage = false - - return -} - -func (p *messagePool) Put(msg *Message) { - p.m = append(p.m, msg) -} diff --git a/message_test.go b/message_test.go index 3766c9221..40b063ca1 100644 --- a/message_test.go +++ b/message_test.go @@ -5,8 +5,9 @@ import ( "reflect" "testing" - "github.com/quickfixgo/quickfix/datadictionary" "github.com/stretchr/testify/suite" + + "github.com/quickfixgo/quickfix/datadictionary" ) func BenchmarkParseMessage(b *testing.B) { diff --git a/mongostore.go b/mongostore.go index 76789dbe9..15c419347 100644 --- a/mongostore.go +++ b/mongostore.go @@ -6,6 +6,7 @@ import ( "github.com/globalsign/mgo" "github.com/globalsign/mgo/bson" + "github.com/pkg/errors" "github.com/quickfixgo/quickfix/config" ) @@ -66,7 +67,11 @@ func newMongoStore(sessionID SessionID, mongoURL string, mongoDatabase string, m messagesCollection: messagesCollection, sessionsCollection: sessionsCollection, } - store.cache.Reset() + + if err = store.cache.Reset(); err != nil { + err = errors.Wrap(err, "cache reset") + return + } if store.db, err = mgo.Dial(mongoURL); err != nil { return @@ -139,27 +144,43 @@ func (store *mongoStore) Refresh() error { return store.populateCache() } -func (store *mongoStore) populateCache() (err error) { +func (store *mongoStore) populateCache() error { msgFilter := generateMessageFilter(&store.sessionID) query := store.db.DB(store.mongoDatabase).C(store.sessionsCollection).Find(msgFilter) - if cnt, err := query.Count(); err == nil && cnt > 0 { + cnt, err := query.Count() + if err != nil { + return errors.Wrap(err, "count") + } + + if cnt > 0 { // session record found, load it sessionData := &mongoQuickFixEntryData{} - err = query.One(&sessionData) - if err == nil { - store.cache.creationTime = sessionData.CreationTime - store.cache.SetNextTargetMsgSeqNum(sessionData.IncomingSeqNum) - store.cache.SetNextSenderMsgSeqNum(sessionData.OutgoingSeqNum) + if err = query.One(&sessionData); err != nil { + return errors.Wrap(err, "query one") } - } else if err == nil && cnt == 0 { - // session record not found, create it - msgFilter.CreationTime = store.cache.creationTime - msgFilter.IncomingSeqNum = store.cache.NextTargetMsgSeqNum() - msgFilter.OutgoingSeqNum = store.cache.NextSenderMsgSeqNum() - err = store.db.DB(store.mongoDatabase).C(store.sessionsCollection).Insert(msgFilter) + + store.cache.creationTime = sessionData.CreationTime + if err = store.cache.SetNextTargetMsgSeqNum(sessionData.IncomingSeqNum); err != nil { + return errors.Wrap(err, "cache set next target") + } + + if err = store.cache.SetNextSenderMsgSeqNum(sessionData.OutgoingSeqNum); err != nil { + return errors.Wrap(err, "cache set next sender") + } + + return nil } - return + + // session record not found, create it + msgFilter.CreationTime = store.cache.creationTime + msgFilter.IncomingSeqNum = store.cache.NextTargetMsgSeqNum() + msgFilter.OutgoingSeqNum = store.cache.NextSenderMsgSeqNum() + + if err = store.db.DB(store.mongoDatabase).C(store.sessionsCollection).Insert(msgFilter); err != nil { + return errors.Wrap(err, "insert") + } + return nil } // NextSenderMsgSeqNum returns the next MsgSeqNum that will be sent @@ -200,13 +221,17 @@ func (store *mongoStore) SetNextTargetMsgSeqNum(next int) error { // IncrNextSenderMsgSeqNum increments the next MsgSeqNum that will be sent func (store *mongoStore) IncrNextSenderMsgSeqNum() error { - store.cache.IncrNextSenderMsgSeqNum() + if err := store.cache.IncrNextSenderMsgSeqNum(); err != nil { + return errors.Wrap(err, "cache incr") + } return store.SetNextSenderMsgSeqNum(store.cache.NextSenderMsgSeqNum()) } // IncrNextTargetMsgSeqNum increments the next MsgSeqNum that should be received func (store *mongoStore) IncrNextTargetMsgSeqNum() error { - store.cache.IncrNextTargetMsgSeqNum() + if err := store.cache.IncrNextTargetMsgSeqNum(); err != nil { + return errors.Wrap(err, "cache incr") + } return store.SetNextTargetMsgSeqNum(store.cache.NextTargetMsgSeqNum()) } diff --git a/parser.go b/parser.go index 8816345d4..13b992d03 100644 --- a/parser.go +++ b/parser.go @@ -5,16 +5,12 @@ import ( "errors" "io" "time" - - "github.com/quickfixgo/quickfix/internal" ) const ( defaultBufSize = 4096 ) -var bufferPool internal.BufferPool - type parser struct { //buffer is a slice of bigBuffer bigBuffer, buffer []byte @@ -145,7 +141,7 @@ func (p *parser) ReadMessage() (msgBytes *bytes.Buffer, err error) { return } - msgBytes = bufferPool.Get() + msgBytes = new(bytes.Buffer) msgBytes.Reset() msgBytes.Write(p.buffer[:index]) p.buffer = p.buffer[index:] diff --git a/parser_test.go b/parser_test.go index ab9c086a7..533691986 100644 --- a/parser_test.go +++ b/parser_test.go @@ -13,7 +13,7 @@ func BenchmarkParser_ReadMessage(b *testing.B) { for i := 0; i < b.N; i++ { reader := strings.NewReader(stream) parser := newParser(reader) - parser.ReadMessage() + _, _ = parser.ReadMessage() } } diff --git a/pending_timeout_test.go b/pending_timeout_test.go index a9e4bf8f1..1ad845d89 100644 --- a/pending_timeout_test.go +++ b/pending_timeout_test.go @@ -3,8 +3,9 @@ package quickfix import ( "testing" - "github.com/quickfixgo/quickfix/internal" "github.com/stretchr/testify/suite" + + "github.com/quickfixgo/quickfix/internal" ) type PendingTimeoutTestSuite struct { diff --git a/quickfix_test.go b/quickfix_test.go index 2f511b77d..189aaf21f 100644 --- a/quickfix_test.go +++ b/quickfix_test.go @@ -3,10 +3,11 @@ package quickfix import ( "time" - "github.com/quickfixgo/quickfix/internal" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + + "github.com/quickfixgo/quickfix/internal" ) type QuickFIXSuite struct { diff --git a/repeating_group.go b/repeating_group.go index d3bc4a58c..10b4150fc 100644 --- a/repeating_group.go +++ b/repeating_group.go @@ -11,7 +11,7 @@ type GroupItem interface { //Tag returns the tag identifying this GroupItem Tag() Tag - //Parameter to Read is tagValues. For most fields, only the first tagValue will be required. + //Read Parameter to Read is tagValues. For most fields, only the first tagValue will be required. //The length of the slice extends from the tagValue mapped to the field to be read through the //following fields. This can be useful for GroupItems made up of repeating groups. // @@ -109,17 +109,18 @@ func (f *RepeatingGroup) Add() *Group { //Write returns tagValues for all Items in the repeating group ordered by //Group sequence and Group template order func (f RepeatingGroup) Write() []TagValue { - tvs := make([]TagValue, 1, 1) + tvs := make([]TagValue, 1) tvs[0].init(f.tag, []byte(strconv.Itoa(len(f.groups)))) for _, group := range f.groups { tags := group.sortedTags() - + group.rwLock.RLock() for _, tag := range tags { if fields, ok := group.tagLookup[tag]; ok { tvs = append(tvs, fields...) } } + group.rwLock.RUnlock() } return tvs @@ -144,8 +145,8 @@ func (f RepeatingGroup) groupTagOrder() tagOrder { } return func(i, j Tag) bool { - orderi := math.MaxInt64 - orderj := math.MaxInt64 + orderi := math.MaxInt32 + orderj := math.MaxInt32 if iIndex, ok := tagMap[i]; ok { orderi = iIndex @@ -199,7 +200,9 @@ func (f *RepeatingGroup) Read(tv []TagValue) ([]TagValue, error) { f.groups = append(f.groups, group) } + group.rwLock.Lock() group.tagLookup[tvRange[0].tag] = tvRange + group.rwLock.Unlock() } if len(f.groups) != expectedGroupSize { diff --git a/resend_state.go b/resend_state.go index a89d3f697..d07559b82 100644 --- a/resend_state.go +++ b/resend_state.go @@ -54,9 +54,6 @@ func (s resendState) FixMsgIn(session *session, msg *Message) (nextState session delete(s.messageStash, targetSeqNum) - //return stashed message to pool - session.returnToPool(msg) - nextState = inSession{}.FixMsgIn(session, msg) if !nextState.IsLoggedOn() { return diff --git a/resend_state_test.go b/resend_state_test.go index aa13cc871..4d8184a9d 100644 --- a/resend_state_test.go +++ b/resend_state_test.go @@ -3,8 +3,9 @@ package quickfix import ( "testing" - "github.com/quickfixgo/quickfix/internal" "github.com/stretchr/testify/suite" + + "github.com/quickfixgo/quickfix/internal" ) type resendStateTestSuite struct { diff --git a/session.go b/session.go index 6d6c1aa0c..6ef143f2a 100644 --- a/session.go +++ b/session.go @@ -30,7 +30,7 @@ type session struct { sessionEvent chan internal.Event messageEvent chan bool application Application - validator + Validator stateMachine stateTimer *internal.EventTimer peerTimer *internal.EventTimer @@ -43,7 +43,6 @@ type session struct { transportDataDictionary *datadictionary.DataDictionary appDataDictionary *datadictionary.DataDictionary - messagePool timestampPrecision TimestampPrecision } @@ -342,6 +341,14 @@ func (s *session) dropQueued() { s.toSend = s.toSend[:0] } +func (s *session) EnqueueBytesAndSend(msg []byte) { + s.sendMutex.Lock() + defer s.sendMutex.Unlock() + + s.toSend = append(s.toSend, msg) + s.sendQueued() +} + func (s *session) sendBytes(msg []byte) { s.log.OnOutgoing(msg) s.messageOut <- msg @@ -433,9 +440,11 @@ func (s *session) handleLogon(msg *Message) error { } if !s.InitiateLogon { - var heartBtInt FIXInt - if err := msg.Body.GetField(tagHeartBtInt, &heartBtInt); err == nil { - s.HeartBtInt = time.Duration(heartBtInt) * time.Second + if !s.HeartBtIntOverride { + var heartBtInt FIXInt + if err := msg.Body.GetField(tagHeartBtInt, &heartBtInt); err == nil { + s.HeartBtInt = time.Duration(heartBtInt) * time.Second + } } s.log.OnEvent("Responding to logon request") @@ -506,8 +515,8 @@ func (s *session) verifySelect(msg *Message, checkTooHigh bool, checkTooLow bool } } - if s.validator != nil { - if reject := s.validator.Validate(msg); reject != nil { + if s.Validator != nil { + if reject := s.Validator.Validate(msg); reject != nil { return reject } } @@ -622,6 +631,9 @@ func (s *session) doReject(msg *Message, rej MessageRejectError) error { if rej.IsBusinessReject() { reply.Header.SetField(tagMsgType, FIXString("j")) reply.Body.SetField(tagBusinessRejectReason, FIXInt(rej.RejectReason())) + if refID := rej.BusinessRejectRefID(); refID != "" { + reply.Body.SetField(tagBusinessRejectRefID, FIXString(refID)) + } } else { reply.Header.SetField(tagMsgType, FIXString("3")) switch { @@ -665,14 +677,6 @@ type fixIn struct { receiveTime time.Time } -func (s *session) returnToPool(msg *Message) { - s.messagePool.Put(msg) - if msg.rawMessage != nil { - bufferPool.Put(msg.rawMessage) - msg.rawMessage = nil - } -} - func (s *session) onDisconnect() { s.log.OnEvent("Disconnected") if s.ResetOnDisconnect { @@ -702,6 +706,15 @@ func (s *session) onAdmin(msg interface{}) { return } + if !s.IsSessionTime() { + s.handleDisconnectState(s) + if msg.err != nil { + msg.err <- errors.New("Connection outside of session time") + close(msg.err) + } + return + } + if msg.err != nil { close(msg.err) } diff --git a/session_factory.go b/session_factory.go index f4f3fc0ef..8fc42d3c0 100644 --- a/session_factory.go +++ b/session_factory.go @@ -1,11 +1,12 @@ package quickfix import ( - "errors" "net" "strconv" "time" + "github.com/pkg/errors" + "github.com/quickfixgo/quickfix/config" "github.com/quickfixgo/quickfix/datadictionary" "github.com/quickfixgo/quickfix/internal" @@ -103,14 +104,22 @@ func (f sessionFactory) newSession( } if s.transportDataDictionary, err = datadictionary.Parse(transportDataDictionaryPath); err != nil { + err = errors.Wrapf( + err, "problem parsing XML datadictionary path '%v' for setting '%v", + settings.settings[config.TransportDataDictionary], config.TransportDataDictionary, + ) return } if s.appDataDictionary, err = datadictionary.Parse(appDataDictionaryPath); err != nil { + err = errors.Wrapf( + err, "problem parsing XML datadictionary path '%v' for setting '%v", + settings.settings[config.AppDataDictionary], config.AppDataDictionary, + ) return } - s.validator = &fixtValidator{s.transportDataDictionary, s.appDataDictionary, validatorSettings} + s.Validator = NewValidator(validatorSettings, s.appDataDictionary, s.transportDataDictionary) } } else if settings.HasSetting(config.DataDictionary) { var dataDictionaryPath string @@ -119,10 +128,14 @@ func (f sessionFactory) newSession( } if s.appDataDictionary, err = datadictionary.Parse(dataDictionaryPath); err != nil { + err = errors.Wrapf( + err, "problem parsing XML datadictionary path '%v' for setting '%v", + settings.settings[config.DataDictionary], config.DataDictionary, + ) return } - s.validator = &fixValidator{s.appDataDictionary, validatorSettings} + s.Validator = NewValidator(validatorSettings, s.appDataDictionary, nil) } if settings.HasSetting(config.ResetOnLogon) { @@ -198,10 +211,18 @@ func (f sessionFactory) newSession( var start, end internal.TimeOfDay if start, err = internal.ParseTimeOfDay(startTimeStr); err != nil { + err = errors.Wrapf( + err, "problem parsing time of day '%v' for setting '%v", + settings.settings[config.StartTime], config.StartTime, + ) return } if end, err = internal.ParseTimeOfDay(endTimeStr); err != nil { + err = errors.Wrapf( + err, "problem parsing time of day '%v' for setting '%v", + settings.settings[config.EndTime], config.EndTime, + ) return } @@ -214,6 +235,10 @@ func (f sessionFactory) newSession( loc, err = time.LoadLocation(locStr) if err != nil { + err = errors.Wrapf( + err, "problem parsing time zone '%v' for setting '%v", + settings.settings[config.TimeZone], config.TimeZone, + ) return } } @@ -286,6 +311,8 @@ func (f sessionFactory) newSession( if err = f.buildInitiatorSettings(s, settings); err != nil { return } + } else if err = f.buildAcceptorSettings(s, settings); err != nil { + return } if s.log, err = logFactory.CreateSessionLog(s.sessionID); err != nil { @@ -303,19 +330,20 @@ func (f sessionFactory) newSession( return } +func (f sessionFactory) buildAcceptorSettings(session *session, settings *SessionSettings) error { + if err := f.buildHeartBtIntSettings(session, settings, false); err != nil { + return err + } + return nil +} + func (f sessionFactory) buildInitiatorSettings(session *session, settings *SessionSettings) error { session.InitiateLogon = true - heartBtInt, err := settings.IntSetting(config.HeartBtInt) - if err != nil { + if err := f.buildHeartBtIntSettings(session, settings, true); err != nil { return err } - if heartBtInt <= 0 { - return errors.New("Heartbeat must be greater than zero") - } - session.HeartBtInt = time.Duration(heartBtInt) * time.Second - session.ReconnectInterval = 30 * time.Second if settings.HasSetting(config.ReconnectInterval) { @@ -394,3 +422,23 @@ func (f sessionFactory) configureSocketConnectAddress(session *session, settings i++ } } + +func (f sessionFactory) buildHeartBtIntSettings(session *session, settings *SessionSettings, mustProvide bool) (err error) { + if settings.HasSetting(config.HeartBtIntOverride) { + if session.HeartBtIntOverride, err = settings.BoolSetting(config.HeartBtIntOverride); err != nil { + return + } + } + + if session.HeartBtIntOverride || mustProvide { + var heartBtInt int + if heartBtInt, err = settings.IntSetting(config.HeartBtInt); err != nil { + return + } else if heartBtInt <= 0 { + err = errors.New("Heartbeat must be greater than zero") + return + } + session.HeartBtInt = time.Duration(heartBtInt) * time.Second + } + return +} diff --git a/session_factory_test.go b/session_factory_test.go index a560843f7..a86cfe369 100644 --- a/session_factory_test.go +++ b/session_factory_test.go @@ -4,9 +4,10 @@ import ( "testing" "time" + "github.com/stretchr/testify/suite" + "github.com/quickfixgo/quickfix/config" "github.com/quickfixgo/quickfix/internal" - "github.com/stretchr/testify/suite" ) type SessionFactorySuite struct { @@ -51,6 +52,7 @@ func (s *SessionFactorySuite) TestDefaults() { s.Equal(Millis, session.timestampPrecision) s.Equal(120*time.Second, session.MaxLatency) s.False(session.DisableMessagePersist) + s.False(session.HeartBtIntOverride) } func (s *SessionFactorySuite) TestResetOnLogon() { @@ -129,7 +131,7 @@ func (s *SessionFactorySuite) TestResendRequestChunkSize() { s.Equal(2500, session.ResendRequestChunkSize) s.SessionSettings.Set(config.ResendRequestChunkSize, "notanint") - session, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + _, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) s.NotNil(err) } @@ -358,9 +360,47 @@ func (s *SessionFactorySuite) TestNewSessionBuildInitiators() { s.Equal("127.0.0.1:5000", session.SocketConnectAddress[0]) } +func (s *SessionFactorySuite) TestNewSessionBuildAcceptors() { + s.sessionFactory.BuildInitiators = false + s.SessionSettings.Set(config.HeartBtInt, "34") + + session, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + s.Nil(err) + s.False(session.InitiateLogon) + s.Zero(session.HeartBtInt) + s.False(session.HeartBtIntOverride) + + s.SessionSettings.Set(config.HeartBtIntOverride, "Y") + session, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + s.Nil(err) + s.False(session.InitiateLogon) + s.Equal(34*time.Second, session.HeartBtInt) + s.True(session.HeartBtIntOverride) +} + func (s *SessionFactorySuite) TestNewSessionBuildInitiatorsValidHeartBtInt() { s.sessionFactory.BuildInitiators = true + _, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + s.NotNil(err, "HeartBtInt should be required for acceptors with override defined") + + s.SessionSettings.Set(config.HeartBtInt, "not a number") + _, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + s.NotNil(err, "HeartBtInt must be a number") + + s.SessionSettings.Set(config.HeartBtInt, "0") + _, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + s.NotNil(err, "HeartBtInt must be greater than zero") + + s.SessionSettings.Set(config.HeartBtInt, "-20") + _, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + s.NotNil(err, "HeartBtInt must be greater than zero") +} + +func (s *SessionFactorySuite) TestNewSessionBuildAcceptorsValidHeartBtInt() { + s.sessionFactory.BuildInitiators = false + + s.SessionSettings.Set(config.HeartBtIntOverride, "Y") _, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) s.NotNil(err, "HeartBtInt should be required for initiators") @@ -518,7 +558,7 @@ func (s *SessionFactorySuite) TestConfigureSocketConnectAddressMulti() { func (s *SessionFactorySuite) TestNewSessionTimestampPrecision() { s.SessionSettings.Set(config.TimeStampPrecision, "blah") - session, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + _, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) s.NotNil(err) var tests = []struct { @@ -533,7 +573,7 @@ func (s *SessionFactorySuite) TestNewSessionTimestampPrecision() { for _, test := range tests { s.SessionSettings.Set(config.TimeStampPrecision, test.config) - session, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + session, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) s.Nil(err) s.Equal(session.timestampPrecision, test.precision) @@ -542,19 +582,19 @@ func (s *SessionFactorySuite) TestNewSessionTimestampPrecision() { func (s *SessionFactorySuite) TestNewSessionMaxLatency() { s.SessionSettings.Set(config.MaxLatency, "not a number") - session, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + _, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) s.NotNil(err, "MaxLatency must be a number") s.SessionSettings.Set(config.MaxLatency, "-20") - session, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + _, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) s.NotNil(err, "MaxLatency must be positive") s.SessionSettings.Set(config.MaxLatency, "0") - session, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + _, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) s.NotNil(err, "MaxLatency must be positive") s.SessionSettings.Set(config.MaxLatency, "20") - session, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + session, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) s.Nil(err) s.Equal(session.MaxLatency, 20*time.Second) } diff --git a/session_id.go b/session_id.go index a261d3789..9ee621c52 100644 --- a/session_id.go +++ b/session_id.go @@ -2,7 +2,7 @@ package quickfix import "bytes" -// SessionID is a unique identifer of a Session +// SessionID is a unique identifier of a Session type SessionID struct { BeginString, TargetCompID, TargetSubID, TargetLocationID, SenderCompID, SenderSubID, SenderLocationID, Qualifier string } diff --git a/session_settings.go b/session_settings.go index 130bf9667..17889eccc 100644 --- a/session_settings.go +++ b/session_settings.go @@ -100,7 +100,7 @@ func (s *SessionSettings) DurationSetting(setting string) (val time.Duration, er return } -//BoolSetting returns the requested setting parsed as a boolean. Returns an errror if the setting is not set or cannot be parsed as a bool. +//BoolSetting returns the requested setting parsed as a boolean. Returns an error if the setting is not set or cannot be parsed as a bool. func (s SessionSettings) BoolSetting(setting string) (bool, error) { stringVal, err := s.Setting(setting) diff --git a/session_state.go b/session_state.go index c45ec9dda..55e341365 100644 --- a/session_state.go +++ b/session_state.go @@ -22,12 +22,6 @@ func (sm *stateMachine) Start(s *session) { } func (sm *stateMachine) Connect(session *session) { - if !sm.IsSessionTime() { - session.log.OnEvent("Connection outside of session time") - sm.handleDisconnectState(session) - return - } - // No special logon logic needed for FIX Acceptors. if !session.InitiateLogon { sm.setState(session, logonState{}) @@ -74,7 +68,7 @@ func (sm *stateMachine) Incoming(session *session, m fixIn) { session.log.OnIncoming(m.bytes.Bytes()) - msg := session.messagePool.Get() + msg := NewMessage() if err := ParseMessageWithDataDictionary(msg, m.bytes, session.transportDataDictionary, session.appDataDictionary); err != nil { session.log.OnEventf("Msg Parse Error: %v, %q", err.Error(), m.bytes) } else { @@ -82,9 +76,6 @@ func (sm *stateMachine) Incoming(session *session, m fixIn) { sm.fixMsgIn(session, msg) } - if !msg.keepMessage { - session.returnToPool(msg) - } session.peerTimer.Reset(time.Duration(float64(1.2) * float64(session.HeartBtInt))) } @@ -224,7 +215,7 @@ type sessionState interface { //Stop triggers a clean stop Stop(*session) (nextState sessionState) - //debugging convenience + //Stringer debugging convenience fmt.Stringer } diff --git a/session_test.go b/session_test.go index 33c7cf417..293b94805 100644 --- a/session_test.go +++ b/session_test.go @@ -276,8 +276,8 @@ func (s *SessionSuite) TestShouldSendReset() { s.session.ResetOnDisconnect = test.ResetOnDisconnect s.session.ResetOnLogout = test.ResetOnLogout - s.MockStore.SetNextSenderMsgSeqNum(test.NextSenderMsgSeqNum) - s.MockStore.SetNextTargetMsgSeqNum(test.NextTargetMsgSeqNum) + s.Require().Nil(s.MockStore.SetNextSenderMsgSeqNum(test.NextSenderMsgSeqNum)) + s.Require().Nil(s.MockStore.SetNextTargetMsgSeqNum(test.NextTargetMsgSeqNum)) s.Equal(s.shouldSendReset(), test.Expected) } @@ -944,7 +944,7 @@ func (suite *SessionSendTestSuite) TestDropAndSendDropsQueueWithReset() { suite.NoMessageSent() suite.MockApp.On("ToAdmin") - suite.MockStore.Reset() + suite.Require().Nil(suite.MockStore.Reset()) require.Nil(suite.T(), suite.dropAndSend(suite.Logon())) suite.MockApp.AssertExpectations(suite.T()) msg := suite.MockApp.lastToAdmin diff --git a/settings_test.go b/settings_test.go index b9c219760..9ce568e50 100644 --- a/settings_test.go +++ b/settings_test.go @@ -4,10 +4,11 @@ import ( "strings" "testing" - "github.com/quickfixgo/quickfix/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + + "github.com/quickfixgo/quickfix/config" ) func TestSettings_New(t *testing.T) { diff --git a/sqlstore.go b/sqlstore.go index b16b0a74d..a57281dd8 100644 --- a/sqlstore.go +++ b/sqlstore.go @@ -3,8 +3,11 @@ package quickfix import ( "database/sql" "fmt" + "regexp" "time" + "github.com/pkg/errors" + "github.com/quickfixgo/quickfix/config" ) @@ -19,6 +22,27 @@ type sqlStore struct { sqlDataSourceName string sqlConnMaxLifetime time.Duration db *sql.DB + placeholder placeholderFunc +} + +type placeholderFunc func(int) string + +var rePlaceholder = regexp.MustCompile(`\?`) + +func sqlString(raw string, placeholder placeholderFunc) string { + if placeholder == nil { + return raw + } + idx := 0 + return rePlaceholder.ReplaceAllStringFunc(raw, func(s string) string { + new := placeholder(idx) + idx++ + return new + }) +} + +func postgresPlaceholder(i int) string { + return fmt.Sprintf("$%d", i+1) } // NewSQLStoreFactory returns a sql-based implementation of MessageStoreFactory @@ -58,7 +82,14 @@ func newSQLStore(sessionID SessionID, driver string, dataSourceName string, conn sqlDataSourceName: dataSourceName, sqlConnMaxLifetime: connMaxLifetime, } - store.cache.Reset() + if err = store.cache.Reset(); err != nil { + err = errors.Wrap(err, "cache reset") + return + } + + if store.sqlDriver == "postgres" { + store.placeholder = postgresPlaceholder + } if store.db, err = sql.Open(store.sqlDriver, store.sqlDataSourceName); err != nil { return nil, err @@ -78,10 +109,10 @@ func newSQLStore(sessionID SessionID, driver string, dataSourceName string, conn // Reset deletes the store records and sets the seqnums back to 1 func (store *sqlStore) Reset() error { s := store.sessionID - _, err := store.db.Exec(`DELETE FROM messages + _, err := store.db.Exec(sqlString(`DELETE FROM messages WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? - AND targetcompid=? AND targetsubid=? AND targetlocid=?`, + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, s.TargetCompID, s.TargetSubID, s.TargetLocationID) @@ -93,11 +124,11 @@ func (store *sqlStore) Reset() error { return err } - _, err = store.db.Exec(`UPDATE sessions + _, err = store.db.Exec(sqlString(`UPDATE sessions SET creation_time=?, incoming_seqnum=?, outgoing_seqnum=? WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? - AND targetcompid=? AND targetsubid=? AND targetlocid=?`, + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), store.cache.CreationTime(), store.cache.NextTargetMsgSeqNum(), store.cache.NextSenderMsgSeqNum(), s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, @@ -114,26 +145,30 @@ func (store *sqlStore) Refresh() error { return store.populateCache() } -func (store *sqlStore) populateCache() (err error) { +func (store *sqlStore) populateCache() error { s := store.sessionID var creationTime time.Time var incomingSeqNum, outgoingSeqNum int - row := store.db.QueryRow(`SELECT creation_time, incoming_seqnum, outgoing_seqnum + row := store.db.QueryRow(sqlString(`SELECT creation_time, incoming_seqnum, outgoing_seqnum FROM sessions WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? - AND targetcompid=? AND targetsubid=? AND targetlocid=?`, + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, s.TargetCompID, s.TargetSubID, s.TargetLocationID) - err = row.Scan(&creationTime, &incomingSeqNum, &outgoingSeqNum) + err := row.Scan(&creationTime, &incomingSeqNum, &outgoingSeqNum) // session record found, load it if err == nil { store.cache.creationTime = creationTime - store.cache.SetNextTargetMsgSeqNum(incomingSeqNum) - store.cache.SetNextSenderMsgSeqNum(outgoingSeqNum) + if err = store.cache.SetNextTargetMsgSeqNum(incomingSeqNum); err != nil { + return errors.Wrap(err, "cache set next target") + } + if err = store.cache.SetNextSenderMsgSeqNum(outgoingSeqNum); err != nil { + return errors.Wrap(err, "cache set next sender") + } return nil } @@ -143,12 +178,12 @@ func (store *sqlStore) populateCache() (err error) { } // session record not found, create it - _, err = store.db.Exec(`INSERT INTO sessions ( + _, err = store.db.Exec(sqlString(`INSERT INTO sessions ( creation_time, incoming_seqnum, outgoing_seqnum, beginstring, session_qualifier, sendercompid, sendersubid, senderlocid, targetcompid, targetsubid, targetlocid) - VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, store.placeholder), store.cache.creationTime, store.cache.NextTargetMsgSeqNum(), store.cache.NextSenderMsgSeqNum(), @@ -172,10 +207,10 @@ func (store *sqlStore) NextTargetMsgSeqNum() int { // SetNextSenderMsgSeqNum sets the next MsgSeqNum that will be sent func (store *sqlStore) SetNextSenderMsgSeqNum(next int) error { s := store.sessionID - _, err := store.db.Exec(`UPDATE sessions SET outgoing_seqnum = ? + _, err := store.db.Exec(sqlString(`UPDATE sessions SET outgoing_seqnum = ? WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? - AND targetcompid=? AND targetsubid=? AND targetlocid=?`, + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), next, s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, s.TargetCompID, s.TargetSubID, s.TargetLocationID) @@ -188,10 +223,10 @@ func (store *sqlStore) SetNextSenderMsgSeqNum(next int) error { // SetNextTargetMsgSeqNum sets the next MsgSeqNum that should be received func (store *sqlStore) SetNextTargetMsgSeqNum(next int) error { s := store.sessionID - _, err := store.db.Exec(`UPDATE sessions SET incoming_seqnum = ? + _, err := store.db.Exec(sqlString(`UPDATE sessions SET incoming_seqnum = ? WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? - AND targetcompid=? AND targetsubid=? AND targetlocid=?`, + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), next, s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, s.TargetCompID, s.TargetSubID, s.TargetLocationID) @@ -203,13 +238,17 @@ func (store *sqlStore) SetNextTargetMsgSeqNum(next int) error { // IncrNextSenderMsgSeqNum increments the next MsgSeqNum that will be sent func (store *sqlStore) IncrNextSenderMsgSeqNum() error { - store.cache.IncrNextSenderMsgSeqNum() + if err := store.cache.IncrNextSenderMsgSeqNum(); err != nil { + return errors.Wrap(err, "cache incr next") + } return store.SetNextSenderMsgSeqNum(store.cache.NextSenderMsgSeqNum()) } // IncrNextTargetMsgSeqNum increments the next MsgSeqNum that should be received func (store *sqlStore) IncrNextTargetMsgSeqNum() error { - store.cache.IncrNextTargetMsgSeqNum() + if err := store.cache.IncrNextTargetMsgSeqNum(); err != nil { + return errors.Wrap(err, "cache incr next") + } return store.SetNextTargetMsgSeqNum(store.cache.NextTargetMsgSeqNum()) } @@ -221,12 +260,12 @@ func (store *sqlStore) CreationTime() time.Time { func (store *sqlStore) SaveMessage(seqNum int, msg []byte) error { s := store.sessionID - _, err := store.db.Exec(`INSERT INTO messages ( + _, err := store.db.Exec(sqlString(`INSERT INTO messages ( msgseqnum, message, beginstring, session_qualifier, sendercompid, sendersubid, senderlocid, targetcompid, targetsubid, targetlocid) - VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, store.placeholder), seqNum, string(msg), s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, @@ -238,12 +277,12 @@ func (store *sqlStore) SaveMessage(seqNum int, msg []byte) error { func (store *sqlStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) { s := store.sessionID var msgs [][]byte - rows, err := store.db.Query(`SELECT message FROM messages + rows, err := store.db.Query(sqlString(`SELECT message FROM messages WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? AND targetcompid=? AND targetsubid=? AND targetlocid=? AND msgseqnum>=? AND msgseqnum<=? - ORDER BY msgseqnum`, + ORDER BY msgseqnum`, store.placeholder), s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, s.TargetCompID, s.TargetSubID, s.TargetLocationID, diff --git a/sqlstore_test.go b/sqlstore_test.go index b252b2c25..da2c8a5a4 100644 --- a/sqlstore_test.go +++ b/sqlstore_test.go @@ -60,6 +60,11 @@ TargetCompID=%s`, sqlDriver, sqlDsn, sessionID.BeginString, sessionID.SenderComp require.Nil(suite.T(), err) } +func (suite *SQLStoreTestSuite) TestSqlPlaceholderReplacement() { + got := sqlString("A ? B ? C ?", postgresPlaceholder) + suite.Equal("A $1 B $2 C $3", got) +} + func (suite *SQLStoreTestSuite) TearDownTest() { suite.msgStore.Close() os.RemoveAll(suite.sqlStoreRootPath) diff --git a/store.go b/store.go index 837bbca13..41b6bc0c0 100644 --- a/store.go +++ b/store.go @@ -1,6 +1,10 @@ package quickfix -import "time" +import ( + "time" + + "github.com/pkg/errors" +) //The MessageStore interface provides methods to record and retrieve messages for resend purposes type MessageStore interface { @@ -107,7 +111,9 @@ type memoryStoreFactory struct{} func (f memoryStoreFactory) Create(sessionID SessionID) (MessageStore, error) { m := new(memoryStore) - m.Reset() + if err := m.Reset(); err != nil { + return m, errors.Wrap(err, "reset") + } return m, nil } diff --git a/store_test.go b/store_test.go index a61ccbb05..7b8f6a2b8 100644 --- a/store_test.go +++ b/store_test.go @@ -30,61 +30,55 @@ func TestMemoryStoreTestSuite(t *testing.T) { suite.Run(t, new(MemoryStoreTestSuite)) } -func (suite *MessageStoreTestSuite) TestMessageStore_SetNextMsgSeqNum_Refresh_IncrNextMsgSeqNum() { - t := suite.T() - +func (s *MessageStoreTestSuite) TestMessageStore_SetNextMsgSeqNum_Refresh_IncrNextMsgSeqNum() { // Given a MessageStore with the following sender and target seqnums - suite.msgStore.SetNextSenderMsgSeqNum(867) - suite.msgStore.SetNextTargetMsgSeqNum(5309) + s.Require().Nil(s.msgStore.SetNextSenderMsgSeqNum(867)) + s.Require().Nil(s.msgStore.SetNextTargetMsgSeqNum(5309)) // When the store is refreshed from its backing store - suite.msgStore.Refresh() + s.Require().Nil(s.msgStore.Refresh()) // Then the sender and target seqnums should still be - assert.Equal(t, 867, suite.msgStore.NextSenderMsgSeqNum()) - assert.Equal(t, 5309, suite.msgStore.NextTargetMsgSeqNum()) + s.Equal(867, s.msgStore.NextSenderMsgSeqNum()) + s.Equal(5309, s.msgStore.NextTargetMsgSeqNum()) // When the sender and target seqnums are incremented - require.Nil(t, suite.msgStore.IncrNextSenderMsgSeqNum()) - require.Nil(t, suite.msgStore.IncrNextTargetMsgSeqNum()) + s.Require().Nil(s.msgStore.IncrNextSenderMsgSeqNum()) + s.Require().Nil(s.msgStore.IncrNextTargetMsgSeqNum()) // Then the sender and target seqnums should be - assert.Equal(t, 868, suite.msgStore.NextSenderMsgSeqNum()) - assert.Equal(t, 5310, suite.msgStore.NextTargetMsgSeqNum()) + s.Equal(868, s.msgStore.NextSenderMsgSeqNum()) + s.Equal(5310, s.msgStore.NextTargetMsgSeqNum()) // When the store is refreshed from its backing store - suite.msgStore.Refresh() + s.Require().Nil(s.msgStore.Refresh()) // Then the sender and target seqnums should still be - assert.Equal(t, 868, suite.msgStore.NextSenderMsgSeqNum()) - assert.Equal(t, 5310, suite.msgStore.NextTargetMsgSeqNum()) + s.Equal(868, s.msgStore.NextSenderMsgSeqNum()) + s.Equal(5310, s.msgStore.NextTargetMsgSeqNum()) } -func (suite *MessageStoreTestSuite) TestMessageStore_Reset() { - t := suite.T() - +func (s *MessageStoreTestSuite) TestMessageStore_Reset() { // Given a MessageStore with the following sender and target seqnums - suite.msgStore.SetNextSenderMsgSeqNum(1234) - suite.msgStore.SetNextTargetMsgSeqNum(5678) + s.Require().Nil(s.msgStore.SetNextSenderMsgSeqNum(1234)) + s.Require().Nil(s.msgStore.SetNextTargetMsgSeqNum(5678)) // When the store is reset - require.Nil(t, suite.msgStore.Reset()) + s.Require().Nil(s.msgStore.Reset()) // Then the sender and target seqnums should be - assert.Equal(t, 1, suite.msgStore.NextSenderMsgSeqNum()) - assert.Equal(t, 1, suite.msgStore.NextTargetMsgSeqNum()) + s.Equal(1, s.msgStore.NextSenderMsgSeqNum()) + s.Equal(1, s.msgStore.NextTargetMsgSeqNum()) // When the store is refreshed from its backing store - suite.msgStore.Refresh() + s.Require().Nil(s.msgStore.Refresh()) // Then the sender and target seqnums should still be - assert.Equal(t, 1, suite.msgStore.NextSenderMsgSeqNum()) - assert.Equal(t, 1, suite.msgStore.NextTargetMsgSeqNum()) + s.Equal(1, s.msgStore.NextSenderMsgSeqNum()) + s.Equal(1, s.msgStore.NextTargetMsgSeqNum()) } -func (suite *MessageStoreTestSuite) TestMessageStore_SaveMessage_GetMessage() { - t := suite.T() - +func (s *MessageStoreTestSuite) TestMessageStore_SaveMessage_GetMessage() { // Given the following saved messages expectedMsgsBySeqNum := map[int]string{ 1: "In the frozen land of Nador", @@ -92,49 +86,49 @@ func (suite *MessageStoreTestSuite) TestMessageStore_SaveMessage_GetMessage() { 3: "and there was much rejoicing", } for seqNum, msg := range expectedMsgsBySeqNum { - require.Nil(t, suite.msgStore.SaveMessage(seqNum, []byte(msg))) + s.Require().Nil(s.msgStore.SaveMessage(seqNum, []byte(msg))) } // When the messages are retrieved from the MessageStore - actualMsgs, err := suite.msgStore.GetMessages(1, 3) - require.Nil(t, err) + actualMsgs, err := s.msgStore.GetMessages(1, 3) + s.Require().Nil(err) // Then the messages should be - require.Len(t, actualMsgs, 3) - assert.Equal(t, expectedMsgsBySeqNum[1], string(actualMsgs[0])) - assert.Equal(t, expectedMsgsBySeqNum[2], string(actualMsgs[1])) - assert.Equal(t, expectedMsgsBySeqNum[3], string(actualMsgs[2])) + s.Require().Len(actualMsgs, 3) + s.Equal(expectedMsgsBySeqNum[1], string(actualMsgs[0])) + s.Equal(expectedMsgsBySeqNum[2], string(actualMsgs[1])) + s.Equal(expectedMsgsBySeqNum[3], string(actualMsgs[2])) // When the store is refreshed from its backing store - suite.msgStore.Refresh() + s.Require().Nil(s.msgStore.Refresh()) // And the messages are retrieved from the MessageStore - actualMsgs, err = suite.msgStore.GetMessages(1, 3) - require.Nil(t, err) + actualMsgs, err = s.msgStore.GetMessages(1, 3) + s.Require().Nil(err) // Then the messages should still be - require.Len(t, actualMsgs, 3) - assert.Equal(t, expectedMsgsBySeqNum[1], string(actualMsgs[0])) - assert.Equal(t, expectedMsgsBySeqNum[2], string(actualMsgs[1])) - assert.Equal(t, expectedMsgsBySeqNum[3], string(actualMsgs[2])) + s.Require().Len(actualMsgs, 3) + s.Equal(expectedMsgsBySeqNum[1], string(actualMsgs[0])) + s.Equal(expectedMsgsBySeqNum[2], string(actualMsgs[1])) + s.Equal(expectedMsgsBySeqNum[3], string(actualMsgs[2])) } -func (suite *MessageStoreTestSuite) TestMessageStore_GetMessages_EmptyStore() { +func (s *MessageStoreTestSuite) TestMessageStore_GetMessages_EmptyStore() { // When messages are retrieved from an empty store - messages, err := suite.msgStore.GetMessages(1, 2) - require.Nil(suite.T(), err) + messages, err := s.msgStore.GetMessages(1, 2) + require.Nil(s.T(), err) // Then no messages should be returned - require.Empty(suite.T(), messages, "Did not expect messages from empty store") + require.Empty(s.T(), messages, "Did not expect messages from empty store") } -func (suite *MessageStoreTestSuite) TestMessageStore_GetMessages_VariousRanges() { - t := suite.T() +func (s *MessageStoreTestSuite) TestMessageStore_GetMessages_VariousRanges() { + t := s.T() // Given the following saved messages - require.Nil(t, suite.msgStore.SaveMessage(1, []byte("hello"))) - require.Nil(t, suite.msgStore.SaveMessage(2, []byte("cruel"))) - require.Nil(t, suite.msgStore.SaveMessage(3, []byte("world"))) + require.Nil(t, s.msgStore.SaveMessage(1, []byte("hello"))) + require.Nil(t, s.msgStore.SaveMessage(2, []byte("cruel"))) + require.Nil(t, s.msgStore.SaveMessage(3, []byte("world"))) // When the following requests are made to the store var testCases = []struct { @@ -154,7 +148,7 @@ func (suite *MessageStoreTestSuite) TestMessageStore_GetMessages_VariousRanges() // Then the returned messages should be for _, tc := range testCases { - actualMsgs, err := suite.msgStore.GetMessages(tc.beginSeqNo, tc.endSeqNo) + actualMsgs, err := s.msgStore.GetMessages(tc.beginSeqNo, tc.endSeqNo) require.Nil(t, err) require.Len(t, actualMsgs, len(tc.expectedBytes)) for i, expectedMsg := range tc.expectedBytes { @@ -163,12 +157,12 @@ func (suite *MessageStoreTestSuite) TestMessageStore_GetMessages_VariousRanges() } } -func (suite *MessageStoreTestSuite) TestMessageStore_CreationTime() { - assert.False(suite.T(), suite.msgStore.CreationTime().IsZero()) +func (s *MessageStoreTestSuite) TestMessageStore_CreationTime() { + s.False(s.msgStore.CreationTime().IsZero()) t0 := time.Now() - suite.msgStore.Reset() + s.Require().Nil(s.msgStore.Reset()) t1 := time.Now() - require.True(suite.T(), suite.msgStore.CreationTime().After(t0)) - require.True(suite.T(), suite.msgStore.CreationTime().Before(t1)) + s.Require().True(s.msgStore.CreationTime().After(t0)) + s.Require().True(s.msgStore.CreationTime().Before(t1)) } diff --git a/tag.go b/tag.go index 800375e34..20fcda0b7 100644 --- a/tag.go +++ b/tag.go @@ -43,6 +43,7 @@ const ( tagBusinessRejectReason Tag = 380 tagSessionRejectReason Tag = 373 tagRefMsgType Tag = 372 + tagBusinessRejectRefID Tag = 379 tagRefTagID Tag = 371 tagRefSeqNum Tag = 45 tagEncryptMethod Tag = 98 diff --git a/tls.go b/tls.go index fa81a3808..e8a541c0b 100644 --- a/tls.go +++ b/tls.go @@ -35,50 +35,40 @@ func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) } if !settings.HasSetting(config.SocketPrivateKeyFile) && !settings.HasSetting(config.SocketCertificateFile) { - if allowSkipClientCerts { - tlsConfig = defaultTLSConfig() - tlsConfig.ServerName = serverName - tlsConfig.InsecureSkipVerify = insecureSkipVerify + if !allowSkipClientCerts { + return } - return - } - - privateKeyFile, err := settings.Setting(config.SocketPrivateKeyFile) - if err != nil { - return - } - - certificateFile, err := settings.Setting(config.SocketCertificateFile) - if err != nil { - return } tlsConfig = defaultTLSConfig() - tlsConfig.Certificates = make([]tls.Certificate, 1) tlsConfig.ServerName = serverName tlsConfig.InsecureSkipVerify = insecureSkipVerify + setMinVersionExplicit(settings, tlsConfig) - minVersion := "TLS12" - if settings.HasSetting(config.SocketMinimumTLSVersion) { - minVersion, err = settings.Setting(config.SocketMinimumTLSVersion) + if settings.HasSetting(config.SocketPrivateKeyFile) || settings.HasSetting(config.SocketCertificateFile) { + + var privateKeyFile string + var certificateFile string + + privateKeyFile, err = settings.Setting(config.SocketPrivateKeyFile) if err != nil { return } - switch minVersion { - case "SSL30": - tlsConfig.MinVersion = tls.VersionSSL30 - case "TLS10": - tlsConfig.MinVersion = tls.VersionTLS10 - case "TLS11": - tlsConfig.MinVersion = tls.VersionTLS11 - case "TLS12": - tlsConfig.MinVersion = tls.VersionTLS12 + certificateFile, err = settings.Setting(config.SocketCertificateFile) + if err != nil { + return + } + + tlsConfig.Certificates = make([]tls.Certificate, 1) + + if tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(certificateFile, privateKeyFile); err != nil { + return } } - if tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(certificateFile, privateKeyFile); err != nil { - return + if !allowSkipClientCerts { + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert } if !settings.HasSetting(config.SocketCAFile) { @@ -103,7 +93,6 @@ func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) tlsConfig.RootCAs = certPool tlsConfig.ClientCAs = certPool - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert return } @@ -122,3 +111,24 @@ func defaultTLSConfig() *tls.Config { }, } } + +func setMinVersionExplicit(settings *SessionSettings, tlsConfig *tls.Config) { + if settings.HasSetting(config.SocketMinimumTLSVersion) { + minVersion, err := settings.Setting(config.SocketMinimumTLSVersion) + if err != nil { + return + } + + switch minVersion { + case "SSL30": + //nolint:staticcheck // SA1019 min version ok + tlsConfig.MinVersion = tls.VersionSSL30 + case "TLS10": + tlsConfig.MinVersion = tls.VersionTLS10 + case "TLS11": + tlsConfig.MinVersion = tls.VersionTLS11 + case "TLS12": + tlsConfig.MinVersion = tls.VersionTLS12 + } + } +} diff --git a/tls_test.go b/tls_test.go index a858dc6db..9274b95dd 100644 --- a/tls_test.go +++ b/tls_test.go @@ -4,8 +4,9 @@ import ( "crypto/tls" "testing" - "github.com/quickfixgo/quickfix/config" "github.com/stretchr/testify/suite" + + "github.com/quickfixgo/quickfix/config" ) type TLSTestSuite struct { @@ -60,7 +61,7 @@ func (s *TLSTestSuite) TestLoadTLSNoCA() { s.Len(tlsConfig.Certificates, 1) s.Nil(tlsConfig.RootCAs) s.Nil(tlsConfig.ClientCAs) - s.Equal(tls.NoClientCert, tlsConfig.ClientAuth) + s.Equal(tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth) } func (s *TLSTestSuite) TestLoadTLSWithBadCA() { @@ -87,6 +88,36 @@ func (s *TLSTestSuite) TestLoadTLSWithCA() { s.Equal(tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth) } +func (s *TLSTestSuite) TestLoadTLSWithOnlyCA() { + s.settings.GlobalSettings().Set(config.SocketUseSSL, "Y") + s.settings.GlobalSettings().Set(config.SocketCAFile, s.CAFile) + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.NotNil(tlsConfig) + + s.NotNil(tlsConfig.RootCAs) + s.NotNil(tlsConfig.ClientCAs) +} + +func (s *TLSTestSuite) TestLoadTLSWithoutSSLWithOnlyCA() { + s.settings.GlobalSettings().Set(config.SocketCAFile, s.CAFile) + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.Nil(tlsConfig) +} + +func (s *TLSTestSuite) TestLoadTLSAllowSkipClientCerts() { + s.settings.GlobalSettings().Set(config.SocketUseSSL, "Y") + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.NotNil(tlsConfig) + + s.Equal(tls.NoClientCert, tlsConfig.ClientAuth) +} + func (s *TLSTestSuite) TestServerNameUseSSL() { s.settings.GlobalSettings().Set(config.SocketUseSSL, "Y") s.settings.GlobalSettings().Set(config.SocketServerName, "DummyServerNameUseSSL") @@ -150,6 +181,7 @@ func (s *TLSTestSuite) TestMinimumTLSVersion() { s.Nil(err) s.NotNil(tlsConfig) + //nolint:staticcheck s.Equal(tlsConfig.MinVersion, uint16(tls.VersionSSL30)) // TLS10 diff --git a/validation.go b/validation.go index dc17abc58..2a3d61dd6 100644 --- a/validation.go +++ b/validation.go @@ -4,31 +4,48 @@ import ( "github.com/quickfixgo/quickfix/datadictionary" ) -type validator interface { +//Validator validates a FIX message +type Validator interface { Validate(*Message) MessageRejectError } -type validatorSettings struct { +//ValidatorSettings describe validation behavior +type ValidatorSettings struct { CheckFieldsOutOfOrder bool RejectInvalidMessage bool } //Default configuration for message validation. //See http://www.quickfixengine.org/quickfix/doc/html/configuration.html. -var defaultValidatorSettings = validatorSettings{ +var defaultValidatorSettings = ValidatorSettings{ CheckFieldsOutOfOrder: true, RejectInvalidMessage: true, } type fixValidator struct { dataDictionary *datadictionary.DataDictionary - settings validatorSettings + settings ValidatorSettings } type fixtValidator struct { transportDataDictionary *datadictionary.DataDictionary appDataDictionary *datadictionary.DataDictionary - settings validatorSettings + settings ValidatorSettings +} + +//NewValidator creates a FIX message validator from the given data dictionaries +func NewValidator(settings ValidatorSettings, appDataDictionary, transportDataDictionary *datadictionary.DataDictionary) Validator { + if transportDataDictionary != nil { + return &fixtValidator{ + transportDataDictionary: transportDataDictionary, + appDataDictionary: appDataDictionary, + settings: settings, + } + } + return &fixValidator{ + dataDictionary: appDataDictionary, + settings: settings, + } } //Validate tests the message against the provided data dictionary. @@ -61,7 +78,7 @@ func (v *fixtValidator) Validate(msg *Message) MessageRejectError { return validateFIXT(v.transportDataDictionary, v.appDataDictionary, v.settings, msgType, msg) } -func validateFIX(d *datadictionary.DataDictionary, settings validatorSettings, msgType string, msg *Message) MessageRejectError { +func validateFIX(d *datadictionary.DataDictionary, settings ValidatorSettings, msgType string, msg *Message) MessageRejectError { if err := validateMsgType(d, msgType, msg); err != nil { return err } @@ -89,7 +106,7 @@ func validateFIX(d *datadictionary.DataDictionary, settings validatorSettings, m return nil } -func validateFIXT(transportDD, appDD *datadictionary.DataDictionary, settings validatorSettings, msgType string, msg *Message) MessageRejectError { +func validateFIXT(transportDD, appDD *datadictionary.DataDictionary, settings ValidatorSettings, msgType string, msg *Message) MessageRejectError { if err := validateMsgType(appDD, msgType, msg); err != nil { return err } @@ -116,7 +133,7 @@ func validateFIXT(transportDD, appDD *datadictionary.DataDictionary, settings va } func validateMsgType(d *datadictionary.DataDictionary, msgType string, msg *Message) MessageRejectError { - if _, validMsgType := d.Messages[msgType]; validMsgType == false { + if _, validMsgType := d.Messages[msgType]; !validMsgType { return InvalidMessageType() } return nil diff --git a/validation_test.go b/validation_test.go index 10a8a7e3f..8224c96c0 100644 --- a/validation_test.go +++ b/validation_test.go @@ -5,13 +5,14 @@ import ( "testing" "time" - "github.com/quickfixgo/quickfix/datadictionary" "github.com/stretchr/testify/assert" + + "github.com/quickfixgo/quickfix/datadictionary" ) type validateTest struct { TestName string - Validator validator + Validator Validator MessageBytes []byte ExpectedRejectReason int ExpectedRefTagID *Tag @@ -126,7 +127,7 @@ func createFIX43NewOrderSingle() *Message { func tcInvalidTagNumberHeader() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} + validator := NewValidator(defaultValidatorSettings, dict, nil) invalidHeaderFieldMessage := createFIX40NewOrderSingle() tag := Tag(9999) invalidHeaderFieldMessage.Header.SetField(tag, FIXString("hello")) @@ -142,7 +143,7 @@ func tcInvalidTagNumberHeader() validateTest { } func tcInvalidTagNumberBody() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} + validator := NewValidator(defaultValidatorSettings, dict, nil) invalidBodyFieldMessage := createFIX40NewOrderSingle() tag := Tag(9999) invalidBodyFieldMessage.Body.SetField(tag, FIXString("hello")) @@ -159,7 +160,7 @@ func tcInvalidTagNumberBody() validateTest { func tcInvalidTagNumberTrailer() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} + validator := NewValidator(defaultValidatorSettings, dict, nil) invalidTrailerFieldMessage := createFIX40NewOrderSingle() tag := Tag(9999) invalidTrailerFieldMessage.Trailer.SetField(tag, FIXString("hello")) @@ -176,7 +177,7 @@ func tcInvalidTagNumberTrailer() validateTest { func tcTagNotDefinedForMessage() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} + validator := NewValidator(defaultValidatorSettings, dict, nil) invalidMsg := createFIX40NewOrderSingle() tag := Tag(41) invalidMsg.Body.SetField(tag, FIXString("hello")) @@ -194,7 +195,7 @@ func tcTagNotDefinedForMessage() validateTest { func tcTagIsDefinedForMessage() validateTest { //compare to tcTagIsNotDefinedForMessage dict, _ := datadictionary.Parse("spec/FIX43.xml") - validator := &fixValidator{dict, defaultValidatorSettings} + validator := NewValidator(defaultValidatorSettings, dict, nil) validMsg := createFIX43NewOrderSingle() msgBytes := validMsg.build() @@ -208,7 +209,7 @@ func tcTagIsDefinedForMessage() validateTest { func tcFieldNotFoundBody() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} + validator := NewValidator(defaultValidatorSettings, dict, nil) invalidMsg1 := NewMessage() invalidMsg1.Header.SetField(tagMsgType, FIXString("D")). SetField(tagBeginString, FIXString("FIX.4.0")). @@ -242,7 +243,7 @@ func tcFieldNotFoundBody() validateTest { func tcFieldNotFoundHeader() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} + validator := NewValidator(defaultValidatorSettings, dict, nil) invalidMsg2 := NewMessage() invalidMsg2.Trailer.SetField(tagCheckSum, FIXString("000")) @@ -275,7 +276,7 @@ func tcFieldNotFoundHeader() validateTest { func tcTagSpecifiedWithoutAValue() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} + validator := NewValidator(defaultValidatorSettings, dict, nil) builder := createFIX40NewOrderSingle() bogusTag := Tag(109) @@ -293,7 +294,7 @@ func tcTagSpecifiedWithoutAValue() validateTest { func tcInvalidMsgType() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} + validator := NewValidator(defaultValidatorSettings, dict, nil) builder := createFIX40NewOrderSingle() builder.Header.SetField(tagMsgType, FIXString("z")) msgBytes := builder.build() @@ -308,7 +309,7 @@ func tcInvalidMsgType() validateTest { func tcValueIsIncorrect() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} + validator := NewValidator(defaultValidatorSettings, dict, nil) tag := Tag(21) builder := createFIX40NewOrderSingle() @@ -326,7 +327,7 @@ func tcValueIsIncorrect() validateTest { func tcIncorrectDataFormatForValue() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} + validator := NewValidator(defaultValidatorSettings, dict, nil) builder := createFIX40NewOrderSingle() tag := Tag(38) builder.Body.SetField(tag, FIXString("+200.00")) @@ -343,7 +344,7 @@ func tcIncorrectDataFormatForValue() validateTest { func tcTagSpecifiedOutOfRequiredOrderHeader() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} + validator := NewValidator(defaultValidatorSettings, dict, nil) builder := createFIX40NewOrderSingle() tag := tagOnBehalfOfCompID @@ -362,7 +363,7 @@ func tcTagSpecifiedOutOfRequiredOrderHeader() validateTest { func tcTagSpecifiedOutOfRequiredOrderTrailer() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} + validator := NewValidator(defaultValidatorSettings, dict, nil) builder := createFIX40NewOrderSingle() tag := tagSignature @@ -382,8 +383,9 @@ func tcTagSpecifiedOutOfRequiredOrderTrailer() validateTest { func tcInvalidTagCheckDisabled() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} - validator.settings.RejectInvalidMessage = false + customValidatorSettings := defaultValidatorSettings + customValidatorSettings.RejectInvalidMessage = false + validator := NewValidator(customValidatorSettings, dict, nil) builder := createFIX40NewOrderSingle() tag := Tag(9999) @@ -400,8 +402,9 @@ func tcInvalidTagCheckDisabled() validateTest { func tcInvalidTagCheckEnabled() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} - validator.settings.RejectInvalidMessage = true + customValidatorSettings := defaultValidatorSettings + customValidatorSettings.RejectInvalidMessage = true + validator := NewValidator(customValidatorSettings, dict, nil) builder := createFIX40NewOrderSingle() tag := Tag(9999) @@ -419,8 +422,9 @@ func tcInvalidTagCheckEnabled() validateTest { func tcTagSpecifiedOutOfRequiredOrderDisabledHeader() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} - validator.settings.CheckFieldsOutOfOrder = false + customValidatorSettings := defaultValidatorSettings + customValidatorSettings.CheckFieldsOutOfOrder = false + validator := NewValidator(customValidatorSettings, dict, nil) builder := createFIX40NewOrderSingle() tag := tagOnBehalfOfCompID @@ -438,8 +442,9 @@ func tcTagSpecifiedOutOfRequiredOrderDisabledHeader() validateTest { func tcTagSpecifiedOutOfRequiredOrderDisabledTrailer() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} - validator.settings.CheckFieldsOutOfOrder = false + customValidatorSettings := defaultValidatorSettings + customValidatorSettings.CheckFieldsOutOfOrder = false + validator := NewValidator(customValidatorSettings, dict, nil) builder := createFIX40NewOrderSingle() tag := tagSignature @@ -457,7 +462,7 @@ func tcTagSpecifiedOutOfRequiredOrderDisabledTrailer() validateTest { func tcTagAppearsMoreThanOnce() validateTest { dict, _ := datadictionary.Parse("spec/FIX40.xml") - validator := &fixValidator{dict, defaultValidatorSettings} + validator := NewValidator(defaultValidatorSettings, dict, nil) tag := Tag(40) return validateTest{ @@ -471,7 +476,7 @@ func tcTagAppearsMoreThanOnce() validateTest { func tcFloatValidation() validateTest { dict, _ := datadictionary.Parse("spec/FIX42.xml") - validator := &fixValidator{dict, defaultValidatorSettings} + validator := NewValidator(defaultValidatorSettings, dict, nil) tag := Tag(38) return validateTest{ TestName: "FloatValidation",