diff --git a/README.md b/README.md index 4032c58..5706852 100644 --- a/README.md +++ b/README.md @@ -66,3 +66,6 @@ npm publish --access public git push git push --tags ``` + +Domains Diagram: +![diagram_encapuslated.svg](images%2Fdiagram_encapuslated.svg) diff --git a/images/diagram_encapuslated.svg b/images/diagram_encapuslated.svg new file mode 100644 index 0000000..bad2f68 --- /dev/null +++ b/images/diagram_encapuslated.svg @@ -0,0 +1,17 @@ + + + + + + + + RPCCLientHandlers1treamCallerMiddlwarerequestdata transformationstreamFactoryJsonRPCRequestUint8ArraymethodsProxyCall TypeTimer attachedRPCCLienthandleStreamHandlerAbortController Timerdata transformationResponseMiddlware \ No newline at end of file diff --git a/package.json b/package.json index 66c6d91..f9303d9 100644 --- a/package.json +++ b/package.json @@ -53,6 +53,16 @@ "ts-node": "^10.9.1", "tsconfig-paths": "^3.9.0", "typedoc": "^0.23.21", - "typescript": "^4.9.3" + "typescript": "^4.9.3", + "@fast-check/jest": "^1.1.0" + }, + "dependencies": { + "@matrixai/async-init": "^1.9.4", + "@matrixai/contexts": "^1.2.0", + "@matrixai/logger": "^3.1.0", + "@matrixai/errors": "^1.2.0", + "@matrixai/events": "^3.2.0", + "@streamparser/json": "^0.0.17", + "ix": "^5.0.0" } } diff --git a/src/RPCClient.ts b/src/RPCClient.ts new file mode 100644 index 0000000..6d19363 --- /dev/null +++ b/src/RPCClient.ts @@ -0,0 +1,592 @@ +import type { WritableStream, ReadableStream } from 'stream/web'; +import type { ContextTimedInput } from '@matrixai/contexts'; +import type { + HandlerType, + JSONRPCRequestMessage, + StreamFactory, + ClientManifest, + RPCStream, + JSONRPCResponseResult, +} from './types'; +import type { JSONValue, IdGen } from './types'; +import type { + JSONRPCRequest, + JSONRPCResponse, + MiddlewareFactory, + MapCallers, +} from './types'; +import type { ErrorRPCRemote } from './errors'; +import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; +import Logger from '@matrixai/logger'; +import { Timer } from '@matrixai/timer'; +import { createDestroy } from '@matrixai/async-init'; +import * as rpcUtilsMiddleware from './utils/middleware'; +import * as rpcErrors from './errors'; +import * as rpcUtils from './utils/utils'; +import { promise } from './utils'; +import { ErrorRPCStreamEnded, never } from './errors'; +import * as events from './events'; + +const timerCleanupReasonSymbol = Symbol('timerCleanUpReasonSymbol'); + +/** + * Events: + * - {@link events.Event} + */ +interface RPCClient + extends createDestroy.CreateDestroy {} +/** + * You must provide an error handler `addEventListener('error')`. + * Otherwise, errors will just be ignored. + * + * Events: + * - {@link events.EventRPCClientDestroy} + * - {@link events.EventRPCClientDestroyed} + */ +@createDestroy.CreateDestroy({ + eventDestroy: events.EventRPCClientDestroy, + eventDestroyed: events.EventRPCClientDestroyed, +}) +class RPCClient { + /** + * @param obj + * @param obj.manifest - Client manifest that defines the types for the rpc + * methods. + * @param obj.streamFactory - An arrow function that when called, creates a + * new stream for each rpc method call. + * @param obj.middlewareFactory - Middleware used to process the rpc messages. + * The middlewareFactory needs to be a function that creates a pair of + * transform streams that convert `JSONRPCRequest` to `Uint8Array` on the forward + * path and `Uint8Array` to `JSONRPCResponse` on the reverse path. + * @param obj.streamKeepAliveTimeoutTime - Timeout time used if no timeout timer was provided when making a call. + * Defaults to 60,000 milliseconds. + * for a client call. + * @param obj.logger + */ + static async createRPCClient({ + manifest, + streamFactory, + middlewareFactory = rpcUtilsMiddleware.defaultClientMiddlewareWrapper(), + streamKeepAliveTimeoutTime = Infinity, // 1 minute + logger = new Logger(this.name), + idGen = () => Promise.resolve(null), + }: { + manifest: M; + streamFactory: StreamFactory; + middlewareFactory?: MiddlewareFactory< + Uint8Array, + JSONRPCRequest, + JSONRPCResponse, + Uint8Array + >; + streamKeepAliveTimeoutTime?: number; + logger?: Logger; + idGen: IdGen; + toError?: (errorData, metadata?: JSONValue) => ErrorRPCRemote; + }) { + logger.info(`Creating ${this.name}`); + const rpcClient = new this({ + manifest, + streamFactory, + middlewareFactory, + streamKeepAliveTimeoutTime: streamKeepAliveTimeoutTime, + logger, + idGen, + }); + logger.info(`Created ${this.name}`); + return rpcClient; + } + protected onTimeoutCallback?: () => void; + protected idGen: IdGen; + protected logger: Logger; + protected streamFactory: StreamFactory; + protected middlewareFactory: MiddlewareFactory< + Uint8Array, + JSONRPCRequest, + JSONRPCResponse, + Uint8Array + >; + protected callerTypes: Record; + toError: (errorData: any, metadata?: JSONValue) => Error; + public registerOnTimeoutCallback(callback: () => void) { + this.onTimeoutCallback = callback; + } + // Method proxies + public readonly streamKeepAliveTimeoutTime: number; + public readonly methodsProxy = new Proxy( + {}, + { + get: (_, method) => { + if (typeof method === 'symbol') return; + switch (this.callerTypes[method]) { + case 'UNARY': + return (params, ctx) => this.unaryCaller(method, params, ctx); + case 'SERVER': + return (params, ctx) => + this.serverStreamCaller(method, params, ctx); + case 'CLIENT': + return (ctx) => this.clientStreamCaller(method, ctx); + case 'DUPLEX': + return (ctx) => this.duplexStreamCaller(method, ctx); + case 'RAW': + return (header, ctx) => this.rawStreamCaller(method, header, ctx); + default: + return; + } + }, + }, + ); + + public constructor({ + manifest, + streamFactory, + middlewareFactory, + streamKeepAliveTimeoutTime, + logger, + idGen = () => Promise.resolve(null), + toError, + }: { + manifest: M; + streamFactory: StreamFactory; + middlewareFactory: MiddlewareFactory< + Uint8Array, + JSONRPCRequest, + JSONRPCResponse, + Uint8Array + >; + streamKeepAliveTimeoutTime: number; + logger: Logger; + idGen: IdGen; + toError?: (errorData, metadata?: JSONValue) => ErrorRPCRemote; + }) { + this.idGen = idGen; + this.callerTypes = rpcUtils.getHandlerTypes(manifest); + this.streamFactory = streamFactory; + this.middlewareFactory = middlewareFactory; + this.streamKeepAliveTimeoutTime = streamKeepAliveTimeoutTime; + this.logger = logger; + this.toError = toError || rpcUtils.toError; + } + + public async destroy({ + errorCode = rpcErrors.JSONRPCErrorCode.RPCStopping, + errorMessage = '', + force = true, + }: { + errorCode?: number; + errorMessage?: string; + force?: boolean; + } = {}): Promise { + this.logger.info(`Destroying ${this.constructor.name}`); + + // You can dispatch an event before the actual destruction starts + this.dispatchEvent(new events.EventRPCClientDestroy()); + + // Dispatch an event after the client has been destroyed + this.dispatchEvent(new events.EventRPCClientDestroyed()); + + this.logger.info(`Destroyed ${this.constructor.name}`); + } + + @ready(new rpcErrors.ErrorRPCCallerFailed()) + public get methods(): MapCallers { + return this.methodsProxy as MapCallers; + } + + /** + * Generic caller for unary RPC calls. + * This returns the response in the provided type. No validation is done so + * make sure the types match the handler types. + * @param method - Method name of the RPC call + * @param parameters - Parameters to be provided with the RPC message. Matches + * the provided I type. + * @param ctx - ContextTimed used for timeouts and cancellation. + */ + @ready(new rpcErrors.ErrorMissingCaller()) + public async unaryCaller( + method: string, + parameters: I, + ctx: Partial = {}, + ): Promise { + const callerInterface = await this.duplexStreamCaller(method, ctx); + const reader = callerInterface.readable.getReader(); + const writer = callerInterface.writable.getWriter(); + try { + await writer.write(parameters); + const output = await reader.read(); + if (output.done) { + throw new rpcErrors.ErrorMissingCaller('Missing response', { + cause: ctx.signal?.reason, + }); + } + await reader.cancel(); + await writer.close(); + return output.value; + } finally { + // Attempt clean up, ignore errors if already cleaned up + await reader.cancel().catch(() => {}); + await writer.close().catch(() => {}); + } + } + + /** + * Generic caller for server streaming RPC calls. + * This returns a ReadableStream of the provided type. When finished, the + * readable needs to be cleaned up, otherwise cleanup happens mostly + * automatically. + * @param method - Method name of the RPC call + * @param parameters - Parameters to be provided with the RPC message. Matches + * the provided I type. + * @param ctx - ContextTimed used for timeouts and cancellation. + */ + @ready(new rpcErrors.ErrorRPCCallerFailed()) + public async serverStreamCaller( + method: string, + parameters: I, + ctx: Partial = {}, + ): Promise> { + const callerInterface = await this.duplexStreamCaller(method, ctx); + const writer = callerInterface.writable.getWriter(); + try { + await writer.write(parameters); + await writer.close(); + } catch (e) { + // Clean up if any problems, ignore errors if already closed + await callerInterface.readable.cancel(e); + throw e; + } + return callerInterface.readable; + } + + /** + * Generic caller for Client streaming RPC calls. + * This returns a WritableStream for writing the input to and a Promise that + * resolves when the output is received. + * When finished the writable stream must be ended. Failing to do so will + * hold the connection open and result in a resource leak until the + * call times out. + * @param method - Method name of the RPC call + * @param ctx - ContextTimed used for timeouts and cancellation. + */ + @ready(new rpcErrors.ErrorRPCCallerFailed()) + public async clientStreamCaller( + method: string, + ctx: Partial = {}, + ): Promise<{ + output: Promise; + writable: WritableStream; + }> { + const callerInterface = await this.duplexStreamCaller(method, ctx); + const reader = callerInterface.readable.getReader(); + const output = reader.read().then(({ value, done }) => { + if (done) { + throw new rpcErrors.ErrorMissingCaller('Missing response', { + cause: ctx.signal?.reason, + }); + } + return value; + }); + return { + output, + writable: callerInterface.writable, + }; + } + + /** + * Generic caller for duplex RPC calls. + * This returns a `ReadableWritablePair` of the types specified. No validation + * is applied to these types so make sure they match the types of the handler + * you are calling. + * When finished the streams must be ended manually. Failing to do so will + * hold the connection open and result in a resource leak until the + * call times out. + * @param method - Method name of the RPC call + * @param ctx - ContextTimed used for timeouts and cancellation. + */ + @ready(new rpcErrors.ErrorRPCCallerFailed()) + public async duplexStreamCaller( + method: string, + ctx: Partial = {}, + ): Promise> { + // Setting up abort signal and timer + const abortController = new AbortController(); + const signal = abortController.signal; + // A promise that will reject if there is an abort signal or timeout + const abortRaceProm = promise(); + // Prevent unhandled rejection when we're done with the promise + abortRaceProm.p.catch(() => {}); + const abortRacePromHandler = () => { + abortRaceProm.rejectP(signal.reason); + }; + signal.addEventListener('abort', abortRacePromHandler); + + let abortHandler: () => void; + if (ctx.signal != null) { + // Propagate signal events + abortHandler = () => { + abortController.abort(ctx.signal?.reason); + }; + if (ctx.signal.aborted) abortHandler(); + ctx.signal.addEventListener('abort', abortHandler); + } + let timer: Timer; + if (!(ctx.timer instanceof Timer)) { + timer = new Timer({ + delay: ctx.timer ?? this.streamKeepAliveTimeoutTime, + }); + } else { + timer = ctx.timer; + } + const cleanUp = () => { + // Clean up the timer and signal + if (ctx.timer == null) timer.cancel(timerCleanupReasonSymbol); + if (ctx.signal != null) { + ctx.signal.removeEventListener('abort', abortHandler); + } + signal.addEventListener('abort', abortRacePromHandler); + }; + // Setting up abort events for timeout + const timeoutError = new rpcErrors.ErrorRPCTimedOut( + 'Error RPC has timed out', + { cause: ctx.signal?.reason }, + ); + void timer.then( + () => { + abortController.abort(timeoutError); + if (this.onTimeoutCallback) { + this.onTimeoutCallback(); + } + }, + () => {}, // Ignore cancellation error + ); + + // Hooking up agnostic stream side + let rpcStream: RPCStream; + const streamFactoryProm = this.streamFactory({ signal, timer }); + try { + rpcStream = await Promise.race([streamFactoryProm, abortRaceProm.p]); + } catch (e) { + cleanUp(); + void streamFactoryProm.then((stream) => + stream.cancel(ErrorRPCStreamEnded), + ); + throw e; + } + void timer.then( + () => { + rpcStream.cancel( + new rpcErrors.ErrorRPCTimedOut('RPC has timed out', { + cause: ctx.signal?.reason, + }), + ); + }, + () => {}, // Ignore cancellation error + ); + // Deciding if we want to allow refreshing + // We want to refresh timer if none was provided + const refreshingTimer: Timer | undefined = + ctx.timer == null ? timer : undefined; + // Composing stream transforms and middleware + const metadata = { + ...(rpcStream.meta ?? {}), + command: method, + }; + const outputMessageTransformStream = + rpcUtils.clientOutputTransformStream(metadata, refreshingTimer); + const inputMessageTransformStream = rpcUtils.clientInputTransformStream( + method, + refreshingTimer, + ); + const middleware = this.middlewareFactory( + { signal, timer }, + rpcStream.cancel, + metadata, + ); + // This `Promise.allSettled` is used to asynchronously track the state + // of the streams. When both have finished we can clean up resources. + void Promise.allSettled([ + rpcStream.readable + .pipeThrough(middleware.reverse) + .pipeTo(outputMessageTransformStream.writable) + // Ignore any errors, we only care about stream ending + .catch(() => {}), + inputMessageTransformStream.readable + .pipeThrough(middleware.forward) + .pipeTo(rpcStream.writable) + // Ignore any errors, we only care about stream ending + .catch(() => {}), + ]).finally(() => { + cleanUp(); + }); + + // Returning interface + return { + readable: outputMessageTransformStream.readable, + writable: inputMessageTransformStream.writable, + cancel: rpcStream.cancel, + meta: metadata, + }; + } + + /** + * Generic caller for raw RPC calls. + * This returns a `ReadableWritablePair` of the raw RPC stream. + * When finished the streams must be ended manually. Failing to do so will + * hold the connection open and result in a resource leak until the + * call times out. + * Raw streams don't support the keep alive timeout. Timeout will only apply\ + * to the creation of the stream. + * @param method - Method name of the RPC call + * @param headerParams - Parameters for the header message. The header is a + * single RPC message that is sent to specify the method for the RPC call. + * Any metadata of extra parameters is provided here. + * @param ctx - ContextTimed used for timeouts and cancellation. + * @param id - Id is generated only once, and used throughout the stream for the rest of the communication + */ + @ready(new rpcErrors.ErrorRPCCallerFailed()) + public async rawStreamCaller( + method: string, + headerParams: JSONValue, + ctx: Partial = {}, + ): Promise< + RPCStream< + Uint8Array, + Uint8Array, + Record & { result: JSONValue; command: string } + > + > { + // Setting up abort signal and timer + const abortController = new AbortController(); + const signal = abortController.signal; + // A promise that will reject if there is an abort signal or timeout + const abortRaceProm = promise(); + // Prevent unhandled rejection when we're done with the promise + abortRaceProm.p.catch(() => {}); + const abortRacePromHandler = () => { + abortRaceProm.rejectP(signal.reason); + }; + signal.addEventListener('abort', abortRacePromHandler); + + let abortHandler: () => void; + if (ctx.signal != null) { + // Propagate signal events + abortHandler = () => { + abortController.abort(ctx.signal?.reason); + }; + if (ctx.signal.aborted) abortHandler(); + ctx.signal.addEventListener('abort', abortHandler); + } + let timer: Timer; + if (!(ctx.timer instanceof Timer)) { + timer = new Timer({ + delay: ctx.timer ?? this.streamKeepAliveTimeoutTime, + }); + } else { + timer = ctx.timer; + } + const cleanUp = () => { + // Clean up the timer and signal + if (ctx.timer == null) timer.cancel(timerCleanupReasonSymbol); + if (ctx.signal != null) { + ctx.signal.removeEventListener('abort', abortHandler); + } + signal.addEventListener('abort', abortRacePromHandler); + }; + // Setting up abort events for timeout + const timeoutError = new rpcErrors.ErrorRPCTimedOut('RPC has timed out', { + cause: ctx.signal?.reason, + }); + void timer.then( + () => { + abortController.abort(timeoutError); + }, + () => {}, // Ignore cancellation error + ); + + const setupStream = async (): Promise< + [JSONValue, RPCStream] + > => { + if (signal.aborted) throw signal.reason; + const abortProm = promise(); + // Ignore error if orphaned + void abortProm.p.catch(() => {}); + signal.addEventListener( + 'abort', + () => { + abortProm.rejectP(signal.reason); + }, + { once: true }, + ); + const rpcStream = await Promise.race([ + this.streamFactory({ signal, timer }), + abortProm.p, + ]); + const tempWriter = rpcStream.writable.getWriter(); + const id = await this.idGen(); + const header: JSONRPCRequestMessage = { + jsonrpc: '2.0', + method, + params: headerParams, + id, + }; + await tempWriter.write(Buffer.from(JSON.stringify(header))); + tempWriter.releaseLock(); + const headTransformStream = rpcUtils.parseHeadStream( + rpcUtils.parseJSONRPCResponse, + ); + void rpcStream.readable + // Allow us to re-use the readable after reading the first message + .pipeTo(headTransformStream.writable) + // Ignore any errors here, we only care that it ended + .catch(() => {}); + const tempReader = headTransformStream.readable.getReader(); + let leadingMessage: JSONRPCResponseResult; + try { + const message = await Promise.race([tempReader.read(), abortProm.p]); + const messageValue = message.value as JSONRPCResponse; + if (message.done) never(); + if ('error' in messageValue) { + const metadata = { + ...(rpcStream.meta ?? {}), + command: method, + }; + throw this.toError(messageValue.error.data, metadata); + } + leadingMessage = messageValue; + } catch (e) { + rpcStream.cancel( + new ErrorRPCStreamEnded('RPC Stream Ended', { cause: e }), + ); + throw e; + } + tempReader.releaseLock(); + const newRpcStream: RPCStream = { + writable: rpcStream.writable, + readable: headTransformStream.readable as ReadableStream, + cancel: rpcStream.cancel, + meta: rpcStream.meta, + }; + return [leadingMessage.result, newRpcStream]; + }; + let streamCreation: [JSONValue, RPCStream]; + try { + streamCreation = await setupStream(); + } finally { + cleanUp(); + } + const [result, rpcStream] = streamCreation; + const metadata = { + ...(rpcStream.meta ?? {}), + result, + command: method, + }; + return { + writable: rpcStream.writable, + readable: rpcStream.readable, + cancel: rpcStream.cancel, + meta: metadata, + }; + } +} + +export default RPCClient; diff --git a/src/RPCServer.ts b/src/RPCServer.ts new file mode 100644 index 0000000..71d8966 --- /dev/null +++ b/src/RPCServer.ts @@ -0,0 +1,662 @@ +import type { ReadableStreamDefaultReadResult } from 'stream/web'; +import type { + ClientHandlerImplementation, + DuplexHandlerImplementation, + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResponseError, + JSONRPCResponseResult, + ServerManifest, + RawHandlerImplementation, + ServerHandlerImplementation, + UnaryHandlerImplementation, + RPCStream, + MiddlewareFactory, +} from './types'; +import type { JSONValue } from './types'; +import type { IdGen } from './types'; +import { ReadableStream, TransformStream } from 'stream/web'; +import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; +import Logger from '@matrixai/logger'; +import { PromiseCancellable } from '@matrixai/async-cancellable'; +import { Timer } from '@matrixai/timer'; +import { createDestroy } from '@matrixai/async-init'; +import { RawHandler } from './handlers'; +import { DuplexHandler } from './handlers'; +import { ServerHandler } from './handlers'; +import { UnaryHandler } from './handlers'; +import { ClientHandler } from './handlers'; +import * as rpcEvents from './events'; +import * as rpcUtils from './utils'; +import * as rpcErrors from './errors'; +import * as rpcUtilsMiddleware from './utils'; +import { ErrorHandlerAborted, JSONRPCErrorCode, never } from './errors'; +import * as events from './events'; + +const cleanupReason = Symbol('CleanupReason'); + +/** + * You must provide a error handler `addEventListener('error')`. + * Otherwise errors will just be ignored. + * + * Events: + * - error + */ +interface RPCServer extends createDestroy.CreateDestroy {} +/** + * You must provide an error handler `addEventListener('error')`. + * Otherwise, errors will just be ignored. + * + * Events: + * - {@link events.EventRPCServerDestroy} + * - {@link events.EventRPCServerDestroyed} + */ +@createDestroy.CreateDestroy({ + eventDestroy: events.EventRPCServerDestroy, + eventDestroyed: events.EventRPCServerDestroyed, +}) +class RPCServer extends EventTarget { + /** + * Creates RPC server. + + * @param obj + * @param obj.manifest - Server manifest used to define the rpc method + * handlers. + * @param obj.middlewareFactory - Middleware used to process the rpc messages. + * The middlewareFactory needs to be a function that creates a pair of + * transform streams that convert `Uint8Array` to `JSONRPCRequest` on the forward + * path and `JSONRPCResponse` to `Uint8Array` on the reverse path. + * @param obj.sensitive - If true, sanitises any rpc error messages of any + * sensitive information. + * @param obj.streamKeepAliveTimeoutTime - Time before a connection is cleaned up due to no activity. This is the + * value used if the handler doesn't specify its own timeout time. This timeout is advisory and only results in a + * signal sent to the handler. Stream is forced to end after the timeoutForceCloseTime. Defaults to 60,000 + * milliseconds. + * @param obj.timeoutForceCloseTime - Time before the stream is forced to end after the initial timeout time. + * The stream will be forced to close after this amount of time after the initial timeout. This is a grace period for + * the handler to handle timeout before it is forced to end. Defaults to 2,000 milliseconds. + * @param obj.logger + */ + public static async createRPCServer({ + manifest, + middlewareFactory = rpcUtilsMiddleware.defaultServerMiddlewareWrapper(), + sensitive = false, + handlerTimeoutTime = Infinity, // 1 minute + logger = new Logger(this.name), + idGen = () => Promise.resolve(null), + fromError = rpcUtils.fromError, + replacer = rpcUtils.replacer, + }: { + manifest: ServerManifest; + middlewareFactory?: MiddlewareFactory< + JSONRPCRequest, + Uint8Array, + Uint8Array, + JSONRPCResponse + >; + sensitive?: boolean; + handlerTimeoutTime?: number; + logger?: Logger; + idGen: IdGen; + fromError?: (error: Error) => JSONValue; + replacer?: (key: string, value: any) => any; + }): Promise { + logger.info(`Creating ${this.name}`); + const rpcServer = new this({ + manifest, + middlewareFactory, + sensitive, + handlerTimeoutTime, + logger, + idGen, + fromError, + replacer, + }); + logger.info(`Created ${this.name}`); + return rpcServer; + } + protected onTimeoutCallback?: () => void; + protected idGen: IdGen; + protected logger: Logger; + protected handlerMap: Map = new Map(); + protected defaultTimeoutMap: Map = new Map(); + protected handlerTimeoutTime: number; + protected activeStreams: Set> = new Set(); + protected sensitive: boolean; + protected fromError: (error: Error, sensitive?: boolean) => JSONValue; + protected replacer: (key: string, value: any) => any; + protected middlewareFactory: MiddlewareFactory< + JSONRPCRequest, + Uint8Array, + Uint8Array, + JSONRPCResponseResult + >; + // Function to register a callback for timeout + public registerOnTimeoutCallback(callback: () => void) { + this.onTimeoutCallback = callback; + } + public constructor({ + manifest, + middlewareFactory, + sensitive, + handlerTimeoutTime = Infinity, // 1 minuet + logger, + idGen = () => Promise.resolve(null), + fromError = rpcUtils.fromError, + replacer = rpcUtils.replacer, + }: { + manifest: ServerManifest; + + middlewareFactory: MiddlewareFactory< + JSONRPCRequest, + Uint8Array, + Uint8Array, + JSONRPCResponseResult + >; + handlerTimeoutTime?: number; + sensitive: boolean; + logger: Logger; + idGen: IdGen; + fromError?: (error: Error) => JSONValue; + replacer?: (key: string, value: any) => any; + }) { + super(); + for (const [key, manifestItem] of Object.entries(manifest)) { + if (manifestItem instanceof RawHandler) { + this.registerRawStreamHandler( + key, + manifestItem.handle, + manifestItem.timeout, + ); + continue; + } + if (manifestItem instanceof DuplexHandler) { + this.registerDuplexStreamHandler( + key, + manifestItem.handle, + manifestItem.timeout, + ); + continue; + } + if (manifestItem instanceof ServerHandler) { + this.registerServerStreamHandler( + key, + manifestItem.handle, + manifestItem.timeout, + ); + continue; + } + if (manifestItem instanceof ClientHandler) { + this.registerClientStreamHandler( + key, + manifestItem.handle, + manifestItem.timeout, + ); + continue; + } + if (manifestItem instanceof ClientHandler) { + this.registerClientStreamHandler( + key, + manifestItem.handle, + manifestItem.timeout, + ); + continue; + } + if (manifestItem instanceof UnaryHandler) { + this.registerUnaryHandler( + key, + manifestItem.handle, + manifestItem.timeout, + ); + continue; + } + never(); + } + this.idGen = idGen; + this.middlewareFactory = middlewareFactory; + this.sensitive = sensitive; + this.handlerTimeoutTime = handlerTimeoutTime; + this.logger = logger; + this.fromError = fromError || rpcUtils.fromError; + this.replacer = replacer || rpcUtils.replacer; + } + + public async destroy(force: boolean = true): Promise { + // Log and dispatch an event before starting the destruction + this.logger.info(`Destroying ${this.constructor.name}`); + this.dispatchEvent(new events.EventRPCServerDestroy()); + + // Your existing logic for stopping active streams and other cleanup + if (force) { + for await (const [activeStream] of this.activeStreams.entries()) { + activeStream.cancel(new rpcErrors.ErrorRPCStopping()); + } + } + + for await (const [activeStream] of this.activeStreams.entries()) { + await activeStream; + } + + // Log and dispatch an event after the destruction has been completed + this.dispatchEvent(new events.EventRPCServerDestroyed()); + this.logger.info(`Destroyed ${this.constructor.name}`); + } + + /** + * Registers a raw stream handler. This is the basis for all handlers as + * handling the streams is done with raw streams only. + * The raw streams do not automatically refresh the timeout timer when + * messages are sent or received. + */ + protected registerRawStreamHandler( + method: string, + handler: RawHandlerImplementation, + timeout: number | undefined, + ) { + this.handlerMap.set(method, handler); + this.defaultTimeoutMap.set(method, timeout); + } + + /** + * Registers a duplex stream handler. + * This handles all message parsing and conversion from generators + * to raw streams. + * + * @param method - The rpc method name. + * @param handler - The handler takes an input async iterable and returns an output async iterable. + * @param timeout + */ + /** + * The ID is generated only once when the function is called and stored in the id variable. + * the ID is associated with the entire stream + * Every response (whether successful or an error) produced within this stream will have the + * same ID, which is consistent with the originating request. + */ + protected registerDuplexStreamHandler< + I extends JSONValue, + O extends JSONValue, + >( + method: string, + handler: DuplexHandlerImplementation, + timeout: number | undefined, + ): void { + const rawSteamHandler: RawHandlerImplementation = async ( + [header, input], + cancel, + meta, + ctx, + ) => { + // Setting up abort controller + const abortController = new AbortController(); + if (ctx.signal.aborted) abortController.abort(ctx.signal.reason); + ctx.signal.addEventListener('abort', () => { + abortController.abort(ctx.signal.reason); + }); + const signal = abortController.signal; + // Setting up middleware + const middleware = this.middlewareFactory(ctx, cancel, meta); + // Forward from the client to the server + // Transparent TransformStream that re-inserts the header message into the + // stream. + const headerStream = new TransformStream({ + start(controller) { + controller.enqueue(Buffer.from(JSON.stringify(header))); + }, + transform(chunk, controller) { + controller.enqueue(chunk); + }, + }); + const forwardStream = input + .pipeThrough(headerStream) + .pipeThrough(middleware.forward); + // Reverse from the server to the client + const reverseStream = middleware.reverse.writable; + // Generator derived from handler + const id = await this.idGen(); + const outputGen = async function* (): AsyncGenerator { + if (signal.aborted) throw signal.reason; + // Input generator derived from the forward stream + const inputGen = async function* (): AsyncIterable { + for await (const data of forwardStream) { + ctx.timer.refresh(); + yield data.params as I; + } + }; + const handlerG = handler(inputGen(), cancel, meta, { + signal, + timer: ctx.timer, + }); + for await (const response of handlerG) { + ctx.timer.refresh(); + const responseMessage: JSONRPCResponseResult = { + jsonrpc: '2.0', + result: response, + id, + }; + yield responseMessage; + } + }; + const outputGenerator = outputGen(); + const reverseMiddlewareStream = new ReadableStream({ + pull: async (controller) => { + try { + const { value, done } = await outputGenerator.next(); + if (done) { + controller.close(); + return; + } + controller.enqueue(value); + } catch (e) { + const rpcError: JSONRPCError = { + code: e.exitCode ?? JSONRPCErrorCode.InternalError, + message: e.description ?? '', + data: JSON.stringify(this.fromError(e), this.replacer), + }; + const rpcErrorMessage: JSONRPCResponseError = { + jsonrpc: '2.0', + error: rpcError, + id, + }; + controller.enqueue(rpcErrorMessage); + // Clean up the input stream here, ignore error if already ended + await forwardStream + .cancel( + new rpcErrors.ErrorRPCHandlerFailed('Error clean up', { + cause: e, + }), + ) + .catch(() => {}); + controller.close(); + } + }, + cancel: async (reason) => { + this.dispatchEvent( + new rpcEvents.RPCErrorEvent({ + detail: new rpcErrors.ErrorRPCStreamEnded( + 'Stream has been cancelled', + { + cause: reason, + }, + ), + }), + ); + // Abort with the reason + abortController.abort(reason); + // If the output stream path fails then we need to end the generator + // early. + await outputGenerator.return(undefined); + }, + }); + // Ignore any errors here, it should propagate to the ends of the stream + void reverseMiddlewareStream.pipeTo(reverseStream).catch(() => {}); + return [undefined, middleware.reverse.readable]; + }; + this.registerRawStreamHandler(method, rawSteamHandler, timeout); + } + + protected registerUnaryHandler( + method: string, + handler: UnaryHandlerImplementation, + timeout: number | undefined, + ) { + const wrapperDuplex: DuplexHandlerImplementation = async function* ( + input, + cancel, + meta, + ctx, + ) { + // The `input` is expected to be an async iterable with only 1 value. + // Unlike generators, there is no `next()` method. + // So we use `break` after the first iteration. + for await (const inputVal of input) { + yield await handler(inputVal, cancel, meta, ctx); + break; + } + }; + this.registerDuplexStreamHandler(method, wrapperDuplex, timeout); + } + + protected registerServerStreamHandler< + I extends JSONValue, + O extends JSONValue, + >( + method: string, + handler: ServerHandlerImplementation, + timeout: number | undefined, + ) { + const wrapperDuplex: DuplexHandlerImplementation = async function* ( + input, + cancel, + meta, + ctx, + ) { + for await (const inputVal of input) { + yield* handler(inputVal, cancel, meta, ctx); + break; + } + }; + this.registerDuplexStreamHandler(method, wrapperDuplex, timeout); + } + + protected registerClientStreamHandler< + I extends JSONValue, + O extends JSONValue, + >( + method: string, + handler: ClientHandlerImplementation, + timeout: number | undefined, + ) { + const wrapperDuplex: DuplexHandlerImplementation = async function* ( + input, + cancel, + meta, + ctx, + ) { + yield await handler(input, cancel, meta, ctx); + }; + this.registerDuplexStreamHandler(method, wrapperDuplex, timeout); + } + + /** + * ID is associated with the stream, not individual messages. + */ + @ready(new rpcErrors.ErrorRPCHandlerFailed()) + public handleStream(rpcStream: RPCStream) { + // This will take a buffer stream of json messages and set up service + // handling for it. + // Constructing the PromiseCancellable for tracking the active stream + const abortController = new AbortController(); + // Setting up timeout timer logic + const timer = new Timer({ + delay: this.handlerTimeoutTime, + handler: () => { + abortController.abort(new rpcErrors.ErrorRPCTimedOut()); + if (this.onTimeoutCallback) { + this.onTimeoutCallback(); + } + }, + }); + + const prom = (async () => { + const id = await this.idGen(); + const headTransformStream = rpcUtilsMiddleware.binaryToJsonMessageStream( + rpcUtils.parseJSONRPCRequest, + ); + // Transparent transform used as a point to cancel the input stream from + const passthroughTransform = new TransformStream< + Uint8Array, + Uint8Array + >(); + const inputStream = passthroughTransform.readable; + const inputStreamEndProm = rpcStream.readable + .pipeTo(passthroughTransform.writable) + // Ignore any errors here, we only care that it ended + .catch(() => {}); + void inputStream + // Allow us to re-use the readable after reading the first message + .pipeTo(headTransformStream.writable, { + preventClose: true, + preventCancel: true, + }) + // Ignore any errors here, we only care that it ended + .catch(() => {}); + const cleanUp = async (reason: any) => { + await inputStream.cancel(reason); + await rpcStream.writable.abort(reason); + await inputStreamEndProm; + timer.cancel(cleanupReason); + await timer.catch(() => {}); + }; + // Read a single empty value to consume the first message + const reader = headTransformStream.readable.getReader(); + // Allows timing out when waiting for the first message + let headerMessage: + | ReadableStreamDefaultReadResult + | undefined + | void; + try { + headerMessage = await Promise.race([ + reader.read(), + timer.then( + () => undefined, + () => {}, + ), + ]); + } catch (e) { + const newErr = new rpcErrors.ErrorRPCHandlerFailed( + 'Stream failed waiting for header', + { cause: e }, + ); + await inputStreamEndProm; + timer.cancel(cleanupReason); + await timer.catch(() => {}); + this.dispatchEvent( + new rpcEvents.RPCErrorEvent({ + detail: new rpcErrors.ErrorRPCOutputStreamError( + 'Stream failed waiting for header', + { + cause: newErr, + }, + ), + }), + ); + return; + } + // Downgrade back to the raw stream + await reader.cancel(); + // There are 2 conditions where we just end here + // 1. The timeout timer resolves before the first message + // 2. the stream ends before the first message + if (headerMessage == null) { + const newErr = new rpcErrors.ErrorRPCTimedOut( + 'Timed out waiting for header', + { cause: new rpcErrors.ErrorRPCStreamEnded() }, + ); + await cleanUp(newErr); + this.dispatchEvent( + new rpcEvents.RPCErrorEvent({ + detail: new rpcErrors.ErrorRPCTimedOut( + 'Timed out waiting for header', + { + cause: newErr, + }, + ), + }), + ); + return; + } + if (headerMessage.done) { + const newErr = new rpcErrors.ErrorMissingHeader('Missing header'); + await cleanUp(newErr); + this.dispatchEvent( + new rpcEvents.RPCErrorEvent({ + detail: new rpcErrors.ErrorRPCOutputStreamError('Missing header', { + cause: newErr, + }), + }), + ); + return; + } + const method = headerMessage.value.method; + const handler = this.handlerMap.get(method); + if (handler == null) { + await cleanUp(new rpcErrors.ErrorRPCHandlerFailed('Missing handler')); + return; + } + if (abortController.signal.aborted) { + await cleanUp( + new rpcErrors.ErrorHandlerAborted('Aborted', { + cause: new ErrorHandlerAborted(), + }), + ); + return; + } + // Setting up Timeout logic + const timeout = this.defaultTimeoutMap.get(method); + if (timeout != null && timeout < this.handlerTimeoutTime) { + // Reset timeout with new delay if it is less than the default + timer.reset(timeout); + } else { + // Otherwise refresh + timer.refresh(); + } + this.logger.info(`Handling stream with method (${method})`); + let handlerResult: [JSONValue | undefined, ReadableStream]; + const headerWriter = rpcStream.writable.getWriter(); + try { + handlerResult = await handler( + [headerMessage.value, inputStream], + rpcStream.cancel, + rpcStream.meta, + { signal: abortController.signal, timer }, + ); + } catch (e) { + const rpcError: JSONRPCError = { + code: e.exitCode ?? JSONRPCErrorCode.InternalError, + message: e.description ?? '', + data: JSON.stringify(this.fromError(e), this.replacer), + }; + const rpcErrorMessage: JSONRPCResponseError = { + jsonrpc: '2.0', + error: rpcError, + id, + }; + await headerWriter.write(Buffer.from(JSON.stringify(rpcErrorMessage))); + await headerWriter.close(); + // Clean up and return + timer.cancel(cleanupReason); + rpcStream.cancel(Error('TMP header message was an error')); + return; + } + const [leadingResult, outputStream] = handlerResult; + + if (leadingResult !== undefined) { + // Writing leading metadata + const leadingMessage: JSONRPCResponseResult = { + jsonrpc: '2.0', + result: leadingResult, + id, + }; + await headerWriter.write(Buffer.from(JSON.stringify(leadingMessage))); + } + headerWriter.releaseLock(); + const outputStreamEndProm = outputStream + .pipeTo(rpcStream.writable) + .catch(() => {}); // Ignore any errors, we only care that it finished + await Promise.allSettled([inputStreamEndProm, outputStreamEndProm]); + this.logger.info(`Handled stream with method (${method})`); + // Cleaning up abort and timer + timer.cancel(cleanupReason); + abortController.abort(new rpcErrors.ErrorRPCStreamEnded()); + })(); + const handlerProm = PromiseCancellable.from(prom, abortController).finally( + () => this.activeStreams.delete(handlerProm), + abortController, + ); + // Putting the PromiseCancellable into the active streams map + this.activeStreams.add(handlerProm); + } +} + +export default RPCServer; diff --git a/src/callers/Caller.ts b/src/callers/Caller.ts new file mode 100644 index 0000000..ddc54a8 --- /dev/null +++ b/src/callers/Caller.ts @@ -0,0 +1,13 @@ +import type { HandlerType, JSONValue } from '../types'; + +abstract class Caller< + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> { + protected _inputType: Input; + protected _outputType: Output; + // Need this to distinguish the classes when inferring types + abstract type: HandlerType; +} + +export default Caller; diff --git a/src/callers/ClientCaller.ts b/src/callers/ClientCaller.ts new file mode 100644 index 0000000..7fb44da --- /dev/null +++ b/src/callers/ClientCaller.ts @@ -0,0 +1,11 @@ +import type { JSONValue } from '../types'; +import Caller from './Caller'; + +class ClientCaller< + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> extends Caller { + public type: 'CLIENT' = 'CLIENT' as const; +} + +export default ClientCaller; diff --git a/src/callers/DuplexCaller.ts b/src/callers/DuplexCaller.ts new file mode 100644 index 0000000..4c079b3 --- /dev/null +++ b/src/callers/DuplexCaller.ts @@ -0,0 +1,11 @@ +import type { JSONValue } from '../types'; +import Caller from './Caller'; + +class DuplexCaller< + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> extends Caller { + public type: 'DUPLEX' = 'DUPLEX' as const; +} + +export default DuplexCaller; diff --git a/src/callers/RawCaller.ts b/src/callers/RawCaller.ts new file mode 100644 index 0000000..a4721cf --- /dev/null +++ b/src/callers/RawCaller.ts @@ -0,0 +1,7 @@ +import type { JSONValue } from '../types'; +import Caller from './Caller'; +class RawCaller extends Caller { + public type: 'RAW' = 'RAW' as const; +} + +export default RawCaller; diff --git a/src/callers/ServerCaller.ts b/src/callers/ServerCaller.ts new file mode 100644 index 0000000..11a9fe9 --- /dev/null +++ b/src/callers/ServerCaller.ts @@ -0,0 +1,11 @@ +import type { JSONValue } from '../types'; +import Caller from './Caller'; + +class ServerCaller< + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> extends Caller { + public type: 'SERVER' = 'SERVER' as const; +} + +export default ServerCaller; diff --git a/src/callers/UnaryCaller.ts b/src/callers/UnaryCaller.ts new file mode 100644 index 0000000..c446073 --- /dev/null +++ b/src/callers/UnaryCaller.ts @@ -0,0 +1,11 @@ +import type { JSONValue } from '../types'; +import Caller from './Caller'; + +class UnaryCaller< + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> extends Caller { + public type: 'UNARY' = 'UNARY' as const; +} + +export default UnaryCaller; diff --git a/src/callers/index.ts b/src/callers/index.ts new file mode 100644 index 0000000..17e8c87 --- /dev/null +++ b/src/callers/index.ts @@ -0,0 +1,6 @@ +export { default as Caller } from './Caller'; +export { default as ClientCaller } from './ClientCaller'; +export { default as DuplexCaller } from './DuplexCaller'; +export { default as RawCaller } from './RawCaller'; +export { default as ServerCaller } from './ServerCaller'; +export { default as UnaryCaller } from './UnaryCaller'; diff --git a/src/errors/errors.ts b/src/errors/errors.ts new file mode 100644 index 0000000..2acc942 --- /dev/null +++ b/src/errors/errors.ts @@ -0,0 +1,266 @@ +import type { Class } from '@matrixai/errors'; +import type { JSONValue } from '@/types'; +import { AbstractError } from '@matrixai/errors'; + +const enum JSONRPCErrorCode { + ParseError = -32700, + InvalidRequest = -32600, + MethodNotFound = -32601, + InvalidParams = -32602, + InternalError = -32603, + HandlerNotFound = -32000, + RPCStopping = -32001, + RPCDestroyed = -32002, + RPCMessageLength = -32003, + RPCMissingResponse = -32004, + RPCOutputStreamError = -32005, + RPCRemote = -32006, + RPCStreamEnded = -32007, + RPCTimedOut = -32008, + RPCConnectionLocal = -32010, + RPCConnectionPeer = -32011, + RPCConnectionKeepAliveTimeOut = -32012, + RPCConnectionInternal = -32013, + MissingHeader = -32014, + HandlerAborted = -32015, + MissingCaller = -32016, +} +interface RPCError extends Error { + code?: number; +} +class ErrorRPC extends AbstractError implements RPCError { + private _description: string = 'Generic Error'; + constructor(message?: string) { + super(message); + } + code?: number; + + get description(): string { + return this._description; + } + + set description(value: string) { + this._description = value; + } +} + +class ErrorRPCDestroyed extends ErrorRPC { + constructor(message?: string) { + super(message); // Call the parent constructor + this.description = 'Rpc is destroyed'; // Set the specific description + this.code = JSONRPCErrorCode.MethodNotFound; + } +} + +class ErrorRPCParse extends ErrorRPC { + static description = 'Failed to parse Buffer stream'; + + constructor(message?: string, options?: { cause: Error }) { + super(message); // Call the parent constructor + this.description = 'Failed to parse Buffer stream'; // Set the specific description + this.code = JSONRPCErrorCode.ParseError; + } +} + +class ErrorRPCStopping extends ErrorRPC { + constructor(message?: string) { + super(message); // Call the parent constructor + this.description = 'Rpc is stopping'; // Set the specific description + this.code = JSONRPCErrorCode.RPCStopping; + } +} + +/** + * This is an internal error, it should not reach the top level. + */ +class ErrorRPCHandlerFailed extends ErrorRPC { + constructor(message?: string, options?: { cause: Error }) { + super(message); // Call the parent constructor + this.description = 'Failed to handle stream'; // Set the specific description + this.code = JSONRPCErrorCode.HandlerNotFound; + } +} +class ErrorRPCCallerFailed extends ErrorRPC { + constructor(message?: string, options?: { cause: Error }) { + super(message); // Call the parent constructor + this.description = 'Failed to call stream'; // Set the specific description + this.code = JSONRPCErrorCode.MissingCaller; + } +} +class ErrorMissingCaller extends ErrorRPC { + constructor(message?: string, options?: { cause: Error }) { + super(message); // Call the parent constructor + this.description = 'Header information is missing'; // Set the specific description + this.code = JSONRPCErrorCode.MissingCaller; + } +} +class ErrorMissingHeader extends ErrorRPC { + constructor(message?: string, options?: { cause: Error }) { + super(message); // Call the parent constructor + this.description = 'Header information is missing'; // Set the specific description + this.code = JSONRPCErrorCode.MissingHeader; + } +} + +class ErrorHandlerAborted extends ErrorRPC { + constructor(message?: string, options?: { cause: Error }) { + super(message); // Call the parent constructor + this.description = 'Handler Aborted Stream.'; // Set the specific description + this.code = JSONRPCErrorCode.HandlerAborted; + } +} +class ErrorRPCMessageLength extends ErrorRPC { + static description = 'RPC Message exceeds maximum size'; + code? = JSONRPCErrorCode.RPCMessageLength; +} + +class ErrorRPCMissingResponse extends ErrorRPC { + constructor(message?: string) { + super(message); + this.description = 'Stream ended before response'; + this.code = JSONRPCErrorCode.RPCMissingResponse; + } +} + +interface ErrorRPCOutputStreamErrorOptions { + cause?: Error; +} +class ErrorRPCOutputStreamError extends ErrorRPC { + constructor(message: string, options: ErrorRPCOutputStreamErrorOptions) { + super(message); + this.description = 'Output stream failed, unable to send data'; + this.code = JSONRPCErrorCode.RPCOutputStreamError; + } +} + +class ErrorRPCRemote extends ErrorRPC { + static description = 'Remote error from RPC call'; + static message: string = 'The server responded with an error'; + metadata: JSONValue | undefined; + + constructor(metadata?: JSONValue, message?: string, options?) { + super(message); + this.metadata = metadata; + this.code = JSONRPCErrorCode.RPCRemote; + this.data = options?.data; + } + + public static fromJSON>( + this: T, + json: any, + ): InstanceType { + if ( + typeof json !== 'object' || + json.type !== this.name || + typeof json.data !== 'object' || + typeof json.data.message !== 'string' || + isNaN(Date.parse(json.data.timestamp)) || + typeof json.data.metadata !== 'object' || + typeof json.data.data !== 'object' || + ('stack' in json.data && typeof json.data.stack !== 'string') + ) { + throw new TypeError(`Cannot decode JSON to ${this.name}`); + } + + // Here, you can define your own metadata object, or just use the one from JSON directly. + const parsedMetadata = json.data.metadata; + + const e = new this(parsedMetadata, json.data.message, { + timestamp: new Date(json.data.timestamp), + data: json.data.data, + cause: json.data.cause, + }); + e.stack = json.data.stack; + return e; + } + public toJSON(): any { + return { + type: this.name, + data: { + description: this.description, + }, + }; + } +} + +class ErrorRPCStreamEnded extends ErrorRPC { + constructor(message?: string, options?: { cause: Error }) { + super(message); + this.description = 'Handled stream has ended'; + this.code = JSONRPCErrorCode.RPCStreamEnded; + } +} + +class ErrorRPCTimedOut extends ErrorRPC { + constructor(message?: string, options?: { cause: Error }) { + super(message); + this.description = 'RPC handler has timed out'; + this.code = JSONRPCErrorCode.RPCTimedOut; + } +} + +class ErrorUtilsUndefinedBehaviour extends ErrorRPC { + constructor(message?: string) { + super(message); + this.description = 'You should never see this error'; + this.code = JSONRPCErrorCode.MethodNotFound; + } +} +export function never(): never { + throw new ErrorRPC('This function should never be called'); +} + +class ErrorRPCMethodNotImplemented extends ErrorRPC { + constructor(message?: string) { + super(message || 'This method must be overridden'); // Default message if none provided + this.name = 'ErrorRPCMethodNotImplemented'; + this.description = + 'This abstract method must be implemented in a derived class'; + this.code = JSONRPCErrorCode.MethodNotFound; + } +} + +class ErrorRPCConnectionLocal extends ErrorRPC { + static description = 'RPC Connection local error'; + code? = JSONRPCErrorCode.RPCConnectionLocal; +} + +class ErrorRPCConnectionPeer extends ErrorRPC { + static description = 'RPC Connection peer error'; + code? = JSONRPCErrorCode.RPCConnectionPeer; +} + +class ErrorRPCConnectionKeepAliveTimeOut extends ErrorRPC { + static description = 'RPC Connection keep alive timeout'; + code? = JSONRPCErrorCode.RPCConnectionKeepAliveTimeOut; +} + +class ErrorRPCConnectionInternal extends ErrorRPC { + static description = 'RPC Connection internal error'; + code? = JSONRPCErrorCode.RPCConnectionInternal; +} + +export { + ErrorRPC, + ErrorRPCDestroyed, + ErrorRPCStopping, + ErrorRPCParse, + ErrorRPCHandlerFailed, + ErrorRPCMessageLength, + ErrorRPCMissingResponse, + ErrorRPCOutputStreamError, + ErrorRPCRemote, + ErrorRPCStreamEnded, + ErrorRPCTimedOut, + ErrorUtilsUndefinedBehaviour, + ErrorRPCMethodNotImplemented, + ErrorRPCConnectionLocal, + ErrorRPCConnectionPeer, + ErrorRPCConnectionKeepAliveTimeOut, + ErrorRPCConnectionInternal, + ErrorMissingHeader, + ErrorHandlerAborted, + ErrorRPCCallerFailed, + ErrorMissingCaller, + JSONRPCErrorCode, +}; diff --git a/src/errors/index.ts b/src/errors/index.ts new file mode 100644 index 0000000..f72bc43 --- /dev/null +++ b/src/errors/index.ts @@ -0,0 +1 @@ +export * from './errors'; diff --git a/src/events.ts b/src/events.ts new file mode 100644 index 0000000..828cca4 --- /dev/null +++ b/src/events.ts @@ -0,0 +1,85 @@ +import type RPCServer from './RPCServer'; +import type RPCClient from './RPCClient'; +import type { + ErrorRPCConnectionLocal, + ErrorRPCConnectionPeer, + ErrorRPCConnectionKeepAliveTimeOut, + ErrorRPCConnectionInternal, +} from './errors'; +import { AbstractEvent } from '@matrixai/events'; +import * as rpcErrors from './errors'; + +abstract class EventRPC extends AbstractEvent {} + +abstract class EventRPCClient extends AbstractEvent {} + +abstract class EventRPCServer extends AbstractEvent {} + +abstract class EventRPCConnection extends AbstractEvent {} + +// Client events +class EventRPCClientDestroy extends EventRPCClient {} + +class EventRPCClientDestroyed extends EventRPCClient {} + +class EventRPCClientCreate extends EventRPCClient {} + +class EventRPCClientCreated extends EventRPCClient {} + +class EventRPCClientError extends EventRPCClient {} + +class EventRPCClientConnect extends EventRPCClient {} + +// Server events + +class EventRPCServerConnection extends EventRPCServer {} + +class EventRPCServerCreate extends EventRPCServer {} + +class EventRPCServerCreated extends EventRPCServer {} + +class EventRPCServerDestroy extends EventRPCServer {} + +class EventRPCServerDestroyed extends EventRPCServer {} + +class EventRPCServerError extends EventRPCServer {} + +class EventRPCConnectionError extends EventRPCConnection< + | ErrorRPCConnectionLocal + | ErrorRPCConnectionPeer + | ErrorRPCConnectionKeepAliveTimeOut + | ErrorRPCConnectionInternal +> {} + +class RPCErrorEvent extends Event { + public detail: Error; + constructor( + options: EventInit & { + detail: Error; + }, + ) { + super('error', options); + this.detail = options.detail; + } +} + +export { + RPCErrorEvent, + EventRPC, + EventRPCClient, + EventRPCServer, + EventRPCConnection, + EventRPCClientDestroy, + EventRPCClientDestroyed, + EventRPCClientCreate, + EventRPCClientCreated, + EventRPCClientError, + EventRPCClientConnect, + EventRPCServerConnection, + EventRPCServerCreate, + EventRPCServerCreated, + EventRPCServerDestroy, + EventRPCServerDestroyed, + EventRPCServerError, + EventRPCConnectionError, +}; diff --git a/src/handlers/ClientHandler.ts b/src/handlers/ClientHandler.ts new file mode 100644 index 0000000..0aea354 --- /dev/null +++ b/src/handlers/ClientHandler.ts @@ -0,0 +1,21 @@ +import type { ContainerType, JSONValue } from '../types'; +import type { ContextTimed } from '@matrixai/contexts'; +import Handler from './Handler'; +import { ErrorRPCMethodNotImplemented } from '../errors'; + +abstract class ClientHandler< + Container extends ContainerType = ContainerType, + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> extends Handler { + public handle = async ( + input: AsyncIterableIterator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): Promise => { + throw new ErrorRPCMethodNotImplemented(); + }; +} + +export default ClientHandler; diff --git a/src/handlers/DuplexHandler.ts b/src/handlers/DuplexHandler.ts new file mode 100644 index 0000000..6534ef6 --- /dev/null +++ b/src/handlers/DuplexHandler.ts @@ -0,0 +1,26 @@ +import type { ContainerType, JSONValue } from '../types'; +import type { ContextTimed } from '@matrixai/contexts'; +import Handler from './Handler'; +import { ErrorRPCMethodNotImplemented } from '../errors'; + +abstract class DuplexHandler< + Container extends ContainerType = ContainerType, + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> extends Handler { + /** + * Note that if the output has an error, the handler will not see this as an + * error. If you need to handle any clean up it should be handled in a + * `finally` block and check the abort signal for potential errors. + */ + public handle = async function* ( + input: AsyncIterableIterator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): AsyncIterableIterator { + throw new ErrorRPCMethodNotImplemented('This method must be overwrtitten.'); + }; +} + +export default DuplexHandler; diff --git a/src/handlers/Handler.ts b/src/handlers/Handler.ts new file mode 100644 index 0000000..fbf2f4e --- /dev/null +++ b/src/handlers/Handler.ts @@ -0,0 +1,19 @@ +import type { ContainerType, JSONValue } from '../types'; +abstract class Handler< + Container extends ContainerType = ContainerType, + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> { + // These are used to distinguish the handlers in the type system. + // Without these the map types can't tell the types of handlers apart. + protected _inputType: Input; + protected _outputType: Output; + /** + * This is the timeout used for the handler. + * If it is not set then the default timeout time for the `RPCServer` is used. + */ + public timeout?: number; + + constructor(protected container: Container) {} +} +export default Handler; diff --git a/src/handlers/RawHandler.ts b/src/handlers/RawHandler.ts new file mode 100644 index 0000000..01fd1d7 --- /dev/null +++ b/src/handlers/RawHandler.ts @@ -0,0 +1,20 @@ +import type { ContextTimed } from '@matrixai/contexts'; +import type { ReadableStream } from 'stream/web'; +import type { ContainerType, JSONRPCRequest, JSONValue } from '../types'; +import Handler from './Handler'; +import { ErrorRPCMethodNotImplemented } from '../errors'; + +abstract class RawHandler< + Container extends ContainerType = ContainerType, +> extends Handler { + public handle = async ( + input: [JSONRPCRequest, ReadableStream], + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): Promise<[JSONValue, ReadableStream]> => { + throw new ErrorRPCMethodNotImplemented('This method must be overridden'); + }; +} + +export default RawHandler; diff --git a/src/handlers/ServerHandler.ts b/src/handlers/ServerHandler.ts new file mode 100644 index 0000000..bebd177 --- /dev/null +++ b/src/handlers/ServerHandler.ts @@ -0,0 +1,21 @@ +import type { ContextTimed } from '@matrixai/contexts'; +import type { ContainerType, JSONValue } from '../types'; +import Handler from './Handler'; +import { ErrorRPCMethodNotImplemented } from '../errors'; + +abstract class ServerHandler< + Container extends ContainerType = ContainerType, + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> extends Handler { + public handle = async function* ( + input: Input, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): AsyncIterableIterator { + throw new ErrorRPCMethodNotImplemented('This method must be overridden'); + }; +} + +export default ServerHandler; diff --git a/src/handlers/UnaryHandler.ts b/src/handlers/UnaryHandler.ts new file mode 100644 index 0000000..0a1e37b --- /dev/null +++ b/src/handlers/UnaryHandler.ts @@ -0,0 +1,21 @@ +import type { ContextTimed } from '@matrixai/contexts'; +import type { ContainerType, JSONValue } from '../types'; +import Handler from './Handler'; +import { ErrorRPCMethodNotImplemented } from '../errors'; + +abstract class UnaryHandler< + Container extends ContainerType = ContainerType, + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> extends Handler { + public handle = async ( + input: Input, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): Promise => { + throw new ErrorRPCMethodNotImplemented('This method must be overridden'); + }; +} + +export default UnaryHandler; diff --git a/src/handlers/index.ts b/src/handlers/index.ts new file mode 100644 index 0000000..2df4ee2 --- /dev/null +++ b/src/handlers/index.ts @@ -0,0 +1,6 @@ +export { default as Handler } from './Handler'; +export { default as ClientHandler } from './ClientHandler'; +export { default as DuplexHandler } from './DuplexHandler'; +export { default as RawHandler } from './RawHandler'; +export { default as ServerHandler } from './ServerHandler'; +export { default as UnaryHandler } from './UnaryHandler'; diff --git a/src/index.ts b/src/index.ts index e69de29..c3a5052 100644 --- a/src/index.ts +++ b/src/index.ts @@ -0,0 +1,8 @@ +export { default as RPCClient } from './RPCClient'; +export { default as RPCServer } from './RPCServer'; +export * as utils from './utils'; +export * as types from './types'; +export * as errors from './errors'; +export * as events from './events'; +export * as handlers from './handlers'; +export * as callers from './callers'; diff --git a/src/types.ts b/src/types.ts new file mode 100644 index 0000000..97e367c --- /dev/null +++ b/src/types.ts @@ -0,0 +1,367 @@ +import type { ReadableStream, ReadableWritablePair } from 'stream/web'; +import type { ContextTimed, ContextTimedInput } from '@matrixai/contexts'; +import type { Caller } from './callers'; +import type { RawCaller } from './callers'; +import type { DuplexCaller } from './callers'; +import type { ServerCaller } from './callers'; +import type { ClientCaller } from './callers'; +import type { UnaryCaller } from './callers'; +import type Handler from './handlers/Handler'; + +/** + * This is the type for the IdGenFunction. It is used to generate the request + */ +type IdGen = () => PromiseLike; + +/** + * This is the JSON RPC request object. this is the generic message type used for the RPC. + */ +type JSONRPCRequestMessage = { + /** + * A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0" + */ + jsonrpc: '2.0'; + /** + * A String containing the name of the method to be invoked. Method names that begin with the word rpc followed by a + * period character (U+002E or ASCII 46) are reserved for rpc-internal methods and extensions and MUST NOT be used + * for anything else. + */ + method: string; + /** + * A Structured value that holds the parameter values to be used during the invocation of the method. + * This member MAY be omitted. + */ + params?: T; + /** + * An identifier established by the Client that MUST contain a String, Number, or NULL value if included. + * If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers + * SHOULD NOT contain fractional parts [2] + */ + id: string | number | null; +}; + +/** + * This is the JSON RPC notification object. this is used for a request that + * doesn't expect a response. + */ +type JSONRPCRequestNotification = { + /** + * A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0" + */ + jsonrpc: '2.0'; + /** + * A String containing the name of the method to be invoked. Method names that begin with the word rpc followed by a + * period character (U+002E or ASCII 46) are reserved for rpc-internal methods and extensions and MUST NOT be used + * for anything else. + */ + method: string; + /** + * A Structured value that holds the parameter values to be used during the invocation of the method. + * This member MAY be omitted. + */ + params?: T; +}; + +/** + * This is the JSON RPC response result object. It contains the response data for a + * corresponding request. + */ +type JSONRPCResponseResult = { + /** + * A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0". + */ + jsonrpc: '2.0'; + /** + * This member is REQUIRED on success. + * This member MUST NOT exist if there was an error invoking the method. + * The value of this member is determined by the method invoked on the Server. + */ + result: T; + /** + * This member is REQUIRED. + * It MUST be the same as the value of the id member in the Request Object. + * If there was an error in detecting the id in the Request object (e.g. Parse error/Invalid Request), + * it MUST be Null. + */ + id: string | number | null; +}; + +/** + * This is the JSON RPC response Error object. It contains any errors that have + * occurred when responding to a request. + */ +type JSONRPCResponseError = { + /** + * A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0". + */ + jsonrpc: '2.0'; + /** + * This member is REQUIRED on error. + * This member MUST NOT exist if there was no error triggered during invocation. + * The value for this member MUST be an Object as defined in section 5.1. + */ + error: JSONRPCError; + /** + * This member is REQUIRED. + * It MUST be the same as the value of the id member in the Request Object. + * If there was an error in detecting the id in the Request object (e.g. Parse error/Invalid Request), + * it MUST be Null. + */ + id: string | number | null; +}; + +/** + * This is a JSON RPC error object, it encodes the error data for the JSONRPCResponseError object. + */ +type JSONRPCError = { + /** + * A Number that indicates the error type that occurred. + * This MUST be an integer. + */ + code: number; + /** + * A String providing a short description of the error. + * The message SHOULD be limited to a concise single sentence. + */ + message: string; + /** + * A Primitive or Structured value that contains additional information about the error. + * This may be omitted. + * The value of this member is defined by the Server (e.g. detailed error information, nested errors etc.). + */ + data?: JSONValue; +}; + +/** + * This is the JSON RPC Request object. It can be a request message or + * notification. + */ +type JSONRPCRequest = + | JSONRPCRequestMessage + | JSONRPCRequestNotification; + +/** + * This is a JSON RPC response object. It can be a response result or error. + */ +type JSONRPCResponse = + | JSONRPCResponseResult + | JSONRPCResponseError; + +/** + * This is a JSON RPC Message object. This is top level and can be any kind of + * message. + */ +type JSONRPCMessage = + | JSONRPCRequest + | JSONRPCResponse; + +// Handler types +type HandlerImplementation = ( + input: I, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, +) => O; + +type RawHandlerImplementation = HandlerImplementation< + [JSONRPCRequest, ReadableStream], + Promise<[JSONValue | undefined, ReadableStream]> +>; + +type DuplexHandlerImplementation< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = HandlerImplementation, AsyncIterable>; + +type ServerHandlerImplementation< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = HandlerImplementation>; + +type ClientHandlerImplementation< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = HandlerImplementation, Promise>; + +type UnaryHandlerImplementation< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = HandlerImplementation>; + +type ContainerType = Record; + +/** + * This interface extends the `ReadableWritablePair` with a method to cancel + * the connection. It also includes some optional generic metadata. This is + * mainly used as the return type for the `StreamFactory`. But the interface + * can be propagated across the RPC system. + */ +interface RPCStream< + R, + W, + M extends Record = Record, +> extends ReadableWritablePair { + cancel: (reason?: any) => void; + meta?: M; +} + +/** + * This is a factory for creating a `RPCStream` when making a RPC call. + * The transport mechanism is a black box to the RPC system. So long as it is + * provided as a RPCStream the RPC system should function. It is assumed that + * the RPCStream communicates with an `RPCServer`. + */ +type StreamFactory = ( + ctx: ContextTimed, +) => PromiseLike>; + +/** + * Middleware factory creates middlewares. + * Each middleware is a pair of forward and reverse. + * Each forward and reverse is a `ReadableWritablePair`. + * The forward pair is used transform input from client to server. + * The reverse pair is used to transform output from server to client. + * FR, FW is the readable and writable types of the forward pair. + * RR, RW is the readable and writable types of the reverse pair. + * FW -> FR is the direction of data flow from client to server. + * RW -> RR is the direction of data flow from server to client. + */ +type MiddlewareFactory = ( + ctx: ContextTimed, + cancel: (reason?: any) => void, + meta: Record | undefined, +) => { + forward: ReadableWritablePair; + reverse: ReadableWritablePair; +}; + +// Convenience callers + +type UnaryCallerImplementation< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = (parameters: I, ctx?: Partial) => Promise; + +type ServerCallerImplementation< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = ( + parameters: I, + ctx?: Partial, +) => Promise>; + +type ClientCallerImplementation< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = ( + ctx?: Partial, +) => Promise<{ output: Promise; writable: WritableStream }>; + +type DuplexCallerImplementation< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = (ctx?: Partial) => Promise>; + +type RawCallerImplementation = ( + headerParams: JSONValue, + ctx?: Partial, +) => Promise< + RPCStream< + Uint8Array, + Uint8Array, + Record & { result: JSONValue; command: string } + > +>; + +type ConvertDuplexCaller = T extends DuplexCaller + ? DuplexCallerImplementation + : never; + +type ConvertServerCaller = T extends ServerCaller + ? ServerCallerImplementation + : never; + +type ConvertClientCaller = T extends ClientCaller + ? ClientCallerImplementation + : never; + +type ConvertUnaryCaller = T extends UnaryCaller + ? UnaryCallerImplementation + : never; + +type ConvertCaller = T extends DuplexCaller + ? ConvertDuplexCaller + : T extends ServerCaller + ? ConvertServerCaller + : T extends ClientCaller + ? ConvertClientCaller + : T extends UnaryCaller + ? ConvertUnaryCaller + : T extends RawCaller + ? RawCallerImplementation + : never; + +/** + * Contains the handler Classes that defines the handling logic and types for the server handlers. + */ +type ServerManifest = Record; + +/** + * Contains the Caller classes that defines the types for the client callers. + */ +type ClientManifest = Record; + +type HandlerType = 'DUPLEX' | 'SERVER' | 'CLIENT' | 'UNARY' | 'RAW'; + +type MapCallers = { + [K in keyof T]: ConvertCaller; +}; + +declare const brand: unique symbol; + +type Opaque = T & { readonly [brand]: K }; + +type JSONValue = + | { [key: string]: JSONValue | undefined } + | Array + | string + | number + | boolean + | null + | undefined; + +type POJO = { [key: string]: any }; +type PromiseDeconstructed = { + p: Promise; + resolveP: (value: T | PromiseLike) => void; + rejectP: (reason?: any) => void; +}; + +export type { + IdGen, + JSONRPCRequestMessage, + JSONRPCRequestNotification, + JSONRPCResponseResult, + JSONRPCResponseError, + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCMessage, + HandlerImplementation, + RawHandlerImplementation, + DuplexHandlerImplementation, + ServerHandlerImplementation, + ClientHandlerImplementation, + UnaryHandlerImplementation, + ContainerType, + RPCStream, + StreamFactory, + MiddlewareFactory, + ServerManifest, + ClientManifest, + HandlerType, + MapCallers, + JSONValue, + POJO, + PromiseDeconstructed, +}; diff --git a/src/utils/index.ts b/src/utils/index.ts new file mode 100644 index 0000000..d799db4 --- /dev/null +++ b/src/utils/index.ts @@ -0,0 +1,2 @@ +export * from './middleware'; +export * from './utils'; diff --git a/src/utils/middleware.ts b/src/utils/middleware.ts new file mode 100644 index 0000000..49f5504 --- /dev/null +++ b/src/utils/middleware.ts @@ -0,0 +1,203 @@ +import type { + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResponseResult, + MiddlewareFactory, +} from '../types'; +import { TransformStream } from 'stream/web'; +import { JSONParser } from '@streamparser/json'; +import * as rpcUtils from './utils'; +import { promise } from './utils'; +import * as rpcErrors from '../errors'; + +/** + * This function is a factory to create a TransformStream that will + * transform a `Uint8Array` stream to a JSONRPC message stream. + * The parsed messages will be validated with the provided messageParser, this + * also infers the type of the stream output. + * @param messageParser - Validates the JSONRPC messages, so you can select for a + * specific type of message + * @param bufferByteLimit - sets the number of bytes buffered before throwing an + * error. This is used to avoid infinitely buffering the input. + */ +function binaryToJsonMessageStream( + messageParser: (message: unknown) => T, + bufferByteLimit: number = 1024 * 1024, +): TransformStream { + const parser = new JSONParser({ + separator: '', + paths: ['$'], + }); + let bytesWritten: number = 0; + + return new TransformStream({ + flush: async () => { + // Avoid potential race conditions by allowing parser to end first + const waitP = promise(); + parser.onEnd = () => waitP.resolveP(); + parser.end(); + await waitP.p; + }, + start: (controller) => { + parser.onValue = (value) => { + const jsonMessage = messageParser(value.value); + controller.enqueue(jsonMessage); + bytesWritten = 0; + }; + }, + transform: (chunk) => { + try { + bytesWritten += chunk.byteLength; + parser.write(chunk); + } catch (e) { + throw new rpcErrors.ErrorRPCParse(undefined, { cause: e }); + } + if (bytesWritten > bufferByteLimit) { + throw new rpcErrors.ErrorRPCMessageLength(); + } + }, + }); +} + +/** + * This function is a factory for a TransformStream that will transform + * JsonRPCMessages into the `Uint8Array` form. This is used for the stream + * output. + */ +function jsonMessageToBinaryStream(): TransformStream< + JSONRPCMessage, + Uint8Array +> { + return new TransformStream({ + transform: (chunk, controller) => { + controller.enqueue(Buffer.from(JSON.stringify(chunk))); + }, + }); +} + +/** + * This function is a factory for creating a pass-through streamPair. It is used + * as the default middleware for the middleware wrappers. + */ +function defaultMiddleware() { + return { + forward: new TransformStream(), + reverse: new TransformStream(), + }; +} + +/** + * This convenience factory for creating wrapping middleware with the basic + * message processing and parsing for the server middleware. + * In the forward path, it will transform the binary stream into the validated + * JsonRPCMessages and pipe it through the provided middleware. + * The reverse path will pipe the output stream through the provided middleware + * and then transform it back to a binary stream. + * @param middlewareFactory - The provided middleware + * @param parserBufferByteLimit + */ +function defaultServerMiddlewareWrapper( + middlewareFactory: MiddlewareFactory< + JSONRPCRequest, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResponse + > = defaultMiddleware, + parserBufferByteLimit: number = 1024 * 1024, +): MiddlewareFactory { + return (ctx, cancel, meta) => { + const inputTransformStream = binaryToJsonMessageStream( + rpcUtils.parseJSONRPCRequest, + parserBufferByteLimit, + ); + const outputTransformStream = new TransformStream< + JSONRPCResponseResult, + JSONRPCResponseResult + >(); + + const middleMiddleware = middlewareFactory(ctx, cancel, meta); + + const forwardReadable = inputTransformStream.readable.pipeThrough( + middleMiddleware.forward, + ); // Usual middleware here + const reverseReadable = outputTransformStream.readable + .pipeThrough(middleMiddleware.reverse) // Usual middleware here + .pipeThrough(jsonMessageToBinaryStream()); + + return { + forward: { + readable: forwardReadable, + writable: inputTransformStream.writable, + }, + reverse: { + readable: reverseReadable, + writable: outputTransformStream.writable, + }, + }; + }; +} + +/** + * This convenience factory for creating wrapping middleware with the basic + * message processing and parsing for the server middleware. + * The forward path will pipe the input through the provided middleware and then + * transform it to the binary stream. + * The reverse path will parse and validate the output and pipe it through the + * provided middleware. + * @param middleware - the provided middleware + * @param parserBufferByteLimit - Max number of bytes to buffer when parsing the stream. Exceeding this results in an + * `ErrorRPCMessageLength` error. + */ +const defaultClientMiddlewareWrapper = ( + middleware: MiddlewareFactory< + JSONRPCRequest, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResponse + > = defaultMiddleware, + parserBufferByteLimit?: number, +): MiddlewareFactory< + Uint8Array, + JSONRPCRequest, + JSONRPCResponse, + Uint8Array +> => { + return (ctx, cancel, meta) => { + const outputTransformStream = binaryToJsonMessageStream( + rpcUtils.parseJSONRPCResponse, + parserBufferByteLimit, + ); + const inputTransformStream = new TransformStream< + JSONRPCRequest, + JSONRPCRequest + >(); + + const middleMiddleware = middleware(ctx, cancel, meta); + const forwardReadable = inputTransformStream.readable + .pipeThrough(middleMiddleware.forward) // Usual middleware here + .pipeThrough(jsonMessageToBinaryStream()); + const reverseReadable = outputTransformStream.readable.pipeThrough( + middleMiddleware.reverse, + ); // Usual middleware here + + return { + forward: { + readable: forwardReadable, + writable: inputTransformStream.writable, + }, + reverse: { + readable: reverseReadable, + writable: outputTransformStream.writable, + }, + }; + }; +}; + +export { + binaryToJsonMessageStream, + jsonMessageToBinaryStream, + defaultMiddleware, + defaultServerMiddlewareWrapper, + defaultClientMiddlewareWrapper, +}; diff --git a/src/utils/utils.ts b/src/utils/utils.ts new file mode 100644 index 0000000..8dbcab2 --- /dev/null +++ b/src/utils/utils.ts @@ -0,0 +1,545 @@ +import type { + ClientManifest, + HandlerType, + JSONRPCError, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCRequestMessage, + JSONRPCRequestNotification, + JSONRPCResponse, + JSONRPCResponseError, + JSONRPCResponseResult, + PromiseDeconstructed, +} from '../types'; +import type { JSONValue, IdGen } from '../types'; +import type { Timer } from '@matrixai/timer'; +import { TransformStream } from 'stream/web'; +import { JSONParser } from '@streamparser/json'; +import { AbstractError } from '@matrixai/errors'; +import * as rpcErrors from '../errors'; +import * as errors from '../errors'; +import { ErrorRPCRemote } from '../errors'; +import { ErrorRPC } from '../errors'; + +// Importing PK funcs and utils which are essential for RPC +function isObject(o: unknown): o is object { + return o !== null && typeof o === 'object'; +} +function promise(): PromiseDeconstructed { + let resolveP, rejectP; + const p = new Promise((resolve, reject) => { + resolveP = resolve; + rejectP = reject; + }); + return { + p, + resolveP, + rejectP, + }; +} + +async function sleep(ms: number): Promise { + return await new Promise((r) => setTimeout(r, ms)); +} + +function parseJSONRPCRequest( + message: unknown, +): JSONRPCRequest { + if (!isObject(message)) { + throw new rpcErrors.ErrorRPCParse('must be a JSON POJO'); + } + if (!('method' in message)) { + throw new rpcErrors.ErrorRPCParse('`method` property must be defined'); + } + if (typeof message.method !== 'string') { + throw new rpcErrors.ErrorRPCParse('`method` property must be a string'); + } + // If ('params' in message && !utils.isObject(message.params)) { + // throw new rpcErrors.ErrorRPCParse('`params` property must be a POJO'); + // } + return message as JSONRPCRequest; +} + +function parseJSONRPCRequestMessage( + message: unknown, +): JSONRPCRequestMessage { + const jsonRequest = parseJSONRPCRequest(message); + if (!('id' in jsonRequest)) { + throw new rpcErrors.ErrorRPCParse('`id` property must be defined'); + } + if ( + typeof jsonRequest.id !== 'string' && + typeof jsonRequest.id !== 'number' && + jsonRequest.id !== null + ) { + throw new rpcErrors.ErrorRPCParse( + '`id` property must be a string, number or null', + ); + } + return jsonRequest as JSONRPCRequestMessage; +} + +function parseJSONRPCRequestNotification( + message: unknown, +): JSONRPCRequestNotification { + const jsonRequest = parseJSONRPCRequest(message); + if ('id' in jsonRequest) { + throw new rpcErrors.ErrorRPCParse('`id` property must not be defined'); + } + return jsonRequest as JSONRPCRequestNotification; +} + +function parseJSONRPCResponseResult( + message: unknown, +): JSONRPCResponseResult { + if (!isObject(message)) { + throw new rpcErrors.ErrorRPCParse('must be a JSON POJO'); + } + if (!('result' in message)) { + throw new rpcErrors.ErrorRPCParse('`result` property must be defined'); + } + if ('error' in message) { + throw new rpcErrors.ErrorRPCParse('`error` property must not be defined'); + } + // If (!utils.isObject(message.result)) { + // throw new rpcErrors.ErrorRPCParse('`result` property must be a POJO'); + // } + if (!('id' in message)) { + throw new rpcErrors.ErrorRPCParse('`id` property must be defined'); + } + if ( + typeof message.id !== 'string' && + typeof message.id !== 'number' && + message.id !== null + ) { + throw new rpcErrors.ErrorRPCParse( + '`id` property must be a string, number or null', + ); + } + return message as JSONRPCResponseResult; +} + +function parseJSONRPCResponseError(message: unknown): JSONRPCResponseError { + if (!isObject(message)) { + throw new rpcErrors.ErrorRPCParse('must be a JSON POJO'); + } + if ('result' in message) { + throw new rpcErrors.ErrorRPCParse('`result` property must not be defined'); + } + if (!('error' in message)) { + throw new rpcErrors.ErrorRPCParse('`error` property must be defined'); + } + parseJSONRPCError(message.error); + if (!('id' in message)) { + throw new rpcErrors.ErrorRPCParse('`id` property must be defined'); + } + if ( + typeof message.id !== 'string' && + typeof message.id !== 'number' && + message.id !== null + ) { + throw new rpcErrors.ErrorRPCParse( + '`id` property must be a string, number or null', + ); + } + return message as JSONRPCResponseError; +} + +function parseJSONRPCError(message: unknown): JSONRPCError { + if (!isObject(message)) { + throw new rpcErrors.ErrorRPCParse('must be a JSON POJO'); + } + if (!('code' in message)) { + throw new rpcErrors.ErrorRPCParse('`code` property must be defined'); + } + if (typeof message.code !== 'number') { + throw new rpcErrors.ErrorRPCParse('`code` property must be a number'); + } + if (!('message' in message)) { + throw new rpcErrors.ErrorRPCParse('`message` property must be defined'); + } + if (typeof message.message !== 'string') { + throw new rpcErrors.ErrorRPCParse('`message` property must be a string'); + } + // If ('data' in message && !utils.isObject(message.data)) { + // throw new rpcErrors.ErrorRPCParse('`data` property must be a POJO'); + // } + return message as JSONRPCError; +} + +function parseJSONRPCResponse( + message: unknown, +): JSONRPCResponse { + if (!isObject(message)) { + throw new rpcErrors.ErrorRPCParse('must be a JSON POJO'); + } + try { + return parseJSONRPCResponseResult(message); + } catch (e) { + // Do nothing + } + try { + return parseJSONRPCResponseError(message); + } catch (e) { + // Do nothing + } + throw new rpcErrors.ErrorRPCParse( + 'structure did not match a `JSONRPCResponse`', + ); +} + +function parseJSONRPCMessage( + message: unknown, +): JSONRPCMessage { + if (!isObject(message)) { + throw new rpcErrors.ErrorRPCParse('must be a JSON POJO'); + } + if (!('jsonrpc' in message)) { + throw new rpcErrors.ErrorRPCParse('`jsonrpc` property must be defined'); + } + if (message.jsonrpc !== '2.0') { + throw new rpcErrors.ErrorRPCParse( + '`jsonrpc` property must be a string of "2.0"', + ); + } + try { + return parseJSONRPCRequest(message); + } catch { + // Do nothing + } + try { + return parseJSONRPCResponse(message); + } catch { + // Do nothing + } + throw new rpcErrors.ErrorRPCParse( + 'Message structure did not match a `JSONRPCMessage`', + ); +} +/** + * Serializes an ErrorRPC instance into a JSONValue object suitable for RPC. + * @param {ErrorRPC} error - The ErrorRPC instance to serialize. + * @param {any} [id] - Optional id for the error object in the RPC response. + * @returns {JSONValue} The serialized ErrorRPC instance. + */ +function fromError(error: ErrorRPC, id?: any): JSONValue { + const data: { [key: string]: JSONValue } = { + message: error.message, + description: error.description, + data: error.data, + }; + if (error.code !== undefined) { + data.code = error.code; + } + return { + jsonrpc: '2.0', + error: { + type: error.name, + ...data, + }, + id: id !== undefined ? id : null, + }; +} + +/** + * Error constructors for non-Polykey errors + * Allows these errors to be reconstructed from RPC metadata + */ +const standardErrors = { + Error, + TypeError, + SyntaxError, + ReferenceError, + EvalError, + RangeError, + URIError, + AggregateError, + AbstractError, + ErrorRPCRemote, + ErrorRPC, +}; +/** + * Creates a replacer function that omits a specific key during serialization. + * @returns {Function} The replacer function. + */ +const createReplacer = () => { + return (keyToRemove) => { + return (key, value) => { + if (key === keyToRemove) { + return undefined; + } + + if (key !== 'code') { + if (value instanceof ErrorRPC) { + return { + code: value.code, + message: value.message, + data: value.data, + type: value.constructor.name, + }; + } + + if (value instanceof AggregateError) { + return { + type: value.constructor.name, + data: { + errors: value.errors, + message: value.message, + stack: value.stack, + }, + }; + } + } + + return value; + }; + }; +}; +/** + * The replacer function to customize the serialization process. + */ +const replacer = createReplacer(); + +/** + * Reviver function for deserializing errors sent over RPC. + * @param {string} key - The key in the JSON object. + * @param {any} value - The value corresponding to the key in the JSON object. + * @returns {any} The reconstructed error object or the original value. + */ +function reviver(key: string, value: any): any { + // If the value is an error then reconstruct it + if ( + typeof value === 'object' && + typeof value.type === 'string' && + typeof value.data === 'object' + ) { + try { + let eClass = errors[value.type]; + if (eClass != null) return eClass.fromJSON(value); + eClass = standardErrors[value.type]; + if (eClass != null) { + let e; + switch (eClass) { + case AbstractError: + return eClass.fromJSON(); + case AggregateError: + if ( + !Array.isArray(value.data.errors) || + typeof value.data.message !== 'string' || + ('stack' in value.data && typeof value.data.stack !== 'string') + ) { + throw new TypeError(`cannot decode JSON to ${value.type}`); + } + e = new eClass(value.data.errors, value.data.message); + e.stack = value.data.stack; + break; + default: + if ( + typeof value.data.message !== 'string' || + ('stack' in value.data && typeof value.data.stack !== 'string') + ) { + throw new TypeError(`Cannot decode JSON to ${value.type}`); + } + e = new eClass(value.data.message); + e.stack = value.data.stack; + break; + } + return e; + } + } catch (e) { + // If `TypeError` which represents decoding failure + // then return value as-is + // Any other exception is a bug + if (!(e instanceof TypeError)) { + throw e; + } + } + // Other values are returned as-is + return value; + } else if (key === '') { + // Root key will be '' + // Reaching here means the root JSON value is not a valid exception + // Therefore ErrorPolykeyUnknown is only ever returned at the top-level + return new rpcErrors.ErrorRPC('Unknown error JSON'); + } else { + return value; + } +} +/** + * Deserializes an error response object into an ErrorRPCRemote instance. + * @param {any} errorResponse - The error response object. + * @param {any} [metadata] - Optional metadata for the deserialized error. + * @returns {ErrorRPCRemote} The deserialized ErrorRPCRemote instance. + * @throws {TypeError} If the errorResponse object is invalid. + */ +function toError(errorResponse: any, metadata?: any): ErrorRPCRemote { + if ( + typeof errorResponse !== 'object' || + errorResponse === null || + !('error' in errorResponse) || + !('type' in errorResponse.error) || + !('message' in errorResponse.error) + ) { + throw new TypeError('Invalid error data object'); + } + + const errorData = errorResponse.error; + const error = new ErrorRPCRemote(metadata, errorData.message, { + cause: errorData.cause, + data: errorData.data === undefined ? null : errorData.data, + }); + error.message = errorData.message; + error.code = errorData.code; + error.description = errorData.description; + error.data = errorData.data; + + return error; +} + +/** + * This constructs a transformation stream that converts any input into a + * JSONRCPRequest message. It also refreshes a timer each time a message is processed if + * one is provided. + * @param method - Name of the method that was called, used to select the + * server side. + * @param timer - Timer that gets refreshed each time a message is provided. + */ +function clientInputTransformStream( + method: string, + timer?: Timer, +): TransformStream { + return new TransformStream({ + transform: (chunk, controller) => { + timer?.refresh(); + const message: JSONRPCRequest = { + method, + jsonrpc: '2.0', + id: null, + params: chunk, + }; + controller.enqueue(message); + }, + }); +} + +/** + * This constructs a transformation stream that converts any error messages + * into errors. It also refreshes a timer each time a message is processed if + * one is provided. + * @param clientMetadata - Metadata that is attached to an error when one is + * created. + * @param timer - Timer that gets refreshed each time a message is provided. + */ +function clientOutputTransformStream( + clientMetadata?: JSONValue, + timer?: Timer, +): TransformStream, O> { + return new TransformStream, O>({ + transform: (chunk, controller) => { + timer?.refresh(); + // `error` indicates it's an error message + if ('error' in chunk) { + throw toError(chunk.error.data, clientMetadata); + } + controller.enqueue(chunk.result); + }, + }); +} + +function getHandlerTypes( + manifest: ClientManifest, +): Record { + const out: Record = {}; + for (const [k, v] of Object.entries(manifest)) { + out[k] = v.type; + } + return out; +} + +/** + * This function is a factory to create a TransformStream that will + * transform a `Uint8Array` stream to a JSONRPC message stream. + * The parsed messages will be validated with the provided messageParser, this + * also infers the type of the stream output. + * @param messageParser - Validates the JSONRPC messages, so you can select for a + * specific type of message + * @param bufferByteLimit - sets the number of bytes buffered before throwing an + * error. This is used to avoid infinitely buffering the input. + */ +function parseHeadStream( + messageParser: (message: unknown) => T, + bufferByteLimit: number = 1024 * 1024, +): TransformStream { + const parser = new JSONParser({ + separator: '', + paths: ['$'], + }); + let bytesWritten: number = 0; + let parsing = true; + let ended = false; + + const endP = promise(); + parser.onEnd = () => endP.resolveP(); + + return new TransformStream( + { + flush: async () => { + if (!parser.isEnded) parser.end(); + await endP.p; + }, + start: (controller) => { + parser.onValue = async (value) => { + const jsonMessage = messageParser(value.value); + controller.enqueue(jsonMessage); + bytesWritten = 0; + parsing = false; + }; + }, + transform: async (chunk, controller) => { + if (parsing) { + try { + bytesWritten += chunk.byteLength; + parser.write(chunk); + } catch (e) { + throw new rpcErrors.ErrorRPCParse(undefined, { + cause: e, + }); + } + if (bytesWritten > bufferByteLimit) { + throw new rpcErrors.ErrorRPCMessageLength(); + } + } else { + // Wait for parser to end + if (!ended) { + parser.end(); + await endP.p; + ended = true; + } + // Pass through normal chunks + controller.enqueue(chunk); + } + }, + }, + { highWaterMark: 1 }, + ); +} + +export { + parseJSONRPCRequest, + parseJSONRPCRequestMessage, + parseJSONRPCRequestNotification, + parseJSONRPCResponseResult, + parseJSONRPCResponseError, + parseJSONRPCResponse, + parseJSONRPCMessage, + replacer, + fromError, + toError, + clientInputTransformStream, + clientOutputTransformStream, + getHandlerTypes, + parseHeadStream, + promise, + isObject, + sleep, +}; diff --git a/tests/RPC.test.ts b/tests/RPC.test.ts new file mode 100644 index 0000000..591ff5d --- /dev/null +++ b/tests/RPC.test.ts @@ -0,0 +1,1046 @@ +import type { ContainerType, JSONRPCRequest } from '@/types'; +import type { ReadableStream } from 'stream/web'; +import type { JSONValue, IdGen } from '@/types'; +import type { ContextTimed } from '@matrixai/contexts'; +import { TransformStream } from 'stream/web'; +import { fc, testProp } from '@fast-check/jest'; +import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; +import RawCaller from '@/callers/RawCaller'; +import DuplexCaller from '@/callers/DuplexCaller'; +import ServerCaller from '@/callers/ServerCaller'; +import ClientCaller from '@/callers/ClientCaller'; +import UnaryCaller from '@/callers/UnaryCaller'; +import * as rpcUtilsMiddleware from '@/utils/middleware'; +import { + ErrorRPC, + ErrorRPCHandlerFailed, + ErrorRPCParse, + ErrorRPCRemote, + ErrorRPCTimedOut, + JSONRPCErrorCode, +} from '@/errors'; +import * as rpcErrors from '@/errors'; +import RPCClient from '@/RPCClient'; +import RPCServer from '@/RPCServer'; +import * as utils from '@/utils'; +import DuplexHandler from '@/handlers/DuplexHandler'; +import RawHandler from '@/handlers/RawHandler'; +import ServerHandler from '@/handlers/ServerHandler'; +import UnaryHandler from '@/handlers/UnaryHandler'; +import ClientHandler from '@/handlers/ClientHandler'; +import { RPCStream } from '@/types'; +import { fromError, promise, replacer, toError } from '@/utils'; +import * as rpcTestUtils from './utils'; + +describe('RPC', () => { + const logger = new Logger(`RPC Test`, LogLevel.WARN, [new StreamHandler()]); + const idGen: IdGen = () => Promise.resolve(null); + testProp( + 'RPC communication with raw stream', + [rpcTestUtils.rawDataArb], + async (inputData) => { + const [outputResult, outputWriterStream] = + rpcTestUtils.streamToArray(); + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + + let header: JSONRPCRequest | undefined; + + class TestMethod extends RawHandler { + public handle = async ( + input: [JSONRPCRequest, ReadableStream], + _cancel: (reason?: any) => void, + _meta: Record | undefined, + ): Promise<[JSONValue, ReadableStream]> => { + return new Promise((resolve) => { + const [header_, stream] = input; + header = header_; + resolve(['some leading data', stream]); + }); + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + rpcServer.handleStream({ + ...serverPair, + cancel: () => {}, + }); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new RawCaller(), + }, + streamFactory: async () => { + return { + ...clientPair, + cancel: () => {}, + }; + }, + logger, + idGen, + }); + + const callerInterface = await rpcClient.methods.testMethod({ + hello: 'world', + }); + const writer = callerInterface.writable.getWriter(); + const pipeProm = callerInterface.readable.pipeTo(outputWriterStream); + for (const value of inputData) { + await writer.write(value); + } + await writer.close(); + const expectedHeader: JSONRPCRequest = { + jsonrpc: '2.0', + method: 'testMethod', + params: { hello: 'world' }, + id: null, + }; + expect(header).toStrictEqual(expectedHeader); + expect(callerInterface.meta?.result).toBe('some leading data'); + expect(await outputResult).toStrictEqual(inputData); + await pipeProm; + await rpcServer.destroy(); + await rpcClient.destroy(); + }, + ); + test('RPC communication with raw stream times out waiting for leading message', async () => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + void (async () => { + for await (const _ of serverPair.readable) { + // Just consume + } + })(); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new RawCaller(), + }, + streamFactory: async () => { + return { + ...clientPair, + cancel: () => {}, + }; + }, + logger, + idGen, + }); + + await expect( + rpcClient.methods.testMethod( + { + hello: 'world', + }, + { timer: 100 }, + ), + ).rejects.toThrow(rpcErrors.ErrorRPCTimedOut); + await rpcClient.destroy(); + }); + test('RPC communication with raw stream, raw handler throws', async () => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + + class TestMethod extends RawHandler { + public handle = async ( + input: [JSONRPCRequest, ReadableStream], + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): Promise<[JSONValue, ReadableStream]> => { + throw new Error('some error'); + }; + } + + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + rpcServer.handleStream({ + ...serverPair, + cancel: () => {}, + }); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new RawCaller(), + }, + streamFactory: async () => { + return { + ...clientPair, + cancel: () => {}, + }; + }, + logger, + idGen, + }); + + await expect( + rpcClient.methods.testMethod({ + hello: 'world', + }), + ).rejects.toThrow(rpcErrors.ErrorRPCRemote); + + await rpcServer.destroy(); + await rpcClient.destroy(); + }); + testProp( + 'RPC communication with duplex stream', + [fc.array(rpcTestUtils.safeJsonValueArb, { minLength: 1 })], + async (values) => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + class TestMethod extends DuplexHandler { + public handle = async function* ( + input: AsyncGenerator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): AsyncGenerator { + yield* input; + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + rpcServer.handleStream({ + ...serverPair, + cancel: () => {}, + }); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new DuplexCaller(), + }, + streamFactory: async () => { + return { + ...clientPair, + cancel: () => {}, + }; + }, + logger, + idGen, + }); + + const callerInterface = await rpcClient.methods.testMethod(); + const writer = callerInterface.writable.getWriter(); + const reader = callerInterface.readable.getReader(); + for (const value of values) { + await writer.write(value); + expect((await reader.read()).value).toStrictEqual(value); + } + await writer.close(); + const result = await reader.read(); + expect(result.value).toBeUndefined(); + expect(result.done).toBeTrue(); + await rpcServer.destroy(); + await rpcClient.destroy(); + }, + ); + testProp( + 'RPC communication with server stream', + [fc.integer({ min: 1, max: 100 })], + async (value) => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + + class TestMethod extends ServerHandler { + public handle = async function* ( + input: number, + ): AsyncGenerator { + for (let i = 0; i < input; i++) { + yield i; + } + }; + } + + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + rpcServer.handleStream({ + ...serverPair, + cancel: () => {}, + }); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new ServerCaller(), + }, + streamFactory: async () => { + return { + ...clientPair, + cancel: () => {}, + }; + }, + logger, + idGen, + }); + + const callerInterface = await rpcClient.methods.testMethod(value); + + const outputs: Array = []; + for await (const num of callerInterface) { + outputs.push(num); + } + expect(outputs.length).toEqual(value); + await rpcServer.destroy(); + await rpcClient.destroy(); + }, + ); + testProp( + 'RPC communication with client stream', + [fc.array(fc.integer(), { minLength: 1 }).noShrink()], + async (values) => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + + class TestMethod extends ClientHandler { + public handle = async ( + input: AsyncIterable, + ): Promise => { + let acc = 0; + for await (const number of input) { + acc += number; + } + return acc; + }; + } + + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + rpcServer.handleStream({ + ...serverPair, + cancel: () => {}, + }); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new ClientCaller(), + }, + streamFactory: async () => { + return { + ...clientPair, + cancel: () => {}, + }; + }, + logger, + idGen, + }); + + const { output, writable } = await rpcClient.methods.testMethod(); + const writer = writable.getWriter(); + for (const value of values) { + await writer.write(value); + } + await writer.close(); + const expectedResult = values.reduce((p, c) => p + c); + await expect(output).resolves.toEqual(expectedResult); + await rpcServer.destroy(); + await rpcClient.destroy(); + }, + ); + testProp( + 'RPC communication with unary call', + [rpcTestUtils.safeJsonValueArb], + async (value) => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + + class TestMethod extends UnaryHandler { + public handle = async (input: JSONValue): Promise => { + return input; + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + rpcServer.handleStream({ + ...serverPair, + cancel: () => {}, + }); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new UnaryCaller(), + }, + streamFactory: async () => { + return { + ...clientPair, + cancel: () => {}, + }; + }, + logger, + idGen, + }); + + const result = await rpcClient.methods.testMethod(value); + expect(result).toStrictEqual(value); + await rpcServer.destroy(); + await rpcClient.destroy(); + }, + ); + testProp( + 'RPC handles and sends errors', + [ + rpcTestUtils.safeJsonValueArb, + rpcTestUtils.errorArb(rpcTestUtils.errorArb()), + ], + async (value, error) => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + + class TestMethod extends UnaryHandler { + public handle = async ( + _input: JSONValue, + _cancel: (reason?: any) => void, + _meta: Record | undefined, + _ctx: ContextTimed, + ): Promise => { + throw error; + }; + } + + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + rpcServer.handleStream({ ...serverPair, cancel: () => {} }); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new UnaryCaller(), + }, + streamFactory: async () => { + return { ...clientPair, cancel: () => {} }; + }, + logger, + idGen, + }); + + // Create a new promise so we can await it multiple times for assertions + const callProm = rpcClient.methods.testMethod(value).catch((e) => e); + + // The promise should be rejected + const rejection = await callProm; + + // The error should have specific properties + expect(rejection).toBeInstanceOf(rpcErrors.ErrorRPCRemote); + expect(rejection).toMatchObject({ code: -32006 }); + + // Cleanup + await rpcServer.destroy(); + await rpcClient.destroy(); + }, + ); + + testProp( + 'RPC handles and sends sensitive errors', + [ + rpcTestUtils.safeJsonValueArb, + rpcTestUtils.errorArb(rpcTestUtils.errorArb()), + ], + async (value, error) => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + + class TestMethod extends UnaryHandler { + public handle = async ( + _input: JSONValue, + _cancel: (reason?: any) => void, + _meta: Record | undefined, + _ctx: ContextTimed, + ): Promise => { + throw error; + }; + } + + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + sensitive: true, + logger, + idGen, + }); + rpcServer.handleStream({ ...serverPair, cancel: () => {} }); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new UnaryCaller(), + }, + streamFactory: async () => { + return { ...clientPair, cancel: () => {} }; + }, + logger, + idGen, + }); + + const callProm = rpcClient.methods.testMethod(ErrorRPCRemote.description); + + // Use Jest's `.rejects` to handle the promise rejection + await expect(callProm).rejects.toBeInstanceOf(rpcErrors.ErrorRPCRemote); + await expect(callProm).rejects.not.toHaveProperty('cause.stack'); + + await rpcServer.destroy(); + await rpcClient.destroy(); + }, + ); + + test('middleware can end stream early', async () => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + class TestMethod extends DuplexHandler { + public handle = async function* ( + input: AsyncIterableIterator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): AsyncIterableIterator { + yield* input; + }; + } + + const middleware = rpcUtilsMiddleware.defaultServerMiddlewareWrapper(() => { + return { + forward: new TransformStream({ + start: (controller) => { + // Controller.terminate(); + controller.error(Error('SOME ERROR')); + }, + }), + reverse: new TransformStream({ + start: (controller) => { + controller.error(Error('SOME ERROR')); + }, + }), + }; + }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + middlewareFactory: middleware, + logger, + idGen, + }); + rpcServer.handleStream({ + ...serverPair, + cancel: () => {}, + }); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new DuplexCaller(), + }, + streamFactory: async () => { + return { + ...clientPair, + cancel: () => {}, + }; + }, + logger, + idGen, + }); + + const callerInterface = await rpcClient.methods.testMethod(); + const writer = callerInterface.writable.getWriter(); + await writer.write({}); + // Allow time to process buffer + await utils.sleep(0); + await expect(writer.write({})).toReject(); + const reader = callerInterface.readable.getReader(); + await expect(reader.read()).toReject(); + await expect(writer.closed).toReject(); + await expect(reader.closed).toReject(); + await expect(rpcServer.destroy(false)).toResolve(); + await rpcClient.destroy(); + }); + test('RPC client and server timeout concurrently', async () => { + let serverTimedOut = false; + let clientTimedOut = false; + // Generate test data (assuming fc.array generates some mock array) + const values = fc.array(rpcTestUtils.safeJsonValueArb, { minLength: 1 }); + + // Setup server and client communication pairs + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + + const timeout = 1; + class TestMethod extends DuplexHandler { + public handle = async function* ( + input: AsyncIterableIterator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): AsyncIterableIterator { + // Check for abort event + ctx.signal.throwIfAborted(); + const abortProm = utils.promise(); + ctx.signal.addEventListener('abort', () => { + abortProm.rejectP(ctx.signal.reason); + }); + await abortProm.p; + }; + } + const testMethodInstance = new TestMethod({}); + // Set up a client and server with matching timeout settings + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: testMethodInstance, + }, + logger, + idGen, + handlerTimeoutTime: timeout, + }); + // Register callback + rpcServer.registerOnTimeoutCallback(() => { + serverTimedOut = true; + }); + rpcServer.handleStream({ + ...serverPair, + cancel: () => {}, + }); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new DuplexCaller(), + }, + streamFactory: async () => { + return { + ...clientPair, + cancel: () => {}, + }; + }, + logger, + idGen, + }); + const callerInterface = await rpcClient.methods.testMethod({ + timer: timeout, + }); + // Register callback + rpcClient.registerOnTimeoutCallback(() => { + clientTimedOut = true; + }); + const writer = callerInterface.writable.getWriter(); + const reader = callerInterface.readable.getReader(); + // Wait for server and client to timeout by checking the flag + await new Promise((resolve) => { + const checkFlag = () => { + if (serverTimedOut && clientTimedOut) resolve(); + else setTimeout(() => checkFlag(), 10); + }; + checkFlag(); + }); + // Expect both the client and the server to time out + await expect(writer.write(values[0])).rejects.toThrow( + 'Timed out waiting for header', + ); + + await expect(reader.read()).rejects.toThrow('Timed out waiting for header'); + + await rpcServer.destroy(); + await rpcClient.destroy(); + }); + // Test description + test('RPC server times out before client', async () => { + let serverTimedOut = false; + + // Generate test data (assuming fc.array generates some mock array) + const values = fc.array(rpcTestUtils.safeJsonValueArb, { minLength: 1 }); + + // Setup server and client communication pairs + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + + // Define the server's method behavior + class TestMethod extends DuplexHandler { + public handle = async function* ( + input: AsyncIterableIterator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ) { + ctx.signal.throwIfAborted(); + const abortProm = utils.promise(); + ctx.signal.addEventListener('abort', () => { + abortProm.rejectP(ctx.signal.reason); + }); + await abortProm.p; + }; + } + + // Create an instance of the RPC server with a shorter timeout + const rpcServer = await RPCServer.createRPCServer({ + manifest: { testMethod: new TestMethod({}) }, + logger, + idGen, + handlerTimeoutTime: 1, + }); + // Register callback + rpcServer.registerOnTimeoutCallback(() => { + serverTimedOut = true; + }); + rpcServer.handleStream({ ...serverPair, cancel: () => {} }); + + // Create an instance of the RPC client with a longer timeout + const rpcClient = await RPCClient.createRPCClient({ + manifest: { testMethod: new DuplexCaller() }, + streamFactory: async () => ({ ...clientPair, cancel: () => {} }), + logger, + idGen, + }); + + // Get server and client interfaces + const callerInterface = await rpcClient.methods.testMethod({ + timer: 10, + }); + const writer = callerInterface.writable.getWriter(); + const reader = callerInterface.readable.getReader(); + // Wait for server to timeout by checking the flag + await new Promise((resolve) => { + const checkFlag = () => { + if (serverTimedOut) resolve(); + else setTimeout(() => checkFlag(), 10); + }; + checkFlag(); + }); + + // We expect server to timeout before the client + await expect(writer.write(values[0])).rejects.toThrow( + 'Timed out waiting for header', + ); + await expect(reader.read()).rejects.toThrow('Timed out waiting for header'); + + // Cleanup + await rpcServer.destroy(); + await rpcClient.destroy(); + }); + test('RPC client times out before server', async () => { + // Generate test data (assuming fc.array generates some mock array) + const values = fc.array(rpcTestUtils.safeJsonValueArb, { minLength: 1 }); + + // Setup server and client communication pairs + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + class TestMethod extends DuplexHandler { + public handle = async function* ( + input: AsyncIterableIterator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): AsyncIterableIterator { + ctx.signal.throwIfAborted(); + const abortProm = utils.promise(); + ctx.signal.addEventListener('abort', () => { + abortProm.rejectP(ctx.signal.reason); + }); + await abortProm.p; + }; + } + // Set up a client and server with matching timeout settings + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + + handlerTimeoutTime: 400, + }); + rpcServer.handleStream({ + ...serverPair, + cancel: () => {}, + }); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new DuplexCaller(), + }, + streamFactory: async () => { + return { + ...clientPair, + cancel: () => {}, + }; + }, + logger, + idGen, + }); + const callerInterface = await rpcClient.methods.testMethod({ timer: 300 }); + const writer = callerInterface.writable.getWriter(); + const reader = callerInterface.readable.getReader(); + // Expect the client to time out first + await expect(writer.write(values[0])).toResolve(); + await expect(reader.read()).toReject(); + + await rpcServer.destroy(); + await rpcClient.destroy(); + }); + test('RPC client and server with infinite timeout', async () => { + // Set up a client and server with infinite timeout settings + const values = fc.array(rpcTestUtils.safeJsonValueArb, { minLength: 3 }); + + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + + class TestMethod extends DuplexHandler { + public handle = async function* ( + input: AsyncIterableIterator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ) { + ctx.signal.throwIfAborted(); + const abortProm = utils.promise(); + ctx.signal.addEventListener('abort', () => { + abortProm.rejectP(ctx.signal.reason); + }); + await abortProm.p; + }; + } + + const rpcServer = await RPCServer.createRPCServer({ + manifest: { testMethod: new TestMethod({}) }, + logger, + idGen, + handlerTimeoutTime: Infinity, + }); + rpcServer.handleStream({ ...serverPair, cancel: () => {} }); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { testMethod: new DuplexCaller() }, + streamFactory: async () => ({ ...clientPair, cancel: () => {} }), + logger, + idGen, + }); + + const callerInterface = await rpcClient.methods.testMethod({ + timer: Infinity, + }); + + const writer = callerInterface.writable.getWriter(); + const reader = callerInterface.readable.getReader(); + + // Trigger a call that will hang indefinitely or for a long time #TODO + + // Write a value to the stream + const writePromise = writer.write(values[0]); + + // Trigger a read that will hang indefinitely + + const readPromise = reader.read(); + // Adding a randomized sleep here to check that neither timeout + const randomSleepTime = Math.floor(Math.random() * 1000) + 1; + // Random time between 1 and 1,000 ms + await utils.sleep(randomSleepTime); + // At this point, writePromise and readPromise should neither be resolved nor rejected + // because the server method is hanging. + + // Check if the promises are neither resolved nor rejected + const timeoutPromise = new Promise((resolve) => + setTimeout(() => resolve('timeout'), 1000), + ); + + const readStatus = await Promise.race([readPromise, timeoutPromise]); + // Check if read status is still pending; + + expect(readStatus).toBe('timeout'); + + // Expect neither to time out and verify that they can still handle other operations #TODO + await rpcServer.destroy(); + await rpcClient.destroy(); + }); + + testProp( + 'RPC Serializes and Deserializes ErrorRPCRemote', + [ + rpcTestUtils.safeJsonValueArb, + rpcTestUtils.errorArb(rpcTestUtils.errorArb()), + ], + async (value, error) => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + + class TestMethod extends UnaryHandler { + public handle = async ( + _input: JSONValue, + _cancel: (reason?: any) => void, + _meta: Record | undefined, + _ctx: ContextTimed, + ): Promise => { + throw error; + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + sensitive: true, + logger, + idGen, + fromError: utils.fromError, + replacer: utils.replacer, + }); + rpcServer.handleStream({ ...serverPair, cancel: () => {} }); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new UnaryCaller(), + }, + streamFactory: async () => { + return { ...clientPair, cancel: () => {} }; + }, + logger, + idGen, + }); + + const errorInstance = new ErrorRPCRemote( + { code: -32006 }, + 'Parse error', + { cause: error, data: 'The server responded with an error' }, + ); + + const serializedError = fromError(errorInstance); + const deserializedError = rpcClient.toError(serializedError); + + expect(deserializedError).toBeInstanceOf(ErrorRPCRemote); + + // Check properties explicitly + const { code, message, data } = deserializedError as ErrorRPCRemote; + expect(code).toBe(-32006); + expect(message).toBe('Parse error'); + expect(data).toBe('The server responded with an error'); + + await rpcServer.destroy(); + await rpcClient.destroy(); + }, + ); + testProp( + 'RPC Serializes and Deserializes ErrorRPCRemote with Custom Replacer Function', + [ + rpcTestUtils.safeJsonValueArb, + rpcTestUtils.errorArb(rpcTestUtils.errorArb()), + ], + async (value, error) => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + + class TestMethod extends UnaryHandler { + public handle = async ( + _input: JSONValue, + _cancel: (reason?: any) => void, + _meta: Record | undefined, + _ctx: ContextTimed, + ): Promise => { + throw error; + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + sensitive: true, + logger, + idGen, + fromError: utils.fromError, + replacer: utils.replacer, + }); + rpcServer.handleStream({ ...serverPair, cancel: () => {} }); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new UnaryCaller(), + }, + streamFactory: async () => { + return { ...clientPair, cancel: () => {} }; + }, + logger, + idGen, + }); + + const errorInstance = new ErrorRPCRemote( + { code: -32006 }, + 'Parse error', + { cause: error, data: 'asda' }, + ); + + const serializedError = JSON.parse( + JSON.stringify(fromError(errorInstance), replacer('data')), + ); + + const callProm = rpcClient.methods.testMethod(serializedError); + const catchError = await callProm.catch((e) => e); + + const deserializedError = toError(serializedError); + + expect(deserializedError).toBeInstanceOf(ErrorRPCRemote); + + // Check properties explicitly + const { code, message, data } = deserializedError as ErrorRPCRemote; + expect(code).toBe(-32006); + expect(message).toBe('Parse error'); + expect(data).toBe(undefined); + + await rpcServer.destroy(); + await rpcClient.destroy(); + }, + ); +}); diff --git a/tests/RPCClient.test.ts b/tests/RPCClient.test.ts new file mode 100644 index 0000000..ade9717 --- /dev/null +++ b/tests/RPCClient.test.ts @@ -0,0 +1,1201 @@ +import type { ContextTimed } from '@matrixai/contexts'; +import type { JSONValue } from '@/types'; +import type { + JSONRPCRequest, + JSONRPCRequestMessage, + JSONRPCResponse, + JSONRPCResponseResult, + RPCStream, +} from '@/types'; +import type { IdGen } from '@/types'; +import { TransformStream, ReadableStream } from 'stream/web'; +import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; +import { testProp, fc } from '@fast-check/jest'; +import RawCaller from '@/callers/RawCaller'; +import DuplexCaller from '@/callers/DuplexCaller'; +import ServerCaller from '@/callers/ServerCaller'; +import ClientCaller from '@/callers/ClientCaller'; +import UnaryCaller from '@/callers/UnaryCaller'; +import RPCClient from '@/RPCClient'; +import RPCServer from '@/RPCServer'; +import * as rpcErrors from '@/errors'; +import * as rpcUtilsMiddleware from '@/utils/middleware'; +import { promise, sleep } from '@/utils'; +import { ErrorRPCRemote } from '@/errors'; +import * as rpcTestUtils from './utils'; + +describe(`${RPCClient.name}`, () => { + const logger = new Logger(`${RPCServer.name} Test`, LogLevel.WARN, [ + new StreamHandler(), + ]); + const idGen: IdGen = () => Promise.resolve(null); + + const methodName = 'testMethod'; + const specificMessageArb = fc + .array(rpcTestUtils.jsonRpcResponseResultArb(), { + minLength: 5, + }) + .noShrink(); + + testProp( + 'raw caller', + [ + rpcTestUtils.safeJsonValueArb, + rpcTestUtils.rawDataArb, + rpcTestUtils.rawDataArb, + ], + async (headerParams, inputData, outputData) => { + const [inputResult, inputWritableStream] = + rpcTestUtils.streamToArray(); + const [outputResult, outputWritableStream] = + rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + readable: new ReadableStream({ + start: (controller) => { + const leadingResponse: JSONRPCResponseResult = { + jsonrpc: '2.0', + result: null, + id: null, + }; + controller.enqueue(Buffer.from(JSON.stringify(leadingResponse))); + for (const datum of outputData) { + controller.enqueue(datum); + } + controller.close(); + }, + }), + writable: inputWritableStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async () => streamPair, + logger, + idGen, + }); + const callerInterface = await rpcClient.rawStreamCaller( + 'testMethod', + headerParams, + ); + await callerInterface.readable.pipeTo(outputWritableStream); + const writer = callerInterface.writable.getWriter(); + for (const inputDatum of inputData) { + await writer.write(inputDatum); + } + await writer.close(); + + const expectedHeader: JSONRPCRequest = { + jsonrpc: '2.0', + method: methodName, + params: headerParams, + id: null, + }; + expect(await inputResult).toStrictEqual([ + Buffer.from(JSON.stringify(expectedHeader)), + ...inputData, + ]); + expect(await outputResult).toStrictEqual(outputData); + }, + ); + testProp('generic duplex caller', [specificMessageArb], async (messages) => { + const inputStream = rpcTestUtils.messagesToReadableStream(messages); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async () => streamPair, + logger, + idGen, + }); + const callerInterface = await rpcClient.duplexStreamCaller< + JSONValue, + JSONValue + >(methodName); + const writable = callerInterface.writable.getWriter(); + for await (const value of callerInterface.readable) { + await writable.write(value); + } + await writable.close(); + + const expectedMessages: Array = messages.map((v) => { + const request: JSONRPCRequestMessage = { + jsonrpc: '2.0', + method: methodName, + id: null, + ...(v.result === undefined ? {} : { params: v.result }), + }; + return request; + }); + const outputMessages = (await outputResult).map((v) => + JSON.parse(v.toString()), + ); + expect(outputMessages).toStrictEqual(expectedMessages); + await rpcClient.destroy(); + }); + testProp( + 'generic server stream caller', + [specificMessageArb, rpcTestUtils.safeJsonValueArb], + async (messages, params) => { + const inputStream = rpcTestUtils.messagesToReadableStream(messages); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async () => streamPair, + logger, + idGen, + }); + const callerInterface = await rpcClient.serverStreamCaller< + JSONValue, + JSONValue + >(methodName, params as JSONValue); + const values: Array = []; + for await (const value of callerInterface) { + values.push(value); + } + const expectedValues = messages.map((v) => v.result); + expect(values).toStrictEqual(expectedValues); + expect((await outputResult)[0]?.toString()).toStrictEqual( + JSON.stringify({ + method: methodName, + jsonrpc: '2.0', + id: null, + params, + }), + ); + await rpcClient.destroy(); + }, + ); + testProp( + 'generic client stream caller', + [ + rpcTestUtils.jsonRpcResponseResultArb(), + fc.array(rpcTestUtils.safeJsonValueArb), + ], + async (message, params) => { + const inputStream = rpcTestUtils.messagesToReadableStream([message]); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async () => streamPair, + logger, + idGen, + }); + const { output, writable } = await rpcClient.clientStreamCaller< + JSONValue, + JSONValue + >(methodName); + const writer = writable.getWriter(); + for (const param of params) { + await writer.write(param); + } + await writer.close(); + expect(await output).toStrictEqual(message.result); + const expectedOutput = params.map((v) => + JSON.stringify({ + method: methodName, + jsonrpc: '2.0', + id: null, + params: v, + }), + ); + expect((await outputResult).map((v) => v.toString())).toStrictEqual( + expectedOutput, + ); + await rpcClient.destroy(); + }, + ); + testProp( + 'generic unary caller', + [rpcTestUtils.jsonRpcResponseResultArb(), rpcTestUtils.safeJsonValueArb], + async (message, params) => { + const inputStream = rpcTestUtils.messagesToReadableStream([message]); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async () => streamPair, + logger, + idGen, + }); + const result = await rpcClient.unaryCaller( + methodName, + params as JSONValue, + ); + expect(result).toStrictEqual(message.result); + expect((await outputResult)[0]?.toString()).toStrictEqual( + JSON.stringify({ + method: methodName, + jsonrpc: '2.0', + id: null, + params: params, + }), + ); + await rpcClient.destroy(); + }, + ); + testProp( + 'generic duplex caller can throw received error message', + [ + fc.array(rpcTestUtils.jsonRpcResponseResultArb()), + rpcTestUtils.jsonRpcResponseErrorArb(rpcTestUtils.errorArb()), + ], + async (messages, errorMessage) => { + const inputStream = rpcTestUtils.messagesToReadableStream([ + ...messages, + errorMessage, + ]); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async () => streamPair, + logger, + idGen, + }); + const callerInterface = await rpcClient.duplexStreamCaller< + JSONValue, + JSONValue + >(methodName); + await callerInterface.writable.close(); + const callProm = (async () => { + for await (const _ of callerInterface.readable) { + // Only consume + } + })(); + await expect(callProm).rejects.toThrow(rpcErrors.ErrorRPCRemote); + await outputResult; + await rpcClient.destroy(); + }, + ); + testProp( + 'generic duplex caller can throw received error message with sensitive', + [ + fc.array(rpcTestUtils.jsonRpcResponseResultArb()), + rpcTestUtils.jsonRpcResponseErrorArb(rpcTestUtils.errorArb(), true), + ], + async (messages, errorMessage) => { + const inputStream = rpcTestUtils.messagesToReadableStream([ + ...messages, + errorMessage, + ]); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async () => streamPair, + logger, + idGen, + }); + const callerInterface = await rpcClient.duplexStreamCaller< + JSONValue, + JSONValue + >(methodName); + await callerInterface.writable.close(); + const callProm = (async () => { + for await (const _ of callerInterface.readable) { + // Only consume + } + })(); + await expect(callProm).rejects.toThrow(rpcErrors.ErrorRPCRemote); + await outputResult; + await rpcClient.destroy(); + }, + ); + testProp( + 'generic duplex caller can throw received error message with causes', + [ + fc.array(rpcTestUtils.jsonRpcResponseResultArb()), + rpcTestUtils.jsonRpcResponseErrorArb( + rpcTestUtils.errorArb(rpcTestUtils.errorArb()), + true, + ), + ], + async (messages, errorMessage) => { + const inputStream = rpcTestUtils.messagesToReadableStream([ + ...messages, + errorMessage, + ]); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async () => streamPair, + logger, + idGen, + }); + const callerInterface = await rpcClient.duplexStreamCaller< + JSONValue, + JSONValue + >(methodName); + await callerInterface.writable.close(); + const callProm = (async () => { + for await (const _ of callerInterface.readable) { + // Only consume + } + })(); + await expect(callProm).rejects.toThrow(rpcErrors.ErrorRPCRemote); + await outputResult; + await rpcClient.destroy(); + }, + ); + testProp( + 'generic duplex caller with forward Middleware', + [specificMessageArb], + async (messages) => { + const inputStream = rpcTestUtils.messagesToReadableStream(messages); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async () => streamPair, + middlewareFactory: rpcUtilsMiddleware.defaultClientMiddlewareWrapper( + () => { + return { + forward: new TransformStream({ + transform: (chunk, controller) => { + controller.enqueue({ + ...chunk, + params: 'one', + }); + }, + }), + reverse: new TransformStream(), + }; + }, + ), + logger, + idGen, + }); + + const callerInterface = await rpcClient.duplexStreamCaller< + JSONValue, + JSONValue + >(methodName); + const reader = callerInterface.readable.getReader(); + const writer = callerInterface.writable.getWriter(); + while (true) { + const { value, done } = await reader.read(); + if (done) { + // We have to end the writer otherwise the stream never closes + await writer.close(); + break; + } + await writer.write(value); + } + + const expectedMessages: Array = messages.map( + () => { + const request: JSONRPCRequestMessage = { + jsonrpc: '2.0', + method: methodName, + id: null, + params: 'one', + }; + return request; + }, + ); + const outputMessages = (await outputResult).map((v) => + JSON.parse(v.toString()), + ); + expect(outputMessages).toStrictEqual(expectedMessages); + await rpcClient.destroy(); + }, + ); + testProp( + 'generic duplex caller with reverse Middleware', + [specificMessageArb], + async (messages) => { + const inputStream = rpcTestUtils.messagesToReadableStream(messages); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async () => streamPair, + middlewareFactory: rpcUtilsMiddleware.defaultClientMiddlewareWrapper( + () => { + return { + forward: new TransformStream(), + reverse: new TransformStream({ + transform: (chunk, controller) => { + controller.enqueue({ + ...chunk, + result: 'one', + }); + }, + }), + }; + }, + ), + logger, + idGen, + }); + + const callerInterface = await rpcClient.duplexStreamCaller< + JSONValue, + JSONValue + >(methodName); + const reader = callerInterface.readable.getReader(); + const writer = callerInterface.writable.getWriter(); + while (true) { + const { value, done } = await reader.read(); + if (done) { + // We have to end the writer otherwise the stream never closes + await writer.close(); + break; + } + expect(value).toBe('one'); + await writer.write(value); + } + await outputResult; + await rpcClient.destroy(); + }, + ); + testProp( + 'manifest server call', + [specificMessageArb, fc.string()], + async (messages, params) => { + const inputStream = rpcTestUtils.messagesToReadableStream(messages); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + server: new ServerCaller(), + }, + streamFactory: async () => streamPair, + logger, + idGen, + }); + const callerInterface = await rpcClient.methods.server(params); + const values: Array = []; + for await (const value of callerInterface) { + values.push(value); + } + const expectedValues = messages.map((v) => v.result); + expect(values).toStrictEqual(expectedValues); + expect((await outputResult)[0]?.toString()).toStrictEqual( + JSON.stringify({ + method: 'server', + jsonrpc: '2.0', + id: null, + params, + }), + ); + await rpcClient.destroy(); + }, + ); + testProp( + 'manifest client call', + [ + rpcTestUtils.jsonRpcResponseResultArb(fc.string()), + fc.array(fc.string(), { minLength: 5 }), + ], + async (message, params) => { + const inputStream = rpcTestUtils.messagesToReadableStream([message]); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + client: new ClientCaller(), + }, + streamFactory: async () => streamPair, + logger, + idGen, + }); + const { output, writable } = await rpcClient.methods.client(); + const writer = writable.getWriter(); + for (const param of params) { + await writer.write(param); + } + expect(await output).toStrictEqual(message.result); + await writer.close(); + const expectedOutput = params.map((v) => + JSON.stringify({ + method: 'client', + jsonrpc: '2.0', + id: null, + params: v, + }), + ); + expect((await outputResult).map((v) => v.toString())).toStrictEqual( + expectedOutput, + ); + await rpcClient.destroy(); + }, + ); + testProp( + 'manifest unary call', + [rpcTestUtils.jsonRpcResponseResultArb().noShrink(), fc.string()], + async (message, params) => { + const inputStream = rpcTestUtils.messagesToReadableStream([message]); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + unary: new UnaryCaller(), + }, + streamFactory: async () => streamPair, + logger, + idGen, + }); + const result = await rpcClient.methods.unary(params); + expect(result).toStrictEqual(message.result); + expect((await outputResult)[0]?.toString()).toStrictEqual( + JSON.stringify({ + method: 'unary', + jsonrpc: '2.0', + id: null, + params: params, + }), + ); + await rpcClient.destroy(); + }, + ); + testProp( + 'manifest raw caller', + [ + rpcTestUtils.safeJsonValueArb, + rpcTestUtils.rawDataArb, + rpcTestUtils.rawDataArb, + ], + async (headerParams, inputData, outputData) => { + const [inputResult, inputWritableStream] = + rpcTestUtils.streamToArray(); + const [outputResult, outputWritableStream] = + rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + readable: new ReadableStream({ + start: (controller) => { + const leadingResponse: JSONRPCResponseResult = { + jsonrpc: '2.0', + result: null, + id: null, + }; + controller.enqueue(Buffer.from(JSON.stringify(leadingResponse))); + for (const datum of outputData) { + controller.enqueue(datum); + } + controller.close(); + }, + }), + writable: inputWritableStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + raw: new RawCaller(), + }, + streamFactory: async () => streamPair, + logger, + idGen, + }); + const callerInterface = await rpcClient.methods.raw(headerParams); + await callerInterface.readable.pipeTo(outputWritableStream); + const writer = callerInterface.writable.getWriter(); + for (const inputDatum of inputData) { + await writer.write(inputDatum); + } + await writer.close(); + + const expectedHeader: JSONRPCRequest = { + jsonrpc: '2.0', + method: 'raw', + params: headerParams, + id: null, + }; + expect(await inputResult).toStrictEqual([ + Buffer.from(JSON.stringify(expectedHeader)), + ...inputData, + ]); + expect(await outputResult).toStrictEqual(outputData); + }, + { seed: -783452149, path: '0:0:0:0:0:0:0', endOnFailure: true }, + ); + testProp( + 'manifest duplex caller', + [ + fc.array(rpcTestUtils.jsonRpcResponseResultArb(fc.string()), { + minLength: 1, + }), + ], + async (messages) => { + const inputStream = rpcTestUtils.messagesToReadableStream(messages); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + duplex: new DuplexCaller(), + }, + streamFactory: async () => streamPair, + logger, + idGen, + }); + let count = 0; + const callerInterface = await rpcClient.methods.duplex(); + const writer = callerInterface.writable.getWriter(); + for await (const value of callerInterface.readable) { + count += 1; + await writer.write(value); + } + await writer.close(); + const result = await outputResult; + // We're just checking that it's consuming the messages as expected + expect(result.length).toEqual(messages.length); + expect(count).toEqual(messages.length); + await rpcClient.destroy(); + }, + ); + test('manifest without handler errors', async () => { + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async () => { + return {} as RPCStream; + }, + logger, + idGen, + }); + // @ts-ignore: ignoring type safety here + expect(() => rpcClient.methods.someMethod()).toThrow(); + // @ts-ignore: ignoring type safety here + expect(() => rpcClient.withMethods.someMethod()).toThrow(); + await rpcClient.destroy(); + }); + describe('raw caller', () => { + test('raw caller uses default timeout when creating stream', async () => { + const holdProm = promise(); + let ctx: ContextTimed | undefined; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async (ctx_) => { + ctx = ctx_; + await holdProm.p; + // Should never reach this when testing + return {} as RPCStream; + }, + streamKeepAliveTimeoutTime: 100, + logger, + idGen, + }); + // Timing out on stream creation + const callerInterfaceProm = rpcClient.rawStreamCaller('testMethod', {}); + await expect(callerInterfaceProm).toReject(); + await expect(callerInterfaceProm).rejects.toThrow( + rpcErrors.ErrorRPCTimedOut, + ); + expect(ctx?.signal.aborted).toBeTrue(); + expect(ctx?.signal.reason).toBeInstanceOf(rpcErrors.ErrorRPCTimedOut); + }); + test('raw caller times out when creating stream', async () => { + const holdProm = promise(); + let ctx: ContextTimed | undefined; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async (ctx_) => { + ctx = ctx_; + await holdProm.p; + // Should never reach this when testing + return {} as RPCStream; + }, + logger, + idGen, + }); + // Timing out on stream creation + const callerInterfaceProm = rpcClient.rawStreamCaller( + 'testMethod', + {}, + { timer: 100 }, + ); + await expect(callerInterfaceProm).toReject(); + await expect(callerInterfaceProm).rejects.toThrow( + rpcErrors.ErrorRPCTimedOut, + ); + expect(ctx?.signal.aborted).toBeTrue(); + expect(ctx?.signal.reason).toBeInstanceOf(rpcErrors.ErrorRPCTimedOut); + }); + test('raw caller handles abort when creating stream', async () => { + const holdProm = promise(); + const ctxProm = promise(); + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async (ctx_) => { + ctxProm.resolveP(ctx_); + await holdProm.p; + // Should never reach this when testing + return {} as RPCStream; + }, + logger, + idGen, + }); + const abortController = new AbortController(); + const rejectReason = Symbol('rejectReason'); + + // Timing out on stream creation + const callerInterfaceProm = rpcClient.rawStreamCaller( + 'testMethod', + {}, + { signal: abortController.signal }, + ); + abortController.abort(rejectReason); + const ctx = await ctxProm.p; + await expect(callerInterfaceProm).toReject(); + await expect(callerInterfaceProm).rejects.toBe(rejectReason); + expect(ctx?.signal.aborted).toBeTrue(); + expect(ctx?.signal.reason).toBe(rejectReason); + }); + test('raw caller times out awaiting stream', async () => { + const forwardPassThroughStream = new TransformStream< + Uint8Array, + Uint8Array + >(); + const reversePassThroughStream = new TransformStream< + Uint8Array, + Uint8Array + >(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + writable: forwardPassThroughStream.writable, + readable: reversePassThroughStream.readable, + }; + const ctxProm = promise(); + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async (ctx_) => { + ctxProm.resolveP(ctx_); + return streamPair; + }, + logger, + idGen, + }); + // Timing out on stream + await expect( + Promise.all([ + rpcClient.rawStreamCaller('testMethod', {}, { timer: 100 }), + forwardPassThroughStream.readable.getReader().read(), + ]), + ).rejects.toThrow(rpcErrors.ErrorRPCTimedOut); + const ctx = await ctxProm.p; + await ctx?.timer; + expect(ctx?.signal.aborted).toBeTrue(); + expect(ctx?.signal.reason).toBeInstanceOf(rpcErrors.ErrorRPCTimedOut); + }); + test('raw caller handles abort awaiting stream', async () => { + const forwardPassThroughStream = new TransformStream< + Uint8Array, + Uint8Array + >(); + const reversePassThroughStream = new TransformStream< + Uint8Array, + Uint8Array + >(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + writable: forwardPassThroughStream.writable, + readable: reversePassThroughStream.readable, + }; + const ctxProm = promise(); + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async (ctx) => { + ctxProm.resolveP(ctx); + return streamPair; + }, + logger, + idGen, + }); + const abortController = new AbortController(); + const rejectReason = Symbol('rejectReason'); + // Timing out on stream + const reader = forwardPassThroughStream.readable.getReader(); + const abortProm = promise(); + const ctxWaitProm = ctxProm.p.then((ctx) => { + if (ctx.signal.aborted) abortProm.resolveP(); + ctx.signal.addEventListener('abort', () => { + abortProm.resolveP(); + }); + abortController.abort(rejectReason); + }); + const rawStreamProm = rpcClient.rawStreamCaller( + 'testMethod', + {}, + { signal: abortController.signal }, + ); + await Promise.allSettled([rawStreamProm, reader.read(), ctxWaitProm]); + await expect(rawStreamProm).rejects.toBe(rejectReason); + const ctx = await ctxProm.p; + await abortProm.p; + expect(ctx?.signal.aborted).toBeTrue(); + expect(ctx?.signal.reason).toBe(rejectReason); + }); + }); + describe('duplex caller', () => { + test('duplex caller uses default timeout when creating stream', async () => { + const holdProm = promise(); + let ctx: ContextTimed | undefined; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async (ctx_) => { + ctx = ctx_; + await holdProm.p; + // Should never reach this when testing + return {} as RPCStream; + }, + streamKeepAliveTimeoutTime: 100, + logger, + idGen, + }); + // Timing out on stream creation + const callerInterfaceProm = rpcClient.duplexStreamCaller('testMethod'); + await expect(callerInterfaceProm).toReject(); + await expect(callerInterfaceProm).rejects.toThrow( + rpcErrors.ErrorRPCTimedOut, + ); + expect(ctx?.signal.aborted).toBeTrue(); + expect(ctx?.signal.reason).toBeInstanceOf(rpcErrors.ErrorRPCTimedOut); + }); + test('duplex caller times out when creating stream', async () => { + const holdProm = promise(); + let ctx: ContextTimed | undefined; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async (ctx_) => { + ctx = ctx_; + await holdProm.p; + // Should never reach this when testing + return {} as RPCStream; + }, + logger, + idGen, + }); + // Timing out on stream creation + const callerInterfaceProm = rpcClient.duplexStreamCaller('testMethod', { + timer: 100, + }); + await expect(callerInterfaceProm).toReject(); + await expect(callerInterfaceProm).rejects.toThrow( + rpcErrors.ErrorRPCTimedOut, + ); + expect(ctx?.signal.aborted).toBeTrue(); + expect(ctx?.signal.reason).toBeInstanceOf(rpcErrors.ErrorRPCTimedOut); + }); + test('duplex caller handles abort when creating stream', async () => { + const holdProm = promise(); + let ctx: ContextTimed | undefined; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async (ctx_) => { + ctx = ctx_; + await holdProm.p; + // Should never reach this when testing + return {} as RPCStream; + }, + logger, + idGen, + }); + const abortController = new AbortController(); + const rejectReason = Symbol('rejectReason'); + abortController.abort(rejectReason); + + // Timing out on stream creation + const callerInterfaceProm = rpcClient.duplexStreamCaller('testMethod', { + signal: abortController.signal, + }); + await expect(callerInterfaceProm).toReject(); + await expect(callerInterfaceProm).rejects.toBe(rejectReason); + expect(ctx?.signal.aborted).toBeTrue(); + expect(ctx?.signal.reason).toBe(rejectReason); + }); + test('duplex caller uses default timeout awaiting stream', async () => { + const forwardPassThroughStream = new TransformStream< + Uint8Array, + Uint8Array + >(); + const reversePassThroughStream = new TransformStream< + Uint8Array, + Uint8Array + >(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + writable: forwardPassThroughStream.writable, + readable: reversePassThroughStream.readable, + }; + let ctx: ContextTimed | undefined; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async (ctx_) => { + ctx = ctx_; + return streamPair; + }, + streamKeepAliveTimeoutTime: 100, + logger, + idGen, + }); + + // Timing out on stream + await rpcClient.duplexStreamCaller('testMethod'); + await ctx?.timer; + expect(ctx?.signal.aborted).toBeTrue(); + expect(ctx?.signal.reason).toBeInstanceOf(rpcErrors.ErrorRPCTimedOut); + }); + test('duplex caller times out awaiting stream', async () => { + const forwardPassThroughStream = new TransformStream< + Uint8Array, + Uint8Array + >(); + const reversePassThroughStream = new TransformStream< + Uint8Array, + Uint8Array + >(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + writable: forwardPassThroughStream.writable, + readable: reversePassThroughStream.readable, + }; + let ctx: ContextTimed | undefined; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async (ctx_) => { + ctx = ctx_; + return streamPair; + }, + logger, + idGen, + }); + + // Timing out on stream + await rpcClient.duplexStreamCaller('testMethod', { + timer: 100, + }); + await ctx?.timer; + expect(ctx?.signal.aborted).toBeTrue(); + expect(ctx?.signal.reason).toBeInstanceOf(rpcErrors.ErrorRPCTimedOut); + }); + test('duplex caller handles abort awaiting stream', async () => { + const forwardPassThroughStream = new TransformStream< + Uint8Array, + Uint8Array + >(); + const reversePassThroughStream = new TransformStream< + Uint8Array, + Uint8Array + >(); + const streamPair: RPCStream = { + cancel: async (reason) => { + await forwardPassThroughStream.readable.cancel(reason); + await reversePassThroughStream.writable.abort(reason); + }, + meta: undefined, + writable: forwardPassThroughStream.writable, + readable: reversePassThroughStream.readable, + }; + const ctxProm = promise(); + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async (ctx) => { + ctxProm.resolveP(ctx); + return streamPair; + }, + logger, + idGen, + }); + const abortController = new AbortController(); + const rejectReason = Symbol('rejectReason'); + abortController.abort(rejectReason); + // Timing out on stream + const stream = await rpcClient.duplexStreamCaller('testMethod', { + signal: abortController.signal, + }); + const ctx = await ctxProm.p; + const abortProm = promise(); + if (ctx.signal.aborted) abortProm.resolveP(); + ctx.signal.addEventListener('abort', () => { + abortProm.resolveP(); + }); + expect(ctx?.signal.aborted).toBeTrue(); + expect(ctx?.signal.reason).toBe(rejectReason); + stream.cancel(Error('asd')); + }); + testProp( + 'duplex caller timeout is refreshed when sending message', + [specificMessageArb], + async (messages) => { + const inputStream = rpcTestUtils.messagesToReadableStream(messages); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + readable: inputStream, + writable: outputStream, + }; + const ctxProm = promise(); + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async (ctx) => { + ctxProm.resolveP(ctx); + return streamPair; + }, + logger, + idGen, + }); + const callerInterface = await rpcClient.duplexStreamCaller< + JSONValue, + JSONValue + >(methodName, { timer: 200 }); + + const ctx = await ctxProm.p; + // Reading refreshes timer + const reader = callerInterface.readable.getReader(); + await sleep(50); + let timeLeft = ctx.timer.getTimeout(); + const message = await reader.read(); + expect(ctx.timer.getTimeout()).toBeGreaterThanOrEqual(timeLeft); + reader.releaseLock(); + for await (const _ of callerInterface.readable) { + // Do nothing + } + + // Writing should refresh timer + const writer = callerInterface.writable.getWriter(); + await sleep(50); + timeLeft = ctx.timer.getTimeout(); + await writer.write(message.value); + expect(ctx.timer.getTimeout()).toBeGreaterThanOrEqual(timeLeft); + await writer.close(); + + await outputResult; + await rpcClient.destroy(); + }, + { numRuns: 5 }, + ); + testProp( + 'Check that ctx is provided to the middleWare and that the middleware can reset the timer', + [specificMessageArb], + async (messages) => { + const inputStream = rpcTestUtils.messagesToReadableStream(messages); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: RPCStream = { + cancel: () => {}, + meta: undefined, + readable: inputStream, + writable: outputStream, + }; + const ctxProm = promise(); + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamFactory: async (ctx) => { + ctxProm.resolveP(ctx); + return streamPair; + }, + middlewareFactory: rpcUtilsMiddleware.defaultClientMiddlewareWrapper( + (ctx) => { + ctx.timer.reset(123); + return { + forward: new TransformStream(), + reverse: new TransformStream(), + }; + }, + ), + logger, + idGen, + }); + const callerInterface = await rpcClient.duplexStreamCaller< + JSONValue, + JSONValue + >(methodName); + + const ctx = await ctxProm.p; + // Writing should refresh timer engage the middleware + const writer = callerInterface.writable.getWriter(); + await writer.write({}); + expect(ctx.timer.delay).toBe(123); + await writer.close(); + + await outputResult; + await rpcClient.destroy(); + }, + { numRuns: 1 }, + ); + }); +}); diff --git a/tests/RPCServer.test.ts b/tests/RPCServer.test.ts new file mode 100644 index 0000000..80213ae --- /dev/null +++ b/tests/RPCServer.test.ts @@ -0,0 +1,1155 @@ +import type { ContextTimed } from '@matrixai/contexts'; +import type { + ContainerType, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResponseError, + JSONValue, + RPCStream, +} from '@/types'; +import type { RPCErrorEvent } from '@/events'; +import type { IdGen } from '@/types'; +import { ReadableStream, TransformStream, WritableStream } from 'stream/web'; +import { fc, testProp } from '@fast-check/jest'; +import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; +import RPCServer from '@/RPCServer'; +import * as rpcErrors from '@/errors/errors'; +import * as rpcUtils from '@/utils'; +import { promise, sleep } from '@/utils'; +import * as rpcUtilsMiddleware from '@/utils/middleware'; +import ServerHandler from '@/handlers/ServerHandler'; +import DuplexHandler from '@/handlers/DuplexHandler'; +import RawHandler from '@/handlers/RawHandler'; +import UnaryHandler from '@/handlers/UnaryHandler'; +import ClientHandler from '@/handlers/ClientHandler'; +import * as rpcTestUtils from './utils'; + +describe(`${RPCServer.name}`, () => { + const logger = new Logger(`${RPCServer.name} Test`, LogLevel.WARN, [ + new StreamHandler(), + ]); + const idGen: IdGen = () => Promise.resolve(null); + const methodName = 'testMethod'; + const specificMessageArb = fc + .array(rpcTestUtils.jsonRpcRequestMessageArb(fc.constant(methodName)), { + minLength: 5, + }) + .noShrink(); + const singleNumberMessageArb = fc.array( + rpcTestUtils.jsonRpcRequestMessageArb( + fc.constant(methodName), + fc.integer({ min: 1, max: 20 }), + ), + { + minLength: 2, + maxLength: 10, + }, + ); + const validToken = 'VALIDTOKEN'; + const invalidTokenMessageArb = rpcTestUtils.jsonRpcRequestMessageArb( + fc.constant('testMethod'), + fc.record({ + metadata: fc.record({ + token: fc.string().filter((v) => v !== validToken), + }), + data: rpcTestUtils.safeJsonValueArb, + }), + ); + + testProp( + 'can stream data with raw duplex stream handler', + [specificMessageArb], + async (messages) => { + const stream = rpcTestUtils + .messagesToReadableStream(messages) + .pipeThrough( + rpcTestUtils.binaryStreamToSnippedStream([4, 7, 13, 2, 6]), + ); + class TestHandler extends RawHandler { + public handle = async ( + input: [JSONRPCRequest, ReadableStream], + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): Promise<[JSONValue, ReadableStream]> => { + for await (const _ of input[1]) { + // No touch, only consume + } + const readableStream = new ReadableStream({ + start: (controller) => { + controller.enqueue(Buffer.from('hello world!')); + controller.close(); + }, + }); + return Promise.resolve([null, readableStream]); + }; + } + + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestHandler({}), + }, + logger, + idGen, + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: stream, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream); + await outputResult; + await rpcServer.destroy(); + }, + { numRuns: 1 }, + ); + testProp( + 'can stream data with duplex stream handler', + [specificMessageArb], + async (messages) => { + const stream = rpcTestUtils.messagesToReadableStream(messages); + class TestMethod extends DuplexHandler { + public handle = async function* ( + input: AsyncGenerator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): AsyncGenerator { + for await (const val of input) { + yield val; + break; + } + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: stream, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream); + await outputResult; + await rpcServer.destroy(); + }, + ); + testProp( + 'can stream data with client stream handler', + [specificMessageArb], + async (messages) => { + const stream = rpcTestUtils.messagesToReadableStream(messages); + class TestMethod extends ClientHandler { + public handle = async ( + input: AsyncGenerator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): Promise => { + let count = 0; + for await (const _ of input) { + count += 1; + } + return count; + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: stream, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream); + await outputResult; + await rpcServer.destroy(); + }, + ); + testProp( + 'can stream data with server stream handler', + [singleNumberMessageArb], + async (messages) => { + const stream = rpcTestUtils.messagesToReadableStream(messages); + class TestMethod extends ServerHandler { + public handle = async function* ( + input: number, + ): AsyncGenerator { + for (let i = 0; i < input; i++) { + yield i; + } + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: stream, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream); + await outputResult; + await rpcServer.destroy(); + }, + ); + testProp( + 'can stream data with server stream handler', + [specificMessageArb], + async (messages) => { + const stream = rpcTestUtils.messagesToReadableStream(messages); + class TestMethod extends UnaryHandler { + public handle = async (input: JSONValue): Promise => { + return input; + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: stream, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream); + await outputResult; + await rpcServer.destroy(); + }, + ); + testProp( + 'handler is provided with container', + [specificMessageArb], + async (messages) => { + const stream = rpcTestUtils.messagesToReadableStream(messages); + const container = { + a: Symbol('a'), + B: Symbol('b'), + C: Symbol('c'), + }; + class TestMethod extends DuplexHandler { + public handle = async function* ( + input: AsyncGenerator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): AsyncGenerator { + expect(this.container).toBe(container); + for await (const val of input) { + yield val; + } + }; + } + + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod(container), + }, + logger, + idGen, + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: stream, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream); + await outputResult; + await rpcServer.destroy(); + }, + ); + testProp( + 'handler is provided with connectionInfo', + [specificMessageArb], + async (messages) => { + const stream = rpcTestUtils.messagesToReadableStream(messages); + const meta = { + localHost: 'hostA', + localPort: 12341, + remoteCertificates: [], + remoteHost: 'hostA', + remotePort: 12341, + }; + let handledMeta; + class TestMethod extends DuplexHandler { + public handle = async function* ( + input: AsyncGenerator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): AsyncGenerator { + handledMeta = meta; + for await (const val of input) { + yield val; + } + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: RPCStream = { + cancel: () => {}, + meta, + readable: stream, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream); + await outputResult; + await rpcServer.destroy(); + expect(handledMeta).toBe(meta); + }, + ); + testProp('handler can be aborted', [specificMessageArb], async (messages) => { + const stream = rpcTestUtils.messagesToReadableStream(messages); + class TestMethod extends DuplexHandler { + public handle = async function* ( + input: AsyncGenerator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): AsyncGenerator { + for await (const val of input) { + if (ctx.signal.aborted) throw ctx.signal.reason; + yield val; + } + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + let thing; + const tapStream = rpcTestUtils.tapTransformStream( + async (_, iteration) => { + if (iteration === 2) { + // @ts-ignore: kidnap private property + const activeStreams = rpcServer.activeStreams.values(); + // @ts-ignore: kidnap private property + for (const activeStream of activeStreams) { + thing = activeStream; + activeStream.cancel(new rpcErrors.ErrorRPCStopping()); + } + } + }, + ); + void tapStream.readable.pipeTo(outputStream).catch(() => {}); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: stream, + writable: tapStream.writable, + }; + rpcServer.handleStream(readWriteStream); + const result = await outputResult; + const lastMessage = result[result.length - 1]; + await expect(thing).toResolve(); + expect(lastMessage).toBeDefined(); + expect(() => + rpcUtils.parseJSONRPCResponseError(JSON.parse(lastMessage.toString())), + ).not.toThrow(); + await rpcServer.destroy(); + }); + testProp('handler yields nothing', [specificMessageArb], async (messages) => { + const stream = rpcTestUtils.messagesToReadableStream(messages); + class TestMethod extends DuplexHandler { + public handle = async function* ( + input: AsyncGenerator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): AsyncGenerator { + for await (const _ of input) { + // Do nothing, just consume + } + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: stream, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream); + await outputResult; + // We're just expecting no errors + await rpcServer.destroy(); + }); + testProp( + 'should send error message', + [specificMessageArb, rpcTestUtils.errorArb(rpcTestUtils.errorArb())], + async (messages, error) => { + const stream = rpcTestUtils.messagesToReadableStream(messages); + class TestMethod extends DuplexHandler { + public handle = async function* (): AsyncGenerator { + throw error; + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + let resolve, reject; + const errorProm = new Promise((resolve_, reject_) => { + resolve = resolve_; + reject = reject_; + }); + rpcServer.addEventListener('error', (thing: RPCErrorEvent) => { + resolve(thing); + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: stream, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream); + const rawErrorMessage = (await outputResult)[0]!.toString(); + const errorMessage = JSON.parse(rawErrorMessage); + expect(errorMessage.error.message).toEqual(error.description); + reject(); + await expect(errorProm).toReject(); + await rpcServer.destroy(); + }, + ); + testProp( + 'should send error message with sensitive', + [specificMessageArb, rpcTestUtils.errorArb(rpcTestUtils.errorArb())], + async (messages, error) => { + const stream = rpcTestUtils.messagesToReadableStream(messages); + class TestMethod extends DuplexHandler { + public handle = async function* (): AsyncGenerator { + throw error; + }; + } + + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + sensitive: true, + logger, + idGen, + }); + let resolve, reject; + const errorProm = new Promise((resolve_, reject_) => { + resolve = resolve_; + reject = reject_; + }); + rpcServer.addEventListener('error', (thing: RPCErrorEvent) => { + resolve(thing); + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: stream, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream); + const rawErrorMessage = (await outputResult)[0]!.toString(); + const errorMessage = JSON.parse(rawErrorMessage); + expect(errorMessage.error.message).toEqual(error.description); + reject(); + await expect(errorProm).toReject(); + await rpcServer.destroy(); + }, + ); + testProp( + 'should emit stream error if input stream fails', + [specificMessageArb], + async (messages) => { + const handlerEndedProm = promise(); + class TestMethod extends DuplexHandler { + public handle = async function* (input): AsyncGenerator { + try { + for await (const _ of input) { + // Consume but don't yield anything + } + } finally { + handlerEndedProm.resolveP(); + } + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + let resolve; + rpcServer.addEventListener('error', (thing: RPCErrorEvent) => { + resolve(thing); + }); + const passThroughStreamIn = new TransformStream(); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: passThroughStreamIn.readable, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream); + const writer = passThroughStreamIn.writable.getWriter(); + // Write messages + for (const message of messages) { + await writer.write(Buffer.from(JSON.stringify(message))); + } + // Abort stream + const writerReason = Symbol('writerAbort'); + await writer.abort(writerReason); + // We should get an error RPC message + await expect(outputResult).toResolve(); + const errorMessage = JSON.parse((await outputResult)[0].toString()); + // Parse without error + rpcUtils.parseJSONRPCResponseError(errorMessage); + // Check that the handler was cleaned up. + await expect(handlerEndedProm.p).toResolve(); + await rpcServer.destroy(); + }, + { numRuns: 1 }, + ); + testProp( + 'should emit stream error if output stream fails', + [specificMessageArb], + async (messages) => { + const handlerEndedProm = promise(); + let ctx: ContextTimed | undefined; + class TestMethod extends DuplexHandler { + public handle = async function* ( + input, + _cancel, + _meta, + ctx_, + ): AsyncGenerator { + ctx = ctx_; + // Echo input + try { + yield* input; + } finally { + handlerEndedProm.resolveP(); + } + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + idGen, + }); + let resolve; + const errorProm = new Promise((resolve_) => { + resolve = resolve_; + }); + rpcServer.addEventListener('error', (thing: RPCErrorEvent) => { + resolve(thing); + }); + const passThroughStreamIn = new TransformStream(); + const passThroughStreamOut = new TransformStream< + Uint8Array, + Uint8Array + >(); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: passThroughStreamIn.readable, + writable: passThroughStreamOut.writable, + }; + rpcServer.handleStream(readWriteStream); + const writer = passThroughStreamIn.writable.getWriter(); + const reader = passThroughStreamOut.readable.getReader(); + // Write messages + for (const message of messages) { + await writer.write(Buffer.from(JSON.stringify(message))); + await reader.read(); + } + // Abort stream + // const writerReason = Symbol('writerAbort'); + const readerReason = Symbol('readerAbort'); + // Await writer.abort(writerReason); + await reader.cancel(readerReason); + // We should get an error event + const event = await errorProm; + await writer.close(); + // Expect(event.detail.cause).toContain(writerReason); + expect(event.detail).toBeInstanceOf(rpcErrors.ErrorRPCStreamEnded); + // Check that the handler was cleaned up. + await expect(handlerEndedProm.p).toResolve(); + // Check that an abort signal happened + expect(ctx).toBeDefined(); + expect(ctx?.signal.aborted).toBeTrue(); + expect(ctx?.signal.reason).toBe(readerReason); + await rpcServer.destroy(); + }, + { numRuns: 1 }, + ); + testProp('forward middlewares', [specificMessageArb], async (messages) => { + const stream = rpcTestUtils.messagesToReadableStream(messages); + class TestMethod extends DuplexHandler { + public handle = async function* ( + input: AsyncGenerator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): AsyncGenerator { + yield* input; + }; + } + const middlewareFactory = rpcUtilsMiddleware.defaultServerMiddlewareWrapper( + () => { + return { + forward: new TransformStream({ + transform: (chunk, controller) => { + chunk.params = 1; + controller.enqueue(chunk); + }, + }), + reverse: new TransformStream(), + }; + }, + ); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + middlewareFactory: middlewareFactory, + logger, + idGen, + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: stream, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream); + const out = await outputResult; + expect(out.map((v) => v!.toString())).toStrictEqual( + messages.map(() => { + return JSON.stringify({ + jsonrpc: '2.0', + result: 1, + id: null, + }); + }), + ); + await rpcServer.destroy(); + }); + testProp('reverse middlewares', [specificMessageArb], async (messages) => { + const stream = rpcTestUtils.messagesToReadableStream(messages); + class TestMethod extends DuplexHandler { + public handle = async function* ( + input: AsyncGenerator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): AsyncGenerator { + yield* input; + }; + } + const middleware = rpcUtilsMiddleware.defaultServerMiddlewareWrapper(() => { + return { + forward: new TransformStream(), + reverse: new TransformStream({ + transform: (chunk, controller) => { + if ('result' in chunk) chunk.result = 1; + controller.enqueue(chunk); + }, + }), + }; + }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + middlewareFactory: middleware, + logger, + idGen, + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: stream, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream); + const out = await outputResult; + expect(out.map((v) => v!.toString())).toStrictEqual( + messages.map(() => { + return JSON.stringify({ + jsonrpc: '2.0', + result: 1, + id: null, + }); + }), + ); + await rpcServer.destroy(); + }); + testProp( + 'forward middleware authentication', + [invalidTokenMessageArb], + async (message) => { + const stream = rpcTestUtils.messagesToReadableStream([message]); + class TestMethod extends DuplexHandler { + public handle = async function* ( + input: AsyncGenerator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): AsyncGenerator { + yield* input; + }; + } + const middleware = rpcUtilsMiddleware.defaultServerMiddlewareWrapper( + () => { + let first = true; + let reverseController: TransformStreamDefaultController; + return { + forward: new TransformStream< + JSONRPCRequest, + JSONRPCRequest + >({ + transform: (chunk, controller) => { + if (first && chunk.params?.metadata.token !== validToken) { + reverseController.enqueue(failureMessage); + // Closing streams early + controller.terminate(); + reverseController.terminate(); + } + first = false; + controller.enqueue(chunk); + }, + }), + reverse: new TransformStream({ + start: (controller) => { + // Kidnapping reverse controller + reverseController = controller; + }, + transform: (chunk, controller) => { + controller.enqueue(chunk); + }, + }), + }; + }, + ); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + middlewareFactory: middleware, + logger, + idGen, + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: stream, + writable: outputStream, + }; + type TestType = { + metadata: { + token: string; + }; + data: JSONValue; + }; + const failureMessage: JSONRPCResponseError = { + jsonrpc: '2.0', + id: null, + error: { + code: 1, + message: 'failure of some kind', + }, + }; + rpcServer.handleStream(readWriteStream); + expect((await outputResult).toString()).toEqual( + JSON.stringify(failureMessage), + ); + await rpcServer.destroy(); + }, + ); + + test('timeout with default time after handler selected', async () => { + const ctxProm = promise(); + + // Diagnostic log to indicate the start of the test + + class TestHandler extends RawHandler { + public handle = async ( + _input: [JSONRPCRequest, ReadableStream], + _cancel: (reason?: any) => void, + _meta: Record | undefined, + ctx_: ContextTimed, + ): Promise<[JSONValue, ReadableStream]> => { + return new Promise((resolve, reject) => { + ctxProm.resolveP(ctx_); + + let controller: ReadableStreamController; + const stream = new ReadableStream({ + start: (controller_) => { + controller = controller_; + }, + }); + + ctx_.signal.addEventListener('abort', () => { + controller!.error(Error('ending')); + }); + + // Return something to fulfill the Promise type expectation. + resolve([null, stream]); + }); + }; + } + + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestHandler({}), + }, + handlerTimeoutTime: 100, + logger, + idGen, + }); + + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const stream = rpcTestUtils.messagesToReadableStream([ + { + jsonrpc: '2.0', + method: 'testMethod', + params: null, + }, + { + jsonrpc: '2.0', + method: 'testMethod', + params: null, + }, + ]); + + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: stream, + writable: outputStream, + }; + + rpcServer.handleStream(readWriteStream); + + const ctx = await ctxProm.p; + + expect(ctx.timer.delay).toEqual(100); + + await ctx.timer; + + expect(ctx.signal.reason).toBeInstanceOf(rpcErrors.ErrorRPCTimedOut); + + await expect(outputResult).toReject(); + + await rpcServer.destroy(); + }); + test('timeout with default time before handler selected', async () => { + const rpcServer = await RPCServer.createRPCServer({ + manifest: {}, + handlerTimeoutTime: 100, + logger, + idGen, + }); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: new ReadableStream({ + // Ignore + cancel: () => {}, + }), + writable: new WritableStream({ + // Ignore + abort: () => {}, + }), + }; + rpcServer.handleStream(readWriteStream); + // With no handler we can only check alive connections through the server + // @ts-ignore: kidnap protected property + const activeStreams = rpcServer.activeStreams; + for await (const [prom] of activeStreams.entries()) { + await prom; + } + await rpcServer.destroy(); + }); + test('handler overrides timeout', async () => { + { + const waitProm = promise(); + const ctxShortProm = promise(); + class TestMethodShortTimeout extends UnaryHandler { + timeout = 25; + public handle = async ( + input: JSONValue, + _cancel, + _meta, + ctx_, + ): Promise => { + ctxShortProm.resolveP(ctx_); + await waitProm.p; + return input; + }; + } + const ctxLongProm = promise(); + class TestMethodLongTimeout extends UnaryHandler { + timeout = 100; + public handle = async ( + input: JSONValue, + _cancel, + _meta, + ctx_, + ): Promise => { + ctxLongProm.resolveP(ctx_); + await waitProm.p; + return input; + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testShort: new TestMethodShortTimeout({}), + testLong: new TestMethodLongTimeout({}), + }, + handlerTimeoutTime: 50, + logger, + idGen, + }); + const streamShort = rpcTestUtils.messagesToReadableStream([ + { + jsonrpc: '2.0', + method: 'testShort', + params: null, + }, + ]); + const readWriteStreamShort: RPCStream = { + cancel: () => {}, + readable: streamShort, + writable: new WritableStream(), + }; + rpcServer.handleStream(readWriteStreamShort); + // Shorter timeout is updated + const ctxShort = await ctxShortProm.p; + expect(ctxShort.timer.delay).toEqual(25); + const streamLong = rpcTestUtils.messagesToReadableStream([ + { + jsonrpc: '2.0', + method: 'testLong', + params: null, + }, + ]); + const readWriteStreamLong: RPCStream = { + cancel: () => {}, + readable: streamLong, + writable: new WritableStream(), + }; + rpcServer.handleStream(readWriteStreamLong); + + // Longer timeout is set to server's default + const ctxLong = await ctxLongProm.p; + expect(ctxLong.timer.delay).toEqual(50); + waitProm.resolveP(); + await rpcServer.destroy(); + } + }); + test('duplex handler refreshes timeout when messages are sent', async () => { + const contextProm = promise(); + const stepProm1 = promise(); + const stepProm2 = promise(); + const passthroughStream = new TransformStream(); + class TestHandler extends DuplexHandler { + public handle = async function* ( + input: AsyncGenerator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): AsyncGenerator { + contextProm.resolveP(ctx); + for await (const _ of input) { + // Do nothing, just consume + } + await stepProm1.p; + yield 1; + await stepProm2.p; + yield 2; + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestHandler({}), + }, + logger, + idGen, + handlerTimeoutTime: 1000, + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const requestMessage = Buffer.from( + JSON.stringify({ + jsonrpc: '2.0', + method: 'testMethod', + params: 1, + }), + ); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: passthroughStream.readable, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream); + const writer = passthroughStream.writable.getWriter(); + await writer.write(requestMessage); + const ctx = await contextProm.p; + const scheduled: Date | undefined = ctx.timer.scheduled; + // Checking writing refreshes timer + await sleep(25); + await writer.write(requestMessage); + expect(ctx.timer.scheduled).toBeAfter(scheduled!); + expect( + ctx.timer.scheduled!.getTime() - scheduled!.getTime(), + ).toBeGreaterThanOrEqual(25); + await writer.close(); + // Checking reading refreshes timer + await sleep(25); + stepProm1.resolveP(); + expect(ctx.timer.scheduled).toBeAfter(scheduled!); + expect( + ctx.timer.scheduled!.getTime() - scheduled!.getTime(), + ).toBeGreaterThanOrEqual(25); + stepProm2.resolveP(); + await outputResult; + await rpcServer.destroy(); + }); + test('stream ending cleans up timer and abortSignal', async () => { + const ctxProm = promise(); + class TestHandler extends RawHandler { + public handle = async ( + input: [JSONRPCRequest, ReadableStream], + _cancel: (reason?: any) => void, + _meta: Record | undefined, + ctx_: ContextTimed, + ): Promise<[JSONValue, ReadableStream]> => { + return new Promise((resolve) => { + ctxProm.resolveP(ctx_); + void (async () => { + for await (const _ of input[1]) { + // Do nothing, just consume + } + })(); + const readableStream = new ReadableStream({ + start: (controller) => { + controller.close(); + }, + }); + resolve([null, readableStream]); + }); + }; + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestHandler({}), + }, + logger, + idGen, + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const stream = rpcTestUtils.messagesToReadableStream([ + { + jsonrpc: '2.0', + method: 'testMethod', + params: null, + }, + ]); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: stream, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream); + const ctx = await ctxProm.p; + await outputResult; + await rpcServer.destroy(false); + expect(ctx.signal.aborted).toBeTrue(); + expect(ctx.signal.reason).toBeInstanceOf(rpcErrors.ErrorRPCStreamEnded); + // If the timer has already resolved then it was cancelled + await expect(ctx.timer).toReject(); + await rpcServer.destroy(); + }); + testProp( + 'middleware can update timeout timer', + [specificMessageArb], + async (messages) => { + const stream = rpcTestUtils.messagesToReadableStream(messages); + const ctxProm = promise(); + class TestMethod extends DuplexHandler { + public handle = async function* ( + input: AsyncGenerator, + cancel: (reason?: any) => void, + meta: Record | undefined, + ctx: ContextTimed, + ): AsyncGenerator { + ctxProm.resolveP(ctx); + yield* input; + }; + } + const middlewareFactory = + rpcUtilsMiddleware.defaultServerMiddlewareWrapper((ctx) => { + ctx.timer.reset(12345); + return { + forward: new TransformStream(), + reverse: new TransformStream(), + }; + }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + middlewareFactory: middlewareFactory, + logger, + idGen, + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: RPCStream = { + cancel: () => {}, + readable: stream, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream); + await outputResult; + const ctx = await ctxProm.p; + expect(ctx.timer.delay).toBe(12345); + }, + ); +}); diff --git a/tests/index.test.ts b/tests/index.test.ts deleted file mode 100644 index d48a1eb..0000000 --- a/tests/index.test.ts +++ /dev/null @@ -1 +0,0 @@ -describe('index', () => {}); diff --git a/tests/utils.ts b/tests/utils.ts new file mode 100644 index 0000000..b10c37d --- /dev/null +++ b/tests/utils.ts @@ -0,0 +1,301 @@ +import type { ReadableWritablePair } from 'stream/web'; +import type { JSONValue } from '@/types'; +import type { + JSONRPCError, + JSONRPCMessage, + JSONRPCRequestNotification, + JSONRPCRequestMessage, + JSONRPCResponseError, + JSONRPCResponseResult, + JSONRPCResponse, + JSONRPCRequest, +} from '@/types'; +import { ReadableStream, WritableStream, TransformStream } from 'stream/web'; +import { fc } from '@fast-check/jest'; +import * as utils from '@/utils'; +import { fromError } from '@/utils'; +import * as rpcErrors from '@/errors'; +import { ErrorRPC } from '@/errors'; + +/** + * This is used to convert regular chunks into randomly sized chunks based on + * a provided pattern. This is to replicate randomness introduced by packets + * splitting up the data. + */ +function binaryStreamToSnippedStream(snippingPattern: Array) { + let buffer = Buffer.alloc(0); + let iteration = 0; + return new TransformStream({ + transform: (chunk, controller) => { + buffer = Buffer.concat([buffer, chunk]); + while (true) { + const snipAmount = snippingPattern[iteration % snippingPattern.length]; + if (snipAmount > buffer.length) break; + iteration += 1; + const returnBuffer = buffer.subarray(0, snipAmount); + controller.enqueue(returnBuffer); + buffer = buffer.subarray(snipAmount); + } + }, + flush: (controller) => { + controller.enqueue(buffer); + }, + }); +} + +/** + * This is used to convert regular chunks into randomly sized chunks based on + * a provided pattern. This is to replicate randomness introduced by packets + * splitting up the data. + */ +function binaryStreamToNoisyStream(noise: Array) { + let iteration: number = 0; + return new TransformStream({ + transform: (chunk, controller) => { + const noiseBuffer = noise[iteration % noise.length]; + const newBuffer = Buffer.from(Buffer.concat([chunk, noiseBuffer])); + controller.enqueue(newBuffer); + iteration += 1; + }, + }); +} + +/** + * This takes an array of JSONRPCMessages and converts it to a readable stream. + * Used to seed input for handlers and output for callers. + */ +const messagesToReadableStream = (messages: Array) => { + return new ReadableStream({ + async start(controller) { + for (const arrayElement of messages) { + controller.enqueue(Buffer.from(JSON.stringify(arrayElement), 'utf-8')); + } + controller.close(); + }, + }); +}; + +/** + * Out RPC data is in form of JSON objects. + * This creates a JSON object of the type `JSONValue` and will be unchanged by + * a json stringify and parse cycle. + */ +const safeJsonValueArb = fc + .json() + .map((value) => JSON.parse(value.replace('__proto__', 'proto')) as JSONValue); + +const idArb = fc.oneof(fc.string(), fc.integer(), fc.constant(null)); + +const jsonRpcRequestMessageArb = ( + method: fc.Arbitrary = fc.string(), + params: fc.Arbitrary = safeJsonValueArb, +) => + fc + .record( + { + jsonrpc: fc.constant('2.0'), + method: method, + params: params, + id: idArb, + }, + { + requiredKeys: ['jsonrpc', 'method', 'id'], + }, + ) + .noShrink() as fc.Arbitrary; + +const jsonRpcRequestNotificationArb = ( + method: fc.Arbitrary = fc.string(), + params: fc.Arbitrary = safeJsonValueArb, +) => + fc + .record( + { + jsonrpc: fc.constant('2.0'), + method: method, + params: params, + }, + { + requiredKeys: ['jsonrpc', 'method'], + }, + ) + .noShrink() as fc.Arbitrary; + +const jsonRpcRequestArb = ( + method: fc.Arbitrary = fc.string(), + params: fc.Arbitrary = safeJsonValueArb, +) => + fc + .oneof( + jsonRpcRequestMessageArb(method, params), + jsonRpcRequestNotificationArb(method, params), + ) + .noShrink() as fc.Arbitrary; + +const jsonRpcResponseResultArb = ( + result: fc.Arbitrary = safeJsonValueArb, +) => + fc + .record({ + jsonrpc: fc.constant('2.0'), + result: result, + id: idArb, + }) + .noShrink() as fc.Arbitrary; +const jsonRpcErrorArb = ( + error: fc.Arbitrary> = fc.constant(new ErrorRPC('test error')), +) => + fc + .record( + { + code: fc.integer(), + message: fc.string(), + data: error.map((e) => JSON.stringify(fromError(e))), + }, + { + requiredKeys: ['code', 'message'], + }, + ) + .noShrink() as fc.Arbitrary; + +const jsonRpcResponseErrorArb = ( + error?: fc.Arbitrary>, + sensitive: boolean = false, +) => + fc + .record({ + jsonrpc: fc.constant('2.0'), + error: jsonRpcErrorArb(error), + id: idArb, + }) + .noShrink() as fc.Arbitrary; + +const jsonRpcResponseArb = ( + result: fc.Arbitrary = safeJsonValueArb, +) => + fc + .oneof(jsonRpcResponseResultArb(result), jsonRpcResponseErrorArb()) + .noShrink() as fc.Arbitrary; + +const jsonRpcMessageArb = ( + method: fc.Arbitrary = fc.string(), + params: fc.Arbitrary = safeJsonValueArb, + result: fc.Arbitrary = safeJsonValueArb, +) => + fc + .oneof(jsonRpcRequestArb(method, params), jsonRpcResponseArb(result)) + .noShrink() as fc.Arbitrary; + +const snippingPatternArb = fc + .array(fc.integer({ min: 1, max: 32 }), { minLength: 100, size: 'medium' }) + .noShrink(); + +const jsonMessagesArb = fc + .array(jsonRpcRequestMessageArb(), { minLength: 2 }) + .noShrink(); + +const rawDataArb = fc.array(fc.uint8Array({ minLength: 1 }), { minLength: 1 }); + +function streamToArray(): [Promise>, WritableStream] { + const outputArray: Array = []; + const result = utils.promise>(); + const outputStream = new WritableStream({ + write: (chunk) => { + outputArray.push(chunk); + }, + close: () => { + result.resolveP(outputArray); + }, + abort: (reason) => { + result.rejectP(reason); + }, + }); + return [result.p, outputStream]; +} + +type TapCallback = (chunk: T, iteration: number) => Promise; + +/** + * This is used to convert regular chunks into randomly sized chunks based on + * a provided pattern. This is to replicate randomness introduced by packets + * splitting up the data. + */ +function tapTransformStream(tapCallback: TapCallback = async () => {}) { + let iteration: number = 0; + return new TransformStream({ + transform: async (chunk, controller) => { + try { + await tapCallback(chunk, iteration); + } catch (e) { + // Ignore errors here + } + controller.enqueue(chunk); + iteration += 1; + }, + }); +} + +function createTapPairs( + forwardTapCallback: TapCallback = async () => {}, + reverseTapCallback: TapCallback = async () => {}, +) { + const forwardTap = tapTransformStream(forwardTapCallback); + const reverseTap = tapTransformStream(reverseTapCallback); + const clientPair: ReadableWritablePair = { + readable: reverseTap.readable, + writable: forwardTap.writable, + }; + const serverPair: ReadableWritablePair = { + readable: forwardTap.readable, + writable: reverseTap.writable, + }; + return { + clientPair, + serverPair, + }; +} + +const errorArb = ( + cause: fc.Arbitrary = fc.constant(undefined), +) => + cause.chain((cause) => + fc.oneof( + fc.constant(new rpcErrors.ErrorRPCRemote()), + fc.constant(new rpcErrors.ErrorRPCMessageLength(undefined)), + fc.constant( + new rpcErrors.ErrorRPCRemote( + { + command: 'someCommand', + host: `someHost`, + port: 0, + }, + undefined, + { + cause, + }, + ), + ), + ), + ); + +export { + binaryStreamToSnippedStream, + binaryStreamToNoisyStream, + messagesToReadableStream, + safeJsonValueArb, + jsonRpcRequestMessageArb, + jsonRpcRequestNotificationArb, + jsonRpcRequestArb, + jsonRpcResponseResultArb, + jsonRpcErrorArb, + jsonRpcResponseErrorArb, + jsonRpcResponseArb, + jsonRpcMessageArb, + snippingPatternArb, + jsonMessagesArb, + rawDataArb, + streamToArray, + tapTransformStream, + createTapPairs, + errorArb, +}; diff --git a/tests/utils/middleware.test.ts b/tests/utils/middleware.test.ts new file mode 100644 index 0000000..bf744a7 --- /dev/null +++ b/tests/utils/middleware.test.ts @@ -0,0 +1,105 @@ +import type { JSONRPCMessage, JSONValue } from '@/types'; +import { TransformStream } from 'stream/web'; +import { fc, testProp } from '@fast-check/jest'; +import { JSONParser } from '@streamparser/json'; +import { AsyncIterableX as AsyncIterable } from 'ix/asynciterable'; +import * as rpcUtils from '@/utils'; +import 'ix/add/asynciterable-operators/toarray'; +import * as rpcErrors from '@/errors'; +import * as rpcUtilsMiddleware from '@/utils/middleware'; +import * as rpcTestUtils from '../utils'; + +describe('Middleware tests', () => { + const noiseArb = fc + .array( + fc.uint8Array({ minLength: 5 }).map((array) => Buffer.from(array)), + { minLength: 5 }, + ) + .noShrink(); + + testProp( + 'can parse json stream', + [rpcTestUtils.jsonMessagesArb], + async (messages) => { + const parsedStream = rpcTestUtils + .messagesToReadableStream(messages) + .pipeThrough( + rpcUtilsMiddleware.binaryToJsonMessageStream( + rpcUtils.parseJSONRPCMessage, + ), + ); // Converting back. + + const asd = await AsyncIterable.as(parsedStream).toArray(); + expect(asd).toEqual(messages); + }, + { numRuns: 1000 }, + ); + testProp( + 'Message size limit is enforced when parsing', + [ + fc.array( + rpcTestUtils.jsonRpcRequestMessageArb(fc.string({ minLength: 100 })), + { + minLength: 1, + }, + ), + ], + async (messages) => { + const parsedStream = rpcTestUtils + .messagesToReadableStream(messages) + .pipeThrough(rpcTestUtils.binaryStreamToSnippedStream([10])) + .pipeThrough( + rpcUtilsMiddleware.binaryToJsonMessageStream( + rpcUtils.parseJSONRPCMessage, + 50, + ), + ); + + const doThing = async () => { + for await (const _ of parsedStream) { + // No touch, only consume + } + }; + await expect(doThing()).rejects.toThrow(rpcErrors.ErrorRPCMessageLength); + }, + { numRuns: 1000 }, + ); + testProp( + 'can parse json stream with random chunk sizes', + [rpcTestUtils.jsonMessagesArb, rpcTestUtils.snippingPatternArb], + async (messages, snippattern) => { + const parsedStream = rpcTestUtils + .messagesToReadableStream(messages) + .pipeThrough(rpcTestUtils.binaryStreamToSnippedStream(snippattern)) // Imaginary internet here + .pipeThrough( + rpcUtilsMiddleware.binaryToJsonMessageStream( + rpcUtils.parseJSONRPCMessage, + ), + ); // Converting back. + + const asd = await AsyncIterable.as(parsedStream).toArray(); + expect(asd).toStrictEqual(messages); + }, + { numRuns: 1000 }, + ); + testProp( + 'Will error on bad data', + [rpcTestUtils.jsonMessagesArb, rpcTestUtils.snippingPatternArb, noiseArb], + async (messages, snippattern, noise) => { + const parsedStream = rpcTestUtils + .messagesToReadableStream(messages) + .pipeThrough(rpcTestUtils.binaryStreamToSnippedStream(snippattern)) // Imaginary internet here + .pipeThrough(rpcTestUtils.binaryStreamToNoisyStream(noise)) // Adding bad data to the stream + .pipeThrough( + rpcUtilsMiddleware.binaryToJsonMessageStream( + rpcUtils.parseJSONRPCMessage, + ), + ); // Converting back. + + await expect(AsyncIterable.as(parsedStream).toArray()).rejects.toThrow( + rpcErrors.ErrorRPCParse, + ); + }, + { numRuns: 1000 }, + ); +}); diff --git a/tests/utils/utils.test.ts b/tests/utils/utils.test.ts new file mode 100644 index 0000000..8594ce7 --- /dev/null +++ b/tests/utils/utils.test.ts @@ -0,0 +1,26 @@ +import { testProp, fc } from '@fast-check/jest'; +import { JSONParser } from '@streamparser/json'; +import * as rpcUtils from '@/utils'; +import 'ix/add/asynciterable-operators/toarray'; +import * as rpcTestUtils from '../utils'; + +describe('utils tests', () => { + testProp( + 'can parse messages', + [rpcTestUtils.jsonRpcMessageArb()], + async (message) => { + rpcUtils.parseJSONRPCMessage(message); + }, + { numRuns: 1000 }, + ); + testProp( + 'malformed data cases parsing errors', + [fc.json()], + async (message) => { + expect(() => + rpcUtils.parseJSONRPCMessage(Buffer.from(JSON.stringify(message))), + ).toThrow(); + }, + { numRuns: 1000 }, + ); +}); diff --git a/tsconfig.json b/tsconfig.json index 907ed72..a120436 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -8,6 +8,7 @@ "allowJs": true, "strictNullChecks": true, "noImplicitAny": false, + "experimentalDecorators": true, "esModuleInterop": true, "allowSyntheticDefaultImports": true, "resolveJsonModule": true,