forked from yihong0618/ChatTTS
-
Notifications
You must be signed in to change notification settings - Fork 12
/
main.py
56 lines (43 loc) · 1.5 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#!/usr/bin/env python3
import functools
import wave

import numpy as np
import pydantic
from fastapi import Depends, FastAPI, HTTPException

import ChatTTS
# FastAPI application object; the /tts route in this module is registered on it.
app = FastAPI()
class TTSInput(pydantic.BaseModel):
    """Request body for the ``/tts`` endpoint."""

    # Text to synthesize; sent to ChatTTS as a single-item batch.
    text: str
    # Server-side filesystem path the generated WAV file is written to.
    output_path: str
    # Seed passed to ``chat.sample_random_speaker`` to choose the speaker
    # embedding (presumably the same seed yields the same voice — confirm).
    seed: int = 697
@functools.lru_cache(maxsize=1)
def get_chat_model() -> ChatTTS.Chat:
    """Return the process-wide ChatTTS model, loading it on first use.

    This function is injected into request handlers via ``Depends``. The
    result is cached with ``lru_cache(maxsize=1)``, so ``load_models()`` runs
    once per process instead of once per request.

    Returns:
        A ``ChatTTS.Chat`` instance with its models loaded.
    """
    # NOTE(review): the cached model is now shared by all requests, and FastAPI
    # may run sync endpoints concurrently in a threadpool. Confirm that
    # ChatTTS.Chat.infer is safe to call concurrently, or serialize access.
    chat = ChatTTS.Chat()
    chat.load_models()
    return chat
def _write_mono_pcm16_wav(path: str, samples: np.ndarray, sample_rate: int) -> None:
    """Write float samples to ``path`` as a mono 16-bit PCM WAV file.

    Samples are clipped to [-1.0, 1.0] before scaling to the int16 range, so
    values that overshoot full scale saturate instead of overflowing the int16
    cast (an out-of-range float-to-int cast is not well defined in NumPy and
    typically wraps around, producing loud clicks).

    Args:
        path: Destination file path; an existing file is overwritten.
        samples: Float audio samples; ``tobytes()`` flattens them in C order.
        sample_rate: Frame rate in Hz stored in the WAV header.
    """
    pcm = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
    with wave.open(path, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
        wf.setframerate(sample_rate)
        wf.writeframes(pcm.tobytes())


@app.post("/tts")
def tts(input: TTSInput, chat: ChatTTS.Chat = Depends(get_chat_model)):
    """Synthesize ``input.text`` and save it as a WAV file on the server.

    The voice comes from a speaker embedding sampled with ``input.seed``. The
    audio is written to ``input.output_path`` as mono 16-bit PCM at 24 kHz.

    Returns:
        ``{"output_path": input.output_path}`` on success.

    Raises:
        HTTPException: status 500 wrapping any error raised during synthesis
            or while writing the file.
    """
    # NOTE(security): output_path comes straight from the request body, so any
    # client can make the server write a file anywhere the process has write
    # access. Consider restricting it to a fixed output directory.
    sample_rate = 24000  # NOTE(review): presumably ChatTTS's output rate — confirm
    try:
        speaker = chat.sample_random_speaker(seed=input.seed)
        params_infer_code = {
            'spk_emb': speaker,  # sampled speaker embedding
            'temperature': .3,  # custom sampling temperature
            'top_P': 0.7,  # top-P decoding
            'top_K': 20,  # top-K decoding
        }
        params_refine_text = {
            # NOTE(review): control tokens for the text-refinement pass
            # (presumably oral / laugh / break levels) — confirm in ChatTTS docs.
            'prompt': '[oral_2][laugh_0][break_6]'
        }
        wavs = chat.infer(
            [input.text],
            params_infer_code=params_infer_code,
            params_refine_text=params_refine_text,
            use_decoder=True,
        )
        # Assumes wavs[0] holds the mono samples for our single input text
        # — TODO confirm the ChatTTS output layout.
        audio_data = np.asarray(wavs[0], dtype=np.float32)
        _write_mono_pcm16_wav(input.output_path, audio_data, sample_rate)
    except Exception as e:
        # Chain the original error so its traceback survives in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"output_path": input.output_path}