Skip to content

Commit

Permalink
feat: use torchaudio instead of soundfile
Browse files Browse the repository at this point in the history
  • Loading branch information
LutingWang committed Dec 12, 2024
1 parent 7df2da2 commit e96be18
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
16 changes: 11 additions & 5 deletions docs/source/pretrained/whisper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import einops
from transformers import AutomaticSpeechRecognitionPipeline, pipeline

from todd.utils import get_audio
Expand All @@ -6,6 +7,8 @@
'https://github.com/SWivid/F5-TTS/raw/refs/heads/main/'
'src/f5_tts/infer/examples/basic/basic_ref_zh.wav',
)
audio_array = audio.numpy()
audio_array = einops.rearrange(audio_array, '1 t -> t')

pipe: AutomaticSpeechRecognitionPipeline = pipeline(
'automatic-speech-recognition',
Expand All @@ -14,17 +17,20 @@
device_map='auto',
)

result = pipe(audio)
result = pipe(audio_array)
print(result)

result = pipe(audio, generate_kwargs=dict(language='zh'))
result = pipe(audio_array, generate_kwargs=dict(language='zh'))
print(result)

result = pipe(audio, generate_kwargs=dict(task='translate', language='en'))
result = pipe(
audio_array,
generate_kwargs=dict(task='translate', language='en'),
)
print(result)

result = pipe(audio, return_timestamps=True)
result = pipe(audio_array, return_timestamps=True)
print(result)

result = pipe(audio, return_timestamps='word')
result = pipe(audio_array, return_timestamps='word')
print(result)
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ dependencies = [
'opencv-python',
'python-pptx',
'pycocotools',
'soundfile',
'tensorboard',
'timm',
'toml',
Expand Down Expand Up @@ -163,7 +162,6 @@ module = [
'pptx.*',
'scipy.*',
'setuptools.*',
'soundfile.*',
'torchvision.*',
'transformers.*',
'yapf.*',
Expand Down
7 changes: 4 additions & 3 deletions todd/utils/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import numpy as np
import numpy.typing as npt
import requests
import soundfile as sf
import torch
import torchaudio
from PIL import Image


Expand All @@ -25,5 +26,5 @@ def get_image(url: str) -> npt.NDArray[np.uint8]:
return np.array(image)


def get_audio(url: str) -> tuple[npt.NDArray[np.float64], int]:
return sf.read(get_bytes(url))
def get_audio(url: str) -> tuple[torch.Tensor, int]:
return torchaudio.load(get_bytes(url))

0 comments on commit e96be18

Please sign in to comment.