Skip to content

Commit

Permalink
Merge pull request #7 from 13Mai13/feat/embedding
Browse files Browse the repository at this point in the history
Feat/embedding
  • Loading branch information
13Mai13 authored Jan 30, 2025
2 parents 1efb068 + f44943b commit 06593e9
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 5 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ pre-commit install
pre-commit run --all-files
```

### Commands

```
python main.py --stage search --env dev --query "python tips"
```

## Project

Expand Down
4 changes: 4 additions & 0 deletions configs/dev-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ logging:
data:
output_path: 'data/processed.json'
input_path: 'data/data.txt'
# Model Config
search:
model_name: "all-MiniLM-L6-v2"
top_k: 5
20 changes: 15 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,20 @@

from src.ultils import load_config, setup_logging
from src.preprocessing.main import main as preprocess_main
from src.semantic_search.search import main as search_main


class Stage(str, Enum):
preprocess = "preprocess"
search = "search"
all = "all"


class Config(str, Enum):
dev = "dev"
prod = "prod"


app = typer.Typer()


def get_config_path(env: Config) -> Path:
"""Get config path and verify it exists"""
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -62,6 +61,10 @@ def main(
env: Config = typer.Option(
Config.dev,
help="Environment to use (dev or prod)"
),
query: str = typer.Option(
None,
help="Search query (required for search stage)"
)
) -> None:
"""Run pipeline"""
Expand All @@ -86,9 +89,16 @@ def main(

logger.info("Pipeline completed successfully")

if stage in [Stage.search, Stage.all]:
if query is None:
logger.error("Search query is required for search stage")
print("[red]Error: Search query is required for search stage[/red]")
raise typer.Exit(1)
logger.info("Starting search stage")
search_main(config, query)

except Exception as e:
logger = logging.getLogger(__name__)
logger.exception(f"Pipeline failed with error: {str(e)}")
print(f"[red]Error: {str(e)}[/red]")
raise typer.Exit(1)


Expand Down
25 changes: 25 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
annotated-types==0.7.0
certifi==2024.12.14
cfgv==3.4.0
charset-normalizer==3.4.1
click==8.1.8
distlib==0.3.9
filelock==3.17.0
fsspec==2024.12.0
huggingface-hub==0.28.0
identify==2.6.6
idna==3.10
iniconfig==2.0.0
jinja2==3.1.5
joblib==1.4.2
markdown-it-py==3.0.0
markupsafe==3.0.2
mdurl==0.1.2
mpmath==1.3.0
networkx==3.4.2
nodeenv==1.9.1
numpy==2.2.2
packaging==24.2
pillow==11.1.0
platformdirs==4.3.6
pluggy==1.5.0
pre-commit==4.1.0
Expand All @@ -17,10 +29,23 @@ pydantic-core==2.27.2
pygments==2.19.1
pytest==8.3.4
pyyaml==6.0.2
regex==2024.11.6
requests==2.32.3
rich==13.9.4
ruff==0.9.3
safetensors==0.5.2
scikit-learn==1.6.1
scipy==1.15.1
sentence-transformers==3.4.1
setuptools==75.8.0
shellingham==1.5.4
sympy==1.13.1
threadpoolctl==3.5.0
tokenizers==0.21.0
torch==2.6.0
tqdm==4.67.1
transformers==4.48.1
typer==0.15.1
typing-extensions==4.12.2
urllib3==2.3.0
virtualenv==20.29.1
83 changes: 83 additions & 0 deletions src/semantic_search/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
Semantic Search functionality
"""
import json
import logging
from pathlib import Path
from typing import List, Dict, Any

import torch
from sentence_transformers import SentenceTransformer, util

logger = logging.getLogger(__name__)

def load_url_data(data_path: Path) -> List[Dict]:
"""Load URL data from JSON file"""
logger.debug(f"Loading data from {data_path}")
try:
with open(data_path) as f:
data = json.load(f)
logger.info(f"Loaded {len(data)} URLs")
return data
except Exception as e:
logger.error(f"Failed to load data: {e}")
raise

def load_model(model_name: str = 'all-MiniLM-L6-v2') -> SentenceTransformer:
"""Load and return the sentence transformer model"""
logger.info(f"Loading model: {model_name}")
return SentenceTransformer(model_name)

def encode_urls(
model: SentenceTransformer,
url_data: List[Dict]
) -> torch.Tensor:
"""Encode URL data using the transformer model"""
logger.debug("Encoding URLs")
texts = [f"{item['title']} {item['url']}" for item in url_data]
return model.encode(texts, convert_to_tensor=True)

def search_urls(
query: str,
model: SentenceTransformer,
url_data: List[Dict],
embeddings: torch.Tensor,
top_k: int = 5
) -> List[Dict]:
"""Search for most similar URLs given a query"""
logger.debug(f"Searching for query: {query}")

query_embedding = model.encode(query, convert_to_tensor=True)
cos_scores = util.cos_sim(query_embedding, embeddings)[0]
top_results = torch.topk(cos_scores, k=min(top_k, len(url_data)))

matches = [
{
'score': float(score),
'url': url_data[int(idx)]['url'],
'title': url_data[int(idx)]['title']
}
for score, idx in zip(top_results[0], top_results[1])
]

logger.info(f"Found {len(matches)} matches")
return matches

def main(config: Dict[str, Any], query: str):
"""Main preprocessing function."""
logger.info("Starting data preprocessing")

model = load_model()
url_data = load_url_data(config['data']['output_path'])
embeddings = encode_urls(model, url_data)

# Perform search
matches = search_urls(query, model, url_data, embeddings, top_k=5)

# Output results
for idx, match in enumerate(matches, 1):
print(f"\n{idx}. Score: {match['score']:.4f}")
print(f" Title: {match['title']}")
print(f" URL: {match['url']}")

logger.info("Data preprocessing completed")
39 changes: 39 additions & 0 deletions test/test_semantic_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
from sentence_transformers import SentenceTransformer
import torch

from src.semantic_search.search import load_model, encode_urls, search_urls

@pytest.fixture
def model():
return load_model()

@pytest.fixture
def sample_data():
return [
{
"url": "https://example.com/python",
"title": "Python Programming",
"content": ""
},
{
"url": "https://example.com/javascript",
"title": "JavaScript Basics",
"content": ""
}
]

def test_load_model():
model = load_model()
assert isinstance(model, SentenceTransformer)

def test_encode_urls(model, sample_data):
embeddings = encode_urls(model, sample_data)
assert isinstance(embeddings, torch.Tensor)
assert len(embeddings) == len(sample_data)

def test_search_urls(model, sample_data):
embeddings = encode_urls(model, sample_data)
results = search_urls("python programming", model, sample_data, embeddings, top_k=1)
assert len(results) == 1
assert all(k in results[0] for k in ['score', 'url', 'title'])

0 comments on commit 06593e9

Please sign in to comment.