From a4975408797f501144734213bcf7fee9f5d41d03 Mon Sep 17 00:00:00 2001 From: Xander Song Date: Tue, 1 Oct 2024 16:42:05 -0700 Subject: [PATCH] feat(playground): streaming chat completions (#4785) * add subscriptions * configure relay * playground mvp * align the runs for multiple * add more functionality * cleanup the components --------- Co-authored-by: Mikyo King --- app/package.json | 5 +- app/pnpm-lock.yaml | 83 +++++++------- app/schema.graphql | 8 ++ app/src/RelayEnvironment.ts | 28 ++++- .../components/auth/OneTimeAPIKeyDialog.tsx | 2 +- app/src/pages/dataset/DatasetCodeDropdown.tsx | 6 +- .../pages/playground/MessageRolePicker.tsx | 46 ++++++++ .../playground/PlaygroundChatTemplate.tsx | 20 +++- .../PlaygroundInputModeRadioGroup.tsx | 2 +- .../pages/playground/PlaygroundInstance.tsx | 2 +- app/src/pages/playground/PlaygroundOutput.tsx | 101 +++++++++++++++++- .../pages/playground/PlaygroundTemplate.tsx | 71 +++++++++--- .../PlaygroundOutputSubscription.graphql.ts | 82 ++++++++++++++ app/src/pages/settings/SettingsPage.tsx | 2 +- app/src/pages/trace/SpanCodeDropdown.tsx | 4 +- app/src/pages/trace/SpanDetails.tsx | 3 - app/src/store/playgroundStore.tsx | 71 +++++++++++- pyproject.toml | 1 + src/phoenix/server/api/schema.py | 2 + src/phoenix/server/api/subscriptions.py | 76 +++++++++++++ src/phoenix/server/app.py | 2 + 21 files changed, 544 insertions(+), 73 deletions(-) create mode 100644 app/src/pages/playground/MessageRolePicker.tsx create mode 100644 app/src/pages/playground/__generated__/PlaygroundOutputSubscription.graphql.ts create mode 100644 src/phoenix/server/api/subscriptions.py diff --git a/app/package.json b/app/package.json index cd16447d6a9..03095faa6c6 100644 --- a/app/package.json +++ b/app/package.json @@ -6,7 +6,7 @@ "license": "None", "private": true, "dependencies": { - "@arizeai/components": "^1.8.1", + "@arizeai/components": "^1.8.3", "@arizeai/openinference-semantic-conventions": "^0.10.0", "@arizeai/point-cloud": "^3.0.6", "@codemirror/autocomplete": "6.12.0", @@ -25,6 +25,7 @@ "d3-scale-chromatic": "^3.1.0", "d3-time-format": "^4.1.0", "date-fns": "^3.6.0", + "graphql-ws": "^5.16.0", "lodash": "^4.17.21", "normalize.css": "^8.0.1", "polished": "^4.3.1", @@ -57,7 +58,7 @@ "@types/jest": "^29.5.12", "@types/lodash": "^4.17.7", "@types/node": "^22.5.4", - "@types/react": "18.2.48", + "@types/react": "^18.3.10", "@types/react-dom": "^18.3.0", "@types/react-relay": "^16.0.6", "@types/recharts": "^1.8.29", diff --git a/app/pnpm-lock.yaml b/app/pnpm-lock.yaml index f3a0ca56610..bab64f61191 100644 --- a/app/pnpm-lock.yaml +++ b/app/pnpm-lock.yaml @@ -12,14 +12,14 @@ importers: .: dependencies: '@arizeai/components': - specifier: ^1.8.1 - version: 1.8.1(@types/react@18.2.48)(eslint@8.57.0)(react-dom@18.2.0(react@18.2.0))(react@18.2.0) + specifier: ^1.8.3 + version: 1.8.3(@types/react@18.3.10)(eslint@8.57.0)(react-dom@18.2.0(react@18.2.0))(react@18.2.0) '@arizeai/openinference-semantic-conventions': specifier: ^0.10.0 version: 0.10.0 '@arizeai/point-cloud': specifier: ^3.0.6 - version: 3.0.6(@react-three/drei@9.108.4(@react-three/fiber@8.0.12(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(@types/react@18.2.48)(@types/three@0.149.0)(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(@react-three/fiber@8.0.12(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(react@18.2.0)(three-stdlib@2.30.4(three@0.139.2))(three@0.139.2) + version: 3.0.6(@react-three/drei@9.108.4(@react-three/fiber@8.0.12(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(@types/react@18.3.10)(@types/three@0.149.0)(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(@react-three/fiber@8.0.12(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(react@18.2.0)(three-stdlib@2.30.4(three@0.139.2))(three@0.139.2) '@codemirror/autocomplete': specifier: 6.12.0 version: 6.12.0(@codemirror/language@6.10.2)(@codemirror/state@6.4.1)(@codemirror/view@6.28.5)(@lezer/common@1.2.1) @@ -40,7 +40,7 @@ importers: version: 6.28.5 '@react-three/drei': specifier: ^9.108.4 - version: 9.108.4(@react-three/fiber@8.0.12(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(@types/react@18.2.48)(@types/three@0.149.0)(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2) + version: 9.108.4(@react-three/fiber@8.0.12(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(@types/react@18.3.10)(@types/three@0.149.0)(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2) '@react-three/fiber': specifier: 8.0.12 version: 8.0.12(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2) @@ -68,6 +68,9 @@ importers: date-fns: specifier: ^3.6.0 version: 3.6.0 + graphql-ws: + specifier: ^5.16.0 + version: 5.16.0(graphql@16.9.0) lodash: specifier: ^4.17.21 version: 4.17.21 @@ -94,7 +97,7 @@ importers: version: 4.5.0(react-dom@18.2.0(react@18.2.0))(react@18.2.0) react-markdown: specifier: ^9.0.1 - version: 9.0.1(@types/react@18.2.48)(react@18.2.0) + version: 9.0.1(@types/react@18.3.10)(react@18.2.0) react-relay: specifier: ^16.2.0 version: 16.2.0(react@18.2.0) @@ -133,11 +136,11 @@ importers: version: 0.0.4(react@18.2.0) zustand: specifier: ^4.5.4 - version: 4.5.4(@types/react@18.2.48)(react@18.2.0) + version: 4.5.4(@types/react@18.3.10)(react@18.2.0) devDependencies: '@emotion/react': specifier: ^11.11.4 - version: 11.11.4(@types/react@18.2.48)(react@18.2.0) + version: 11.11.4(@types/react@18.3.10)(react@18.2.0) '@playwright/test': specifier: ^1.47.0 version: 1.47.0 @@ -160,8 +163,8 @@ importers: specifier: ^22.5.4 version: 22.5.4 '@types/react': - specifier: 18.2.48 - version: 18.2.48 + specifier: ^18.3.10 + version: 18.3.10 '@types/react-dom': specifier: ^18.3.0 version: 18.3.0 @@ -244,8 +247,8 @@ packages: resolution: {integrity: sha512-30iZtAPgz+LTIYoeivqYo853f02jBYSd5uGnGpkFV0M3xOt9aN73erkgYAmZU43x4VfqcnLxW9Kpg3R5LC4YYw==} engines: {node: '>=6.0.0'} - '@arizeai/components@1.8.1': - resolution: {integrity: sha512-djcZ9noXJRaHsYHYJXAGKeON+vacT3aJvKdFcF0hJeWZ/j7wMGCpvzYwxRgLUooZvj5ifTGBzG9DmF6/sC+v5g==} + '@arizeai/components@1.8.3': + resolution: {integrity: sha512-U+EV8+GDm0yLNh+xXcEiKY4v7cJgepHMVLRD9pmtYYYdCgiBOLZ7TNQv6gxcRRvlxuqUgLIcP//nwY/roI6MdQ==} engines: {node: '>=14'} peerDependencies: react: '>=18' @@ -1631,8 +1634,8 @@ packages: '@types/react-relay@16.0.6': resolution: {integrity: sha512-VTntVQJhlwQYNUlbNgGf8RYy7EtQPRZqsD/w2Si0ygZspJXuNlVdRkklWMFN99EMRhHDpqlNHD8i3wIs7QRz9g==} - '@types/react@18.2.48': - resolution: {integrity: sha512-qboRCl6Ie70DQQG9hhNREz81jqC1cs9EVNcjQ1AU+jH6NFfSAhVVbrrY/+nSF+Bsk4AOwm9Qa61InvMCyV+H3w==} + '@types/react@18.3.10': + resolution: {integrity: sha512-02sAAlBnP39JgXwkAq3PeU9DVaaGpZyF3MGcC0MKgQVkZor5IiiDAipVaxQHtDJAmO4GIy/rVBy/LzVj76Cyqg==} '@types/recharts@1.8.29': resolution: {integrity: sha512-ulKklaVsnFIIhTQsQw226TnOibrddW1qUQNFVhoQEyY1Z7FRQrNecFCGt7msRuJseudzE9czVawZb17dK/aPXw==} @@ -1640,9 +1643,6 @@ packages: '@types/relay-runtime@14.1.24': resolution: {integrity: sha512-ta7vPoFXtEG1wu0Mk7sTngzhmfNGnIe8cDiy3yBEm8pJcGpv55YY/+vWrd9gYd9OQht8rALZpXIYSOLzS/0PVg==} - '@types/scheduler@0.23.0': - resolution: {integrity: sha512-YIoDCTH3Af6XM5VuwGG/QL/CJqga1Zm3NkU3HZ4ZHK2fRMPYP1VczsTUqtsf43PH/iJNVlPHAo2oWX7BSdB2Hw==} - '@types/stack-utils@2.0.3': resolution: {integrity: sha512-9aEbYZ3TbYMznPdcdr3SmIrLXwC/AKZXQeCf9Pgao5CKb8CyHuEX5jzWPTkvregvhRJHcpRO6BFoGW9ycaOkYw==} @@ -2789,6 +2789,12 @@ packages: graphemer@1.4.0: resolution: {integrity: sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag==} + graphql-ws@5.16.0: + resolution: {integrity: sha512-Ju2RCU2dQMgSKtArPbEtsK5gNLnsQyTNIo/T7cZNp96niC1x0KdJNZV0TIoilceBPQwfb5itrGl8pkFeOUMl4A==} + engines: {node: '>=10'} + peerDependencies: + graphql: '>=0.11 <=16' + graphql@15.3.0: resolution: {integrity: sha512-GTCJtzJmkFLWRfFJuoo9RWWa/FfamUHgiFosxi/X1Ani4AVWbeyBenZTNX6dM+7WSbbFfTo/25eh0LLkwHMw2w==} engines: {node: '>= 10.x'} @@ -4879,9 +4885,9 @@ snapshots: '@jridgewell/gen-mapping': 0.3.5 '@jridgewell/trace-mapping': 0.3.25 - '@arizeai/components@1.8.1(@types/react@18.2.48)(eslint@8.57.0)(react-dom@18.2.0(react@18.2.0))(react@18.2.0)': + '@arizeai/components@1.8.3(@types/react@18.3.10)(eslint@8.57.0)(react-dom@18.2.0(react@18.2.0))(react@18.2.0)': dependencies: - '@emotion/react': 11.11.4(@types/react@18.2.48)(react@18.2.0) + '@emotion/react': 11.11.4(@types/react@18.3.10)(react@18.2.0) '@react-aria/breadcrumbs': 3.5.13(react@18.2.0) '@react-aria/button': 3.9.5(react@18.2.0) '@react-aria/dialog': 3.5.14(react-dom@18.2.0(react@18.2.0))(react@18.2.0) @@ -4923,9 +4929,9 @@ snapshots: '@arizeai/openinference-semantic-conventions@0.10.0': {} - '@arizeai/point-cloud@3.0.6(@react-three/drei@9.108.4(@react-three/fiber@8.0.12(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(@types/react@18.2.48)(@types/three@0.149.0)(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(@react-three/fiber@8.0.12(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(react@18.2.0)(three-stdlib@2.30.4(three@0.139.2))(three@0.139.2)': + '@arizeai/point-cloud@3.0.6(@react-three/drei@9.108.4(@react-three/fiber@8.0.12(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(@types/react@18.3.10)(@types/three@0.149.0)(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(@react-three/fiber@8.0.12(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(react@18.2.0)(three-stdlib@2.30.4(three@0.139.2))(three@0.139.2)': dependencies: - '@react-three/drei': 9.108.4(@react-three/fiber@8.0.12(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(@types/react@18.2.48)(@types/three@0.149.0)(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2) + '@react-three/drei': 9.108.4(@react-three/fiber@8.0.12(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(@types/react@18.3.10)(@types/three@0.149.0)(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2) '@react-three/fiber': 8.0.12(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2) react: 18.2.0 three: 0.139.2 @@ -5282,7 +5288,7 @@ snapshots: '@emotion/memoize@0.8.1': {} - '@emotion/react@11.11.4(@types/react@18.2.48)(react@18.2.0)': + '@emotion/react@11.11.4(@types/react@18.3.10)(react@18.2.0)': dependencies: '@babel/runtime': 7.24.8 '@emotion/babel-plugin': 11.11.0 @@ -5294,7 +5300,7 @@ snapshots: hoist-non-react-statics: 3.3.2 react: 18.2.0 optionalDependencies: - '@types/react': 18.2.48 + '@types/react': 18.3.10 transitivePeerDependencies: - supports-color @@ -6265,7 +6271,7 @@ snapshots: '@swc/helpers': 0.5.12 react: 18.2.0 - '@react-three/drei@9.108.4(@react-three/fiber@8.0.12(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(@types/react@18.2.48)(@types/three@0.149.0)(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2)': + '@react-three/drei@9.108.4(@react-three/fiber@8.0.12(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2))(@types/react@18.3.10)(@types/three@0.149.0)(react-dom@18.2.0(react@18.2.0))(react@18.2.0)(three@0.139.2)': dependencies: '@babel/runtime': 7.24.8 '@mediapipe/tasks-vision': 0.10.8 @@ -6289,7 +6295,7 @@ snapshots: three-mesh-bvh: 0.7.6(three@0.139.2) three-stdlib: 2.30.4(three@0.139.2) troika-three-text: 0.49.1(three@0.139.2) - tunnel-rat: 0.1.2(@types/react@18.2.48)(react@18.2.0) + tunnel-rat: 0.1.2(@types/react@18.3.10)(react@18.2.0) utility-types: 3.11.0 uuid: 9.0.1 zustand: 3.7.2(react@18.2.0) @@ -6603,32 +6609,29 @@ snapshots: '@types/react-dom@18.3.0': dependencies: - '@types/react': 18.2.48 + '@types/react': 18.3.10 '@types/react-reconciler@0.26.7': dependencies: - '@types/react': 18.2.48 + '@types/react': 18.3.10 '@types/react-relay@16.0.6': dependencies: - '@types/react': 18.2.48 + '@types/react': 18.3.10 '@types/relay-runtime': 14.1.24 - '@types/react@18.2.48': + '@types/react@18.3.10': dependencies: '@types/prop-types': 15.7.12 - '@types/scheduler': 0.23.0 csstype: 3.1.3 '@types/recharts@1.8.29': dependencies: '@types/d3-shape': 1.3.12 - '@types/react': 18.2.48 + '@types/react': 18.3.10 '@types/relay-runtime@14.1.24': {} - '@types/scheduler@0.23.0': {} - '@types/stack-utils@2.0.3': {} '@types/stats.js@0.17.3': {} @@ -8077,6 +8080,10 @@ snapshots: graphemer@1.4.0: {} + graphql-ws@5.16.0(graphql@16.9.0): + dependencies: + graphql: 16.9.0 + graphql@15.3.0: {} graphql@16.9.0: {} @@ -9705,10 +9712,10 @@ snapshots: react-is@18.3.1: {} - react-markdown@9.0.1(@types/react@18.2.48)(react@18.2.0): + react-markdown@9.0.1(@types/react@18.3.10)(react@18.2.0): dependencies: '@types/hast': 3.0.4 - '@types/react': 18.2.48 + '@types/react': 18.3.10 devlop: 1.1.0 hast-util-to-jsx-runtime: 2.3.0 html-url-attributes: 3.0.0 @@ -10320,9 +10327,9 @@ snapshots: tslib@2.6.3: {} - tunnel-rat@0.1.2(@types/react@18.2.48)(react@18.2.0): + tunnel-rat@0.1.2(@types/react@18.3.10)(react@18.2.0): dependencies: - zustand: 4.5.4(@types/react@18.2.48)(react@18.2.0) + zustand: 4.5.4(@types/react@18.3.10)(react@18.2.0) transitivePeerDependencies: - '@types/react' - immer @@ -10661,11 +10668,11 @@ snapshots: optionalDependencies: react: 18.2.0 - zustand@4.5.4(@types/react@18.2.48)(react@18.2.0): + zustand@4.5.4(@types/react@18.3.10)(react@18.2.0): dependencies: use-sync-external-store: 1.2.0(react@18.2.0) optionalDependencies: - '@types/react': 18.2.48 + '@types/react': 18.3.10 react: 18.2.0 zwitch@2.0.4: {} diff --git a/app/schema.graphql b/app/schema.graphql index 9d3c39903d6..f7ec737a63d 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -65,6 +65,10 @@ enum AuthMethod { union Bin = NominalBin | IntervalBin | MissingValueBin +input ChatCompletionInput { + message: String! +} + input ClearProjectInput { id: GlobalID! @@ -1372,6 +1376,10 @@ enum SpanStatusCode { UNSET } +type Subscription { + chatCompletion(input: ChatCompletionInput!): String! +} + type SystemApiKey implements ApiKey & Node { """Name of the API key.""" name: String! diff --git a/app/src/RelayEnvironment.ts b/app/src/RelayEnvironment.ts index 6122d551e07..ee4dbe6a51c 100644 --- a/app/src/RelayEnvironment.ts +++ b/app/src/RelayEnvironment.ts @@ -1,9 +1,15 @@ +import { createClient, Sink } from "graphql-ws"; import { Environment, FetchFunction, + GraphQLResponse, Network, + Observable, RecordSource, + RequestParameters, Store, + SubscribeFunction, + Variables, } from "relay-runtime"; import { authFetch } from "@phoenix/authFetch"; @@ -50,9 +56,29 @@ const fetchRelay: FetchFunction = async (params, variables, _cacheConfig) => { return json; }; +const wsClient = createClient({ + url: "ws://localhost:6006/graphql", +}); + +const subscribe: SubscribeFunction = ( + operation: RequestParameters, + variables: Variables +) => { + return Observable.create((sink) => { + return wsClient.subscribe( + { + operationName: operation.name, + query: operation.text as string, + variables, + }, + sink as Sink + ); + }); +}; + // Export a singleton instance of Relay Environment configured with our network layer: export default new Environment({ - network: Network.create(fetchRelay), + network: Network.create(fetchRelay, subscribe), store: new Store(new RecordSource(), { // This property tells Relay to not immediately clear its cache when the user // navigates around the app. Relay will hold onto the specified number of diff --git a/app/src/components/auth/OneTimeAPIKeyDialog.tsx b/app/src/components/auth/OneTimeAPIKeyDialog.tsx index 82e1d31c874..a85b3b86d9b 100644 --- a/app/src/components/auth/OneTimeAPIKeyDialog.tsx +++ b/app/src/components/auth/OneTimeAPIKeyDialog.tsx @@ -37,7 +37,7 @@ export function OneTimeAPIKeyDialog(props: { jwt: string }) { - + diff --git a/app/src/pages/dataset/DatasetCodeDropdown.tsx b/app/src/pages/dataset/DatasetCodeDropdown.tsx index 61480577394..cd338421c6b 100644 --- a/app/src/pages/dataset/DatasetCodeDropdown.tsx +++ b/app/src/pages/dataset/DatasetCodeDropdown.tsx @@ -72,7 +72,7 @@ export function DatasetCodeDropdown() { width="100%" > - + @@ -105,7 +105,7 @@ export function DatasetCodeDropdown() {
- +
diff --git a/app/src/pages/playground/MessageRolePicker.tsx b/app/src/pages/playground/MessageRolePicker.tsx new file mode 100644 index 00000000000..33307a5539a --- /dev/null +++ b/app/src/pages/playground/MessageRolePicker.tsx @@ -0,0 +1,46 @@ +import React from "react"; +import { css } from "@emotion/react"; + +import { Item, Picker } from "@arizeai/components"; + +import { ChatMessageRole } from "@phoenix/store"; + +const hiddenLabelCSS = css` + .ac-field-label { + display: none; + } +`; + +type MessageRolePickerProps = { + /** + * The currently selected message role + */ + role: ChatMessageRole; + /** + * Whether to display a label for the picker + * This may be set to false in cases where the picker is rendered in a table for instance + * @default true + */ + includeLabel?: boolean; +}; + +export function MessageRolePicker({ + role, + includeLabel = true, +}: MessageRolePickerProps) { + return ( + {}} + > + System + User + Assistant + + ); +} diff --git a/app/src/pages/playground/PlaygroundChatTemplate.tsx b/app/src/pages/playground/PlaygroundChatTemplate.tsx index 99623e68c8e..f2103de9873 100644 --- a/app/src/pages/playground/PlaygroundChatTemplate.tsx +++ b/app/src/pages/playground/PlaygroundChatTemplate.tsx @@ -1,9 +1,11 @@ import React from "react"; +import { css } from "@emotion/react"; import { Card, TextArea } from "@arizeai/components"; import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext"; +import { MessageRolePicker } from "./MessageRolePicker"; import { PlaygroundInstanceProps } from "./types"; interface PlaygroundChatTemplateProps extends PlaygroundInstanceProps {} @@ -22,11 +24,25 @@ export function PlaygroundChatTemplate(props: PlaygroundChatTemplateProps) { } return ( -
    +
      {template.messages.map((message, index) => { return (
    • - + + } + variant="compact" + backgroundColor="light" + borderColor="light" + >