Skip to content

Commit

Permalink
Simplify configuration (#84)
Browse files Browse the repository at this point in the history
* feat: Bring back support for prompt env var

* feat: Implement auth to es embedder
  • Loading branch information
TilmanGriesel authored Jan 19, 2025
1 parent e0e554d commit f7c9e81
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 14 deletions.
1 change: 1 addition & 0 deletions services/api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ DAILY_RATE_LIMIT=86400
REQUIRE_SECURE=false
DEBUG=false
SYSTEM_PROMPT_PATH=./
SYSTEM_PROMPT=

# Temperature
# Controls the randomness of the model's outputs.
Expand Down
2 changes: 0 additions & 2 deletions services/api/src/core/document_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def initialize_store(self) -> ElasticsearchDocumentStore:
if (
self.es_basic_auth_user
and self.es_basic_auth_password
and isinstance(self.es_basic_auth_user, str)
and isinstance(self.es_basic_auth_password, str)
and self.es_basic_auth_user.strip()
and self.es_basic_auth_password.strip()
):
Expand Down
14 changes: 13 additions & 1 deletion services/api/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,21 @@ def show_welcome():


def load_systemprompt(base_path: str) -> str:
file = Path(base_path) / ".systemprompt"
default_prompt = ""

# Use environment variable if available
env_var_name = "SYSTEM_PROMPT"
env_prompt = os.getenv(env_var_name)
if env_prompt is not None:
content = env_prompt.strip()
logger.info(
f"Using system prompt from '{env_var_name}' environment variable; content: '{content}'"
)
return content

# Try reading from file
file = Path(base_path) / ".systemprompt"

if not file.exists():
logger.info("No .systemprompt file found. Using default prompt.")
return default_prompt
Expand Down
2 changes: 2 additions & 0 deletions tools/embed/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ HF_API_KEY=your-huggingface-api-key
# Elasticsearch
ES_URL=http://elasticsearch:9200
ES_INDEX=default
ES_BASIC_AUTH_USERNAME=
ES_BASIC_AUTH_PASSWORD=

# Embedding
EMBEDDING_MODEL_NAME=snowflake-arctic-embed2
Expand Down
14 changes: 14 additions & 0 deletions tools/embed/src/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,20 @@ def parse_args():
help="Index for the Elasticsearch service",
)

# Optional Elasticsearch basic-auth username. Defaults to the
# ES_BASIC_AUTH_USERNAME environment variable, or "" when it is unset.
# Note the flag is spelled "...-user" while the env var ends in "USERNAME".
parser.add_argument(
"--es-basic-auth-user",
type=str,
default=os.getenv("ES_BASIC_AUTH_USERNAME", ""),
help="Username for the Elasticsearch service authentication",
)

# Optional Elasticsearch basic-auth password. Defaults to the
# ES_BASIC_AUTH_PASSWORD environment variable, or "" when it is unset.
parser.add_argument(
"--es-basic-auth-password",
type=str,
default=os.getenv("ES_BASIC_AUTH_PASSWORD", ""),
help="Password for the Elasticsearch service authentication",
)

parser.add_argument(
"--ollama-url",
type=str,
Expand Down
43 changes: 33 additions & 10 deletions tools/embed/src/core/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@

@dataclass
class PipelineConfig:
    """Settings consumed by the embedding pipeline.

    Instances are built with keyword arguments; the embedder fills each
    value from its constructor argument or, failing that, from the
    matching environment variable.
    """

    # Embedding provider identifier and the model it should load.
    provider: str
    embedding_model: str

    # Elasticsearch endpoint and target index.
    es_url: str
    es_index: str

    # Optional Elasticsearch basic-auth credentials. The document store only
    # receives them when both are set and non-blank.
    es_basic_auth_user: Optional[str] = None
    es_basic_auth_password: Optional[str] = None

    # Provider-specific settings; None when not supplied.
    ollama_url: Optional[str] = None
    hf_api_key: Optional[str] = None

Expand Down Expand Up @@ -67,10 +69,12 @@ def log_metrics(self, logger):
class RAGEmbedder:
def __init__(
self,
es_url: str = None,
es_index: str = None,
provider_name: str = None,
embedding_model: str = None,
es_url: str = None,
es_index: str = None,
es_basic_auth_user: str = None,
es_basic_auth_password: str = None,
ollama_url: str = None,
hf_api_key: str = None,
):
Expand All @@ -95,10 +99,14 @@ def __init__(
self.logger.info(f"EMBEDDING MODEL:{embedding_model}")

self.config = PipelineConfig(
es_url=es_url or os.getenv("ES_URL"),
es_index=es_index or os.getenv("ES_INDEX"),
provider=provider,
embedding_model=embedding_model,
es_url=es_url or os.getenv("ES_URL"),
es_index=es_index or os.getenv("ES_INDEX"),
es_basic_auth_user=es_basic_auth_user
or os.getenv("ES_BASIC_AUTH_USERNAME"),
es_basic_auth_password=es_basic_auth_password
or os.getenv("ES_BASIC_AUTH_PASSWORD"),
ollama_url=ollama_url or os.getenv("OLLAMA_URL"),
hf_api_key=hf_api_key or os.getenv("HF_API_KEY"),
)
Expand Down Expand Up @@ -183,10 +191,25 @@ def _initialize_ollama(self):

def _initialize_document_store(self) -> ElasticsearchDocumentStore:
try:
document_store = ElasticsearchDocumentStore(
hosts=self.config.es_url,
index=self.config.es_index,
)
params = {
"hosts": self.config.es_url,
"index": self.config.es_index,
"embedding_similarity_function": "cosine",
}

# Add basic auth if non-empty
if (
self.config.es_basic_auth_user
and self.config.es_basic_auth_password
and self.config.es_basic_auth_user.strip()
and self.config.es_basic_auth_password.strip()
):
params["basic_auth"] = (
self.config.es_basic_auth_user,
self.config.es_basic_auth_password,
)

document_store = ElasticsearchDocumentStore(**params)
doc_count = document_store.count_documents()
self.logger.info(
f"Document store initialized successfully with {doc_count} documents"
Expand Down
4 changes: 3 additions & 1 deletion tools/embed/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ def main():
logger.info("Initializing RAG Embedder")
embedder = RAGEmbedder(
provider_name=args.provider,
ollama_url=args.ollama_url,
es_url=args.es_url,
es_index=args.es_index,
ollama_url=args.ollama_url,
es_basic_auth_user=args.es_basic_auth_user,
es_basic_auth_password=args.es_basic_auth_password,
embedding_model=args.embedding_model,
)
logger.debug("RAG Embedder initialized successfully")
Expand Down

0 comments on commit f7c9e81

Please sign in to comment.