diff --git a/client.go b/client.go index d135f504..ff0e41e4 100644 --- a/client.go +++ b/client.go @@ -3,8 +3,10 @@ package steam import ( "bytes" "compress/gzip" + "context" "crypto/rand" "encoding/binary" + "errors" "fmt" "hash/crc32" "io/ioutil" @@ -181,6 +183,58 @@ func (c *Client) ConnectToBind(addr *netutil.PortAddr, local *net.TCPAddr) error return nil } +// ConnectContext connects to a random Steam server and returns its address. +// If this client is already connected, it is disconnected first. +// This method tries to use an address from the Steam Directory and falls +// back to the built-in server list if the Steam Directory can't be reached. +// If you want to connect to a specific server, use `ConnectTo`. Context +// aware. +func (c *Client) ConnectContext(ctx context.Context) (*netutil.PortAddr, error) { + var server *netutil.PortAddr + + // try to initialize the directory cache + if !steamDirectoryCache.IsInitialized() { + _ = steamDirectoryCache.Initialize() + } + if steamDirectoryCache.IsInitialized() { + server = steamDirectoryCache.GetRandomCM() + } else { + server = GetRandomCM() + } + + err := c.ConnectToContext(ctx, server) + return server, err +} + +// ConnectToContext connects to a specific server. +// You may want to use one of the `GetRandom*CM()` functions in this package. +// If this client is already connected, it is disconnected first. Context +// aware. +func (c *Client) ConnectToContext(ctx context.Context, addr *netutil.PortAddr) error { + return c.ConnectToBindContext(ctx, addr, nil) +} + +// ConnectToBindContext connects to a specific server, and binds to a specified local IP +// If this client is already connected, it is disconnected first. Context aware. +func (c *Client) ConnectToBindContext(ctx context.Context, addr *netutil.PortAddr, local *net.TCPAddr) error { + c.Disconnect() + + conn, err := dialTCPContext(ctx, local, addr.ToTCPAddr()) + if err != nil { + if !(errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)) { + c.Fatalf("Connect failed: %v", err) + } + return err + } + c.conn = conn + c.writeChan = make(chan protocol.IMsg, 5) + + go c.readLoop() + go c.writeLoop() + + return nil +} + func (c *Client) Disconnect() { c.mutex.Lock() defer c.mutex.Unlock() diff --git a/client_test.go b/client_test.go new file mode 100644 index 00000000..bd69cce3 --- /dev/null +++ b/client_test.go @@ -0,0 +1,77 @@ +package steam + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/Philipp15b/go-steam/v3/netutil" +) + +func TestConnectToContext(t *testing.T) { + + c := NewClient() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for event := range c.Events() { + switch e := event.(type) { + + case nil: + return + + default: + t.Logf("\t\tReceived event of type %T", e) + } + } + }() + + // Times out on dial. + badServer := "162.254.198.104:27018" + + // Should work. + okServer := "155.133.246.52:27018" + + // Fail 5 times with bad server to test for deadlocks. + for i := 0; i < 5; i++ { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500) + + t.Log("Testing [ConnectToContext] on bad server", badServer) + err := c.ConnectToContext( + ctx, + netutil.ParsePortAddr(badServer), + ) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + t.Logf("\t[ConnectToContext] timed out as expected: %v", err) + } else { + t.Fatalf("\t[ConnectToContext] failed: %v", err) + } + } + cancel() + } + + // Connect to OK server. + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + t.Log("Testing [ConnectToContext] on OK server", okServer) + err := c.ConnectToContext( + ctx, + netutil.ParsePortAddr(okServer), + ) + if err != nil { + t.Fatalf("\t[ConnectToContext] failed (is the server still OK?): %v", err) + } + + t.Logf("\t[ConnectToContext] connected without issues") + + time.Sleep(time.Second) + c.Emit(nil) // Stop the event loop. + + wg.Wait() +} diff --git a/connection.go b/connection.go index 5c2db9b4..e7dc5382 100644 --- a/connection.go +++ b/connection.go @@ -1,6 +1,7 @@ package steam import ( + "context" "crypto/aes" "crypto/cipher" "encoding/binary" @@ -40,6 +41,36 @@ func dialTCP(laddr, raddr *net.TCPAddr) (*tcpConnection, error) { }, nil } +// dialTCPContext is the same as dialTCP, but is context aware. +func dialTCPContext(ctx context.Context, laddr, raddr *net.TCPAddr) (*tcpConnection, error) { + + // TODO add DialTCPContext once it's available. + // Currently a workaround. + ok := make(chan struct{}) + var ( + conn *net.TCPConn + err error + ) + go func() { + conn, err = net.DialTCP("tcp", laddr, raddr) + close(ok) + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ok: + } + + if err != nil { + return nil, err + } + + return &tcpConnection{ + conn: conn, + }, nil +} + func (c *tcpConnection) Read() (*protocol.Packet, error) { // All packets begin with a packet length var packetLen uint32