From 194615ce043a1db8af3c2ffb2b69b17e1416142d Mon Sep 17 00:00:00 2001
From: linyq <linyqemail@163.com>
Date: Fri, 23 Aug 2024 13:53:13 +0800
Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=20fastapi=20=E6=8E=A5?=
 =?UTF-8?q?=E5=8F=A3?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 TTS/server/server_fastapi.py | 236 +++++++++++++++++++++++++++++++++++
 1 file changed, 236 insertions(+)
 create mode 100644 TTS/server/server_fastapi.py

diff --git a/TTS/server/server_fastapi.py b/TTS/server/server_fastapi.py
new file mode 100644
index 0000000000..5d222bbf0e
--- /dev/null
+++ b/TTS/server/server_fastapi.py
@@ -0,0 +1,236 @@
+"""
+这是coqui-tts的fastapi接口
+pip install fastapi uvicorn python-multipart
+"""
+import io
+import json
+import os
+import logging
+import shutil
+import uuid
+import traceback
+from pathlib import Path
+from threading import Lock
+from typing import Union, List, Optional
+from pydantic import Field, BaseModel
+from fastapi import FastAPI, Query, Form, HTTPException, UploadFile, File
+from fastapi.responses import StreamingResponse, FileResponse
+
+from TTS.utils.manage import ModelManager
+from TTS.utils.synthesizer import Synthesizer
+
app = FastAPI()
# Process-wide lock: the endpoints serialize model use on this because the
# synthesizer is shared mutable state.
lock = Lock()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Model registry (.models.json) located one directory above this file.
path = Path(__file__).parent / "../.models.json"
manager = ModelManager(path)

# Update the model in use to the specified released model.
# NOTE(review): these module-level path globals are never reassigned —
# load_models() keeps its paths in locals and only publishes the global
# `synthesizer` — so they appear to be dead; confirm before removing.
model_path = None
config_path = None
speakers_file_path = None
vocoder_path = None
vocoder_config_path = None
+
+
def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict, None]:
    """Normalize a style-wav argument into a wav path or a GST dict.

    Behaviour:
      * a path to an existing ``.wav`` file on the server is returned
        unchanged,
      * a JSON string is parsed into a GST (Guided Style Tokens) dict
        ``{token1_id: token1_weight, ...}`` and returned,
      * an empty/None value, or a string that is neither an existing wav
        path nor valid JSON, yields ``None``.

    Args:
        style_wav (str): uri — wav file path or JSON-encoded GST dict.

    Returns:
        Union[str, dict, None]: path to file (str), gst style (dict), or None.
    """
    if not style_wav:
        return None
    if style_wav.endswith(".wav") and os.path.isfile(style_wav):
        return style_wav  # style_wav is a .wav file located on the server
    try:
        # style_wav is a GST dict with {token1_id : token1_weight, ...}
        return json.loads(style_wav)
    except json.JSONDecodeError:
        # BUGFIX: the docstring promises None for inputs that are neither a
        # file path nor a GST dict, but json.loads used to raise instead
        # (including for .wav paths that don't exist on disk).
        return None
+
+
@app.get("/list_models", summary="获取模型列表")
async def get_list_models():
    """Return the names of every available TTS and vocoder model."""
    return {
        "tts_models": manager.list_tts_models(),
        "vocoder_models": manager.list_vocoder_models(),
    }
+
+
@app.get("/load_model", summary="加载模型")
async def load_models(
        model_name: str = Query(description="模型名称"),
        vocoder_name: str = Query(default=None, description="vocoder名称,一般不用填"),
        config_path: str = Query(default=None, description="配置文件路径,一般不用填,只有`tts_models/multilingual/multi-dataset/xtts_v2`模型需要填 `/root/.local/share/tts/tts_models--multilingual--multi-dataset--xtts_v2/config.json`")
):
    """
    Download (if necessary) and load the requested model, replacing the
    process-wide ``synthesizer`` used by the /tts and /clone_tts endpoints.

    Returns:
        dict: a success message plus the model's speaker/language managers
        (None when the model is single-speaker / single-language).
    """
    global synthesizer

    model_path, tts_config_path, model_item = manager.download_model(model_name)
    # Fall back to the model's default vocoder when none was requested.
    vocoder_name = model_item["default_vocoder"] if vocoder_name is None else vocoder_name

    # BUGFIX: the downloaded vocoder paths were previously discarded and
    # empty strings were always passed to Synthesizer; wire them through.
    vocoder_path = None
    vocoder_config_path = None
    if vocoder_name is not None:
        vocoder_path, vocoder_config_path, _ = manager.download_model(vocoder_name)

    # Caller-supplied config overrides the downloaded one (needed for xtts_v2).
    if config_path is not None:
        tts_config_path = config_path

    # BUGFIX: os.getenv returns a string when the variable is set, and ANY
    # non-empty string ("0", "false", ...) is truthy; parse the flag instead.
    use_cuda = os.getenv("CUDA", "").strip().lower() in ("1", "true", "yes", "on")

    # Serialize the swap against synthesis requests that hold the same lock.
    with lock:
        synthesizer = Synthesizer(
            tts_checkpoint=model_path,
            tts_config_path=tts_config_path,
            tts_speakers_file="",
            tts_languages_file="",
            vocoder_checkpoint=vocoder_path or "",
            vocoder_config=vocoder_config_path or "",
            encoder_checkpoint="",
            encoder_config="",
            use_cuda=use_cuda,
        )

    # Multi-speaker support, if the loaded model provides it.
    speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)
    # Multi-language support, if the loaded model provides it.
    language_manager = getattr(synthesizer.tts_model, "language_manager", None)

    return {
        "message": "模型加载成功",
        "speaker_manager": speaker_manager,
        "language_manager": language_manager
    }
+
+
class TTSRequest(BaseModel):
    """
    Request schema for speech synthesis.

    NOTE(review): currently unused by the endpoints, which read query/form
    parameters directly; presumably kept for a future JSON-body variant.
    """
    # Text to synthesize.
    text: Optional[str] = Field(description="需要转换的文本")
    # Speaker id for multi-speaker models.
    speaker_idx: Optional[str] = Field(default=None, description="说话人 id")
    # Language id for multilingual models.
    language_idx: Optional[str] = Field(default=None, description="语言 id")
    # BUGFIX: the original line ended with a stray trailing comma, which
    # turned the attribute into a one-element tuple `(FieldInfo,)` instead
    # of a pydantic Field, breaking the model definition.
    speed: Optional[float] = Field(default=1.0, description="生成音频的速度。默认为 1.0。(如果远低于 1.0,可能会产生伪影)")
    split_sentences: Optional[bool] = Field(default=True, description="将输入文本拆分为句子")
    # Todo: the parameters below are not understood/exposed yet.
    # style_wav: Optional[str] = Field(default=None, description="GST的样式波形")
    # style_text: Optional[str] = Field(default=None, description="Capacitron 的 style_wav 转录")
    # reference_wav: Optional[str] = Field(default=None, description="用于语音转换的参考波形")
    # reference_speaker_name: Optional[str] = Field(default=None, description="参考波形的扬声器 ID")
+
@app.get("/tts", summary="语音合成")
def tts(
        text: Optional[str] = Query(description="需要转换的文本"),
        speaker_idx: Optional[str] = Query(default=None, description="说话人"),
        language_idx: Optional[str] = Query(default=None, description="语种;如:en"),
        speed: Optional[float] = Query(default=1.0, description="生成音频的速度。默认为 1.0。(如果远低于 1.0,可能会产生伪影)"),
        split_sentences: Optional[bool] = Query(default=True, description="将输入文本拆分为句子")
):
    """
    Synthesize speech for *text* with the currently loaded model and stream
    it back as a WAV file.

    Raises:
        HTTPException 501: no model has been loaded via /load_model yet.
        HTTPException 500: synthesis failed for any other reason.
    """
    try:
        with lock:  # the shared synthesizer is not safe for concurrent use
            try:
                wavs = synthesizer.tts(
                    text=text,
                    speaker_name=speaker_idx,
                    language_name=language_idx,
                    split_sentences=split_sentences,
                    speed=speed
                )
                out = io.BytesIO()
                synthesizer.save_wav(wavs, out)
            # BUGFIX: global `synthesizer` only exists after /load_model, so
            # the unloaded case raises NameError, not AttributeError.
            except (NameError, AttributeError) as e:
                logger.error(f"语音合成失败: {e}")
                raise HTTPException(status_code=501, detail="未加载TTS模型")
            except Exception:
                logger.error(f"语音合成失败: {traceback.format_exc()}")
                raise HTTPException(status_code=500, detail="语音合成失败")
    # BUGFIX: without this clause the 501 above was swallowed by the generic
    # handler below and re-wrapped as a 500.
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"处理请求时出错: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e))

    # BUGFIX: save_wav leaves the buffer positioned at EOF; rewind it or
    # StreamingResponse sends an empty body.
    out.seek(0)
    headers = {
        'Content-Disposition': 'inline; filename="output.wav"'
    }
    return StreamingResponse(out, media_type="audio/wav", headers=headers)
+
+
@app.post("/clone_tts", summary="语音克隆并合成")
def clone_tts(
        text: Optional[str] = Form(description="需要转换的文本"),
        speaker_idx: Optional[str] = Form(default=None, description="说话人"),
        language_idx: Optional[str] = Form(default=None, description="语种;如:en"),
        speed: Optional[float] = Form(default=1.0, description="生成音频的速度。默认为 1.0。(如果远低于 1.0,可能会产生伪影)"),
        split_sentences: Optional[bool] = Form(default=True, description="将输入文本拆分为句子"),
        speaker_wav: List[UploadFile] = File(..., description="需要克隆的说话人音频wav文件;支持单个或多个文件;单个音频需要大于3s;wav文件数量和种类决定了克隆的效果")
):
    """
    Clone the voice from the uploaded reference wav file(s) and synthesize
    *text* with it, streaming the result back as a WAV file.

    Uploaded files are spilled to a local ``wav_cache`` directory (the
    synthesizer reads them by path) and removed again in ``finally``.

    Raises:
        HTTPException 501: no model has been loaded via /load_model yet.
        HTTPException 500: synthesis failed for any other reason.
    """
    # BUGFIX: these must be bound BEFORE the try block; otherwise an early
    # failure makes the finally clause raise NameError, masking the error.
    cache_dir = "wav_cache"
    speaker_wav_paths = []
    try:
        with lock:  # the shared synthesizer is not safe for concurrent use
            os.makedirs(cache_dir, exist_ok=True)

            # Persist the uploads under unique names and collect their paths.
            if speaker_wav is not None:
                for file in speaker_wav:
                    file_id = str(uuid.uuid4())
                    file_path = os.path.join(cache_dir, f"{file_id}_{file.filename}")
                    with open(file_path, "wb") as f:
                        shutil.copyfileobj(file.file, f)
                    speaker_wav_paths.append(file_path)

            logger.info(f" > Model input: {text}")
            logger.info(f" > Language Idx: {language_idx}")

            try:
                wavs = synthesizer.tts(
                    text=text,
                    language_name=language_idx,
                    speaker_wav=speaker_wav_paths,
                    split_sentences=split_sentences,
                    speed=speed,
                    speaker_name=speaker_idx
                )
                out = io.BytesIO()
                synthesizer.save_wav(wavs, out)
            # BUGFIX: global `synthesizer` only exists after /load_model, so
            # the unloaded case raises NameError, not AttributeError.
            except (NameError, AttributeError) as e:
                logger.error(f"语音合成失败: {e}")
                raise HTTPException(status_code=501, detail="未加载TTS模型")
            except Exception:
                logger.error(f"语音合成失败: {traceback.format_exc()}")
                raise HTTPException(status_code=500, detail="语音合成失败")
    # BUGFIX: without this clause the 501 above was swallowed by the generic
    # handler below and re-wrapped as a 500.
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"处理请求时出错: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Remove the temporary reference wavs.
        for file_path in speaker_wav_paths:
            if os.path.exists(file_path):
                os.remove(file_path)

        # Drop the cache directory when it exists and is empty (guarded:
        # an early failure may mean it was never created).
        if os.path.isdir(cache_dir) and not os.listdir(cache_dir):
            os.rmdir(cache_dir)

    # BUGFIX: save_wav leaves the buffer positioned at EOF; rewind it or
    # StreamingResponse sends an empty body.
    out.seek(0)
    return StreamingResponse(out, media_type="audio/wav")
+
+
if __name__ == "__main__":
    # Launch a standalone development server when executed directly.
    import uvicorn

    uvicorn.run(app, port=5002, host='0.0.0.0', log_level="debug")