stt_llm_tts_model.py
import time
import torch
import numpy as np
from copy import deepcopy
from nemo_loader import load_rnnt_model
from style_tts2_model import StyleTTS2Model
from transformers import AutoTokenizer, AutoModelForCausalLM
VERBOSE = False
class STT(torch.nn.Module):
# Constants for the FastConformer RNNT model
LAST_CHANNEL_CACHE_SIZE = 70
MAX_SYMBOLS = 10
SOS = 1024
BLANK_INDEX = 1024
PRED_RNN_LAYERS = 1
# Options for loading model checkpoint from NGC cloud
NEMO_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_fastconformer_hybrid_large_streaming_80ms/versions/1.20.0/files/stt_en_fastconformer_hybrid_large_streaming_80ms.nemo"
NEMO_DESCRIPTION = "For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_80ms"
def __init__(self, vocabulary_path):
"""Initialize the encoder, decoder and joint model of the Streaming STT FastConformer.
Args:
vocabulary_path: Valid file path to a .npy file containing the vocabulary of the Joint model
Returns:
Instance of the Streaming STT model
"""
super(STT, self).__init__()
# load vocabulary
self.vocabulary = list(np.load(vocabulary_path))
# load encoder, decoder and joint of the FastConformer RNNT model (the checkpoint is downloaded and cached)
enc_dec_rnnt_model = load_rnnt_model(self.NEMO_URL, self.NEMO_DESCRIPTION)
self.encoder = enc_dec_rnnt_model.encoder.eval()
self.decoder = enc_dec_rnnt_model.decoder.eval()
self.joint = enc_dec_rnnt_model.joint.eval()
# Init encoder state
self.cache_last_channel, self.cache_last_time, self.cache_last_channel_len = self.encoder.get_initial_cache_state(batch_size=1)
## Init decoder state
self.y_sequence = []
self.dec_state = None
self.last_token = None
def forward(self, processed_signal, processed_signal_length):
"""Perform an encoding, decoding and joint step on the given audio signal.
Args:
processed_signal: preprocessed audio chunk
processed_signal_length: number of encodings in the chunk
Returns:
Sequence of transcribed tokens
"""
with torch.no_grad():
## call encoder
with torch.jit.optimized_execution(False):
encoder_output = self.encoder(
audio_signal=processed_signal,
length=processed_signal_length,
cache_last_channel=self.cache_last_channel,
cache_last_time=self.cache_last_time,
cache_last_channel_len=self.cache_last_channel_len,
)
encoded, encoded_len, self.cache_last_channel, self.cache_last_time, self.cache_last_channel_len = encoder_output
self.cache_last_channel = self.cache_last_channel[:, :, -self.LAST_CHANNEL_CACHE_SIZE :, :]
## call decoder and joint
with torch.inference_mode():
self.y_sequence, self.dec_state, self.last_token = self.greedy_RNNT_decode(
encoder_output=encoded,
encoded_lengths=encoded_len,
y_sequence=self.y_sequence,
dec_state=self.dec_state,
last_token=self.last_token)
return self.y_sequence
def states_to_device(self, dec_state, device='cpu'):
"""Maps the decoding state to the given device
Args:
dec_state: hidden state of the decoder model
device: target device
Returns:
Hidden state of the decoder mapped to the target device
"""
if torch.is_tensor(dec_state):
dec_state = dec_state.to(device)
elif isinstance(dec_state, (list, tuple)):
dec_state = tuple(self.states_to_device(dec_i, device) for dec_i in dec_state)
return dec_state
def label_collate(self, labels, device=None):
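"""Collate a batch of label sequences into a zero-padded int64 tensor on the given device."""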
if isinstance(labels, torch.Tensor):
return labels.type(torch.int64)
batch_size = len(labels)
max_len = max(len(label) for label in labels)
cat_labels = np.full((batch_size, max_len), fill_value=0.0, dtype=np.int32)
for e, l in enumerate(labels):
cat_labels[e, : len(l)] = l
labels = torch.tensor(cat_labels, dtype=torch.int64, device=device)
return labels
def batch_select_state(self, batch_states, idx):
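"""Select the decoder hidden state of a single sample from a batched state."""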
if batch_states is not None:
state_list = []
for state_id in range(len(batch_states)):
states = [batch_states[state_id][layer][idx] for layer in range(self.PRED_RNN_LAYERS)]
state_list.append(states)
return state_list
else:
return None
def batch_concat_states(self, batch_states):
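"""Concatenate per-sample decoder hidden states into a single batched state."""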
state_list = []
for state_id in range(len(batch_states[0])):
batch_list = []
for sample_id in range(len(batch_states)):
tensor = torch.stack(batch_states[sample_id][state_id]) # [L, H]
tensor = tensor.unsqueeze(0) # [1, L, H]
batch_list.append(tensor)
state_tensor = torch.cat(batch_list, 0) # [B, L, H]
state_tensor = state_tensor.transpose(1, 0) # [L, B, H]
state_list.append(state_tensor)
return state_list
def tokens_to_text(self, prediction):
"""Translate a sequence of predicted tokens to text
Args:
prediction: sequence of tokens
Returns:
A string containing the translated text
"""
prediction = [p for p in prediction if p != self.BLANK_INDEX]
# De-tokenize the integer tokens
text = ""
for token in prediction:
text += self.vocabulary[token]
return text.replace("▁"," ")
def decoder_step(self, label, hidden):
"""Perform a single decoding step
Args:
label: previous predicted token
hidden: hidden state after the last decoding step
Returns:
decoder output and the new hidden state
"""
if isinstance(label, torch.Tensor):
if label.dtype != torch.long:
label = label.long()
else:
if label == self.SOS:
# Last token was the start/stop token, call decoder with an empty target
return self.decoder.predict(None, hidden, add_sos=False, batch_size=None)
label = self.label_collate([[label]])
# call decoder conditioned on the previous predicted label
return self.decoder.predict(label, hidden, add_sos=False, batch_size=None)
def joint_step(self, enc, pred):
"""Perform a single joint step
Args:
enc: encoded audio signal
pred: decoder output
Returns:
probabilities over the tokens of the vocabulary
"""
with torch.no_grad():
logits = self.joint.joint(enc, pred)
return logits
def greedy_RNNT_decode(self, encoder_output, encoded_lengths, y_sequence=None, dec_state=None, last_token=None):
"""Perform decoder and joint step for every encoded signal in the given audio chunk
Args:
encoder_output: encoded audio chunk containing multiple encodings
encoded_lengths: number of encodings in the chunk
y_sequence: current sequence of transcribed tokens
dec_state: previous decoder hidden state
last_token: previous predicted token
Returns:
updated sequence of transcribed tokens
updated decoder hidden state
updated last predicted token
"""
y_sequence = [] if y_sequence is None else y_sequence
encoder_output = encoder_output.transpose(1, 2)
encoder_output = encoder_output[0, :, :].unsqueeze(1)
encoded_lengths = encoded_lengths[0]
y_sequence = (y_sequence.cpu().tolist() if isinstance(y_sequence, torch.Tensor) else y_sequence)
if dec_state is not None:
dec_state = self.batch_concat_states([dec_state])
dec_state = self.states_to_device(dec_state, encoder_output.device)
# For timestep t in X_t
for time_idx in range(encoded_lengths):
# Extract encoder embedding at timestep t
encoder_output_t = encoder_output.narrow(dim=0, start=time_idx, length=1)
# Setup exit flags and counter
not_blank = True
symbols_added = 0
# While blank is not predicted and we don't exceed the max symbols per timestep
while not_blank and (self.MAX_SYMBOLS is None or symbols_added < self.MAX_SYMBOLS):
# In the first timestep, we initialize the network with RNNT Blank
# In later timesteps, we provide previous predicted label as input.
if last_token is None and dec_state is None:
last_label = self.SOS
else:
last_label = self.label_collate([[last_token]])
# Decoder + Joint Step
dec_output, hidden_prime = self.decoder_step(last_label, dec_state)
logp = self.joint_step(encoder_output_t, dec_output)[0, 0, 0, :]
del dec_output
# torch.max(0) op doesn't exist for FP16.
if logp.dtype != torch.float32:
logp = logp.float()
# get index k, of max prob
v, k = logp.max(0)
k = k.item() # K is the label at timestep t_s in inner loop, s >= 0.
del logp
# If blank token is predicted, exit inner loop, move onto next timestep t
if k == self.BLANK_INDEX:
not_blank = False
else:
y_sequence.append(k)
dec_state = hidden_prime
last_token = k
# Increment token counter.
symbols_added += 1
# prepare outputs and decoder state for the next step
y_sequence = (y_sequence.to(torch.long) if isinstance(y_sequence, torch.Tensor) else torch.tensor(y_sequence, dtype=torch.long))
dec_state = self.batch_select_state(dec_state, 0)
dec_state = self.states_to_device(dec_state)
return y_sequence, dec_state, last_token
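# Sketch (not part of the original pipeline): streaming transcription over a sequence of
# preprocessed audio chunks. The function name and the `chunks` argument are hypothetical;
# it assumes `chunks` yields (processed_signal, processed_signal_length) pairs produced by
# the matching NeMo preprocessor and that "vocab.npy" is available.
def _example_streaming_transcription(chunks, vocabulary_path="vocab.npy"):
    stt = STT(vocabulary_path=vocabulary_path)
    y_sequence = torch.tensor([], dtype=torch.long)
    for processed_signal, processed_signal_length in chunks:
        # the encoder/decoder caches persist across calls, so each call extends the same transcript
        y_sequence = stt(processed_signal, processed_signal_length)
    return stt.tokens_to_text(list(y_sequence.cpu().numpy()))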
class LLM(torch.nn.Module):
MIN_LENGTH = 10 # minimum number of generated tokens before EOS token can be predicted
def __init__(self, model_name, device="cuda"):
"""Initialized tokenizer and LLM model from huggingface
Args:
model_name: huggingface model descriptor
Returns:
Instnatce of the LLM model
"""
super(LLM, self).__init__()
self.device = device
# Initialize given model. Weights are downloaded from HF hub and cached
self.model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, trust_remote_code=True).to(self.device)
self.model.eval()
# Initialize tokenizer for given model. Weights are downloaded from HF hub and cached
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# Specify special tokens for inference
self.pad_token_id = self.model.generation_config.eos_token_id
self.eos_token_id = [self.model.generation_config.eos_token_id]
self.eos_token_id_tensor = torch.tensor(self.eos_token_id).to(self.device)
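# token id 13 is assumed here to be the "." token of this tokenizer; it is used to detect sentence ends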
self.sentence_stop_token_id_tensor = torch.tensor([13]).to(self.device)
def tokenize(self, text):
"""Resolve the given text into a sequence of tokens
Args:
text: string of arbitrary length
Returns:
Sequence of tokens
"""
return self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)
def detokenize(self, tokens):
"""Resolve the given sequence of tokens to a text string
Args:
tokens: sequence of tokens
Returns:
string containing the resolved text
"""
return self.tokenizer.batch_decode(tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
def forward(self, input_token, past_key_values, cur_len):
"""Perform a single LLM inference step
Args:
input_token: either input token from user or previous generated token from LLM
past_key_values: key-value cache from LLM
cur_len: length of the current token sequence that was generated by the LLM
Returns:
next token and updated key-value cache
"""
unfinished_sequences = torch.ones([1], dtype=torch.long, device=self.device)
# prepare model inputs and set the key-value cache
model_inputs = self.model.prepare_inputs_for_generation(input_token, past_key_values=past_key_values, use_cache=True)
# forward pass to get logits and updated key value cache
outputs = self.model(**model_inputs, return_dict=True)
past_key_values = outputs.past_key_values
next_token_logits = outputs.logits[:, -1, :]
# Prevent prediction of the EOS token if the minimum sequence length is not reached
if cur_len < self.MIN_LENGTH:
for i in self.eos_token_id:
next_token_logits[:, i] = -float("inf")
# Select the token with the highest probability
next_tokens = torch.argmax(next_token_logits, dim=-1)
next_tokens = next_tokens * unfinished_sequences + self.pad_token_id * (1 - unfinished_sequences)
return next_tokens, past_key_values
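# Sketch (not part of the original code): greedy generation with the key-value cache, mirroring
# how STT_LLM_TTS.forward drives LLM.forward token by token. The function name is hypothetical;
# it assumes an already constructed `llm = LLM(...)` and a non-empty prompt string.
def _example_generate(llm, prompt, max_new_tokens=20):
    past_key_values = None
    generated = []
    # feed the prompt tokens one by one to build up the key-value cache;
    # the prediction after the last prompt token is the first generated token
    for t in llm.tokenize(prompt)[0]:
        next_token, past_key_values = llm(t.view(1, 1), past_key_values, len(generated))
    # greedily generate, feeding each prediction back in as the next input
    for _ in range(max_new_tokens):
        if next_token.item() in llm.eos_token_id:
            break
        generated.append(next_token)
        next_token, past_key_values = llm(next_token.unsqueeze(0), past_key_values, len(generated))
    return "".join(llm.detokenize(generated))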
class TTS(torch.nn.Module):
def __init__(self, device="cuda"):
"""Initialize STT model and perform warm up steps
Args:
Returns:
Instance of TTS model
"""
super(TTS, self).__init__()
# init Style TTS model
self.tts_model = StyleTTS2Model(device=device)
# perform multiple inference steps for warmup
for _ in range(3):
start = time.time()
self.forward("warming up!")
if VERBOSE:
print("tts warm up: ", round(time.time() - start,3))
def forward(self, text):
"""Syntehsize the entire text
Args:
text: String of text that should be synthesized
Returns:
synthesized audio data in wave format
"""
wav = self.tts_model(text)
return wav
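# Usage sketch: `wav = TTS()("Hello there!")` returns the waveform synthesized by the underlying StyleTTS2 model.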
class STT_LLM_TTS(torch.nn.Module):
THRESHOLD_VOICE_DETECTION = -31000
def __init__(self, device):
"""Initialize STT, LLM and TTS model and the state of the voice assistant
Args:
Returns:
Instance of voice assistant model
"""
super(STT_LLM_TTS, self).__init__()
# Init STT model
self.stt = STT(vocabulary_path="vocab.npy")
# Init LLM model
self.llm = LLM(model_name="microsoft/phi-2", device=device)
# Init TTS model
self.tts = TTS(device=device)
# Init voice assistant state
self.transcribed_tokens = []
self.transcribed_words = []
self.current_word = ""
self.first = True
self.last_token_timestep = None
self.start_generation_timestep = None
self.past_key_values = None
self.past_key_values_backup = None
self.last_token = None
self.generating = False
self.response_sequence = []
self.response_sentence = []
self.transcribing = True
self.last_return = None
def forward(self, processed_signal, processed_signal_length):
"""Perform a single voice assistant forward path
1) If a processed signal is passed and speech is detected, the audio chunk is transcribed
and the transcribed tokens are passed to the LLM to update its key-value chache
2) If processed signal is None (buffer did not collected enough bytes of audio signal) or no speech is detected
Only LLM is called to generate the next token
Args:
processed_signal: preprossed audio chunk
processed_signal_length: Number of timesteps in the chunk
Returns:
If end of sequence of end of sentence token was generated, return generated text, syntehsized
audio and flag if an interrupt was detected
else: return None
"""
with torch.inference_mode():
with torch.no_grad():
transcribed_text = ""
if processed_signal is not None:
## Detect voice to activate transcription
if not self.transcribing and torch.sum(processed_signal).item() > self.THRESHOLD_VOICE_DETECTION:
self.transcribing = True
## Speech to text
if self.transcribing:
s = time.time()
# Call STT model to transcribe the given audio chunk
y_sequence = self.stt(processed_signal, processed_signal_length)
y_sequence = list(y_sequence.cpu().numpy())
# Select new tokens and add them to the current sequence. new_tokens can also be empty
# if no word or subword could be detected in the chunk
new_tokens = y_sequence[len(self.transcribed_tokens):]
self.transcribed_tokens += new_tokens
# transcribe new tokens to text
transcribed_text = self.stt.tokens_to_text(new_tokens)
if VERBOSE:
print(" -STT: ",transcribed_text, round(time.time()-s,3))
# Handle Interrupt when generating but user continues speaking
if self.generating and len(transcribed_text)>0:
# User speaks while sentence is generated
# Decide if it's part of the current query or the user wants to interrupt and ask something else
if time.time() - self.last_token_timestep <= 0.8:
# --> Same query
self.generating = False
self.last_token_timestep = time.time()
# reset the key-value cache to the beginning of generation
self.past_key_values = deepcopy(self.past_key_values_backup)
# reset generated sentence and whole generated sequence
self.response_sequence = []
self.response_sentence = []
else:
# --> Interrupt
# reset state of voice assistant entirely TODO: find better solution
self.first = True
self.past_key_values = None
self.generating = False
self.response_sentence = []
self.response_sequence = []
self.current_word = ""
self.transcribing = True
backup = self.start_generation_timestep
self.last_token_timestep = None
self.start_generation_timestep = None
return None, None, True
# Handle token generation when no new word was detected
if self.last_token_timestep is not None and not self.first and len(transcribed_text)==0:
# stop transcription after certain time until LLM generation is finished
# TODO disable STT right at the beginning and handle interrupts by amplitude in audio input
if self.transcribing:
if VERBOSE:
print("[Stop transcribing!]")
self.transcribing = False
# Start token generation
if not self.generating:
self.generating = True
# backup the key-value cache to restore it in case of an interrupt
self.past_key_values_backup = deepcopy(self.past_key_values)
self.start_generation_timestep = time.time()
if VERBOSE:
print("\n\n[START] ", round(time.time()-self.last_token_timestep,3))
# pass the remaining words / subwords of the input sequence to the LLM
if self.current_word != "":
tokens = self.llm.tokenize(self.current_word)
for t in tokens[0]:
t = torch.unsqueeze(torch.unsqueeze(t, 0),0)
s = time.time()
self.last_token, self.past_key_values = self.llm(t, self.past_key_values, len(self.response_sequence))
if VERBOSE:
print(" -LLM (Input): ", round(time.time()-s,3))
# Add format tokens
# TODO calculate key value cache earlier and use it here to save inference time
tokens = self.llm.tokenize("\nOutput:")
for t in tokens[0]:
t = torch.unsqueeze(torch.unsqueeze(t, 0),0)
s = time.time()
self.last_token, self.past_key_values = self.llm(t, self.past_key_values, len(self.response_sequence))
if VERBOSE:
print(" -LLM (Format): ", round(time.time()-s,3))
self.last_token = self.llm.tokenize(" ")
# LLM generates a new token and adds it to the response sequence
if self.generating:
s = time.time()
self.last_token, self.past_key_values = self.llm(self.last_token, self.past_key_values, len(self.response_sequence))
if VERBOSE:
print(" -LLM (Generation): ", round(time.time()-s,3))
self.response_sequence.append(self.last_token)
self.response_sentence.append(self.last_token)
self.last_token = torch.unsqueeze(self.last_token, 0)
# check for stopping conditions
# TODO improve check for end of sentence or end of sequence
end_of_sentence = self.last_token.eq(self.llm.sentence_stop_token_id_tensor)
end_of_sequence = self.last_token.eq(self.llm.eos_token_id_tensor)
if end_of_sentence or len(self.response_sentence)>=50:
# --> End of Sentence
if time.time() - self.last_token_timestep > 0.3:
response = self.llm.detokenize(self.response_sentence)
response = "".join(response)
# only reset current sentence not the entire response sequence
self.response_sentence = []
# synthesize generated sequence
wav = self.tts(response)
# calculate and print latency
latency = time.time()-self.start_generation_timestep
print("## Total Latency: ", round(latency,3))
# reset timer
self.start_generation_timestep = time.time()
# return generated sentence and timestep for time measurement
return response, wav, False
if end_of_sequence:
# --> End of Sequence
# This if statement prevents the assistant from returning a sequence too early when the user is
# still speaking TODO find a better solution to deal with early EOS tokens
if time.time() - self.last_token_timestep > 0.3:
# reset the state of the voice assistant (but keep the key-value cache for context)
self.first = True
self.generating = False
self.current_word = ""
self.transcribing = True
self.last_token_timestep = None
# detokenize current sentence
response = self.llm.detokenize(self.response_sentence)
# reset sentence and sequence
self.response_sequence = []
self.response_sentence = []
response = "".join(response)
# Sometimes the previous sentence was already the end of the sequence but the EOS
# token is generated after the end of sentence token. If the sequence ends but the current
# sentence is too short, nothing is returned and no speech is synthesized
if len(response) > 3:
# synthesize generated sequence
wav = self.tts(response)
# calculate and print latency
latency = time.time()-self.start_generation_timestep
print("## Total Latency: ", round(latency,3))
self.start_generation_timestep = None
return response, wav, False
else:
self.start_generation_timestep = None
else:
# --> not in generation mode handle new transcribed token
# skip if there was no full audio chunk in the buffer
if processed_signal is None:
self.last_return = time.time()
return None, None, False
if len(transcribed_text)>0:
# reset timer for last recognized word / subword
self.last_token_timestep = time.time()
print(transcribed_text)
# The first STT run is very slow; show the start prompt afterwards and add format tokens to the LLM
if self.first:
self.first = False
print("\n\n ## Listening... ")
# Add format token
# TODO precompute this earlier
tokens = self.llm.tokenize("\nInstruct:")
for t in tokens[0]:
t = torch.unsqueeze(torch.unsqueeze(t, 0),0)
s = time.time()
self.last_token, self.past_key_values = self.llm(t, self.past_key_values, len(self.response_sequence))
if VERBOSE:
print(" -LLM (Format): ", round(time.time()-s,3))
if (len(transcribed_text) == 0 or transcribed_text.startswith(" ")) and len(self.current_word)>0:
# --> new word
new_word = self.current_word
self.transcribed_words.append(new_word)
self.current_word = transcribed_text
self.counter = 0
# call LLM to process new input token
tokens = self.llm.tokenize(new_word)
for t in tokens[0]:
t = torch.unsqueeze(torch.unsqueeze(t, 0), 0)
s = time.time()
self.last_token, self.past_key_values = self.llm(t, self.past_key_values, len(self.response_sequence))
if VERBOSE:
print(" -LLM (Input): ", round(time.time()-s,3))
# if new token is just a subword, add it to the current word
else:
self.current_word += transcribed_text
# If the generated sentence/sequence is not finished --> return None
self.last_return = time.time()
return None, None, False
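if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file). It assumes "vocab.npy" is available,
    # a CUDA device is present, and that `get_next_chunk()` is a hypothetical placeholder for the
    # audio capture / NeMo preprocessing step that yields (processed_signal, processed_signal_length)
    # pairs, or (None, None) when the buffer has not yet collected a full chunk.
    assistant = STT_LLM_TTS(device="cuda")
    while True:
        processed_signal, processed_signal_length = get_next_chunk()  # hypothetical helper
        response, wav, interrupted = assistant(processed_signal, processed_signal_length)
        if interrupted:
            print("[interrupt detected]")
        elif response is not None:
            print(response)  # `wav` could be handed to an audio playback queue here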