Commit: showing 9 changed files with 170 additions and 218 deletions.
(One of the changed files was deleted; its contents are not shown below.)
chatbot_laura/chatbot_logic.py (new file)
@@ -0,0 +1,88 @@
from openai import OpenAI
import pandas as pd
import json
import faiss
import numpy as np
import os
import pickle
from django.conf import settings

client = OpenAI(api_key=settings.OPENAI_API_KEY)

# Cached embeddings and FAISS index are stored next to the app code.
embedding_file = os.path.join(settings.BASE_DIR, "chatbot_laura", "embeddings.pkl")
index_file = os.path.join(settings.BASE_DIR, "chatbot_laura", "faiss_index.index")

def get_embedding(text, model="text-embedding-ada-002"):
    response = client.embeddings.create(input=text, model=model)
    return response.data[0].embedding

def save_embeddings(embeddings, file_name):
    with open(file_name, 'wb') as f:
        pickle.dump(embeddings, f)

def load_embeddings(file_name):
    with open(file_name, 'rb') as f:
        return pickle.load(f)

def process_data(json_file):
    with open(json_file, "r") as file:
        data = json.load(file)

    processed_data = []
    for item in data:
        if item['type'] == 'qa':
            text_for_embedding = f"{item['content']['pregunta']} {item['content']['respuesta']}"
        elif item['type'] == 'info':
            text_for_embedding = f"{item['content']['titulo']} {item['content'].get('descripcion', '')}"
        else:
            continue

        processed_data.append({
            'text_for_embedding': text_for_embedding,
            'full_content': item['content'],
            'type': item['type'],
            'url': item.get('url', ''),
            'metadata': item.get('metadata', {})
        })

    return pd.DataFrame(processed_data)

def initialize_or_load_index(df):
    if os.path.exists(embedding_file) and os.path.exists(index_file):
        embeddings = load_embeddings(embedding_file)
        index = faiss.read_index(index_file)
    else:
        df['embedding'] = df['text_for_embedding'].apply(lambda x: get_embedding(x))
        embedding_matrix = np.array(df['embedding'].tolist()).astype('float32')
        # Normalize rows so the inner-product index behaves like cosine similarity.
        embedding_matrix /= np.linalg.norm(embedding_matrix, axis=1)[:, None]

        index = faiss.IndexFlatIP(embedding_matrix.shape[1])
        index.add(embedding_matrix)

        save_embeddings(df['embedding'].tolist(), embedding_file)
        faiss.write_index(index, index_file)
        embeddings = df['embedding'].tolist()

    return index, embeddings

def search(query, df, index, k=3):
    query_embedding = np.array(get_embedding(query)).astype('float32')
    query_embedding /= np.linalg.norm(query_embedding)
    D, I = index.search(np.array([query_embedding]), k)

    results = []
    for i in range(k):
        result = df.iloc[I[0][i]]
        results.append({
            'content': result['full_content'],
            'url': result['url'],
            'type': result['type'],
            'metadata': result['metadata'],
            'similarity_score': float(D[0][i])
        })

    return results

# Initialize data and index at import time
df = process_data(os.path.join(settings.BASE_DIR, "chatbot_laura", "preguntas_respuestas_procesadasV1.json"))
index, embeddings = initialize_or_load_index(df)
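The input format of preguntas_respuestas_procesadasV1.json is only implied by the keys process_data() reads. The sketch below shows a record shape that would satisfy those accesses, with invented sample values, followed by a query against the module-level df and index; it is illustrative, not taken from the real dataset.

# Illustrative only: records shaped like the ones process_data() expects,
# inferred from the key accesses above. Values are invented.
sample_items = [
    {
        "type": "qa",
        "content": {"pregunta": "How do I book an appointment?", "respuesta": "Through the appointments portal."},
        "url": "https://example.org/appointments",   # optional
        "metadata": {"topic": "appointments"}        # optional
    },
    {
        "type": "info",
        "content": {"titulo": "Opening hours", "descripcion": "Monday to Friday, 8:00-16:00."}
    },
]

# Querying the loaded index returns up to k ranked hits with cosine-style scores.
for hit in search("What are the opening hours?", df, index, k=3):
    print(hit['similarity_score'], hit['type'], hit['url'])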
chatbot_laura/models.py (new file)
@@ -0,0 +1,11 @@
from django.db import models

class SearchResult(models.Model):
    content = models.JSONField()
    url = models.URLField(blank=True)
    type = models.CharField(max_length=50)
    metadata = models.JSONField(default=dict)
    similarity_score = models.FloatField()

    def __str__(self):
        return f"{self.type} - {self.similarity_score}"
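The model fields line up one-to-one with the dicts returned by search(), so one plausible use, not shown in this commit, is persisting hits for later inspection. A minimal sketch, assuming the model lives in the chatbot_laura app:

# Hypothetical helper, not part of this commit: store each hit returned by
# search() as a SearchResult row. Field names match the dict keys above.
from chatbot_laura.models import SearchResult

def log_results(hits):
    for hit in hits:
        SearchResult.objects.create(
            content=hit['content'],
            url=hit['url'],
            type=hit['type'],
            metadata=hit['metadata'],
            similarity_score=hit['similarity_score'],
        )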
chatbot_laura/routing.py (modified)
@@ -1,7 +1,9 @@
-# chatbot_laura/routing.py
-from django.urls import path
-from .consumers import LauraChatConsumer
+from channels.routing import ProtocolTypeRouter, URLRouter
+from django.urls import path
+from chatbot_laura.views import ChatConsumer

-websocket_urlpatterns = [
-    path('ws/laura-chat/', LauraChatConsumer.as_asgi()),
-]
+application = ProtocolTypeRouter({
+    "websocket": URLRouter([
+        path("ws/chat/", ChatConsumer.as_asgi()),
+    ]),
+})
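The new routing.py builds the ProtocolTypeRouter itself and only declares WebSocket traffic; a common alternative is to assemble both HTTP and WebSocket handling in the project-level asgi.py. A minimal sketch under that assumption, with the project module name "config" as a placeholder not taken from this commit:

# Sketch of a project asgi.py routing HTTP to Django and WebSockets to the
# consumer; "config.settings" is a placeholder for the real settings module.
import os
from django.core.asgi import get_asgi_application
from django.urls import path
from channels.routing import ProtocolTypeRouter, URLRouter
from chatbot_laura.views import ChatConsumer

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")

application = ProtocolTypeRouter({
    "http": get_asgi_application(),
    "websocket": URLRouter([
        path("ws/chat/", ChatConsumer.as_asgi()),
    ]),
})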
chatbot_laura/urls.py (modified)
@@ -1,6 +1,6 @@
 from django.urls import path
-from .views import chatbot_laura_view
+from . import views

 urlpatterns = [
-    path('chatbot-laura/', chatbot_laura_view, name='chatbot_laura'),
-]
+    path('search/', views.search_view, name='search'),
+]
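The new route exposes search_view as a plain GET endpoint. A quick illustrative check with Django's test client, assuming chatbot_laura's urls.py is included at the project root (which this commit does not show):

# Hypothetical smoke test of the /search/ endpoint; assumes the app's URLconf
# is mounted at the project root.
from django.test import Client

client = Client()
response = client.get('/search/', {'query': 'opening hours'})
print(response.status_code)        # 200 if the view resolves
print(response.json()['results'])  # ranked hits from chatbot_logic.search()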
chatbot_laura/views.py (modified)
@@ -1,26 +1,26 @@
 from django.http import JsonResponse
-from rest_framework.decorators import api_view
-from .chatbot_laura_logic import ChatbotLauraLogic
+from channels.generic.websocket import AsyncWebsocketConsumer
+import json
+from .chatbot_logic import search, df, index

-# Initialize the chatbot globally so it is loaded only once
-chatbot_logic = ChatbotLauraLogic()
+class ChatConsumer(AsyncWebsocketConsumer):
+    async def connect(self):
+        await self.accept()

-@api_view(['POST'])
-def chatbot_laura_view(request):
-    """
-    View that processes requests to the Laura chatbot.
-    """
-    try:
-        # Get the message sent by the user from the POST body
-        mensaje = request.data.get('mensaje', '')
-        if not mensaje:
-            return JsonResponse({'error': 'No se proporcionó el mensaje'}, status=400)

-        # Run the search using the chatbot logic
-        resultados = chatbot_logic.search(mensaje)
+    async def disconnect(self, close_code):
+        pass

-        # Return the results as a JSON response
-        return JsonResponse({'resultados': resultados}, status=200)

-    except Exception as e:
-        return JsonResponse({'error': str(e)}, status=500)
+    async def receive(self, text_data):
+        text_data_json = json.loads(text_data)
+        query = text_data_json['message']

+        results = search(query, df, index)

+        await self.send(text_data=json.dumps({
+            'message': results
+        }))

+def search_view(request):
+    query = request.GET.get('query', '')
+    results = search(query, df, index)
+    return JsonResponse({'results': results})
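ChatConsumer expects a JSON text frame with a 'message' key and replies with the search hits under the same key. A minimal sketch of that round trip using channels' test helper, assuming the application router from the new routing.py; the test itself is illustrative and not part of the commit (it also assumes pytest-asyncio is installed):

# Illustrative round-trip test for ChatConsumer; not part of this commit.
# Requires pytest-asyncio and the `application` router from routing.py.
import json
import pytest
from channels.testing import WebsocketCommunicator
from chatbot_laura.routing import application

@pytest.mark.asyncio
async def test_chat_round_trip():
    communicator = WebsocketCommunicator(application, "/ws/chat/")
    connected, _ = await communicator.connect()
    assert connected

    # The consumer reads the 'message' key and answers with ranked hits.
    await communicator.send_to(text_data=json.dumps({"message": "opening hours"}))
    reply = json.loads(await communicator.receive_from())
    assert "message" in reply

    await communicator.disconnect()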