Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

immediate display transcribed text added #472

Merged
merged 1 commit into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ayushma/models/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class ChatMessage(BaseModel):
original_message = models.TextField(blank=True, null=True)
language = models.CharField(max_length=10, blank=False, default="en")
reference_documents = models.ManyToManyField(Document, blank=True)
# generated ayushma voice audio via TTS
audio = models.FileField(blank=True, null=True)
meta = models.JSONField(blank=True, null=True)
temperature = models.FloatField(blank=True, null=True)
Expand Down
15 changes: 12 additions & 3 deletions ayushma/serializers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ class ConverseSerializer(serializers.Serializer):
stream = serializers.BooleanField(default=True)
generate_audio = serializers.BooleanField(default=True)
noonce = serializers.CharField(required=False)
transcript_start_time = serializers.FloatField(required=False)
transcript_end_time = serializers.FloatField(required=False)


class ChatDetailSerializer(serializers.ModelSerializer):
Expand Down Expand Up @@ -146,9 +148,11 @@ def get_chats(self, obj):
)
return [
{
"messageType": ChatMessageType.USER
if thread_message.role == "user"
else ChatMessageType.AYUSHMA,
"messageType": (
ChatMessageType.USER
if thread_message.role == "user"
else ChatMessageType.AYUSHMA
),
"message": thread_message.content[0].text.value,
"reference_documents": thread_message.content[0].text.annotations,
"language": "en",
Expand All @@ -159,3 +163,8 @@ def get_chats(self, obj):
chatmessages = ChatMessage.objects.filter(chat=obj).order_by("created_at")
context = {"request": self.context.get("request")}
return ChatMessageSerializer(chatmessages, many=True, context=context).data


class SpeechToTextSerializer(serializers.Serializer):
audio = serializers.FileField(required=True)
language = serializers.CharField(default="en")
6 changes: 6 additions & 0 deletions ayushma/utils/converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def converse_api(
audio = request.data.get("audio")
text = request.data.get("text")
language = request.data.get("language") or "en"

try:
service: Service = request.service
except AttributeError:
Expand Down Expand Up @@ -128,6 +129,11 @@ def converse_api(
translated_text = transcript

elif converse_type == "text":
if request.data.get("transcript_start_time") and request.data.get(
"transcript_end_time"
):
stats["transcript_start_time"] = request.data["transcript_start_time"]
stats["transcript_end_time"] = request.data["transcript_end_time"]
translated_text = text

if language != "en":
Expand Down
44 changes: 43 additions & 1 deletion ayushma/views/chat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import time

from django.conf import settings
from drf_spectacular.utils import extend_schema
from rest_framework import filters, status
from rest_framework.decorators import action
from rest_framework.decorators import action, api_view, permission_classes
from rest_framework.exceptions import ValidationError
from rest_framework.mixins import (
CreateModelMixin,
Expand All @@ -20,8 +22,10 @@
ChatFeedbackSerializer,
ChatSerializer,
ConverseSerializer,
SpeechToTextSerializer,
)
from ayushma.utils.converse import converse_api
from ayushma.utils.speech_to_text import speech_to_text
from utils.views.base import BaseModelViewSet
from utils.views.mixins import PartialUpdateModelMixin

Expand All @@ -42,6 +46,7 @@ class ChatViewSet(
"retrieve": ChatDetailSerializer,
"list_all": ChatDetailSerializer,
"converse": ConverseSerializer,
"speech_to_text": SpeechToTextSerializer,
}
permission_classes = (IsTempTokenOrAuthenticated,)
lookup_field = "external_id"
Expand Down Expand Up @@ -100,6 +105,43 @@ def list_all(self, *args, **kwarg):
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)

@extend_schema(
tags=("chats",),
)
@action(detail=True, methods=["post"])
def speech_to_text(self, *args, **kwarg):
serializer = self.get_serializer(data=self.request.data)
serializer.is_valid()

project_id = kwarg["project_external_id"]
audio = serializer.validated_data["audio"]
language = serializer.validated_data.get("language", "en")

stats = {}
try:
stt_engine = Project.objects.get(external_id=project_id).stt_engine
except Project.DoesNotExist:
return Response(
{"error": "Project not found"}, status=status.HTTP_400_BAD_REQUEST
)
try:
stats["transcript_start_time"] = time.time()
transcript = speech_to_text(stt_engine, audio, language + "-IN")
stats["transcript_end_time"] = time.time()
translated_text = transcript
except Exception as e:
print(f"Failed to transcribe speech with {stt_engine} engine: {e}")
return Response(
{
"error": "Something went wrong in getting transcription, please try again later"
},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)

return Response(
{"transcript": translated_text, "stats": stats}, status=status.HTTP_200_OK
)

@extend_schema(
tags=("chats",),
)
Expand Down
Loading