diff --git a/mailfox/mailfox.py b/mailfox/mailfox.py index 314a96d..4a92f0e 100644 --- a/mailfox/mailfox.py +++ b/mailfox/mailfox.py @@ -1,11 +1,12 @@ import os import yaml from tqdm.auto import tqdm +import click import typer from typing_extensions import Annotated from typing import Optional from .email_interface import EmailHandler, EmailLLM -from .vector import VectorDatabase, FolderCluster +from .vector import VectorDatabase, FolderCluster, EmbeddingFunctions import time from functools import partial @@ -31,7 +32,7 @@ def run(): email_db_path = os.path.expanduser(config['email_db_path']) if os.path.exists(email_db_path) and not VectorDatabase(email_db_path).is_emails_empty(): - vector_db = VectorDatabase(email_db_path) + vector_db = VectorDatabase(email_db_path, embedding_function=config["default_embedding_function"], openai_api_key=api_key) else: typer.echo("It looks like you don't have an email database yet. Let's create one!") should_download_emails = typer.confirm("Would you like to download all your emails now?") @@ -111,7 +112,9 @@ def init(): clustering_path = typer.prompt("Enter the path for clustering data", default="~/.mailfox/data/clustering.pkl") flagged_folders = typer.prompt("Enter any flagged folders (comma-separated)", default=None) default_classifier = typer.prompt("Enter the default classifier", default="clustering") - set_config(email_db_path, clustering_path, flagged_folders, default_classifier) + default_embedding_function_choices = click.Choice([func.value for func in EmbeddingFunctions]) + default_embedding_function = typer.prompt("Enter the default embedding function", default="st", show_choices=True, type=default_embedding_function_choices) + set_config(email_db_path, clustering_path, flagged_folders, default_classifier, default_embedding_function) typer.echo("Step 3: Downloading emails") should_download_emails = typer.confirm("Would you like to download your emails to the VectorDB now?") @@ -126,12 +129,51 @@ def init(): typer.echo("Setup complete!") +# Database Commands +database_app = typer.Typer() +app.add_typer(database_app, name="database") + +@database_app.command("remove") +def remove_database(): + remove_db_flag = typer.confirm("Are you sure you want to remove the database? This action cannot be undone.", default=False) + + if remove_db_flag: + if not os.path.exists(os.path.expanduser(get_config()['email_db_path'])): + typer.echo("Database does not exist") + return + + import shutil + shutil.rmtree(os.path.expanduser(get_config()['email_db_path'])) + typer.echo("Database removed") + else: + typer.echo("Database removal cancelled") + + +# Clustering Commands +clustering_app = typer.Typer() +app.add_typer(clustering_app, name="clustering") + +@clustering_app.command("remove") +def remove_clustering(): + remove_clustering_flag = typer.confirm("Are you sure you want to remove the clustering data? This action cannot be undone.", default=False) + + if remove_clustering_flag: + if not os.path.exists(os.path.expanduser(get_config()['clustering_path'])): + typer.echo("Clustering data does not exist") + return + + os.remove(os.path.expanduser(get_config()['clustering_path'])) + typer.echo("Clustering data removed") + else: + typer.echo("Clustering data removal cancelled") + + # CONFIG SETTINGS config_app = typer.Typer() app.add_typer(config_app, name="config") @config_app.command("set") -def set_config(email_db_path: Optional[str] = None, clustering_path: Optional[str] = None, flagged_folders: Optional[str] = None, default_classifier: Optional[str] = None): +def set_config(email_db_path: Optional[str] = None, clustering_path: Optional[str] = None, flagged_folders: Optional[str] = None, default_classifier: Optional[str] = None, default_embedding_function: Optional[str] = None): """ Set configuration for the application. @@ -140,6 +182,7 @@ def set_config(email_db_path: Optional[str] = None, clustering_path: Optional[st clustering_path (Optional[str]): Path to the clustering data. flagged_folders (Optional[str]): Comma-separated list of flagged folders. default_classifier (Optional[str]): Default classifier to use (either clustering or llm). + default_embedding_function (Otional[str]): Default embedding function to use (either openai or sentence-transformers (st)). """ try: config_file = os.path.expanduser("~/.mailfox/config.yaml") @@ -157,7 +200,8 @@ def set_config(email_db_path: Optional[str] = None, clustering_path: Optional[st "email_db_path": email_db_path if email_db_path is not None else existing_config.get("email_db_path"), "clustering_path": clustering_path if clustering_path is not None else existing_config.get("clustering_path"), "flagged_folders": [folder.strip() for folder in flagged_folders.split(",")] if flagged_folders is not None else existing_config.get("flagged_folders"), - "default_classifier": default_classifier if default_classifier is not None else existing_config.get("default_classifier") + "default_classifier": default_classifier if default_classifier is not None else existing_config.get("default_classifier"), + "default_embedding_function": default_embedding_function if default_embedding_function is not None else existing_config.get("default_embedding_function") } # Save updated config diff --git a/mailfox/vector/database.py b/mailfox/vector/database.py index e019d8d..1011975 100644 --- a/mailfox/vector/database.py +++ b/mailfox/vector/database.py @@ -4,11 +4,23 @@ import numpy as np from tqdm.auto import tqdm +from enum import Enum + +class EmbeddingFunctions(str, Enum): + SENTENCE_TRANSFORMER = "st" + OPENAI = "openai" + class VectorDatabase(): - def __init__(self, db_path="./chroma_db/"): + def __init__(self, db_path="./chroma_db/", *, embedding_function=None, openai_api_key=None): self.chroma_client = chromadb.PersistentClient(db_path) - self.default_ef = embedding_functions.DefaultEmbeddingFunction() + if embedding_function is None or embedding_function == "st": + self.default_ef = embedding_functions.DefaultEmbeddingFunction() + elif embedding_function == "openai": + self.default_ef = embedding_functions.OpenAIEmbeddingFunction(api_key=openai_api_key, model_name="text-embedding-3-small") + else: + raise ValueError("Invalid embedding function") + self.emails_collection = self.chroma_client.get_or_create_collection(name="emails", embedding_function=self.default_ef) def is_emails_empty(self): @@ -19,8 +31,9 @@ def embed(self, text: list[str]): return self.default_ef(text) def embed_email(self, email: dict): - embeddings = np.array(self.default_ef([email['from'], email['subject'], email['body']])) - embedding = embeddings[0] * 0.3 + embeddings[1] * 0.2 + embeddings[2] * 0.5 + embeddings = np.array(self.default_ef([email['from'] + " " + email['subject'], email['body']])) + # embedding = embeddings[0] * 0.3 + embeddings[1] * 0.2 + embeddings[2] * 0.5 + embedding = embeddings[0] * 0.3 + embeddings[1] * 0.7 embedding = embedding.reshape(1, -1) return embedding