Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Context aware Connect() functions #136

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package steam
import (
"bytes"
"compress/gzip"
"context"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"io/ioutil"
Expand Down Expand Up @@ -181,6 +183,58 @@ func (c *Client) ConnectToBind(addr *netutil.PortAddr, local *net.TCPAddr) error
return nil
}

// ConnectContext connects to a random Steam server and returns its address.
// If this client is already connected, it is disconnected first.
// This method tries to use an address from the Steam Directory and falls
// back to the built-in server list if the Steam Directory can't be reached.
// If you want to connect to a specific server, use `ConnectTo`. Context
// aware.
func (c *Client) ConnectContext(ctx context.Context) (*netutil.PortAddr, error) {
var server *netutil.PortAddr

// try to initialize the directory cache
if !steamDirectoryCache.IsInitialized() {
_ = steamDirectoryCache.Initialize()
}
if steamDirectoryCache.IsInitialized() {
server = steamDirectoryCache.GetRandomCM()
} else {
server = GetRandomCM()
}

err := c.ConnectToContext(ctx, server)
return server, err
}

// ConnectToContext connects to a specific server.
// You may want to use one of the `GetRandom*CM()` functions in this package.
// If this client is already connected, it is disconnected first. Context
// aware.
func (c *Client) ConnectToContext(ctx context.Context, addr *netutil.PortAddr) error {
return c.ConnectToBindContext(ctx, addr, nil)
}

// ConnectToBindContext connects to a specific server, and binds to a specified local IP
// If this client is already connected, it is disconnected first. Context aware.
func (c *Client) ConnectToBindContext(ctx context.Context, addr *netutil.PortAddr, local *net.TCPAddr) error {
c.Disconnect()

conn, err := dialTCPContext(ctx, local, addr.ToTCPAddr())
if err != nil {
if !(errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)) {
c.Fatalf("Connect failed: %v", err)
}
return err
}
c.conn = conn
c.writeChan = make(chan protocol.IMsg, 5)

go c.readLoop()
go c.writeLoop()

return nil
}

func (c *Client) Disconnect() {
c.mutex.Lock()
defer c.mutex.Unlock()
Expand Down
77 changes: 77 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package steam

import (
"context"
"errors"
"sync"
"testing"
"time"

"github.com/Philipp15b/go-steam/v3/netutil"
)

func TestConnectToContext(t *testing.T) {

c := NewClient()

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for event := range c.Events() {
switch e := event.(type) {

case nil:
return

default:
t.Logf("\t\tReceived event of type %T", e)
}
}
}()

// Times out on dial.
badServer := "162.254.198.104:27018"

// Should work.
okServer := "155.133.246.52:27018"

// Fail 5 times with bad server to test for deadlocks.
for i := 0; i < 5; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500)

t.Log("Testing [ConnectToContext] on bad server", badServer)
err := c.ConnectToContext(
ctx,
netutil.ParsePortAddr(badServer),
)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
t.Logf("\t[ConnectToContext] timed out as expected: %v", err)
} else {
t.Fatalf("\t[ConnectToContext] failed: %v", err)
}
}
cancel()
}

// Connect to OK server.
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

t.Log("Testing [ConnectToContext] on OK server", okServer)
err := c.ConnectToContext(
ctx,
netutil.ParsePortAddr(okServer),
)
if err != nil {
t.Fatalf("\t[ConnectToContext] failed (is the server still OK?): %v", err)
}

t.Logf("\t[ConnectToContext] connected without issues")

time.Sleep(time.Second)
c.Emit(nil) // Stop the event loop.

wg.Wait()
}
31 changes: 31 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package steam

import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/binary"
Expand Down Expand Up @@ -40,6 +41,36 @@ func dialTCP(laddr, raddr *net.TCPAddr) (*tcpConnection, error) {
}, nil
}

// dialTCPContext is the same as dialTCP, but is context aware.
func dialTCPContext(ctx context.Context, laddr, raddr *net.TCPAddr) (*tcpConnection, error) {

// TODO add DialTCPContext once it's available.
// Currently a workaround.
ok := make(chan struct{})
var (
conn *net.TCPConn
err error
)
go func() {
conn, err = net.DialTCP("tcp", laddr, raddr)
close(ok)
}()

select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ok:
}

if err != nil {
return nil, err
}

return &tcpConnection{
conn: conn,
}, nil
}

func (c *tcpConnection) Read() (*protocol.Packet, error) {
// All packets begin with a packet length
var packetLen uint32
Expand Down