Skip to content

Commit

Permalink
Merge pull request #407 from NexaAI/perry/server-dev
Browse files Browse the repository at this point in the history
added file path params for server whisper API
  • Loading branch information
Davidqian123 authored Mar 1, 2025
2 parents 01d3ceb + 8443596 commit e6f017d
Showing 1 changed file with 48 additions and 15 deletions.
63 changes: 48 additions & 15 deletions nexa/gguf/server/nexa_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1720,18 +1720,28 @@ async def process_audio(
language: Optional[str] = Query(
None, description="Language code (e.g. 'en', 'fr') for transcription."),
temperature: Optional[float] = Query(
0.0, description="Temperature for sampling.")
0.0, description="Temperature for sampling."),
tmp_file_dir: Optional[str] = Query(
None, description="Directory to save temporary audio file. If not provided, uses system temp directory.")
):
temp_audio_path = None
try:
if not whisper_model:
raise HTTPException(
status_code=400,
detail="Whisper model is not loaded. Please load a Whisper model first."
)

with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_audio:
temp_audio.write(await file.read())
temp_audio_path = temp_audio.name
# Modify temp file creation to use custom directory if provided
if tmp_file_dir:
os.makedirs(tmp_file_dir, exist_ok=True)
temp_audio_path = os.path.join(tmp_file_dir, f"temp_{file.filename}")
with open(temp_audio_path, 'wb') as temp_audio:
temp_audio.write(await file.read())
else:
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_audio:
temp_audio.write(await file.read())
temp_audio_path = temp_audio.name

# Set up parameters for Whisper or similar model
task_params = {
Expand All @@ -1753,9 +1763,13 @@ async def process_audio(
raise HTTPException(
status_code=500, detail=f"Error during {task}: {str(e)}")
finally:
if 'temp_audio_path' in locals() and os.path.exists(temp_audio_path):
os.unlink(temp_audio_path)

# Clean up temp file if it was created
if temp_audio_path and os.path.exists(temp_audio_path):
try:
os.unlink(temp_audio_path)
logging.info(f"Cleaned up temporary file: {temp_audio_path}")
except Exception as e:
logging.error(f"Error cleaning up temporary file {temp_audio_path}: {e}")

@app.post("/v1/audio/processing_stream", tags=["Audio"])
async def processing_stream_audio(
Expand All @@ -1768,18 +1782,30 @@ async def processing_stream_audio(
"auto", description="Language code (e.g., 'en', 'fr')"),
min_chunk: Optional[float] = Query(
1.0, description="Minimum chunk duration for streaming"),
tmp_file_dir: Optional[str] = Query(
None, description="Directory to save temporary audio file. If not provided, uses system memory.")
):
temp_audio_path = None
try:
if not whisper_model:
raise HTTPException(
status_code=400,
detail="Whisper model is not loaded. Please load a Whisper model first."
)

# Read the entire file into memory
audio_bytes = await file.read()
a_full = load_audio_from_bytes(audio_bytes)
duration = len(a_full) / SAMPLING_RATE
# Modify audio loading to optionally save to file
if tmp_file_dir:
os.makedirs(tmp_file_dir, exist_ok=True)
temp_audio_path = os.path.join(tmp_file_dir, f"temp_{file.filename}")
with open(temp_audio_path, 'wb') as temp_audio:
audio_bytes = await file.read()
temp_audio.write(audio_bytes)
# Read the saved file
a_full = load_audio_from_bytes(audio_bytes)
else:
# Original in-memory processing
audio_bytes = await file.read()
a_full = load_audio_from_bytes(audio_bytes)

# Only include language parameter if task is "transcribe"
# For "translate", the language is always defined as "en"
Expand All @@ -1798,13 +1824,13 @@ async def processing_stream_audio(

def stream_generator():
nonlocal beg
while beg < duration:
while beg < len(a_full) / SAMPLING_RATE:
now = time.time() - start
if now < beg + min_chunk:
time.sleep((beg + min_chunk) - now)
end = time.time() - start
if end > duration:
end = duration
if end > len(a_full) / SAMPLING_RATE:
end = len(a_full) / SAMPLING_RATE

chunk_samples = int((end - beg)*SAMPLING_RATE)
chunk_audio = a_full[int(
Expand Down Expand Up @@ -1839,7 +1865,14 @@ def stream_generator():
except Exception as e:
logging.error(f"Error in audio processing stream: {e}")
raise HTTPException(status_code=500, detail=str(e))

finally:
# Clean up temp file if it was created
if temp_audio_path and os.path.exists(temp_audio_path):
try:
os.unlink(temp_audio_path)
logging.info(f"Cleaned up temporary file: {temp_audio_path}")
except Exception as e:
logging.error(f"Error cleaning up temporary file {temp_audio_path}: {e}")

@app.post("/v1/audiolm/chat/completions", tags=["AudioLM"])
async def audio_chat_completions(
Expand Down

0 comments on commit e6f017d

Please sign in to comment.