Skip to content

Commit

Permalink
refactor: clean up code and modularize RAG actions
Browse files Browse the repository at this point in the history
  • Loading branch information
Nyumat committed Nov 10, 2024
1 parent 9c6cc47 commit b2445c0
Show file tree
Hide file tree
Showing 10 changed files with 236 additions and 363 deletions.
108 changes: 20 additions & 88 deletions src/app/api/chat/route.ts
Original file line number Diff line number Diff line change
@@ -1,108 +1,40 @@
import { auth } from "@/lib/auth";
import { prisma } from "@/lib/prisma";
import { CreateMessageInput } from "@/lib/models";
import {
getFileContext,
saveMessagesInTransaction,
} from "@/lib/retrieval-augmentation-gen";
import { buildPrompt } from "@/lib/utils";
import { openai } from "@ai-sdk/openai";
import { streamText } from "ai";
import axios from "axios";
import { type Result } from "../embeddings/route";

/**
* POST /api/chat
* Processes a message and returns a response.
* @param req The incoming request.
* @returns A response indicating the status of the operation.
* @throws If an error occurs while processing the request.
*/
export async function POST(req: Request) {
try {
const session = await auth();

if (!session?.user?.id)
return new Response("Unauthorized", { status: 401 });

if (req.headers.get("content-type") !== "application/json")
return new Response("Invalid content type", { status: 400 });

const { messages, chatId, fileName } = await req.json();

const chat = await prisma.chat.findUnique({
where: {
id: chatId,
userId: session.user.id,
},
include: {
CourseMaterial: true,
},
});

if (!chat) return new Response("Chat not found", { status: 404 });

const userMessage = messages[messages.length - 1];

if (userMessage.role === "user") {
const existingMessage = await prisma.message.findFirst({
where: {
chatId,
content: userMessage.content,
role: "user",
},
});

if (!existingMessage) {
await prisma.message.create({
data: {
chatId,
content: userMessage.content,
role: userMessage.role,
},
});
}
}

let fileContext = "";

if (fileName) {
const message = await prisma.message.findFirst({
where: {
chatId,
role: "user",
},
});
if (!message) throw new Error("No message found");

const response = await axios.get(
`http://localhost:3000/api/embeddings?query=${message.content}`,
);

const results = response.data.results as Result[];

if (results.length > 0) {
// Merge all TopK results into a single string
fileContext = results.map((result: Result) => result.text).join("\n");
} else {
fileContext = "";
}
}

const systemPrompt = {
role: "system",
content: `AI assistant is a brand new, powerful, human-like artificial intelligence.
The traits of AI include expert knowledge, helpfulness, cleverness, and articulateness.
AI is a well-behaved and well-mannered individual.
AI is always friendly, kind, and inspiring, and he is eager to provide vivid and thoughtful responses to the user.
AI has the sum of all knowledge in their brain, and is able to accurately answer nearly any question about any topic in conversation.
AI assistant is a big fan of Pinecone and Vercel.
START CONTEXT BLOCK
${fileContext || ""}
END OF CONTEXT BLOCK
AI assistant will take into account any CONTEXT BLOCK that is provided in a conversation.
If the context does not provide the answer to question, the AI assistant will say, "I'm sorry, but I don't know the answer to that question".
AI assistant will not apologize for previous responses, but instead will indicated new information was gained.
AI assistant will not invent anything that is not drawn directly from the context.`,
};
const { messages, chatId, fileId }: CreateMessageInput = await req.json();
const fileContext = fileId
? await getFileContext(messages[messages.length - 1])
: "";

const systemPrompt = buildPrompt(fileContext);
const response = await streamText({
model: openai("gpt-4o-mini"),
messages: [systemPrompt, ...messages],
async onFinish({ text }) {
await prisma.message.create({
data: {
chatId,
content: text,
role: "assistant",
},
});
saveMessagesInTransaction({ messages, chatId, text });
},
});

Expand Down
209 changes: 46 additions & 163 deletions src/app/api/embeddings/route.ts
Original file line number Diff line number Diff line change
@@ -1,132 +1,42 @@
import { getPresignedUrl } from "@/lib/cloudFlareClient";
import { createEmbedding } from "@/lib/openAiClient";
import { pineconeIndex } from "@/lib/pineconeClient";
import { prisma } from "@/lib/prisma";
import { WebPDFLoader } from "@langchain/community/document_loaders/web/pdf";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import { PdfRecord } from "@/lib/models";
import { deleteDocumentFromPinecone, queryDocuments } from "@/lib/pinecone";
import {
handleEmbeddingAndStorage,
processDocument,
syncDocumentWithDb,
} from "@/lib/retrieval-augmentation-gen";
import { NextResponse } from "next/server";
import fetch from "node-fetch";

/**
 * A single piece of text cut from a PDF, ready to be embedded and stored.
 */
interface DocumentChunk {
  /** Chunk identifier. */
  id: string;
  /** The chunk's text content. */
  text: string;
  /** Where the chunk came from. */
  metadata: {
    /** Name of the PDF file the chunk was extracted from. */
    fileName: string;
    /** Page number associated with the chunk. */
    pageNumber: number;
  };
}

/**
 * One similarity-search hit as exposed to consumers of the embeddings
 * search endpoint.
 */
export type Result = {
  /** Text of the matched chunk. */
  text: string;
  /** Name of the file the matched chunk belongs to. */
  fileName: string;
  /** Page number recorded for the matched chunk. */
  pageNumber: number;
  /** Score reported for the match. */
  score: number;
};

/**
 * Fetches the resource behind a pre-signed URL and wraps its bytes in a Blob
 * typed as `application/pdf`.
 *
 * @param presignedUrl URL to download from.
 * @returns A Blob containing the response body.
 * @throws If the response status is not OK; the message carries the
 *   response's status text.
 */
async function downloadPDFFromPresignedUrl(
  presignedUrl: string,
): Promise<Blob> {
  const res = await fetch(presignedUrl);

  if (res.ok) {
    const bytes = await res.arrayBuffer();
    return new Blob([bytes], { type: "application/pdf" });
  }

  throw new Error(`Failed to download PDF: ${res.statusText}`);
}

/**
 * Downloads a stored PDF and breaks it into overlapping text chunks.
 *
 * The file is fetched through a pre-signed URL, loaded one document per page
 * (`splitPages: true`), and then split into chunks of 1000 characters with a
 * 200-character overlap.
 *
 * @param fileName Name of the stored PDF; also used as the prefix of every
 *   chunk id (`<fileName>-chunk-<index>`).
 * @returns One entry per chunk, with newlines in the text replaced by spaces.
 * @throws If no pre-signed URL is produced, or if the download fails.
 */
async function processPDFIntoChunks(
  fileName: string,
): Promise<DocumentChunk[]> {
  // Get pre-signed URL and download PDF
  const presignedUrl = await getPresignedUrl(fileName);
  if (!presignedUrl) {
    throw new Error("Failed to generate pre-signed URL");
  }

  const pdfBlob = await downloadPDFFromPresignedUrl(presignedUrl);

  // Load and parse PDF
  const loader = new WebPDFLoader(pdfBlob, {
    splitPages: true,
    parsedItemSeparator: "",
  });
  const documents = await loader.load();

  // Split into chunks
  const textSplitter = new RecursiveCharacterTextSplitter({
    chunkSize: 1000,
    chunkOverlap: 200,
  });

  const chunks = await textSplitter.splitDocuments(documents);

  // Format chunks for processing
  return chunks.map((chunk, index) => ({
    id: `${fileName}-chunk-${index}`,
    text: chunk.pageContent.replace(/\n/g, " "),
    metadata: {
      fileName,
      // The PDF loader records the page under `metadata.loc.pageNumber`, and
      // the splitter carries `loc` over to each chunk. Reading only
      // `metadata.pageNumber` therefore always fell through to 0. The old key
      // is kept as a fallback for documents that do set it at the top level.
      pageNumber:
        chunk.metadata.loc?.pageNumber ?? chunk.metadata.pageNumber ?? 0,
    },
  }));
}

/**
* POST /api/embeddings
* Processes a PDF file into chunks, embeds the text of each chunk, and stores the chunks in the Pinecone index.
* @param request The incoming request.
* @returns A response indicating the status of the operation.
* @throws If an error occurs while processing the PDF.
*/
export async function POST(request: Request) {
try {
const { fileName } = await request.json();

if (!fileName) {
const requestBody = await request.json();
const args: PdfRecord = requestBody.data;
if (!args.fileName)
return NextResponse.json(
{ error: "fileName is required" },
{ status: 400 },
);
}

// Process PDF into chunks
const chunks = await processPDFIntoChunks(fileName);

const upsertPromises = chunks.map(async (chunk) => {
const embedding = await createEmbedding(chunk.text);
return pineconeIndex.namespace("documents").upsert([
{
id: chunk.id,
values: embedding,
metadata: {
text: chunk.text,
fileName: chunk.metadata.fileName,
pageNumber: chunk.metadata.pageNumber,
},
},
]);
});

await Promise.all(upsertPromises);
const chunks = await processDocument(args);
const [taskOne, taskTwo] = await Promise.allSettled([
await handleEmbeddingAndStorage(chunks),
await syncDocumentWithDb(args, chunks),
]);

const courseMaterial = await prisma.courseMaterial.findFirst({
where: {
fileName: fileName,
},
});

if (!courseMaterial) {
if (taskOne.status === "rejected" || taskTwo.status === "rejected") {
return NextResponse.json(
{ error: "Course material not found" },
{ status: 404 },
{ error: "Failed to process PDF" },
{ status: 500 },
);
}

await prisma.courseMaterial.update({
where: {
id: courseMaterial.id,
},
data: {
documentIds: {
set: chunks.map((chunk) => chunk.id),
},
isIndexed: true,
},
});

return NextResponse.json({
message: "PDF processed and stored successfully",
chunks: chunks.length,
Expand All @@ -140,46 +50,26 @@ export async function POST(request: Request) {
}
}

/**
* GET /api/embeddings
* Searches for documents similar to the given query.
* @param request The incoming request.
* @returns A response containing the search results.
* @throws If an error occurs while searching for documents.
*/
export async function GET(request: Request) {
try {
const { searchParams } = new URL(request.url);
const query = searchParams.get("query");

if (!query) {
if (!query)
return NextResponse.json(
{ error: "Query parameter is required" },
{ status: 400 },
);
}

// Generate embedding for the query
const queryEmbedding = await createEmbedding(query);

// Prepare search options
const searchOptions = {
vector: queryEmbedding,
topK: 5,
includeMetadata: true,
};

// Search in Pinecone
const results = await pineconeIndex
.namespace("documents")
.query(searchOptions);

// Format results
const formattedResults = results.matches.map((match) => ({
text: match?.metadata?.text,
fileName: match?.metadata?.fileName,
pageNumber: match?.metadata?.pageNumber,
score: match.score,
}));

// TODO: filter out non-fileName matches

return NextResponse.json({
results: formattedResults,
});
// Relevant chunks are stored in the Pinecone index
const results = await queryDocuments(query);
return NextResponse.json({ results });
} catch (error) {
console.error("Error searching documents:", error);
return NextResponse.json(
Expand All @@ -189,33 +79,26 @@ export async function GET(request: Request) {
}
}

/**
* DELETE /api/embeddings
* Deletes the documents associated with the given chat ID.
* @param request The incoming request.
* @returns A response indicating the status of the operation.
* @throws If an error occurs while deleting the documents.
*/
export async function DELETE(request: Request) {
try {
const { searchParams } = new URL(request.url);
const id = searchParams.get("chatId");

if (!id) {
if (!id)
return NextResponse.json(
{ error: "ChatID parameter is required" },
{ status: 400 },
);
}

const chat = await prisma.chat.findFirstOrThrow({
where: {
id: id,
},
include: {
CourseMaterial: true,
},
});
await deleteDocumentFromPinecone(id);

const documentIds = chat.CourseMaterial?.documentIds || [];
// Delete one or many document(s) from Pinecone
await pineconeIndex.namespace("documents").deleteMany(documentIds);
return NextResponse.json({
message: "Document deleted successfully",
});
return NextResponse.json({ message: "Document deleted successfully" });
} catch (error) {
console.error("Error deleting document:", error);
return NextResponse.json(
Expand Down
2 changes: 1 addition & 1 deletion src/app/api/files/[id]/sign/route.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { getPresignedUrl } from "@/lib/cloudFlareClient";
import { getPresignedUrl } from "@/lib/cloudflare";
import { prisma } from "@/lib/prisma";
import { NextResponse } from "next/server";

Expand Down
Loading

0 comments on commit b2445c0

Please sign in to comment.