diff --git a/app/frontend/src/components/ChatComponent.tsx b/app/frontend/src/components/ChatComponent.tsx deleted file mode 100644 index c947922b..00000000 --- a/app/frontend/src/components/ChatComponent.tsx +++ /dev/null @@ -1,492 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC -"use client"; - -import React, { useState, useEffect, useRef } from "react"; -import { Card } from "./ui/card"; -import { Button } from "./ui/button"; -import * as ScrollArea from "@radix-ui/react-scroll-area"; -import { useLocation } from "react-router-dom"; -import { Spinner } from "./ui/spinner"; -import { User, ChevronDown, Send } from "lucide-react"; -import { Textarea } from "./ui/textarea"; -import logo from "../assets/tt_logo.svg"; -import { - Breadcrumb, - BreadcrumbEllipsis, - BreadcrumbItem, - BreadcrumbLink, - BreadcrumbList, - BreadcrumbPage, - BreadcrumbSeparator, -} from "./ui/breadcrumb"; -import { - DropdownMenu, - DropdownMenuContent, - DropdownMenuItem, - DropdownMenuTrigger, -} from "./ui/dropdown-menu"; -import { fetchModels } from "../api/modelsDeployedApis"; -import ChatExamples from "./ChatExamples"; -import axios from "axios"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "./ui/select"; -import { useQuery } from "react-query"; -import { fetchCollections } from "@/src/components/rag"; -import { - Tooltip, - TooltipContent, - TooltipProvider, - TooltipTrigger, -} from "./ui/tooltip"; -import InferenceStats from "./InferenceStats"; -interface InferenceRequest { - deploy_id: string; - text: string; - rag_context?: { documents: string[] }; -} - -interface RagDataSource { - id: string; - name: string; - metadata: Record; -} - -interface ChatMessage { - sender: "user" | "assistant"; - text: string; - inferenceStats?: InferenceStats; // Optional property for stats -} - -interface Model { - id: string; - name: string; -} - -interface InferenceStats { - user_ttft_ms: number; - user_tps: number; - user_ttft_e2e_ms: number; - prefill: { - tokens_prefilled: number; - tps: number; - }; - decode: { - tokens_decoded: number; - tps: number; - }; - batch_size: number; - context_length: number; -} - -export default function ChatComponent() { - const location = useLocation(); - const [textInput, setTextInput] = useState(""); - const [ragDatasource, setRagDatasource] = useState< - RagDataSource | undefined - >(); - const { data: ragDataSources } = useQuery("collectionsList", { - queryFn: fetchCollections, - initialData: [], - }); - const [chatHistory, setChatHistory] = useState([]); - const [modelID, setModelID] = useState(null); - const [modelName, setModelName] = useState(null); - const [isStreaming, setIsStreaming] = useState(false); - const viewportRef = useRef(null); - const bottomRef = useRef(null); - const [isScrollButtonVisible, setIsScrollButtonVisible] = useState(false); - const [modelsDeployed, setModelsDeployed] = useState([]); - - useEffect(() => { - if (location.state) { - setModelID(location.state.containerID); - setModelName(location.state.modelName); - } - - const loadModels = async () => { - try { - const models = await fetchModels(); - setModelsDeployed(models); - } catch (error) { - console.error("Error fetching models:", error); - } - }; - - loadModels(); - }, [location.state]); - - const scrollToBottom = () => { - if (viewportRef.current) { - viewportRef.current.scrollTo({ - top: viewportRef.current.scrollHeight, - behavior: "smooth", - }); - } - }; - - const handleScroll = () => { - if (viewportRef.current) { - const { scrollTop, scrollHeight, clientHeight } = viewportRef.current; - const isAtBottom = scrollHeight - scrollTop <= clientHeight + 1; - setIsScrollButtonVisible(!isAtBottom); - } - }; - - const getRagContext = async (request: InferenceRequest) => { - const ragContext: { documents: string[] } = { documents: [] }; - - if (!ragDatasource) return ragContext; - - try { - const response = await axios.get( - `/collections-api/${ragDatasource.name}/query`, - { - params: { query: request.text }, - }, - ); - if (response?.data) { - ragContext.documents = response.data.documents; - } - } catch (e) { - console.error(`Error fetching RAG context: ${e}`); - } - - return ragContext; - }; - - const runInference = async (request: InferenceRequest) => { - try { - if (ragDatasource) { - request.rag_context = await getRagContext(request); - } - - setIsStreaming(true); - const response = await fetch(`/models-api/inference/`, { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify(request), - }); - const reader = response.body?.getReader(); - setChatHistory((prevHistory) => [ - ...prevHistory, - { sender: "user", text: textInput }, - { sender: "assistant", text: "" }, - ]); - setTextInput(""); - - let result = ""; - if (reader) { - let done = false; - while (!done) { - const { done: streamDone, value } = await reader.read(); - done = streamDone; - - if (value) { - const decoder = new TextDecoder(); - const chunk = decoder.decode(value); - console.log("Chunk:", chunk); - result += chunk; - const endOfStreamIndex = result.indexOf("<>"); - if (endOfStreamIndex !== -1) { - result = result.substring(0, endOfStreamIndex); - done = true; - } - const cleanedResult = result - .replace(/<\|eot_id\|>/g, "") // Remove "<|eot_id|>" - .replace(/<\|endoftext\|>/g, "") - .trim(); - const statsStartIndex = cleanedResult.indexOf("{"); - const statsEndIndex = cleanedResult.lastIndexOf("}"); - - let chatContent = cleanedResult; - - if (statsStartIndex !== -1 && statsEndIndex !== -1) { - chatContent = cleanedResult.substring(0, statsStartIndex).trim(); - - const statsJson = cleanedResult.substring( - statsStartIndex, - statsEndIndex + 1, - ); - try { - const parsedStats = JSON.parse(statsJson); - setChatHistory((prevHistory) => { - const updatedHistory = [...prevHistory]; - const lastAssistantMessage = updatedHistory.findLastIndex( - (message) => message.sender === "assistant", - ); - if (lastAssistantMessage !== -1) { - updatedHistory[lastAssistantMessage] = { - ...updatedHistory[lastAssistantMessage], - inferenceStats: parsedStats, - }; - } - return updatedHistory; - }); - } catch (e) { - console.error("Error parsing inference stats:", e); - } - } - - setChatHistory((prevHistory) => { - const updatedHistory = [...prevHistory]; - updatedHistory[updatedHistory.length - 1] = { - ...updatedHistory[updatedHistory.length - 1], - text: chatContent, - }; - return updatedHistory; - }); - } - } - } - - setIsStreaming(false); - } catch (error) { - console.error("Error running inference:", error); - setIsStreaming(false); - } - }; - - const handleInference = () => { - if (textInput.trim() === "" || !modelID) return; - - const inferenceRequest: InferenceRequest = { - deploy_id: modelID, - text: textInput, - }; - - runInference(inferenceRequest); - }; - - const RagContextSelector = ({ - collections, - onChange, - activeCollection, - }: { - collections: RagDataSource[]; - activeCollection?: RagDataSource; - onChange: (v: string) => void; - }) => ( -
- -
- ); - - const handleKeyPress = (e: React.KeyboardEvent) => { - if (e.key === "Enter" && !e.shiftKey) { - e.preventDefault(); - handleInference(); - } - }; - - const handleTextAreaInput = (e: React.ChangeEvent) => { - e.target.style.height = "auto"; - e.target.style.height = `${e.target.scrollHeight}px`; - setTextInput(e.target.value); - }; - - return ( -
- -
- {/* Breadcrumbs and RAG context selector */} - - - - - - - - Models Deployed - - - -

View all deployed models

-
-
-
-
- - / - - - - - - - - - Toggle menu - - - {modelsDeployed.map((model) => ( - { - setModelID(model.id); - setModelName(model.name); - }} - > - {model.name} - - ))} - - - - -

Select a different model

-
-
-
-
- - / - - - - - - - {modelName} - - - -

Current selected model

-
-
-
-
-
-
- { - const dataSource = ragDataSources.find( - (rds: RagDataSource) => rds.name === v, - ); - if (dataSource) { - setRagDatasource(dataSource); - } - }} - activeCollection={ragDatasource} - /> -
- {/* Chat history section */} -
- {chatHistory.length === 0 && ( - - )} - {chatHistory.length > 0 && ( - - -
- {chatHistory.map((message, index) => ( -
-
-
- {message.sender === "user" ? ( - - ) : ( - Tenstorrent Logo - )} -
-
-
- {message.text} -
- {message.sender === "assistant" && - message.inferenceStats && ( - - )} -
- ))} -
-
- - - )} - -
- -
-
- -
-
-