diff --git a/config/websockets.php b/config/websockets.php index 8001a3b444..b4256118fe 100644 --- a/config/websockets.php +++ b/config/websockets.php @@ -84,23 +84,12 @@ 'capacity' => null, 'enable_client_messages' => false, 'enable_statistics' => true, + 'allowed_origins' => [ + // + ], ], ], - /* - |-------------------------------------------------------------------------- - | Allowed Origins - |-------------------------------------------------------------------------- - | - | If not empty, you can whitelist certain origins that will be allowed - | to connect to the websocket server. - | - */ - - 'allowed_origins' => [ - // - ], - /* |-------------------------------------------------------------------------- | Maximum Request Size diff --git a/docs/basic-usage/pusher.md b/docs/basic-usage/pusher.md index df6de5d39b..cc0589e4ea 100644 --- a/docs/basic-usage/pusher.md +++ b/docs/basic-usage/pusher.md @@ -74,6 +74,7 @@ You may add additional apps in your `config/websockets.php` file. 'capacity' => null, 'enable_client_messages' => false, 'enable_statistics' => true, + 'allowed_origins' => [], ], ], ``` diff --git a/src/Apps/App.php b/src/Apps/App.php index 980e5546d9..8844079465 100644 --- a/src/Apps/App.php +++ b/src/Apps/App.php @@ -33,6 +33,9 @@ class App /** @var bool */ public $statisticsEnabled = true; + /** @var array */ + public $allowedOrigins = []; + public static function findById($appId) { return app(AppManager::class)->findById($appId); @@ -106,4 +109,11 @@ public function enableStatistics(bool $enabled = true) return $this; } + + public function setAllowedOrigins(array $allowedOrigins) + { + $this->allowedOrigins = $allowedOrigins; + + return $this; + } } diff --git a/src/Apps/ConfigAppManager.php b/src/Apps/ConfigAppManager.php index e3f3217e99..235d89affe 100644 --- a/src/Apps/ConfigAppManager.php +++ b/src/Apps/ConfigAppManager.php @@ -78,7 +78,8 @@ protected function instantiate(?array $appAttributes): ?App $app ->enableClientMessages($appAttributes['enable_client_messages']) ->enableStatistics($appAttributes['enable_statistics']) - ->setCapacity($appAttributes['capacity'] ?? null); + ->setCapacity($appAttributes['capacity'] ?? null) + ->setAllowedOrigins($appAttributes['allowed_origins'] ?? []); return $app; } diff --git a/src/Server/OriginCheck.php b/src/Server/OriginCheck.php deleted file mode 100644 index 5a3bd050c5..0000000000 --- a/src/Server/OriginCheck.php +++ /dev/null @@ -1,60 +0,0 @@ -_component = $component; - - $this->allowedOrigins = $allowedOrigins; - } - - public function onOpen(ConnectionInterface $connection, RequestInterface $request = null) - { - if ($request->hasHeader('Origin')) { - $this->verifyOrigin($connection, $request); - } - - return $this->_component->onOpen($connection, $request); - } - - public function onMessage(ConnectionInterface $from, $msg) - { - return $this->_component->onMessage($from, $msg); - } - - public function onClose(ConnectionInterface $connection) - { - return $this->_component->onClose($connection); - } - - public function onError(ConnectionInterface $connection, \Exception $e) - { - return $this->_component->onError($connection, $e); - } - - protected function verifyOrigin(ConnectionInterface $connection, RequestInterface $request) - { - $header = (string) $request->getHeader('Origin')[0]; - $origin = parse_url($header, PHP_URL_HOST) ?: $header; - - if (! empty($this->allowedOrigins) && ! in_array($origin, $this->allowedOrigins)) { - return $this->close($connection, 403); - } - } -} diff --git a/src/Server/WebSocketServerFactory.php b/src/Server/WebSocketServerFactory.php index 0e4ab4bc9c..bafeaa146a 100644 --- a/src/Server/WebSocketServerFactory.php +++ b/src/Server/WebSocketServerFactory.php @@ -79,11 +79,9 @@ public function createServer(): IoServer $socket = new SecureServer($socket, $this->loop, config('websockets.ssl')); } - $urlMatcher = new UrlMatcher($this->routes, new RequestContext); - - $router = new Router($urlMatcher); - - $app = new OriginCheck($router, config('websockets.allowed_origins', [])); + $app = new Router( + new UrlMatcher($this->routes, new RequestContext) + ); $httpServer = new HttpServer($app, config('websockets.max_request_size_in_kb') * 1024); diff --git a/src/WebSockets/Exceptions/OriginNotAllowed.php b/src/WebSockets/Exceptions/OriginNotAllowed.php new file mode 100644 index 0000000000..aebbe37af7 --- /dev/null +++ b/src/WebSockets/Exceptions/OriginNotAllowed.php @@ -0,0 +1,12 @@ +message = "The origin is not allowed for `{$appKey}`."; + $this->code = 4009; + } +} diff --git a/src/WebSockets/WebSocketHandler.php b/src/WebSockets/WebSocketHandler.php index 7820960b6b..3a49a4de90 100644 --- a/src/WebSockets/WebSocketHandler.php +++ b/src/WebSockets/WebSocketHandler.php @@ -8,6 +8,7 @@ use BeyondCode\LaravelWebSockets\QueryParameters; use BeyondCode\LaravelWebSockets\WebSockets\Channels\ChannelManager; use BeyondCode\LaravelWebSockets\WebSockets\Exceptions\ConnectionsOverCapacity; +use BeyondCode\LaravelWebSockets\WebSockets\Exceptions\OriginNotAllowed; use BeyondCode\LaravelWebSockets\WebSockets\Exceptions\UnknownAppKey; use BeyondCode\LaravelWebSockets\WebSockets\Exceptions\WebSocketException; use BeyondCode\LaravelWebSockets\WebSockets\Messages\PusherMessageFactory; @@ -30,6 +31,7 @@ public function onOpen(ConnectionInterface $connection) { $this ->verifyAppKey($connection) + ->verifyOrigin($connection) ->limitConcurrentConnections($connection) ->generateSocketId($connection) ->establishConnection($connection); @@ -77,6 +79,23 @@ protected function verifyAppKey(ConnectionInterface $connection) return $this; } + protected function verifyOrigin(ConnectionInterface $connection) + { + if (! $connection->app->allowedOrigins) { + return $this; + } + + $header = (string) ($connection->httpRequest->getHeader('Origin')[0] ?? null); + + $origin = parse_url($header, PHP_URL_HOST) ?: $header; + + if (! $header || ! in_array($origin, $connection->app->allowedOrigins)) { + throw new OriginNotAllowed($connection->app->key); + } + + return $this; + } + protected function limitConcurrentConnections(ConnectionInterface $connection) { if (! is_null($capacity = $connection->app->capacity)) { diff --git a/tests/ClientProviders/ConfigAppManagerTest.php b/tests/ClientProviders/ConfigAppManagerTest.php index 14b73821c5..9ba5561515 100644 --- a/tests/ClientProviders/ConfigAppManagerTest.php +++ b/tests/ClientProviders/ConfigAppManagerTest.php @@ -22,7 +22,7 @@ public function it_can_get_apps_from_the_config_file() { $apps = $this->appManager->all(); - $this->assertCount(1, $apps); + $this->assertCount(2, $apps); /** @var $app */ $app = $apps[0]; diff --git a/tests/ConnectionTest.php b/tests/ConnectionTest.php index 81f4ac0e65..0aba6eccf9 100644 --- a/tests/ConnectionTest.php +++ b/tests/ConnectionTest.php @@ -5,6 +5,7 @@ use BeyondCode\LaravelWebSockets\Apps\App; use BeyondCode\LaravelWebSockets\Tests\Mocks\Message; use BeyondCode\LaravelWebSockets\WebSockets\Exceptions\ConnectionsOverCapacity; +use BeyondCode\LaravelWebSockets\WebSockets\Exceptions\OriginNotAllowed; use BeyondCode\LaravelWebSockets\WebSockets\Exceptions\UnknownAppKey; class ConnectionTest extends TestCase @@ -14,7 +15,7 @@ public function unknown_app_keys_can_not_connect() { $this->expectException(UnknownAppKey::class); - $this->pusherServer->onOpen($this->getWebSocketConnection('/?appKey=test')); + $this->pusherServer->onOpen($this->getWebSocketConnection('test')); } /** @test */ @@ -65,4 +66,38 @@ public function ping_returns_pong() $connection->assertSentEvent('pusher:pong'); } + + /** @test */ + public function origin_validation_should_fail_for_no_origin() + { + $this->expectException(OriginNotAllowed::class); + + $connection = $this->getWebSocketConnection('TestOrigin'); + + $this->pusherServer->onOpen($connection); + + $connection->assertSentEvent('pusher:connection_established'); + } + + /** @test */ + public function origin_validation_should_fail_for_wrong_origin() + { + $this->expectException(OriginNotAllowed::class); + + $connection = $this->getWebSocketConnection('TestOrigin', ['Origin' => 'https://google.ro']); + + $this->pusherServer->onOpen($connection); + + $connection->assertSentEvent('pusher:connection_established'); + } + + /** @test */ + public function origin_validation_should_pass_for_the_right_origin() + { + $connection = $this->getWebSocketConnection('TestOrigin', ['Origin' => 'https://test.origin.com']); + + $this->pusherServer->onOpen($connection); + + $connection->assertSentEvent('pusher:connection_established'); + } } diff --git a/tests/TestCase.php b/tests/TestCase.php index 4ad82dd83c..9deb436a3e 100644 --- a/tests/TestCase.php +++ b/tests/TestCase.php @@ -70,6 +70,19 @@ protected function getEnvironmentSetUp($app) 'capacity' => null, 'enable_client_messages' => false, 'enable_statistics' => true, + 'allowed_origins' => [], + ], + [ + 'name' => 'Origin Test App', + 'id' => '1234', + 'key' => 'TestOrigin', + 'secret' => 'TestSecret', + 'capacity' => null, + 'enable_client_messages' => false, + 'enable_statistics' => true, + 'allowed_origins' => [ + 'test.origin.com', + ], ], ]); @@ -107,20 +120,20 @@ protected function getEnvironmentSetUp($app) } } - protected function getWebSocketConnection(string $url = '/?appKey=TestKey'): Connection + protected function getWebSocketConnection(string $appKey = 'TestKey', array $headers = []): Connection { $connection = new Connection(); - $connection->httpRequest = new Request('GET', $url); + $connection->httpRequest = new Request('GET', "/?appKey={$appKey}", $headers); return $connection; } - protected function getConnectedWebSocketConnection(array $channelsToJoin = [], string $url = '/?appKey=TestKey'): Connection + protected function getConnectedWebSocketConnection(array $channelsToJoin = [], string $appKey = 'TestKey', array $headers = []): Connection { $connection = new Connection(); - $connection->httpRequest = new Request('GET', $url); + $connection->httpRequest = new Request('GET', "/?appKey={$appKey}", $headers); $this->pusherServer->onOpen($connection);