From a953d4e4986a1fdc318c7ec6bfef01ae73a2635b Mon Sep 17 00:00:00 2001 From: Quang Le Date: Sat, 5 Apr 2025 13:38:11 +0700 Subject: [PATCH 1/3] refactor(sse): streamline SSE connection management logic --- package-lock.json | 3 +- src/transports/sse/server.ts | 278 ++++++++++++++++++++--------------- 2 files changed, 162 insertions(+), 119 deletions(-) diff --git a/package-lock.json b/package-lock.json index 0b54df2..52ed759 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1323,7 +1323,8 @@ "version": "1.1.8", "resolved": "https://registry.npmjs.org/@types/content-type/-/content-type-1.1.8.tgz", "integrity": "sha512-1tBhmVUeso3+ahfyaKluXe38p+94lovUZdoVfQ3OnJo9uJC42JT7CBoN3k9HYhAae+GwiBYmHu+N9FZhOG+2Pg==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/@types/estree": { "version": "1.0.7", diff --git a/src/transports/sse/server.ts b/src/transports/sse/server.ts index 8b52e96..100695c 100644 --- a/src/transports/sse/server.ts +++ b/src/transports/sse/server.ts @@ -1,19 +1,16 @@ -import { randomUUID } from "node:crypto" -import { IncomingMessage, Server as HttpServer, ServerResponse, createServer } from "node:http" -import { JSONRPCMessage, ClientRequest } from "@modelcontextprotocol/sdk/types.js" -import contentType from "content-type" -import getRawBody from "raw-body" -import { APIKeyAuthProvider } from "../../auth/providers/apikey.js" -import { DEFAULT_AUTH_ERROR } from "../../auth/types.js" -import { AbstractTransport } from "../base.js" -import { DEFAULT_SSE_CONFIG, SSETransportConfig, SSETransportConfigInternal, DEFAULT_CORS_CONFIG, CORSConfig } from "./types.js" -import { logger } from "../../core/Logger.js" -import { getRequestHeader, setResponseHeaders } from "../../utils/headers.js" +import { randomUUID } from "node:crypto"; +import { IncomingMessage, Server as HttpServer, ServerResponse, createServer } from "node:http"; +import { JSONRPCMessage, ClientRequest } from "@modelcontextprotocol/sdk/types.js"; +import contentType from "content-type"; +import getRawBody from "raw-body"; +import { APIKeyAuthProvider } from "../../auth/providers/apikey.js"; +import { DEFAULT_AUTH_ERROR } from "../../auth/types.js"; +import { AbstractTransport } from "../base.js"; +import { DEFAULT_SSE_CONFIG, SSETransportConfig, SSETransportConfigInternal, DEFAULT_CORS_CONFIG, CORSConfig } from "./types.js"; +import { logger } from "../../core/Logger.js"; +import { getRequestHeader, setResponseHeaders } from "../../utils/headers.js"; import { PING_SSE_MESSAGE } from "../utils/ping-message.js"; -interface ExtendedIncomingMessage extends IncomingMessage { - body?: ClientRequest -} const SSE_HEADERS = { "Content-Type": "text/event-stream", @@ -25,14 +22,14 @@ export class SSEServerTransport extends AbstractTransport { readonly type = "sse" private _server?: HttpServer - private _sseResponse?: ServerResponse - private _sessionId: string + private _connections: Map // Map + private _sessionId: string // Server instance ID private _config: SSETransportConfigInternal - private _keepAliveInterval?: NodeJS.Timeout constructor(config: SSETransportConfig = {}) { super() - this._sessionId = randomUUID() + this._connections = new Map() + this._sessionId = randomUUID() // Used to validate POST messages belong to this server instance this._config = { ...DEFAULT_SSE_CONFIG, ...config @@ -76,11 +73,11 @@ export class SSEServerTransport extends AbstractTransport { } return new Promise((resolve) => { - this._server = createServer(async (req, res) => { + this._server = createServer(async (req: IncomingMessage, res: ServerResponse) => { try { await this.handleRequest(req, res) - } catch (error) { - logger.error(`Error handling request: ${error}`) + } catch (error: any) { + logger.error(`Error handling request: ${error instanceof Error ? error.message : String(error)}`) res.writeHead(500).end("Internal Server Error") } }) @@ -90,8 +87,8 @@ export class SSEServerTransport extends AbstractTransport { resolve() }) - this._server.on("error", (error) => { - logger.error(`SSE server error: ${error}`) + this._server.on("error", (error: Error) => { + logger.error(`SSE server error: ${error.message}`) this._onerror?.(error) }) @@ -102,7 +99,7 @@ export class SSEServerTransport extends AbstractTransport { }) } - private async handleRequest(req: ExtendedIncomingMessage, res: ServerResponse): Promise { + private async handleRequest(req: IncomingMessage, res: ServerResponse): Promise { logger.debug(`Incoming request: ${req.method} ${req.url}`) if (req.method === "OPTIONS") { @@ -122,25 +119,23 @@ export class SSEServerTransport extends AbstractTransport { if (!isAuthenticated) return } - if (this._sseResponse?.writableEnded) { - this._sseResponse = undefined - } - - if (this._sseResponse) { - logger.warn("SSE connection already established; closing the old connection to allow a new one.") - this._sseResponse.end() - this.cleanupConnection() - } - - this.setupSSEConnection(res) - return + // Remove check for existing single _sseResponse + // Generate a unique ID for this specific connection + const connectionId = randomUUID(); + this.setupSSEConnection(res, connectionId); + return; } if (req.method === "POST" && url.pathname === this._config.messageEndpoint) { - if (sessionId !== this._sessionId) { - logger.warn(`Invalid session ID received: ${sessionId}, expected: ${this._sessionId}`) - res.writeHead(403).end("Invalid session ID") - return + // **Connection Validation (User Requested):** + // Check if the 'sessionId' from the POST request URL query parameter + // (which should contain a connectionId provided by the server via the 'endpoint' event) + // corresponds to an active connection in the `_connections` map. + if (!sessionId || !this._connections.has(sessionId)) { + logger.warn(`Invalid or inactive connection ID in POST request URL: ${sessionId}`); + // Use 403 Forbidden as the client is attempting an operation for an invalid/unknown connection + res.writeHead(403).end("Invalid or inactive connection ID"); + return; } if (this._config.auth?.endpoints?.messages !== false) { @@ -155,7 +150,7 @@ export class SSEServerTransport extends AbstractTransport { res.writeHead(404).end("Not Found") } - private async handleAuthentication(req: ExtendedIncomingMessage, res: ServerResponse, context: string): Promise { + private async handleAuthentication(req: IncomingMessage, res: ServerResponse, context: string): Promise { if (!this._config.auth?.provider) { return true } @@ -203,9 +198,8 @@ export class SSEServerTransport extends AbstractTransport { return true } - private setupSSEConnection(res: ServerResponse): void { - logger.debug(`Setting up SSE connection for session: ${this._sessionId}`) - + private setupSSEConnection(res: ServerResponse, connectionId: string): void { + logger.debug(`Setting up SSE connection: ${connectionId} for server session: ${this._sessionId}`); const headers = { ...SSE_HEADERS, ...this.getCorsHeaders(), @@ -218,60 +212,65 @@ export class SSEServerTransport extends AbstractTransport { res.socket.setNoDelay(true) res.socket.setTimeout(0) res.socket.setKeepAlive(true, 1000) - logger.debug('Socket optimized for SSE connection') + logger.debug('Socket optimized for SSE connection'); } - - const endpointUrl = `${this._config.messageEndpoint}?sessionId=${this._sessionId}` - logger.debug(`Sending endpoint URL: ${endpointUrl}`) - res.write(`event: endpoint\ndata: ${endpointUrl}\n\n`) - - logger.debug('Sending initial keep-alive') - - this._keepAliveInterval = setInterval(() => { - if (this._sseResponse && !this._sseResponse.writableEnded) { - try { - this._sseResponse.write(PING_SSE_MESSAGE); - } catch (error) { - logger.error(`Error sending keep-alive: ${error}`) - this.cleanupConnection() + // **Important Change:** The endpoint URL now includes the specific connectionId + // in the 'sessionId' query parameter, as requested by user feedback. + // The client should use this exact URL for subsequent POST messages. + const endpointUrl = `${this._config.messageEndpoint}?sessionId=${connectionId}`; + logger.debug(`Sending endpoint URL for connection ${connectionId}: ${endpointUrl}`); + res.write(`event: endpoint\ndata: ${endpointUrl}\n\n`); + // Send the unique connection ID separately as well for potential client-side use + res.write(`event: connectionId\ndata: ${connectionId}\n\n`); + logger.debug(`Sending initial keep-alive for connection: ${connectionId}`); + const intervalId = setInterval(() => { + const connection = this._connections.get(connectionId); + if (connection && !connection.res.writableEnded) { + try { + connection.res.write(PING_SSE_MESSAGE); + } + catch (error: any) { + logger.error(`Error sending keep-alive for connection ${connectionId}: ${error instanceof Error ? error.message : String(error)}`); + this.cleanupConnection(connectionId); + } } - } - }, 15000) - - this._sseResponse = res - - const cleanup = () => this.cleanupConnection() - + else { + // Should not happen if cleanup is working, but clear interval just in case + logger.warn(`Keep-alive interval running for missing/ended connection: ${connectionId}`); + this.cleanupConnection(connectionId); // Will clear interval + } + }, 15000); + this._connections.set(connectionId, { res, intervalId }); + const cleanup = () => this.cleanupConnection(connectionId); res.on("close", () => { - logger.info(`SSE connection closed for session: ${this._sessionId}`) - cleanup() - }) - - res.on("error", (error) => { - logger.error(`SSE connection error for session ${this._sessionId}: ${error}`) - this._onerror?.(error) - cleanup() - }) - + logger.info(`SSE connection closed: ${connectionId}`); + cleanup(); + }); + res.on("error", (error: Error) => { + logger.error(`SSE connection error for ${connectionId}: ${error.message}`); + this._onerror?.(error); + cleanup(); + }); res.on("end", () => { - logger.info(`SSE connection ended for session: ${this._sessionId}`) - cleanup() - }) - - logger.info(`SSE connection established successfully for session: ${this._sessionId}`) + logger.info(`SSE connection ended: ${connectionId}`); + cleanup(); + }); + logger.info(`SSE connection established successfully: ${connectionId}`); } - private async handlePostMessage(req: ExtendedIncomingMessage, res: ServerResponse): Promise { - if (!this._sseResponse || this._sseResponse.writableEnded) { - logger.warn(`Rejecting message: no active SSE connection for session ${this._sessionId}`) - res.writeHead(409).end("SSE connection not established") - return + private async handlePostMessage(req: IncomingMessage, res: ServerResponse): Promise { + // Check if *any* connection is active, not just the old single _sseResponse + if (this._connections.size === 0) { + logger.warn(`Rejecting message: no active SSE connections for server session ${this._sessionId}`); + // Use 409 Conflict as it indicates the server state prevents fulfilling the request + res.writeHead(409).end("No active SSE connection established"); + return; } let currentMessage: { id?: string | number; method?: string } = {} try { - const rawMessage = req.body || await (async () => { + const rawMessage = (req as any).body || await (async () => { // Cast req to any to access potential body property const ct = contentType.parse(req.headers["content-type"] ?? "") if (ct.type !== "application/json") { throw new Error(`Unsupported content-type: ${ct.type}`) @@ -316,7 +315,7 @@ export class SSEServerTransport extends AbstractTransport { logger.debug(`Successfully processed message ${rpcMessage.id}`) - } catch (error) { + } catch (error: any) { const errorMessage = error instanceof Error ? error.message : String(error) logger.error(`Error handling message for session ${this._sessionId}:`) logger.error(`- Error: ${errorMessage}`) @@ -332,7 +331,7 @@ export class SSEServerTransport extends AbstractTransport { data: { method: currentMessage.method || "unknown", sessionId: this._sessionId, - connectionActive: Boolean(this._sseResponse), + connectionActive: Boolean(this._connections.size > 0), type: "message_handler_error" } } @@ -343,42 +342,85 @@ export class SSEServerTransport extends AbstractTransport { } } + // Broadcast message to all connected clients async send(message: JSONRPCMessage): Promise { - if (!this._sseResponse || this._sseResponse.writableEnded) { - throw new Error("SSE connection not established") - } - - this._sseResponse.write(`data: ${JSON.stringify(message)}\n\n`) + if (this._connections.size === 0) { + logger.warn("Attempted to send message, but no clients are connected."); + // Optionally throw an error or just log + // throw new Error("No SSE connections established"); + return; + } + const messageString = `data: ${JSON.stringify(message)}\n\n`; + logger.debug(`Broadcasting message to ${this._connections.size} clients: ${JSON.stringify(message)}`); + let failedSends = 0; + for (const [connectionId, connection] of this._connections.entries()) { + if (connection.res && !connection.res.writableEnded) { + try { + connection.res.write(messageString); + } + catch (error: any) { + failedSends++; + logger.error(`Error sending message to connection ${connectionId}: ${error instanceof Error ? error.message : String(error)}`); + // Clean up the problematic connection + this.cleanupConnection(connectionId); + } + } + else { + // Should not happen if cleanup is working, but handle defensively + logger.warn(`Attempted to send to ended connection: ${connectionId}`); + this.cleanupConnection(connectionId); + } + } + if (failedSends > 0) { + logger.warn(`Failed to send message to ${failedSends} connections.`); + } } async close(): Promise { - if (this._sseResponse && !this._sseResponse.writableEnded) { - this._sseResponse.end() - } - - this.cleanupConnection() - - return new Promise((resolve) => { - if (!this._server) { - resolve() - return + logger.info(`Closing SSE transport and ${this._connections.size} connections.`); + // Close all active client connections + for (const connectionId of this._connections.keys()) { + this.cleanupConnection(connectionId, true); // Pass true to end the response } - - this._server.close(() => { - logger.info("SSE server stopped") - this._server = undefined - this._onclose?.() - resolve() - }) - }) + this._connections.clear(); // Ensure map is empty + // Close the main server + return new Promise((resolve) => { + if (!this._server) { + logger.debug("Server already stopped."); + resolve(); + return; + } + this._server.close(() => { + logger.info("SSE server stopped"); + this._server = undefined; + this._onclose?.(); + resolve(); + }); + }); } - private cleanupConnection(): void { - if (this._keepAliveInterval) { - clearInterval(this._keepAliveInterval) - this._keepAliveInterval = undefined - } - this._sseResponse = undefined + // Clean up a specific connection by its ID + private cleanupConnection(connectionId: string, endResponse = false): void { + const connection = this._connections.get(connectionId); + if (connection) { + logger.debug(`Cleaning up connection: ${connectionId}`); + if (connection.intervalId) { + clearInterval(connection.intervalId); + } + if (endResponse && connection.res && !connection.res.writableEnded) { + try { + connection.res.end(); + } + catch (e: any) { + logger.warn(`Error ending response for connection ${connectionId}: ${e instanceof Error ? e.message : String(e)}`); + } + } + this._connections.delete(connectionId); + logger.debug(`Connection removed: ${connectionId}. Remaining connections: ${this._connections.size}`); + } + else { + logger.debug(`Attempted to clean up non-existent connection: ${connectionId}`); + } } isRunning(): boolean { From d75f0d920880557f734e46907a6fbea25a9f5ead Mon Sep 17 00:00:00 2001 From: Quang Le Date: Sat, 5 Apr 2025 13:43:27 +0700 Subject: [PATCH 2/3] build: add prepare script to package.json --- package.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/package.json b/package.json index 26f2d1d..b2cc201 100644 --- a/package.json +++ b/package.json @@ -24,7 +24,8 @@ "watch": "tsc --watch", "lint": "eslint", "lint:fix": "eslint --fix", - "format": "prettier --write \"src/**/*.ts\"" + "format": "prettier --write \"src/**/*.ts\"", + "prepare": "npm run build" }, "engines": { "node": ">=18.19.0" From 94d48cf8c7edd3e1a83a78658dd8bfe3ec3eeee2 Mon Sep 17 00:00:00 2001 From: Quang Le Date: Tue, 22 Apr 2025 14:51:19 +0700 Subject: [PATCH 3/3] feat(sse): add session ID handling for connections --- src/transports/sse/server.ts | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/transports/sse/server.ts b/src/transports/sse/server.ts index 100695c..82ab638 100644 --- a/src/transports/sse/server.ts +++ b/src/transports/sse/server.ts @@ -119,7 +119,25 @@ export class SSEServerTransport extends AbstractTransport { if (!isAuthenticated) return } - // Remove check for existing single _sseResponse + // Check if a sessionId was provided in the request + if (sessionId) { + // If sessionId exists but is not in our connections map, it's invalid or inactive + if (!this._connections.has(sessionId)) { + logger.info(`Invalid or inactive session ID in GET request: ${sessionId}. Creating new connection.`); + // Continue execution to create a new connection below + } else { + // If the connection exists and is still active, we could either: + // 1. Return an error (409 Conflict) as a client shouldn't create duplicate connections + // 2. Close the old connection and create a new one + // 3. Keep the old connection and return its details + + // Option 2: Close old connection and create new one + logger.info(`Replacing existing connection for session ID: ${sessionId}`); + this.cleanupConnection(sessionId); + // Continue execution to create a new connection below + } + } + // Generate a unique ID for this specific connection const connectionId = randomUUID(); this.setupSSEConnection(res, connectionId);