-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from 13Mai13/feat/embedding
Feat/embedding
- Loading branch information
Showing
6 changed files
with
170 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
""" | ||
Semantic Search functionality | ||
""" | ||
import json | ||
import logging | ||
from pathlib import Path | ||
from typing import List, Dict, Any | ||
|
||
import torch | ||
from sentence_transformers import SentenceTransformer, util | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
def load_url_data(data_path: Path) -> List[Dict]:
    """Load URL records from a JSON file.

    Args:
        data_path: Location of the JSON file to read.

    Returns:
        The decoded JSON payload. Other functions in this module index each
        record by ``'url'`` and ``'title'``, so a list of such dicts is
        expected.

    Raises:
        Exception: Whatever ``open`` or ``json.load`` raised (for example
            ``OSError`` or ``json.JSONDecodeError``); it is logged with its
            traceback and then re-raised unchanged.
    """
    logger.debug(f"Loading data from {data_path}")
    try:
        # JSON text is UTF-8; depending on the platform default encoding
        # would corrupt or reject non-ASCII titles on some systems.
        with open(data_path, encoding="utf-8") as f:
            data = json.load(f)
        logger.info(f"Loaded {len(data)} URLs")
        return data
    except Exception as e:
        # logger.exception keeps the original message and adds the traceback.
        logger.exception(f"Failed to load data: {e}")
        raise
|
||
def load_model(model_name: str = 'all-MiniLM-L6-v2') -> SentenceTransformer:
    """Build the sentence-transformer model identified by ``model_name``.

    Args:
        model_name: Name handed straight to ``SentenceTransformer``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    logger.info(f"Loading model: {model_name}")
    transformer = SentenceTransformer(model_name)
    return transformer
|
||
def encode_urls(
    model: SentenceTransformer,
    url_data: List[Dict]
) -> torch.Tensor:
    """Embed every URL record as one "<title> <url>" text.

    Args:
        model: Encoder whose ``encode`` method produces the embeddings.
        url_data: Records providing ``'title'`` and ``'url'`` entries.

    Returns:
        Whatever ``model.encode`` returns for the texts when asked for a
        tensor (``convert_to_tensor=True``).
    """
    logger.debug("Encoding URLs")
    corpus = []
    for entry in url_data:
        # Title and URL are joined so both contribute to the embedding.
        corpus.append(f"{entry['title']} {entry['url']}")
    encoded = model.encode(corpus, convert_to_tensor=True)
    return encoded
|
||
def search_urls(
    query: str,
    model: SentenceTransformer,
    url_data: List[Dict],
    embeddings: torch.Tensor,
    top_k: int = 5
) -> List[Dict]:
    """Return the URL records most similar to ``query``.

    Args:
        query: Free-text search query.
        model: Encoder used to embed the query.
        url_data: Records with ``'url'`` and ``'title'`` entries; expected to
            line up index-for-index with ``embeddings``.
        embeddings: Embeddings of ``url_data`` (see ``encode_urls``).
        top_k: Maximum number of matches to return.

    Returns:
        Up to ``top_k`` dicts with ``'score'`` (cosine similarity as a
        float), ``'url'`` and ``'title'``, in the order produced by
        ``torch.topk`` (highest score first). An empty list is returned when
        ``url_data`` is empty or ``top_k`` is not positive.
    """
    logger.debug(f"Searching for query: {query}")

    # Nothing can be ranked for an empty corpus or a non-positive top_k, so
    # bail out before encoding the query rather than handing cos_sim/topk
    # degenerate input.
    k = min(top_k, len(url_data))
    if k <= 0:
        logger.info("Found 0 matches")
        return []

    query_embedding = model.encode(query, convert_to_tensor=True)
    # cos_sim yields one row per query; there is a single query, so take row 0.
    cos_scores = util.cos_sim(query_embedding, embeddings)[0]
    top_results = torch.topk(cos_scores, k=k)

    # tolist() converts scores and indices to plain Python numbers in one go
    # instead of one tensor-to-scalar conversion per field.
    scores = top_results.values.tolist()
    indices = top_results.indices.tolist()
    matches = [
        {
            'score': score,
            'url': url_data[idx]['url'],
            'title': url_data[idx]['title'],
        }
        for score, idx in zip(scores, indices)
    ]

    logger.info(f"Found {len(matches)} matches")
    return matches
|
||
def main(config: Dict[str, Any], query: str, top_k: int = 5):
    """Run a semantic search over the stored URL data and print the matches.

    Loads the default model, reads the URL records from the path found at
    ``config['data']['output_path']``, embeds them, searches for ``query``
    and prints each match's rank, score, title and URL to stdout.

    Args:
        config: Configuration mapping; only ``config['data']['output_path']``
            is read here.
        query: Free-text search query.
        top_k: Maximum number of matches to print (defaults to 5, the value
            that used to be hard-coded).
    """
    logger.info("Starting semantic search")

    model = load_model()
    url_data = load_url_data(config['data']['output_path'])
    embeddings = encode_urls(model, url_data)

    # Perform search
    matches = search_urls(query, model, url_data, embeddings, top_k=top_k)

    # Output results, ranked from 1
    for idx, match in enumerate(matches, 1):
        print(f"\n{idx}. Score: {match['score']:.4f}")
        print(f" Title: {match['title']}")
        print(f" URL: {match['url']}")

    logger.info("Semantic search completed")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import pytest | ||
from sentence_transformers import SentenceTransformer | ||
import torch | ||
|
||
from src.semantic_search.search import load_model, encode_urls, search_urls | ||
|
||
@pytest.fixture(scope="module")
def model():
    """Provide one shared model instance for every test in this module.

    Module scope avoids calling ``load_model()`` again for each test that
    requests the fixture; the tests here only pass the model on to
    ``encode_urls`` and ``search_urls``.
    """
    return load_model()
|
||
@pytest.fixture
def sample_data():
    """Two minimal URL records (url, title, empty content) for the tests."""
    pages = (
        ("https://example.com/python", "Python Programming"),
        ("https://example.com/javascript", "JavaScript Basics"),
    )
    return [
        {"url": url, "title": title, "content": ""}
        for url, title in pages
    ]
|
||
def test_load_model():
    """load_model() with its default name yields a SentenceTransformer."""
    assert isinstance(load_model(), SentenceTransformer)
|
||
def test_encode_urls(model, sample_data):
    """encode_urls returns a tensor with one entry per input record."""
    vectors = encode_urls(model, sample_data)
    expected_rows = len(sample_data)
    assert isinstance(vectors, torch.Tensor)
    assert len(vectors) == expected_rows
|
||
def test_search_urls(model, sample_data):
    """search_urls honours top_k and returns dicts with the expected keys."""
    corpus_vectors = encode_urls(model, sample_data)
    hits = search_urls(
        "python programming", model, sample_data, corpus_vectors, top_k=1
    )
    assert len(hits) == 1
    top_hit = hits[0]
    for key in ['score', 'url', 'title']:
        assert key in top_hit