diff --git a/app/cmd/client.go b/app/cmd/client.go index c1d04bd786..1e9bca59ff 100644 --- a/app/cmd/client.go +++ b/app/cmd/client.go @@ -21,6 +21,7 @@ import ( "github.com/apernet/hysteria/app/internal/forwarding" "github.com/apernet/hysteria/app/internal/http" + "github.com/apernet/hysteria/app/internal/proxymux" "github.com/apernet/hysteria/app/internal/redirect" "github.com/apernet/hysteria/app/internal/socks5" "github.com/apernet/hysteria/app/internal/tproxy" @@ -531,7 +532,7 @@ func clientSOCKS5(config socks5Config, c client.Client) error { if config.Listen == "" { return configError{Field: "listen", Err: errors.New("listen address is empty")} } - l, err := correctnet.Listen("tcp", config.Listen) + l, err := proxymux.ListenSOCKS(config.Listen) if err != nil { return configError{Field: "listen", Err: err} } @@ -556,7 +557,7 @@ func clientHTTP(config httpConfig, c client.Client) error { if config.Listen == "" { return configError{Field: "listen", Err: errors.New("listen address is empty")} } - l, err := correctnet.Listen("tcp", config.Listen) + l, err := proxymux.ListenHTTP(config.Listen) if err != nil { return configError{Field: "listen", Err: err} } diff --git a/app/internal/proxymux/.mockery.yaml b/app/internal/proxymux/.mockery.yaml new file mode 100644 index 0000000000..7d3fac08a5 --- /dev/null +++ b/app/internal/proxymux/.mockery.yaml @@ -0,0 +1,12 @@ +with-expecter: true +dir: internal/mocks +outpkg: mocks +packages: + net: + interfaces: + Listener: + config: + mockname: MockListener + Conn: + config: + mockname: MockConn diff --git a/app/internal/proxymux/internal/mocks/mock_Conn.go b/app/internal/proxymux/internal/mocks/mock_Conn.go new file mode 100644 index 0000000000..0bcbe657d7 --- /dev/null +++ b/app/internal/proxymux/internal/mocks/mock_Conn.go @@ -0,0 +1,427 @@ +// Code generated by mockery v2.42.2. DO NOT EDIT. + +package mocks + +import ( + net "net" + + mock "github.com/stretchr/testify/mock" + + time "time" +) + +// MockConn is an autogenerated mock type for the Conn type +type MockConn struct { + mock.Mock +} + +type MockConn_Expecter struct { + mock *mock.Mock +} + +func (_m *MockConn) EXPECT() *MockConn_Expecter { + return &MockConn_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function with given fields: +func (_m *MockConn) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockConn_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockConn_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockConn_Expecter) Close() *MockConn_Close_Call { + return &MockConn_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockConn_Close_Call) Run(run func()) *MockConn_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockConn_Close_Call) Return(_a0 error) *MockConn_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockConn_Close_Call) RunAndReturn(run func() error) *MockConn_Close_Call { + _c.Call.Return(run) + return _c +} + +// LocalAddr provides a mock function with given fields: +func (_m *MockConn) LocalAddr() net.Addr { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for LocalAddr") + } + + var r0 net.Addr + if rf, ok := ret.Get(0).(func() net.Addr); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Addr) + } + } + + return r0 +} + +// MockConn_LocalAddr_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LocalAddr' +type MockConn_LocalAddr_Call struct { + *mock.Call +} + +// LocalAddr is a helper method to define mock.On call +func (_e *MockConn_Expecter) LocalAddr() *MockConn_LocalAddr_Call { + return &MockConn_LocalAddr_Call{Call: _e.mock.On("LocalAddr")} +} + +func (_c *MockConn_LocalAddr_Call) Run(run func()) *MockConn_LocalAddr_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockConn_LocalAddr_Call) Return(_a0 net.Addr) *MockConn_LocalAddr_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockConn_LocalAddr_Call) RunAndReturn(run func() net.Addr) *MockConn_LocalAddr_Call { + _c.Call.Return(run) + return _c +} + +// Read provides a mock function with given fields: b +func (_m *MockConn) Read(b []byte) (int, error) { + ret := _m.Called(b) + + if len(ret) == 0 { + panic("no return value specified for Read") + } + + var r0 int + var r1 error + if rf, ok := ret.Get(0).(func([]byte) (int, error)); ok { + return rf(b) + } + if rf, ok := ret.Get(0).(func([]byte) int); ok { + r0 = rf(b) + } else { + r0 = ret.Get(0).(int) + } + + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(b) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockConn_Read_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Read' +type MockConn_Read_Call struct { + *mock.Call +} + +// Read is a helper method to define mock.On call +// - b []byte +func (_e *MockConn_Expecter) Read(b interface{}) *MockConn_Read_Call { + return &MockConn_Read_Call{Call: _e.mock.On("Read", b)} +} + +func (_c *MockConn_Read_Call) Run(run func(b []byte)) *MockConn_Read_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]byte)) + }) + return _c +} + +func (_c *MockConn_Read_Call) Return(n int, err error) *MockConn_Read_Call { + _c.Call.Return(n, err) + return _c +} + +func (_c *MockConn_Read_Call) RunAndReturn(run func([]byte) (int, error)) *MockConn_Read_Call { + _c.Call.Return(run) + return _c +} + +// RemoteAddr provides a mock function with given fields: +func (_m *MockConn) RemoteAddr() net.Addr { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for RemoteAddr") + } + + var r0 net.Addr + if rf, ok := ret.Get(0).(func() net.Addr); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Addr) + } + } + + return r0 +} + +// MockConn_RemoteAddr_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoteAddr' +type MockConn_RemoteAddr_Call struct { + *mock.Call +} + +// RemoteAddr is a helper method to define mock.On call +func (_e *MockConn_Expecter) RemoteAddr() *MockConn_RemoteAddr_Call { + return &MockConn_RemoteAddr_Call{Call: _e.mock.On("RemoteAddr")} +} + +func (_c *MockConn_RemoteAddr_Call) Run(run func()) *MockConn_RemoteAddr_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockConn_RemoteAddr_Call) Return(_a0 net.Addr) *MockConn_RemoteAddr_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockConn_RemoteAddr_Call) RunAndReturn(run func() net.Addr) *MockConn_RemoteAddr_Call { + _c.Call.Return(run) + return _c +} + +// SetDeadline provides a mock function with given fields: t +func (_m *MockConn) SetDeadline(t time.Time) error { + ret := _m.Called(t) + + if len(ret) == 0 { + panic("no return value specified for SetDeadline") + } + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(t) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockConn_SetDeadline_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetDeadline' +type MockConn_SetDeadline_Call struct { + *mock.Call +} + +// SetDeadline is a helper method to define mock.On call +// - t time.Time +func (_e *MockConn_Expecter) SetDeadline(t interface{}) *MockConn_SetDeadline_Call { + return &MockConn_SetDeadline_Call{Call: _e.mock.On("SetDeadline", t)} +} + +func (_c *MockConn_SetDeadline_Call) Run(run func(t time.Time)) *MockConn_SetDeadline_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(time.Time)) + }) + return _c +} + +func (_c *MockConn_SetDeadline_Call) Return(_a0 error) *MockConn_SetDeadline_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockConn_SetDeadline_Call) RunAndReturn(run func(time.Time) error) *MockConn_SetDeadline_Call { + _c.Call.Return(run) + return _c +} + +// SetReadDeadline provides a mock function with given fields: t +func (_m *MockConn) SetReadDeadline(t time.Time) error { + ret := _m.Called(t) + + if len(ret) == 0 { + panic("no return value specified for SetReadDeadline") + } + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(t) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockConn_SetReadDeadline_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetReadDeadline' +type MockConn_SetReadDeadline_Call struct { + *mock.Call +} + +// SetReadDeadline is a helper method to define mock.On call +// - t time.Time +func (_e *MockConn_Expecter) SetReadDeadline(t interface{}) *MockConn_SetReadDeadline_Call { + return &MockConn_SetReadDeadline_Call{Call: _e.mock.On("SetReadDeadline", t)} +} + +func (_c *MockConn_SetReadDeadline_Call) Run(run func(t time.Time)) *MockConn_SetReadDeadline_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(time.Time)) + }) + return _c +} + +func (_c *MockConn_SetReadDeadline_Call) Return(_a0 error) *MockConn_SetReadDeadline_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockConn_SetReadDeadline_Call) RunAndReturn(run func(time.Time) error) *MockConn_SetReadDeadline_Call { + _c.Call.Return(run) + return _c +} + +// SetWriteDeadline provides a mock function with given fields: t +func (_m *MockConn) SetWriteDeadline(t time.Time) error { + ret := _m.Called(t) + + if len(ret) == 0 { + panic("no return value specified for SetWriteDeadline") + } + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(t) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockConn_SetWriteDeadline_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetWriteDeadline' +type MockConn_SetWriteDeadline_Call struct { + *mock.Call +} + +// SetWriteDeadline is a helper method to define mock.On call +// - t time.Time +func (_e *MockConn_Expecter) SetWriteDeadline(t interface{}) *MockConn_SetWriteDeadline_Call { + return &MockConn_SetWriteDeadline_Call{Call: _e.mock.On("SetWriteDeadline", t)} +} + +func (_c *MockConn_SetWriteDeadline_Call) Run(run func(t time.Time)) *MockConn_SetWriteDeadline_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(time.Time)) + }) + return _c +} + +func (_c *MockConn_SetWriteDeadline_Call) Return(_a0 error) *MockConn_SetWriteDeadline_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockConn_SetWriteDeadline_Call) RunAndReturn(run func(time.Time) error) *MockConn_SetWriteDeadline_Call { + _c.Call.Return(run) + return _c +} + +// Write provides a mock function with given fields: b +func (_m *MockConn) Write(b []byte) (int, error) { + ret := _m.Called(b) + + if len(ret) == 0 { + panic("no return value specified for Write") + } + + var r0 int + var r1 error + if rf, ok := ret.Get(0).(func([]byte) (int, error)); ok { + return rf(b) + } + if rf, ok := ret.Get(0).(func([]byte) int); ok { + r0 = rf(b) + } else { + r0 = ret.Get(0).(int) + } + + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(b) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockConn_Write_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Write' +type MockConn_Write_Call struct { + *mock.Call +} + +// Write is a helper method to define mock.On call +// - b []byte +func (_e *MockConn_Expecter) Write(b interface{}) *MockConn_Write_Call { + return &MockConn_Write_Call{Call: _e.mock.On("Write", b)} +} + +func (_c *MockConn_Write_Call) Run(run func(b []byte)) *MockConn_Write_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]byte)) + }) + return _c +} + +func (_c *MockConn_Write_Call) Return(n int, err error) *MockConn_Write_Call { + _c.Call.Return(n, err) + return _c +} + +func (_c *MockConn_Write_Call) RunAndReturn(run func([]byte) (int, error)) *MockConn_Write_Call { + _c.Call.Return(run) + return _c +} + +// NewMockConn creates a new instance of MockConn. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockConn(t interface { + mock.TestingT + Cleanup(func()) +}) *MockConn { + mock := &MockConn{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/app/internal/proxymux/internal/mocks/mock_Listener.go b/app/internal/proxymux/internal/mocks/mock_Listener.go new file mode 100644 index 0000000000..e4ca2f46cd --- /dev/null +++ b/app/internal/proxymux/internal/mocks/mock_Listener.go @@ -0,0 +1,185 @@ +// Code generated by mockery v2.42.2. DO NOT EDIT. + +package mocks + +import ( + net "net" + + mock "github.com/stretchr/testify/mock" +) + +// MockListener is an autogenerated mock type for the Listener type +type MockListener struct { + mock.Mock +} + +type MockListener_Expecter struct { + mock *mock.Mock +} + +func (_m *MockListener) EXPECT() *MockListener_Expecter { + return &MockListener_Expecter{mock: &_m.Mock} +} + +// Accept provides a mock function with given fields: +func (_m *MockListener) Accept() (net.Conn, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Accept") + } + + var r0 net.Conn + var r1 error + if rf, ok := ret.Get(0).(func() (net.Conn, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() net.Conn); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Conn) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockListener_Accept_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Accept' +type MockListener_Accept_Call struct { + *mock.Call +} + +// Accept is a helper method to define mock.On call +func (_e *MockListener_Expecter) Accept() *MockListener_Accept_Call { + return &MockListener_Accept_Call{Call: _e.mock.On("Accept")} +} + +func (_c *MockListener_Accept_Call) Run(run func()) *MockListener_Accept_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockListener_Accept_Call) Return(_a0 net.Conn, _a1 error) *MockListener_Accept_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockListener_Accept_Call) RunAndReturn(run func() (net.Conn, error)) *MockListener_Accept_Call { + _c.Call.Return(run) + return _c +} + +// Addr provides a mock function with given fields: +func (_m *MockListener) Addr() net.Addr { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Addr") + } + + var r0 net.Addr + if rf, ok := ret.Get(0).(func() net.Addr); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Addr) + } + } + + return r0 +} + +// MockListener_Addr_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Addr' +type MockListener_Addr_Call struct { + *mock.Call +} + +// Addr is a helper method to define mock.On call +func (_e *MockListener_Expecter) Addr() *MockListener_Addr_Call { + return &MockListener_Addr_Call{Call: _e.mock.On("Addr")} +} + +func (_c *MockListener_Addr_Call) Run(run func()) *MockListener_Addr_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockListener_Addr_Call) Return(_a0 net.Addr) *MockListener_Addr_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockListener_Addr_Call) RunAndReturn(run func() net.Addr) *MockListener_Addr_Call { + _c.Call.Return(run) + return _c +} + +// Close provides a mock function with given fields: +func (_m *MockListener) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockListener_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockListener_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockListener_Expecter) Close() *MockListener_Close_Call { + return &MockListener_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockListener_Close_Call) Run(run func()) *MockListener_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockListener_Close_Call) Return(_a0 error) *MockListener_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockListener_Close_Call) RunAndReturn(run func() error) *MockListener_Close_Call { + _c.Call.Return(run) + return _c +} + +// NewMockListener creates a new instance of MockListener. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockListener(t interface { + mock.TestingT + Cleanup(func()) +}) *MockListener { + mock := &MockListener{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/app/internal/proxymux/manager.go b/app/internal/proxymux/manager.go new file mode 100644 index 0000000000..92df918ba0 --- /dev/null +++ b/app/internal/proxymux/manager.go @@ -0,0 +1,72 @@ +package proxymux + +import ( + "net" + "sync" + + "github.com/apernet/hysteria/extras/correctnet" +) + +type muxManager struct { + listeners map[string]*muxListener + lock sync.Mutex +} + +var globalMuxManager *muxManager + +func init() { + globalMuxManager = &muxManager{ + listeners: make(map[string]*muxListener), + } +} + +func (m *muxManager) GetOrCreate(address string) (*muxListener, error) { + key, err := m.canonicalizeAddrPort(address) + if err != nil { + return nil, err + } + + m.lock.Lock() + defer m.lock.Unlock() + + if ml, ok := m.listeners[key]; ok { + return ml, nil + } + + listener, err := correctnet.Listen("tcp", key) + if err != nil { + return nil, err + } + + ml := newMuxListener(listener, func() { + m.lock.Lock() + defer m.lock.Unlock() + delete(m.listeners, key) + }) + m.listeners[key] = ml + return ml, nil +} + +func (m *muxManager) canonicalizeAddrPort(address string) (string, error) { + taddr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + return "", err + } + return taddr.String(), nil +} + +func ListenHTTP(address string) (net.Listener, error) { + ml, err := globalMuxManager.GetOrCreate(address) + if err != nil { + return nil, err + } + return ml.ListenHTTP() +} + +func ListenSOCKS(address string) (net.Listener, error) { + ml, err := globalMuxManager.GetOrCreate(address) + if err != nil { + return nil, err + } + return ml.ListenSOCKS() +} diff --git a/app/internal/proxymux/manager_test.go b/app/internal/proxymux/manager_test.go new file mode 100644 index 0000000000..c776058f1e --- /dev/null +++ b/app/internal/proxymux/manager_test.go @@ -0,0 +1,110 @@ +package proxymux + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestListenSOCKS(t *testing.T) { + address := "127.2.39.129:11081" + + sl, err := ListenSOCKS(address) + if !assert.NoError(t, err) { + return + } + defer func() { + sl.Close() + }() + + hl, err := ListenHTTP(address) + if !assert.NoError(t, err) { + return + } + defer hl.Close() + + _, err = ListenSOCKS(address) + if !assert.ErrorIs(t, err, ErrProtocolInUse) { + return + } + sl.Close() + + sl, err = ListenSOCKS(address) + if !assert.NoError(t, err) { + return + } +} + +func TestListenHTTP(t *testing.T) { + address := "127.2.39.129:11082" + + hl, err := ListenHTTP(address) + if !assert.NoError(t, err) { + return + } + defer func() { + hl.Close() + }() + + sl, err := ListenSOCKS(address) + if !assert.NoError(t, err) { + return + } + defer sl.Close() + + _, err = ListenHTTP(address) + if !assert.ErrorIs(t, err, ErrProtocolInUse) { + return + } + hl.Close() + + hl, err = ListenHTTP(address) + if !assert.NoError(t, err) { + return + } +} + +func TestRelease(t *testing.T) { + address := "127.2.39.129:11083" + + hl, err := ListenHTTP(address) + if !assert.NoError(t, err) { + return + } + sl, err := ListenSOCKS(address) + if !assert.NoError(t, err) { + return + } + + if !assert.True(t, globalMuxManager.testAddressExists(address)) { + return + } + _, err = net.Listen("tcp", address) + if !assert.Error(t, err) { + return + } + + hl.Close() + sl.Close() + + // Wait for muxListener released + time.Sleep(time.Second) + if !assert.False(t, globalMuxManager.testAddressExists(address)) { + return + } + lis, err := net.Listen("tcp", address) + if !assert.NoError(t, err) { + return + } + defer lis.Close() +} + +func (m *muxManager) testAddressExists(address string) bool { + m.lock.Lock() + defer m.lock.Unlock() + + _, ok := m.listeners[address] + return ok +} diff --git a/app/internal/proxymux/mux.go b/app/internal/proxymux/mux.go new file mode 100644 index 0000000000..1f0b7b09dd --- /dev/null +++ b/app/internal/proxymux/mux.go @@ -0,0 +1,320 @@ +package proxymux + +import ( + "errors" + "fmt" + "io" + "net" + "sync" +) + +func newMuxListener(listener net.Listener, deleteFunc func()) *muxListener { + l := &muxListener{ + base: listener, + acceptChan: make(chan net.Conn), + closeChan: make(chan struct{}), + deleteFunc: deleteFunc, + } + go l.acceptLoop() + go l.mainLoop() + return l +} + +type muxListener struct { + lock sync.Mutex + base net.Listener + acceptErr error + + acceptChan chan net.Conn + closeChan chan struct{} + + socksListener *subListener + httpListener *subListener + + deleteFunc func() +} + +func (l *muxListener) acceptLoop() { + defer close(l.acceptChan) + + for { + conn, err := l.base.Accept() + if err != nil { + l.lock.Lock() + l.acceptErr = err + l.lock.Unlock() + return + } + select { + case <-l.closeChan: + return + case l.acceptChan <- conn: + } + } +} + +func (l *muxListener) mainLoop() { + defer func() { + l.deleteFunc() + l.base.Close() + + close(l.closeChan) + + l.lock.Lock() + defer l.lock.Unlock() + + if sl := l.httpListener; sl != nil { + close(sl.acceptChan) + l.httpListener = nil + } + if sl := l.socksListener; sl != nil { + close(sl.acceptChan) + l.socksListener = nil + } + }() + + for { + var socksCloseChan, httpCloseChan chan struct{} + if l.httpListener != nil { + httpCloseChan = l.httpListener.closeChan + } + if l.socksListener != nil { + socksCloseChan = l.socksListener.closeChan + } + select { + case <-l.closeChan: + return + case conn, ok := <-l.acceptChan: + if !ok { + return + } + go l.dispatch(conn) + case <-socksCloseChan: + l.lock.Lock() + if socksCloseChan == l.socksListener.closeChan { + // not replaced by another ListenSOCKS() + l.socksListener = nil + } + l.lock.Unlock() + if l.checkIdle() { + return + } + case <-httpCloseChan: + l.lock.Lock() + if httpCloseChan == l.httpListener.closeChan { + // not replaced by another ListenHTTP() + l.httpListener = nil + } + l.lock.Unlock() + if l.checkIdle() { + return + } + } + } +} + +func (l *muxListener) dispatch(conn net.Conn) { + var b [1]byte + if _, err := io.ReadFull(conn, b[:]); err != nil { + conn.Close() + return + } + + l.lock.Lock() + var target *subListener + if b[0] == 5 { + target = l.socksListener + } else { + target = l.httpListener + } + l.lock.Unlock() + + if target == nil { + conn.Close() + return + } + + wconn := &connWithOneByte{Conn: conn, b: b[0]} + + select { + case <-target.closeChan: + case target.acceptChan <- wconn: + } +} + +func (l *muxListener) checkIdle() bool { + l.lock.Lock() + defer l.lock.Unlock() + + return l.httpListener == nil && l.socksListener == nil +} + +func (l *muxListener) getAndClearAcceptError() error { + l.lock.Lock() + defer l.lock.Unlock() + + if l.acceptErr == nil { + return nil + } + err := l.acceptErr + l.acceptErr = nil + return err +} + +func (l *muxListener) ListenHTTP() (net.Listener, error) { + l.lock.Lock() + defer l.lock.Unlock() + + if l.httpListener != nil { + subListenerPendingClosed := false + select { + case <-l.httpListener.closeChan: + subListenerPendingClosed = true + default: + } + if !subListenerPendingClosed { + return nil, OpErr{ + Addr: l.base.Addr(), + Protocol: "http", + Op: "bind-protocol", + Err: ErrProtocolInUse, + } + } + l.httpListener = nil + } + + select { + case <-l.closeChan: + return nil, net.ErrClosed + default: + } + + sl := newSubListener(l.getAndClearAcceptError, l.base.Addr) + l.httpListener = sl + return sl, nil +} + +func (l *muxListener) ListenSOCKS() (net.Listener, error) { + l.lock.Lock() + defer l.lock.Unlock() + + if l.socksListener != nil { + subListenerPendingClosed := false + select { + case <-l.socksListener.closeChan: + subListenerPendingClosed = true + default: + } + if !subListenerPendingClosed { + return nil, OpErr{ + Addr: l.base.Addr(), + Protocol: "socks", + Op: "bind-protocol", + Err: ErrProtocolInUse, + } + } + l.socksListener = nil + } + + select { + case <-l.closeChan: + return nil, net.ErrClosed + default: + } + + sl := newSubListener(l.getAndClearAcceptError, l.base.Addr) + l.socksListener = sl + return sl, nil +} + +func newSubListener(acceptErrorFunc func() error, addrFunc func() net.Addr) *subListener { + return &subListener{ + acceptChan: make(chan net.Conn), + acceptErrorFunc: acceptErrorFunc, + closeChan: make(chan struct{}), + addrFunc: addrFunc, + } +} + +type subListener struct { + // receive connections or closure from upstream + acceptChan chan net.Conn + // get an error of Accept() from upstream + acceptErrorFunc func() error + // notify upstream that we are closed + closeChan chan struct{} + + // Listener.Addr() implementation of base listener + addrFunc func() net.Addr +} + +func (l *subListener) Accept() (net.Conn, error) { + select { + case <-l.closeChan: + // closed by ourselves + return nil, net.ErrClosed + case conn, ok := <-l.acceptChan: + if !ok { + // closed by upstream + if acceptErr := l.acceptErrorFunc(); acceptErr != nil { + return nil, acceptErr + } + return nil, net.ErrClosed + } + return conn, nil + } +} + +func (l *subListener) Addr() net.Addr { + return l.addrFunc() +} + +// Close implements net.Listener.Close. +// Upstream should use close(l.acceptChan) instead. +func (l *subListener) Close() error { + select { + case <-l.closeChan: + return nil + default: + } + close(l.closeChan) + return nil +} + +// connWithOneByte is a net.Conn that returns b for the first read +// request, then forwards everything else to Conn. +type connWithOneByte struct { + net.Conn + + b byte + bRead bool +} + +func (c *connWithOneByte) Read(bs []byte) (int, error) { + if c.bRead { + return c.Conn.Read(bs) + } + if len(bs) == 0 { + return 0, nil + } + c.bRead = true + bs[0] = c.b + return 1, nil +} + +type OpErr struct { + Addr net.Addr + Protocol string + Op string + Err error +} + +func (m OpErr) Error() string { + return fmt.Sprintf("mux-listen: %s[%s]: %s: %v", m.Addr, m.Protocol, m.Op, m.Err) +} + +func (m OpErr) Unwrap() error { + return m.Err +} + +var ErrProtocolInUse = errors.New("protocol already in use") diff --git a/app/internal/proxymux/mux_test.go b/app/internal/proxymux/mux_test.go new file mode 100644 index 0000000000..46cbf95c9f --- /dev/null +++ b/app/internal/proxymux/mux_test.go @@ -0,0 +1,154 @@ +package proxymux + +import ( + "bytes" + "net" + "sync" + "testing" + "time" + + "github.com/apernet/hysteria/app/internal/proxymux/internal/mocks" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +//go:generate mockery + +func testMockListener(t *testing.T, connChan <-chan net.Conn) net.Listener { + closedChan := make(chan struct{}) + mockListener := mocks.NewMockListener(t) + mockListener.EXPECT().Accept().RunAndReturn(func() (net.Conn, error) { + select { + case <-closedChan: + return nil, net.ErrClosed + case conn, ok := <-connChan: + if !ok { + panic("unexpected closed channel (connChan)") + } + return conn, nil + } + }) + mockListener.EXPECT().Close().RunAndReturn(func() error { + select { + case <-closedChan: + default: + close(closedChan) + } + return nil + }) + return mockListener +} + +func testMockConn(t *testing.T, b []byte) net.Conn { + buf := bytes.NewReader(b) + isClosed := false + mockConn := mocks.NewMockConn(t) + mockConn.EXPECT().Read(mock.Anything).RunAndReturn(func(b []byte) (int, error) { + if isClosed { + return 0, net.ErrClosed + } + return buf.Read(b) + }) + mockConn.EXPECT().Close().RunAndReturn(func() error { + isClosed = true + return nil + }) + return mockConn +} + +func TestMuxHTTP(t *testing.T) { + connChan := make(chan net.Conn) + mockListener := testMockListener(t, connChan) + mockConn := testMockConn(t, []byte("CONNECT example.com:443 HTTP/1.1\r\n\r\n")) + + mux := newMuxListener(mockListener, func() {}) + hl, err := mux.ListenHTTP() + if !assert.NoError(t, err) { + return + } + sl, err := mux.ListenSOCKS() + if !assert.NoError(t, err) { + return + } + + connChan <- mockConn + + var socksConn, httpConn net.Conn + var socksErr, httpErr error + + var wg sync.WaitGroup + wg.Add(2) + go func() { + socksConn, socksErr = sl.Accept() + wg.Done() + }() + go func() { + httpConn, httpErr = hl.Accept() + wg.Done() + }() + + time.Sleep(time.Second) + + sl.Close() + hl.Close() + + wg.Wait() + + assert.Nil(t, socksConn) + assert.ErrorIs(t, socksErr, net.ErrClosed) + assert.NotNil(t, httpConn) + httpConn.Close() + assert.NoError(t, httpErr) + + // Wait for muxListener released + <-mux.acceptChan +} + +func TestMuxSOCKS(t *testing.T) { + connChan := make(chan net.Conn) + mockListener := testMockListener(t, connChan) + mockConn := testMockConn(t, []byte{0x05, 0x02, 0x00, 0x01}) // SOCKS5 Connect Request: NOAUTH+GSSAPI + + mux := newMuxListener(mockListener, func() {}) + hl, err := mux.ListenHTTP() + if !assert.NoError(t, err) { + return + } + sl, err := mux.ListenSOCKS() + if !assert.NoError(t, err) { + return + } + + connChan <- mockConn + + var socksConn, httpConn net.Conn + var socksErr, httpErr error + + var wg sync.WaitGroup + wg.Add(2) + go func() { + socksConn, socksErr = sl.Accept() + wg.Done() + }() + go func() { + httpConn, httpErr = hl.Accept() + wg.Done() + }() + + time.Sleep(time.Second) + + sl.Close() + hl.Close() + + wg.Wait() + + assert.NotNil(t, socksConn) + socksConn.Close() + assert.NoError(t, socksErr) + assert.Nil(t, httpConn) + assert.ErrorIs(t, httpErr, net.ErrClosed) + + // Wait for muxListener released + <-mux.acceptChan +}