Skip to content

Commit

Permalink
Anirud/fix refactor chatui page (#64)
Browse files Browse the repository at this point in the history
* refactor(Header): Resolve ref warning and improve type safety

- Extract ModelSelector as a separate forwardRef component
- Fix TypeScript errors in DropdownMenu implementation
- Improve overall type safety in Header component
- fixes #62

* refactor ChatUI component
- Break down large ChatComponent into smaller, reusable components in its own folder
- Create separate Header, ChatHistory, and InputArea components
- move chat components like stats, examples into its own folder called "chatui"

* refactor: chatui history and examples
- add scroll back which allows user to scoll to bottom of chat

* keep focus on chat text area

* Refactor ChatComponent for improved modularity
- Move utility functions to separate files
- Update imports to use new utility files
- Simplify handleInference function

* increase cycle time before moving to a new chat example

* Create refactored runInference.ts for inference processing

* Create getRagContext.ts for RAG functionality

* Create types.ts with shared interfaces across chat ui
  • Loading branch information
anirudTT committed Oct 29, 2024
1 parent 24c2214 commit 419fcf2
Show file tree
Hide file tree
Showing 12 changed files with 768 additions and 556 deletions.
492 changes: 0 additions & 492 deletions app/frontend/src/components/ChatComponent.tsx

This file was deleted.

61 changes: 0 additions & 61 deletions app/frontend/src/components/ChatExamples.tsx

This file was deleted.

94 changes: 94 additions & 0 deletions app/frontend/src/components/chatui/ChatComponent.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
import { useState, useEffect } from "react";
import { Card } from "../ui/card";
import { useLocation } from "react-router-dom";
import logo from "../../assets/tt_logo.svg";
import { fetchModels } from "../../api/modelsDeployedApis";
import { useQuery } from "react-query";
import { fetchCollections } from "@/src/components/rag";
import Header from "./Header";
import ChatHistory from "./ChatHistory";
import InputArea from "./InputArea";
import { InferenceRequest, RagDataSource, ChatMessage, Model } from "./types";
import { runInference } from "./runInference";

export default function ChatComponent() {
const location = useLocation();
const [textInput, setTextInput] = useState<string>("");
const [ragDatasource, setRagDatasource] = useState<
RagDataSource | undefined
>();
const { data: ragDataSources } = useQuery("collectionsList", {
queryFn: fetchCollections,
initialData: [],
});
const [chatHistory, setChatHistory] = useState<ChatMessage[]>([]);
const [modelID, setModelID] = useState<string | null>(null);
const [modelName, setModelName] = useState<string | null>(null);
const [isStreaming, setIsStreaming] = useState(false);
const [modelsDeployed, setModelsDeployed] = useState<Model[]>([]);

useEffect(() => {
if (location.state) {
setModelID(location.state.containerID);
setModelName(location.state.modelName);
}

const loadModels = async () => {
try {
const models = await fetchModels();
setModelsDeployed(models);
} catch (error) {
console.error("Error fetching models:", error);
}
};

loadModels();
}, [location.state]);

const handleInference = () => {
if (textInput.trim() === "" || !modelID) return;

const inferenceRequest: InferenceRequest = {
deploy_id: modelID,
text: textInput,
};

runInference(
inferenceRequest,
ragDatasource,
textInput,
setChatHistory,
setIsStreaming,
);
setTextInput("");
};

return (
<div className="flex flex-col w-10/12 mx-auto h-screen overflow-hidden">
<Card className="flex flex-col w-full h-full">
<Header
modelName={modelName}
modelsDeployed={modelsDeployed}
setModelID={setModelID}
setModelName={setModelName}
ragDataSources={ragDataSources}
ragDatasource={ragDatasource}
setRagDatasource={setRagDatasource}
/>
<ChatHistory
chatHistory={chatHistory}
logo={logo}
setTextInput={setTextInput}
/>
<InputArea
textInput={textInput}
setTextInput={setTextInput}
handleInference={handleInference}
isStreaming={isStreaming}
/>
</Card>
</div>
);
}
120 changes: 120 additions & 0 deletions app/frontend/src/components/chatui/ChatExamples.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

import React, { useState, useEffect } from "react";
import { Button } from "../ui/button";
import {
MessageCircle,
Smile,
CloudSun,
Lightbulb,
Code,
Book,
Globe,
Rocket,
} from "lucide-react";

interface ChatExamplesProps {
logo: string;
setTextInput: React.Dispatch<React.SetStateAction<string>>;
}

const allExamples = [
{
icon: <MessageCircle className="h-6 w-6" />,
text: "Hello, how are you today?",
color: "text-blue-500 dark:text-blue-400",
},
{
icon: <Smile className="h-6 w-6" />,
text: "Can you tell me a joke?",
color: "text-red-500 dark:text-red-400",
},
{
icon: <CloudSun className="h-6 w-6" />,
text: "What's the weather like?",
color: "text-yellow-500 dark:text-yellow-400",
},
{
icon: <Lightbulb className="h-6 w-6" />,
text: "Tell me a fun fact.",
color: "text-green-500 dark:text-green-400",
},
{
icon: <Code className="h-6 w-6" />,
text: "Explain a coding concept.",
color: "text-purple-500 dark:text-purple-400",
},
{
icon: <Book className="h-6 w-6" />,
text: "Recommend a book to read.",
color: "text-pink-500 dark:text-pink-400",
},
{
icon: <Globe className="h-6 w-6" />,
text: "Describe a random country.",
color: "text-teal-500 dark:text-teal-400",
},
{
icon: <Rocket className="h-6 w-6" />,
text: "Share a space exploration fact.",
color: "text-orange-500 dark:text-orange-400",
},
];

const ChatExamples: React.FC<ChatExamplesProps> = ({ logo, setTextInput }) => {
const [displayedExamples, setDisplayedExamples] = useState(
allExamples.slice(0, 4),
);

useEffect(() => {
const interval = setInterval(() => {
setDisplayedExamples((prevExamples) => {
const nextIndex =
(allExamples.indexOf(prevExamples[3]) + 1) % allExamples.length;
return [
...allExamples.slice(nextIndex, nextIndex + 4),
...allExamples.slice(
0,
Math.max(0, 4 - (allExamples.length - nextIndex)),
),
];
});
}, 15000); // Cycle time in milliseconds

return () => clearInterval(interval);
}, []);

return (
<div className="flex flex-col items-center justify-center min-h-[28rem] p-4 transition-colors duration-200">
<img
src={logo}
alt="Tenstorrent Logo"
className="w-16 h-16 mb-6 transform transition duration-300 hover:scale-110"
/>
<h2 className="text-2xl font-bold mb-6 text-gray-800 dark:text-gray-200 transition-colors duration-200">
Start a conversation with LLM Studio Chat...
</h2>
<div className="grid grid-cols-1 sm:grid-cols-2 gap-4 w-full max-w-4xl">
{displayedExamples.map((example, index) => (
<Button
key={index}
variant="outline"
className={`h-auto py-4 px-6 flex flex-col items-center justify-center text-center space-y-2 transition-all duration-300
bg-white dark:bg-gray-800 hover:bg-gray-100 dark:hover:bg-gray-700
border border-gray-200 dark:border-gray-600
text-gray-800 dark:text-gray-200`}
onClick={() => setTextInput(example.text)}
>
<span className={`${example.color} transition-colors duration-200`}>
{example.icon}
</span>
<span className="text-sm font-medium">{example.text}</span>
</Button>
))}
</div>
</div>
);
};

export default ChatExamples;
Loading

0 comments on commit 419fcf2

Please sign in to comment.