From 066bb888225967d7ff2bf4af0e5de61171f2994b Mon Sep 17 00:00:00 2001 From: CJ Cenizal Date: Fri, 12 Jul 2024 14:01:07 -0700 Subject: [PATCH] Fix StreamQueryClient handling of HTTP errors and add support for None reranker and GenerationInfo event. (#24) * Add prettier rules. --- .prettierrc.json | 5 + docs/src/index.tsx | 1 + src/apiV1/client.test.ts | 10 +- src/apiV2/EventBuffer.test.ts | 14 +- src/apiV2/EventBuffer.ts | 45 ++-- src/apiV2/apiTypes.ts | 10 +- src/apiV2/client.mocks.ts | 10 +- src/apiV2/client.test.ts | 279 +++++++++++++++--------- src/apiV2/client.ts | 87 ++++---- src/apiV2/types.ts | 8 + src/common/createTestStreamingServer.ts | 32 ++- src/common/generateStream.ts | 20 +- 12 files changed, 309 insertions(+), 212 deletions(-) create mode 100644 .prettierrc.json diff --git a/.prettierrc.json b/.prettierrc.json new file mode 100644 index 0000000..b28a34e --- /dev/null +++ b/.prettierrc.json @@ -0,0 +1,5 @@ +{ + "semi": true, + "printWidth": 120, + "trailingComma": "none" +} diff --git a/docs/src/index.tsx b/docs/src/index.tsx index c8b52c2..380dddf 100644 --- a/docs/src/index.tsx +++ b/docs/src/index.tsx @@ -75,6 +75,7 @@ const App = () => { sentencesBefore: 2, sentencesAfter: 2, }, + reranker: "none", }, generation: { maxUsedSearchResults: 5, diff --git a/src/apiV1/client.test.ts b/src/apiV1/client.test.ts index 9eea40e..200d901 100644 --- a/src/apiV1/client.test.ts +++ b/src/apiV1/client.test.ts @@ -10,13 +10,13 @@ describe("stream-query-client API v1", () => { let server: SetupServerApi; beforeAll(async () => { - server = createTestStreamingServer( - "/v1/stream-query", + server = createTestStreamingServer({ + path: "/v1/stream-query", chunks, - (json: any) => { + createChunk: (json: any) => { return encoder.encode(JSON.stringify(json)); - } - ); + }, + }); await server.listen(); }); diff --git a/src/apiV2/EventBuffer.test.ts b/src/apiV2/EventBuffer.test.ts index 583a264..e9164fb 100644 --- a/src/apiV2/EventBuffer.test.ts +++ b/src/apiV2/EventBuffer.test.ts @@ -16,8 +16,8 @@ data:{"type":"end"} expect(onStreamEvent).toHaveBeenNthCalledWith(1, { type: "error", messages: [ - "INVALID_ARGUMENT: The filter expression contains an error. Syntax error at 1:0 nc79bc8s must be referenced as doc.nc79bc8s or part.nc79bc8s", - ], + "INVALID_ARGUMENT: The filter expression contains an error. Syntax error at 1:0 nc79bc8s must be referenced as doc.nc79bc8s or part.nc79bc8s" + ] }); expect(onStreamEvent).toHaveBeenNthCalledWith(2, { type: "end" }); @@ -42,7 +42,7 @@ data:{"type":"search_results", expect(onStreamEvent).toHaveBeenCalledWith({ type: "searchResults", - searchResults: [{ id: "doc1" }], + searchResults: [{ id: "doc1" }] }); }); @@ -63,8 +63,8 @@ data:{"type":"end"} expect(onStreamEvent).toHaveBeenNthCalledWith(1, { type: "error", messages: [ - "INVALID_ARGUMENT: The filter expression contains an error. Syntax error at 1:0 nc79bc8s must be referenced as doc.nc79bc8s or part.nc79bc8s", - ], + "INVALID_ARGUMENT: The filter expression contains an error. Syntax error at 1:0 nc79bc8s must be referenced as doc.nc79bc8s or part.nc79bc8s" + ] }); expect(onStreamEvent).toHaveBeenNthCalledWith(2, { type: "end" }); @@ -82,7 +82,7 @@ data:{"type":"end"} type: "unexpectedError", raw: ` {"messages":["Request failed. See https://status.vectara.com for the latest info on any outages. 
If the problem persists, please contact us via support or via our community forums at https://discuss.vectara.com if you’re a Growth user."],"request_id":"00000000000000000000000000000000"} - `, + ` }); }); @@ -98,7 +98,7 @@ data:{"type":"apocalypse"} expect(onStreamEvent).toHaveBeenCalledWith({ type: "unexpectedEvent", rawType: "meteor_strike", - raw: { type: "apocalypse" }, + raw: { type: "apocalypse" } }); }); }); diff --git a/src/apiV2/EventBuffer.ts b/src/apiV2/EventBuffer.ts index 434a341..f406b49 100644 --- a/src/apiV2/EventBuffer.ts +++ b/src/apiV2/EventBuffer.ts @@ -8,11 +8,7 @@ export class EventBuffer { private eventInProgress = ""; private updatedText = ""; - constructor( - onStreamEvent: (event: any) => void, - includeRaw = false, - status = 200 - ) { + constructor(onStreamEvent: (event: any) => void, includeRaw = false, status = 200) { this.events = []; this.onStreamEvent = onStreamEvent; this.includeRaw = includeRaw; @@ -44,8 +40,12 @@ export class EventBuffer { const rawEvent = JSON.parse(this.eventInProgress); this.enqueueEvent(rawEvent); this.eventInProgress = ""; - } catch { - // @tes-expect-error no-empty + } catch (error: any) { + const isJsonError = error.stack.includes("at JSON.parse"); + // Silently ignore JSON parsing errors, as they are expected. + if (!isJsonError) { + console.error(error); + } } }); @@ -61,6 +61,8 @@ export class EventBuffer { turn_id, factual_consistency_score, generation_chunk, + rendered_prompt, + rephrased_query } = rawEvent; switch (type) { @@ -68,7 +70,7 @@ export class EventBuffer { this.events.push({ type: "error", messages, - ...(this.includeRaw && { raw: rawEvent }), + ...(this.includeRaw && { raw: rawEvent }) }); break; @@ -76,7 +78,7 @@ export class EventBuffer { this.events.push({ type: "searchResults", searchResults: search_results, - ...(this.includeRaw && { raw: rawEvent }), + ...(this.includeRaw && { raw: rawEvent }) }); break; @@ -85,7 +87,7 @@ export class EventBuffer { type: "chatInfo", chatId: chat_id, turnId: turn_id, - ...(this.includeRaw && { raw: rawEvent }), + ...(this.includeRaw && { raw: rawEvent }) }); break; @@ -95,14 +97,23 @@ export class EventBuffer { type: "generationChunk", updatedText: this.updatedText, generationChunk: generation_chunk, - ...(this.includeRaw && { raw: rawEvent }), + ...(this.includeRaw && { raw: rawEvent }) + }); + break; + + case "generation_info": + this.events.push({ + type: "generationInfo", + renderedPrompt: rendered_prompt, + rephrasedQuery: rephrased_query, + ...(this.includeRaw && { raw: rawEvent }) }); break; case "generation_end": this.events.push({ type: "generationEnd", - ...(this.includeRaw && { raw: rawEvent }), + ...(this.includeRaw && { raw: rawEvent }) }); break; @@ -110,14 +121,14 @@ export class EventBuffer { this.events.push({ type: "factualConsistencyScore", factualConsistencyScore: factual_consistency_score, - ...(this.includeRaw && { raw: rawEvent }), + ...(this.includeRaw && { raw: rawEvent }) }); break; case "end": this.events.push({ type: "end", - ...(this.includeRaw && { raw: rawEvent }), + ...(this.includeRaw && { raw: rawEvent }) }); break; @@ -126,20 +137,20 @@ export class EventBuffer { this.events.push({ type: "unexpectedEvent", rawType: type, - raw: rawEvent, + raw: rawEvent }); } else if (this.status !== 200) { // Assume an error. this.events.push({ type: "requestError", status: this.status, - raw: rawEvent, + raw: rawEvent }); } else { // Assume an error. 
this.events.push({ type: "unexpectedError", - raw: rawEvent, + raw: rawEvent }); } } diff --git a/src/apiV2/apiTypes.ts b/src/apiV2/apiTypes.ts index a2e12f3..19631e8 100644 --- a/src/apiV2/apiTypes.ts +++ b/src/apiV2/apiTypes.ts @@ -1,6 +1,8 @@ import { SummaryLanguage } from "../common/types"; export namespace Query { + export type NoneReranker = { type: "none" }; + export type CustomerSpecificReranker = { type: "customer_reranker"; reranker_id: string; @@ -29,7 +31,7 @@ export namespace Query { start_tag?: string; end_tag?: string; }; - reranker?: CustomerSpecificReranker | MmrReranker; + reranker?: NoneReranker | CustomerSpecificReranker | MmrReranker; }; export type NoneCitations = { @@ -64,11 +66,7 @@ export namespace Query { frequency_penalty: number; presence_penalty: number; }; - citations?: - | NoneCitations - | NumericCitations - | HtmlCitations - | MarkdownCitations; + citations?: NoneCitations | NumericCitations | HtmlCitations | MarkdownCitations; enable_factual_consistency_score?: boolean; }; diff --git a/src/apiV2/client.mocks.ts b/src/apiV2/client.mocks.ts index b5f6d9e..a05c460 100644 --- a/src/apiV2/client.mocks.ts +++ b/src/apiV2/client.mocks.ts @@ -10,12 +10,16 @@ data:{"type":"chat_info","chat_id":"cht_74b5a5f3-1f51-4427-a317-f62efb493928","t const chunk3 = `event:generation_chunk data:{"type":"generation_chunk","generation_chunk":"Markdown is "}`; +// Generation info. +const chunk4 = `event:generation_info +data:{"type":"generation_info","rephrased_query":"Rephrased query","rendered_prompt":"Rendered prompt"}`; + // FCS. -const chunk4 = `event:factual_consistency_score +const chunk5 = `event:factual_consistency_score data:{"type":"factual_consistency_score","factual_consistency_score":0.41796625}`; // // End. -const chunk5 = `event:end +const chunk6 = `event:end data:{"type":"end"}`; -export const chunks = [chunk1, chunk2, chunk3, chunk4, chunk5]; +export const chunks = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]; diff --git a/src/apiV2/client.test.ts b/src/apiV2/client.test.ts index 92bcb67..23ab630 100644 --- a/src/apiV2/client.test.ts +++ b/src/apiV2/client.test.ts @@ -6,126 +6,189 @@ import { chunks } from "./client.mocks"; const encoder = new TextEncoder(); +const streamQueryConfig: StreamQueryConfig = { + customerId: "1366999410", + apiKey: "zqt_UXrBcnI2UXINZkrv4g1tQPhzj02vfdtqYJIDiA", + corpusKey: "1", + query: "test query", + search: { + offset: 0, + limit: 5, + metadataFilter: "" + }, + generation: { + maxUsedSearchResults: 5, + responseLanguage: "eng", + enableFactualConsistencyScore: true, + promptName: "vectara-experimental-summary-ext-2023-12-11-large" + }, + chat: { + store: true + } +}; + describe("stream-query-client API v2", () => { let server: SetupServerApi; - beforeAll(async () => { - server = createTestStreamingServer("/v2/chats", chunks, (value: any) => { - return encoder.encode(value); + describe("happy path", () => { + beforeAll(async () => { + server = createTestStreamingServer({ + path: "/v2/chats", + chunks, + createChunk: (value: any) => { + return encoder.encode(value); + } + }); + await server.listen(); }); - await server.listen(); - }); - afterEach(() => { - server.resetHandlers(); - }); + afterEach(() => { + server.resetHandlers(); + }); + + afterAll(() => { + server.close(); + }); + + it("streamQuery converts streamed chunks into usable data", async () => { + const handleEvent = jest.fn(); + + const onStreamEvent = (event: StreamEvent) => { + handleEvent(event); + + if (event.type === "end") { + 
expect(handleEvent).toHaveBeenNthCalledWith(1, { + type: "searchResults", + searchResults: [ + { + text: "(If you're not a Markdown Here user, check out the Markdown Cheatsheet that is not specific to MDH. But, really, you should also use Markdown Here, because it's awesome. http://markdown-here.com)", + score: 0.7467775344848633, + document_metadata: { + "Application-Name": "Microsoft Word 12.0.0", + "Application-Version": 12, + "Character Count": 475, + "Character-Count-With-Spaces": 583, + "Content-Type": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "Creation-Date": "2021-02-25T10:03:47Z", + "Last-Modified": "2021-02-25T10:03:47Z", + "Last-Save-Date": "2021-02-25T10:03:47Z", + "Line-Count": 12, + "Page-Count": 1, + "Paragraph-Count": 8, + Template: "Normal.dotm", + "Total-Time": 6, + "Word-Count": 83, + "X-Parsed-By": "org.apache.tika.parser.microsoft.ooxml.OOXMLParser", + date: "2021-02-25T10:03:47Z", + "dcterms:created": "2021-02-25T10:03:47Z", + "dcterms:modified": "2021-02-25T10:03:47Z", + "extended-properties:AppVersion": 12, + "extended-properties:Application": "Microsoft Word 12.0.0", + "extended-properties:DocSecurityString": "None", + "extended-properties:Template": "Normal.dotm", + "extended-properties:TotalTime": 6, + "meta:character-count": 475, + "meta:character-count-with-spaces": 583, + "meta:creation-date": "2021-02-25T10:03:47Z", + "meta:line-count": 12, + "meta:page-count": 1, + "meta:paragraph-count": 8, + "meta:save-date": "2021-02-25T10:03:47Z", + "meta:word-count": 83, + modified: "2021-02-25T10:03:47Z", + "xmpTPg:NPages": 1 + }, + part_metadata: { + lang: "eng", + len: 25, + offset: 648 + }, + document_id: "914e8885-1a65-4b56-a279-95661b264f3b" + } + ] + }); + + expect(handleEvent).toHaveBeenNthCalledWith(2, { + type: "chatInfo", + chatId: "cht_74b5a5f3-1f51-4427-a317-f62efb493928", + turnId: "trn_74b5a5f3-1f51-4427-a317-f62efb493928" + }); + + expect(handleEvent).toHaveBeenNthCalledWith(3, { + type: "generationChunk", + generationChunk: "Markdown is ", + updatedText: "Markdown is " + }); + + expect(handleEvent).toHaveBeenNthCalledWith(4, { + type: "generationInfo", + renderedPrompt: "Rendered prompt", + rephrasedQuery: "Rephrased query" + }); + + expect(handleEvent).toHaveBeenNthCalledWith(5, { + type: "factualConsistencyScore", + factualConsistencyScore: 0.41796625 + }); + + expect(handleEvent).toHaveBeenNthCalledWith(6, { type: "end" }); + } + }; + + await streamQueryV2({ streamQueryConfig, onStreamEvent }); + }); - afterAll(() => { - server.close(); + it("surfaces response headers", async () => { + const { responseHeaders } = await streamQueryV2({ + streamQueryConfig, + onStreamEvent: jest.fn() + }); + expect(responseHeaders?.get("Content-Type")).toBe("text/json"); + }); + + it("surfaces response status", async () => { + const { status } = await streamQueryV2({ + streamQueryConfig, + onStreamEvent: jest.fn() + }); + expect(status).toBe(200); + }); }); - it("streamQuery converts streamed chunks into usable data", async () => { - const streamQueryConfig: StreamQueryConfig = { - customerId: "1366999410", - apiKey: "zqt_UXrBcnI2UXINZkrv4g1tQPhzj02vfdtqYJIDiA", - corpusKey: "1", - query: "test query", - search: { - offset: 0, - limit: 5, - metadataFilter: "", - }, - generation: { - maxUsedSearchResults: 5, - responseLanguage: "eng", - enableFactualConsistencyScore: true, - promptName: "vectara-experimental-summary-ext-2023-12-11-large", - }, - chat: { - store: true, - }, - }; - - const handleEvent = jest.fn(); - - const onStreamEvent 
= (event: StreamEvent) => { - handleEvent(event); - - if (event.type === "end") { - expect(handleEvent).toHaveBeenNthCalledWith(1, { - type: "searchResults", - searchResults: [ - { - text: "(If you're not a Markdown Here user, check out the Markdown Cheatsheet that is not specific to MDH. But, really, you should also use Markdown Here, because it's awesome. http://markdown-here.com)", - score: 0.7467775344848633, - document_metadata: { - "Application-Name": "Microsoft Word 12.0.0", - "Application-Version": 12, - "Character Count": 475, - "Character-Count-With-Spaces": 583, - "Content-Type": - "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - "Creation-Date": "2021-02-25T10:03:47Z", - "Last-Modified": "2021-02-25T10:03:47Z", - "Last-Save-Date": "2021-02-25T10:03:47Z", - "Line-Count": 12, - "Page-Count": 1, - "Paragraph-Count": 8, - Template: "Normal.dotm", - "Total-Time": 6, - "Word-Count": 83, - "X-Parsed-By": - "org.apache.tika.parser.microsoft.ooxml.OOXMLParser", - date: "2021-02-25T10:03:47Z", - "dcterms:created": "2021-02-25T10:03:47Z", - "dcterms:modified": "2021-02-25T10:03:47Z", - "extended-properties:AppVersion": 12, - "extended-properties:Application": "Microsoft Word 12.0.0", - "extended-properties:DocSecurityString": "None", - "extended-properties:Template": "Normal.dotm", - "extended-properties:TotalTime": 6, - "meta:character-count": 475, - "meta:character-count-with-spaces": 583, - "meta:creation-date": "2021-02-25T10:03:47Z", - "meta:line-count": 12, - "meta:page-count": 1, - "meta:paragraph-count": 8, - "meta:save-date": "2021-02-25T10:03:47Z", - "meta:word-count": 83, - modified: "2021-02-25T10:03:47Z", - "xmpTPg:NPages": 1, - }, - part_metadata: { - lang: "eng", - len: 25, - offset: 648, - }, - document_id: "914e8885-1a65-4b56-a279-95661b264f3b", - }, - ], - }); + describe("unhappy path", () => { + beforeAll(async () => { + server = createTestStreamingServer({ + path: "/v2/chats", + chunks, + createChunk: (value: any) => { + return encoder.encode(value); + }, + shouldRequestsFail: true + }); + await server.listen(); + }); - expect(handleEvent).toHaveBeenNthCalledWith(2, { - type: "chatInfo", - chatId: "cht_74b5a5f3-1f51-4427-a317-f62efb493928", - turnId: "trn_74b5a5f3-1f51-4427-a317-f62efb493928", - }); + afterEach(() => { + server.resetHandlers(); + }); - expect(handleEvent).toHaveBeenNthCalledWith(3, { - type: "generationChunk", - generationChunk: "Markdown is ", - updatedText: "Markdown is ", - }); + afterAll(() => { + server.close(); + }); - expect(handleEvent).toHaveBeenNthCalledWith(4, { - type: "factualConsistencyScore", - factualConsistencyScore: 0.41796625, - }); + it("surfaces HTTP errors", async () => { + const handleEvent = jest.fn(); - expect(handleEvent).toHaveBeenNthCalledWith(5, { type: "end" }); - } - }; + const onStreamEvent = (event: StreamEvent) => { + handleEvent(event); - await streamQueryV2({ streamQueryConfig, onStreamEvent }); + expect(handleEvent).toHaveBeenNthCalledWith(1, { + type: "genericError", + error: new Error("Request failed (Not found)") + }); + }; + + await streamQueryV2({ streamQueryConfig, onStreamEvent }); + }); }); }); diff --git a/src/apiV2/client.ts b/src/apiV2/client.ts index f3cdc34..60bafa0 100644 --- a/src/apiV2/client.ts +++ b/src/apiV2/client.ts @@ -3,29 +3,33 @@ import { StreamQueryConfig, StreamEventHandler, StreamQueryRequest, - StreamQueryRequestHeaders, + StreamQueryRequestHeaders } from "./types"; import { Query } from "./apiTypes"; import { DEFAULT_DOMAIN } from "../common/constants"; 
import { generateStream } from "../common/generateStream"; import { EventBuffer } from "./EventBuffer"; -const convertReranker = ( - reranker?: StreamQueryConfig["search"]["reranker"] -) => { +const convertReranker = (reranker?: StreamQueryConfig["search"]["reranker"]) => { if (!reranker) return; + if (reranker.type === "none") { + return { + type: reranker.type + }; + } + if (reranker.type === "customer_reranker") { return { type: reranker.type, - reranker_id: reranker.rerankerId, + reranker_id: reranker.rerankerId }; } if (reranker.type === "mmr") { return { type: reranker.type, - diversity_bias: reranker.diversityBias, + diversity_bias: reranker.diversityBias }; } }; @@ -35,7 +39,7 @@ const convertCitations = (citations?: GenerationConfig["citations"]) => { if (citations.style === "none" || citations.style === "numeric") { return { - style: citations.style, + style: citations.style }; } @@ -43,7 +47,7 @@ const convertCitations = (citations?: GenerationConfig["citations"]) => { return { style: citations.style, url_pattern: citations.urlPattern, - text_pattern: citations.textPattern, + text_pattern: citations.textPattern }; } }; @@ -51,7 +55,7 @@ const convertCitations = (citations?: GenerationConfig["citations"]) => { export const streamQueryV2 = async ({ streamQueryConfig, onStreamEvent, - includeRawEvents = false, + includeRawEvents = false }: { streamQueryConfig: StreamQueryConfig; onStreamEvent: StreamEventHandler; @@ -72,10 +76,10 @@ export const streamQueryV2 = async ({ offset, limit, contextConfiguration, - reranker, + reranker }, generation, - chat, + chat } = streamQueryConfig; const body: Query.Body = { @@ -87,8 +91,8 @@ export const streamQueryV2 = async ({ metadata_filter: metadataFilter, lexical_interpolation: lexicalInterpolation, custom_dimensions: customDimensions, - semantics, - }, + semantics + } ], offset, limit, @@ -98,11 +102,11 @@ export const streamQueryV2 = async ({ sentences_before: contextConfiguration?.sentencesBefore, sentences_after: contextConfiguration?.sentencesAfter, start_tag: contextConfiguration?.startTag, - end_tag: contextConfiguration?.endTag, + end_tag: contextConfiguration?.endTag }, - reranker: convertReranker(reranker), + reranker: convertReranker(reranker) }, - stream_response: true, + stream_response: true }; if (generation) { @@ -114,7 +118,7 @@ export const streamQueryV2 = async ({ responseLanguage, modelParameters, citations, - enableFactualConsistencyScore, + enableFactualConsistencyScore } = generation; body.generation = { @@ -127,16 +131,16 @@ export const streamQueryV2 = async ({ max_tokens: modelParameters.maxTokens, temperature: modelParameters.temperature, frequency_penalty: modelParameters.frequencyPenalty, - presence_penalty: modelParameters.presencePenalty, + presence_penalty: modelParameters.presencePenalty }, citations: convertCitations(citations), - enable_factual_consistency_score: enableFactualConsistencyScore, + enable_factual_consistency_score: enableFactualConsistencyScore }; } if (chat) { body.chat = { - store: chat.store, + store: chat.store }; } @@ -154,7 +158,7 @@ export const streamQueryV2 = async ({ const headers: StreamQueryRequestHeaders = { "customer-id": customerId, - "Content-Type": "application/json", + "Content-Type": "application/json" }; if (apiKey) headers["x-api-key"] = apiKey; @@ -166,12 +170,11 @@ export const streamQueryV2 = async ({ method: "POST", url, headers, - body, + body }; try { - const { cancelStream, stream, status, responseHeaders } = - await generateStream(headers, JSON.stringify(body), 
url); + const { cancelStream, stream, status, responseHeaders } = await generateStream(headers, JSON.stringify(body), url); const consumeStream = async () => { try { @@ -181,27 +184,15 @@ export const streamQueryV2 = async ({ try { buffer.consumeChunk(chunk); } catch (error) { - if (error instanceof Error) { - onStreamEvent({ - type: "genericError", - error, - }); - } else { - throw error; - } + handleError(error, onStreamEvent); } } } catch (error) { if (error instanceof DOMException && error.name == "AbortError") { // Swallow the "DOMException: BodyStreamBuffer was aborted" error // triggered by cancelling a stream. - } else if (error instanceof Error) { - onStreamEvent({ - type: "genericError", - error, - }); } else { - throw error; + handleError(error, onStreamEvent); } } }; @@ -210,15 +201,19 @@ export const streamQueryV2 = async ({ return { cancelStream, request, status, responseHeaders }; } catch (error) { - if (error instanceof Error) { - onStreamEvent({ - type: "genericError", - error, - }); - } else { - throw error; - } + handleError(error, onStreamEvent); } return { request }; }; + +const handleError = (error: unknown, onStreamEvent: StreamEventHandler) => { + if (error instanceof Error) { + onStreamEvent({ + type: "genericError", + error + }); + } else { + throw error; + } +}; diff --git a/src/apiV2/types.ts b/src/apiV2/types.ts index 5fcaec3..c4693ca 100644 --- a/src/apiV2/types.ts +++ b/src/apiV2/types.ts @@ -74,6 +74,7 @@ export type StreamQueryConfig = { endTag?: string; }; reranker?: + | { type: "none" } | { type: "customer_reranker"; rerankerId: string; @@ -114,6 +115,7 @@ export type StreamEvent = | SearchResultsEvent | ChatInfoEvent | GenerationChunkEvent + | GenerationInfoEvent | GenerationEndEvent | FactualConsistencyScoreEvent | EndEvent @@ -148,6 +150,12 @@ export type GenerationChunkEvent = BaseEvent & { generationChunk: string; }; +export type GenerationInfoEvent = BaseEvent & { + type: "generationInfo"; + renderedPrompt?: string; + rephrasedQuery?: string; +}; + export type GenerationEndEvent = BaseEvent & { type: "generationEnd"; }; diff --git a/src/common/createTestStreamingServer.ts b/src/common/createTestStreamingServer.ts index ecfc7eb..7196863 100644 --- a/src/common/createTestStreamingServer.ts +++ b/src/common/createTestStreamingServer.ts @@ -1,14 +1,27 @@ import { setupServer } from "msw/node"; -import { http } from "msw"; +import { HttpResponse, http } from "msw"; import { DEFAULT_DOMAIN } from "./constants"; -export const createTestStreamingServer = ( - path: string, - chunks: any[], - createChunk: (value: any) => any -) => { +export const createTestStreamingServer = ({ + path, + chunks, + createChunk, + shouldRequestsFail +}: { + path: string; + chunks: any[]; + createChunk: (value: any) => any; + shouldRequestsFail?: boolean; +}) => { return setupServer( http.post(`${DEFAULT_DOMAIN}${path}`, () => { + if (shouldRequestsFail) { + return new HttpResponse(null, { + status: 404, + statusText: "Not found" + }); + } + const stream = new ReadableStream({ start(controller) { chunks.forEach((chunk) => { @@ -16,15 +29,14 @@ export const createTestStreamingServer = ( }); controller.close(); - }, + } }); // Send the mocked response immediately. 
const response = new Response(stream, { - // status: 200, headers: { - "Content-Type": "text/json", - }, + "Content-Type": "text/json" + } }); return response; diff --git a/src/common/generateStream.ts b/src/common/generateStream.ts index 574ea0e..c8098b9 100644 --- a/src/common/generateStream.ts +++ b/src/common/generateStream.ts @@ -1,30 +1,30 @@ -export const generateStream = async ( - headers: Record, - body: string, - url: string -) => { +export const generateStream = async (headers: Record, body: string, url: string) => { const controller = new AbortController(); const response = await fetch(url, { method: "POST", headers, body, - signal: controller.signal, + signal: controller.signal }); + if (!response.ok) { + throw new Error(`Request failed (${response.statusText})`, { + cause: response.status + }); + } + if (!response.body) throw new Error("Response body does not exist"); return { stream: getIterableStream(response.body), cancelStream: () => controller.abort(), status: response.status, - responseHeaders: response.headers, + responseHeaders: response.headers }; }; -async function* getIterableStream( - body: ReadableStream -): AsyncIterable { +async function* getIterableStream(body: ReadableStream): AsyncIterable { const reader = body.getReader(); const decoder = new TextDecoder();
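
For reference, below is a minimal consumer-side sketch of the behavior this patch introduces: the { type: "none" } reranker option, the new generationInfo stream event, and HTTP errors surfaced through onStreamEvent as genericError events. The import path, customer ID, and API key are placeholders rather than values taken from this patch; the option names, config fields, and event shapes mirror the types added in src/apiV2/types.ts above.

// Illustrative usage sketch; not part of the patch. The import path is a
// placeholder -- adjust it to your setup or the published package entry point.
import { streamQueryV2, StreamEvent, StreamQueryConfig } from "./src/apiV2/client";

const streamQueryConfig: StreamQueryConfig = {
  customerId: "<your-customer-id>", // placeholder
  apiKey: "<your-api-key>", // placeholder
  corpusKey: "1",
  query: "What is Markdown?",
  search: {
    offset: 0,
    limit: 5,
    metadataFilter: "",
    // New in this patch: explicitly disable reranking.
    reranker: { type: "none" }
  },
  generation: {
    maxUsedSearchResults: 5,
    responseLanguage: "eng",
    enableFactualConsistencyScore: true,
    promptName: "vectara-experimental-summary-ext-2023-12-11-large"
  },
  chat: {
    store: true
  }
};

const onStreamEvent = (event: StreamEvent) => {
  switch (event.type) {
    case "generationChunk":
      console.log(event.updatedText);
      break;

    // New in this patch: the generation_info event is exposed as "generationInfo".
    case "generationInfo":
      console.log("Rendered prompt:", event.renderedPrompt);
      console.log("Rephrased query:", event.rephrasedQuery);
      break;

    // New in this patch: HTTP errors are reported here instead of being
    // silently consumed as part of the stream.
    case "genericError":
      console.error(event.error);
      break;

    case "end":
      console.log("Stream complete.");
      break;
  }
};

const run = async () => {
  // Mirrors the pattern used in the tests above: status (and responseHeaders)
  // are surfaced alongside the stream.
  const { status } = await streamQueryV2({ streamQueryConfig, onStreamEvent });
  console.log("HTTP status:", status);
};

run();

As the new unhappy-path test above shows, a non-2xx response now makes generateStream throw (with the HTTP status attached as the error cause), and streamQueryV2 routes that Error to onStreamEvent as a genericError rather than letting the error body flow through EventBuffer.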