diff --git a/examples/pipelines/rag/llamaindex_ollama_pipeline.py b/examples/pipelines/rag/llamaindex_ollama_pipeline.py
index efafe672..0b7765c0 100644
--- a/examples/pipelines/rag/llamaindex_ollama_pipeline.py
+++ b/examples/pipelines/rag/llamaindex_ollama_pipeline.py
@@ -10,28 +10,48 @@
 from typing import List, Union, Generator, Iterator
 from schemas import OpenAIChatMessage
+import os
+
+from pydantic import BaseModel
 
 
 class Pipeline:
+
+    class Valves(BaseModel):
+        LLAMAINDEX_OLLAMA_BASE_URL: str
+        LLAMAINDEX_MODEL_NAME: str
+        LLAMAINDEX_EMBEDDING_MODEL_NAME: str
+
     def __init__(self):
         self.documents = None
         self.index = None
 
+        self.valves = self.Valves(
+            **{
+                "LLAMAINDEX_OLLAMA_BASE_URL": os.getenv("LLAMAINDEX_OLLAMA_BASE_URL", "http://localhost:11434"),
+                "LLAMAINDEX_MODEL_NAME": os.getenv("LLAMAINDEX_MODEL_NAME", "llama3"),
+                "LLAMAINDEX_EMBEDDING_MODEL_NAME": os.getenv("LLAMAINDEX_EMBEDDING_MODEL_NAME", "nomic-embed-text"),
+            }
+        )
+
     async def on_startup(self):
         from llama_index.embeddings.ollama import OllamaEmbedding
         from llama_index.llms.ollama import Ollama
         from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader
 
         Settings.embed_model = OllamaEmbedding(
-            model_name="nomic-embed-text",
-            base_url="http://localhost:11434",
+            model_name=self.valves.LLAMAINDEX_EMBEDDING_MODEL_NAME,
+            base_url=self.valves.LLAMAINDEX_OLLAMA_BASE_URL,
+        )
+        Settings.llm = Ollama(
+            model=self.valves.LLAMAINDEX_MODEL_NAME,
+            base_url=self.valves.LLAMAINDEX_OLLAMA_BASE_URL,
         )
-        Settings.llm = Ollama(model="llama3")
 
         # This function is called when the server is started.
         global documents, index
 
-        self.documents = SimpleDirectoryReader("./data").load_data()
+        self.documents = SimpleDirectoryReader("/app/backend/data").load_data()
         self.index = VectorStoreIndex.from_documents(self.documents)
         pass
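
The diff replaces the hardcoded Ollama host and model names with a pydantic Valves model whose fields are seeded from environment variables at construction time, so the same pipeline file can be pointed at different Ollama deployments without editing code. Below is a minimal sketch of how that wiring behaves; the import path of the Pipeline class is hypothetical and depends on how the pipelines server loads this module.

    import os

    # Override the Valves defaults before the Pipeline is constructed;
    # these variable names come straight from the diff above, but the
    # host value here is just an example.
    os.environ["LLAMAINDEX_OLLAMA_BASE_URL"] = "http://ollama:11434"
    os.environ["LLAMAINDEX_MODEL_NAME"] = "llama3"

    from llamaindex_ollama_pipeline import Pipeline  # hypothetical import path

    pipeline = Pipeline()
    # __init__ read the environment, so the valves now carry the override:
    print(pipeline.valves.LLAMAINDEX_OLLAMA_BASE_URL)  # http://ollama:11434

Unset variables fall back to the defaults baked into __init__ ("http://localhost:11434", "llama3", "nomic-embed-text"), matching the values that were previously hardcoded in on_startup.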