from dataclasses import dataclass
import logging
import sys
import traceback

from modal import Image

import common
from common import app

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


@dataclass
class TranscriptionProgress:
    percent_done: int
    transcript: dict = None


class TranscriptionError(Exception):
    pass


def load_whisper():
    # Import inside the function so the dependency is only required in the image.
    import whisper

    # Pre-download the model weights so they are baked into the container image.
    whisper.load_model(common.MODEL_NAME)


transcriber_image = (
    Image.debian_slim(python_version="3.10.8")
    .apt_install("ffmpeg")
    .pip_install(
        "https://github.com/openai/whisper/archive/v20230314.tar.gz", "tqdm"
    )
    .run_function(load_whisper)
)


@app.function(
    gpu=["A100-40GB", "A10G"],
    cpu=8.0,
    container_idle_timeout=180,
    image=transcriber_image,
    network_file_systems=common.nfs,
    timeout=1200,
)
def transcribe(transcription_id, language, prompt=None):
    import torch.multiprocessing as mp

    t = common.db.select(transcription_id)
    if not t:
        raise TranscriptionError(f"invalid id: {transcription_id}")

    device = common.get_device()

    # Run Whisper in a separate process so progress can be streamed back
    # through a multiprocessing queue while the model is transcribing.
    mp.set_start_method("spawn", force=True)
    q = mp.Queue()
    p = mp.Process(
        target=worker,
        args=(q, str(t.transcoded_file), device, language, prompt),
    )
    logger.info("spawning whisper process")
    p.start()

    # Yield progress percentages, then the final transcript, as the worker
    # puts them on the queue; a None sentinel marks completion.
    while True:
        res = q.get()
        if res is None:
            break
        yield res
    p.join()


def worker(q, audio, device, language, prompt):
    import tqdm
    import whisper
    import whisper.transcribe

    class Progress(tqdm.tqdm):
        """tqdm subclass that forwards progress percentages to the parent queue."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._current = self.n

        def update(self, n):
            super().update(n)
            self._current += n
            percent_done = int(100 * self._current / self.total)
            logger.info(f"progress: {percent_done}")
            q.put(percent_done)

    try:
        # patch whisper's tqdm so we can generate progress updates
        transcribe_module = sys.modules["whisper.transcribe"]
        transcribe_module.tqdm.tqdm = Progress

        # run recognition and send back the transcript. this also sends
        # progress back to the parent process via the queue
        use_gpu = device == "gpu"
        logger.info("transcribe: loading model")
        model = whisper.load_model(common.MODEL_NAME, device=device)
        logger.info(f"transcribe {language} (gpu:{use_gpu}). prompt: {prompt}")
        transcript = model.transcribe(
            audio, language=language, prompt=prompt, fp16=use_gpu, verbose=False
        )
        q.put(transcript)
        q.put(None)
    except Exception as e:
        traceback.print_exc()
        q.put(e)
        q.put(None)
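

# Illustrative usage sketch (an assumption, not part of the original module):
# a Modal local entrypoint that drives the transcribe generator remotely.
# The argument values are placeholders and assume a matching record already
# exists in common.db; `remote_gen` is Modal's call style for iterating a
# remote generator function.
#
# @app.local_entrypoint()
# def main(transcription_id: str, language: str = "en"):
#     for update in transcribe.remote_gen(transcription_id, language):
#         if isinstance(update, dict):
#             # final whisper transcript
#             print(update.get("text", ""))
#         elif isinstance(update, int):
#             # intermediate progress percentage from the worker
#             print(f"progress: {update}%")
#         elif isinstance(update, Exception):
#             # the worker forwards exceptions through the queue
#             raise update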