Skip to content

Commit

Permalink
Add OpenAI embedder
Browse files Browse the repository at this point in the history
  • Loading branch information
Stelath committed Jun 23, 2024
1 parent 9b350b4 commit 71fab2b
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 9 deletions.
54 changes: 49 additions & 5 deletions mailfox/mailfox.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
import yaml
from tqdm.auto import tqdm
import click
import typer
from typing_extensions import Annotated
from typing import Optional
from .email_interface import EmailHandler, EmailLLM
from .vector import VectorDatabase, FolderCluster
from .vector import VectorDatabase, FolderCluster, EmbeddingFunctions
import time
from functools import partial

Expand All @@ -31,7 +32,7 @@ def run():

email_db_path = os.path.expanduser(config['email_db_path'])
if os.path.exists(email_db_path) and not VectorDatabase(email_db_path).is_emails_empty():
vector_db = VectorDatabase(email_db_path)
vector_db = VectorDatabase(email_db_path, embedding_function=config["default_embedding_function"], openai_api_key=api_key)
else:
typer.echo("It looks like you don't have an email database yet. Let's create one!")
should_download_emails = typer.confirm("Would you like to download all your emails now?")
Expand Down Expand Up @@ -111,7 +112,9 @@ def init():
clustering_path = typer.prompt("Enter the path for clustering data", default="~/.mailfox/data/clustering.pkl")
flagged_folders = typer.prompt("Enter any flagged folders (comma-separated)", default=None)
default_classifier = typer.prompt("Enter the default classifier", default="clustering")
set_config(email_db_path, clustering_path, flagged_folders, default_classifier)
default_embedding_function_choices = click.Choice([func.value for func in EmbeddingFunctions])
default_embedding_function = typer.prompt("Enter the default embedding function", default="st", show_choices=True, type=default_embedding_function_choices)
set_config(email_db_path, clustering_path, flagged_folders, default_classifier, default_embedding_function)

typer.echo("Step 3: Downloading emails")
should_download_emails = typer.confirm("Would you like to download your emails to the VectorDB now?")
Expand All @@ -126,12 +129,51 @@ def init():
typer.echo("Setup complete!")


# Database Commands
database_app = typer.Typer()
app.add_typer(database_app, name="database")

@database_app.command("remove")
def remove_database():
    """Delete the email database directory after interactive confirmation.

    The location is taken from the ``email_db_path`` entry of the config
    returned by ``get_config()``, with ``~`` expanded. Nothing is deleted
    unless the user confirms; the confirmation prompt defaults to "no".
    """
    confirmed = typer.confirm("Are you sure you want to remove the database? This action cannot be undone.", default=False)
    if not confirmed:
        typer.echo("Database removal cancelled")
        return

    # Resolve the path exactly once so the existence check and the deletion
    # below always refer to the same location.
    db_path = os.path.expanduser(get_config()['email_db_path'])
    if not os.path.exists(db_path):
        typer.echo("Database does not exist")
        return

    # Local import: shutil is only needed by this command.
    import shutil
    shutil.rmtree(db_path)
    typer.echo("Database removed")


# Clustering Commands
clustering_app = typer.Typer()
app.add_typer(clustering_app, name="clustering")

@clustering_app.command("remove")
def remove_clustering():
    """Delete the stored clustering data file after interactive confirmation.

    The location is taken from the ``clustering_path`` entry of the config
    returned by ``get_config()``, with ``~`` expanded. Nothing is deleted
    unless the user confirms; the confirmation prompt defaults to "no".
    """
    confirmed = typer.confirm("Are you sure you want to remove the clustering data? This action cannot be undone.", default=False)
    if not confirmed:
        typer.echo("Clustering data removal cancelled")
        return

    # Resolve the path exactly once so the existence check and the deletion
    # below always refer to the same location.
    clustering_file = os.path.expanduser(get_config()['clustering_path'])
    if not os.path.exists(clustering_file):
        typer.echo("Clustering data does not exist")
        return

    os.remove(clustering_file)
    typer.echo("Clustering data removed")


# CONFIG SETTINGS
config_app = typer.Typer()
app.add_typer(config_app, name="config")

@config_app.command("set")
def set_config(email_db_path: Optional[str] = None, clustering_path: Optional[str] = None, flagged_folders: Optional[str] = None, default_classifier: Optional[str] = None):
def set_config(email_db_path: Optional[str] = None, clustering_path: Optional[str] = None, flagged_folders: Optional[str] = None, default_classifier: Optional[str] = None, default_embedding_function: Optional[str] = None):
"""
Set configuration for the application.
Expand All @@ -140,6 +182,7 @@ def set_config(email_db_path: Optional[str] = None, clustering_path: Optional[st
clustering_path (Optional[str]): Path to the clustering data.
flagged_folders (Optional[str]): Comma-separated list of flagged folders.
default_classifier (Optional[str]): Default classifier to use (either clustering or llm).
default_embedding_function (Optional[str]): Default embedding function to use (either openai or sentence-transformers (st)).
"""
try:
config_file = os.path.expanduser("~/.mailfox/config.yaml")
Expand All @@ -157,7 +200,8 @@ def set_config(email_db_path: Optional[str] = None, clustering_path: Optional[st
"email_db_path": email_db_path if email_db_path is not None else existing_config.get("email_db_path"),
"clustering_path": clustering_path if clustering_path is not None else existing_config.get("clustering_path"),
"flagged_folders": [folder.strip() for folder in flagged_folders.split(",")] if flagged_folders is not None else existing_config.get("flagged_folders"),
"default_classifier": default_classifier if default_classifier is not None else existing_config.get("default_classifier")
"default_classifier": default_classifier if default_classifier is not None else existing_config.get("default_classifier"),
"default_embedding_function": default_embedding_function if default_embedding_function is not None else existing_config.get("default_embedding_function")
}

# Save updated config
Expand Down
21 changes: 17 additions & 4 deletions mailfox/vector/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,23 @@
import numpy as np
from tqdm.auto import tqdm

from enum import Enum

class EmbeddingFunctions(str, Enum):
    """Closed set of identifiers for the selectable embedding backends.

    Each member is also a ``str``, so it compares equal to its plain string
    value (for example when that value is read back from a config file).
    """

    # Short code used for the sentence-transformers embedder.
    SENTENCE_TRANSFORMER = "st"

    # Code used for the OpenAI embedder.
    OPENAI = "openai"

class VectorDatabase():
def __init__(self, db_path="./chroma_db/"):
def __init__(self, db_path="./chroma_db/", *, embedding_function=None, openai_api_key=None):
self.chroma_client = chromadb.PersistentClient(db_path)

self.default_ef = embedding_functions.DefaultEmbeddingFunction()
if embedding_function is None or embedding_function == "st":
self.default_ef = embedding_functions.DefaultEmbeddingFunction()
elif embedding_function == "openai":
self.default_ef = embedding_functions.OpenAIEmbeddingFunction(api_key=openai_api_key, model_name="text-embedding-3-small")
else:
raise ValueError("Invalid embedding function")

self.emails_collection = self.chroma_client.get_or_create_collection(name="emails", embedding_function=self.default_ef)

def is_emails_empty(self):
Expand All @@ -19,8 +31,9 @@ def embed(self, text: list[str]):
return self.default_ef(text)

def embed_email(self, email: dict):
embeddings = np.array(self.default_ef([email['from'], email['subject'], email['body']]))
embedding = embeddings[0] * 0.3 + embeddings[1] * 0.2 + embeddings[2] * 0.5
embeddings = np.array(self.default_ef([email['from'] + " " + email['subject'], email['body']]))
# embedding = embeddings[0] * 0.3 + embeddings[1] * 0.2 + embeddings[2] * 0.5
embedding = embeddings[0] * 0.3 + embeddings[1] * 0.7
embedding = embedding.reshape(1, -1)

return embedding
Expand Down

0 comments on commit 71fab2b

Please sign in to comment.