diff --git a/.github/workflows/docker-build.yaml b/.github/workflows/docker-build.yaml index 40ee6e9..6231c1d 100644 --- a/.github/workflows/docker-build.yaml +++ b/.github/workflows/docker-build.yaml @@ -15,9 +15,17 @@ jobs: - name: Set up Docker Buildx uses: docker/setup-buildx-action@v1 - - name: Build + - name: Build EasyVideoTrans service uses: docker/build-push-action@v2 with: context: . push: false - tags: hanfa/pytvzhen-web:${{github.event.pull_request.number}} + tags: hanfa/easyvideotrans:${{github.event.pull_request.number}} + + - name: Build EasyVideoTrans workloads + uses: docker/build-push-action@v2 + with: + context: . + file: Dockerfile-gpu-workload + push: false + tags: hanfa/easyvideotrans-workloads:${{github.event.pull_request.number}} diff --git a/.github/workflows/docker-release.yaml b/.github/workflows/docker-release.yaml index 88468e7..0e88f78 100644 --- a/.github/workflows/docker-release.yaml +++ b/.github/workflows/docker-release.yaml @@ -1,4 +1,4 @@ -name: Pytvzhen-web Docker Image Release +name: EasyVideoTrans Service Docker Image Release on: workflow_run: @@ -29,4 +29,4 @@ jobs: with: context: . push: true - tags: hanfa/pytvzhen-web:latest + tags: hanfa/easyvideotrans:latest diff --git a/.github/workflows/docker-workload-release.yaml b/.github/workflows/docker-workload-release.yaml new file mode 100644 index 0000000..ccbdccf --- /dev/null +++ b/.github/workflows/docker-workload-release.yaml @@ -0,0 +1,32 @@ +name: EasyVideoTrans Workloads Docker Image Release + +on: + workflow_run: + workflows: [ "Pytvzhen-web application test" ] + branches: [ "master" ] + types: + - completed + +jobs: + build: + runs-on: self-hosted + + steps: + - name: Check out code + uses: actions/checkout@v2 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v1 + + - name: Login to DockerHub + uses: docker/login-action@v1 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Build and push + uses: docker/build-push-action@v2 + with: + context: . + push: true + tags: hanfa/easyvideotrans-workloads:latest diff --git a/.gitignore b/.gitignore index 3f92276..b29cccd 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,5 @@ output/ !celery_results/* .DS_Store + +.pytest_cache diff --git a/Dockerfile b/Dockerfile index 35281a3..6eef609 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,8 @@ # Use an official NVIDIA runtime with CUDA and Miniconda as a parent image -FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime AS base +FROM python:3.9-slim AS base + +ENV PYTHONDONTWRITEBYTECODE 1 +ENV PYTHONUNBUFFERED 1 # Disable interactive debian ENV TZ=America/New_York \ @@ -15,7 +18,7 @@ COPY requirements.txt . # Install dependencies RUN pip install --upgrade pip -RUN pip install -r requirements.txt +RUN pip install --default-timeout=200 -r requirements.txt FROM base AS final @@ -27,11 +30,11 @@ COPY . 
/app COPY configs/supervisord.conf /etc/supervisor/conf.d/supervisord.conf # Set environment variables to configure Celery -ENV CELERY_BROKER_DOMAIN=localhost -ENV CELERY_BROKER_URL=pyamqp://guest@localhost:5672// -ENV CELERY_RESULT_BACKEND=file:///app/celery_results -ENV CELERY_WORKER_PREFETCH_MULTIPLIER=1 -ENV CELERY_TASK_ACKS_LATE=true +ENV CELERY_BROKER_DOMAIN localhost +ENV CELERY_BROKER_URL pyamqp://guest@localhost:5672// +ENV CELERY_RESULT_BACKEND file:///app/celery_results +ENV CELERY_WORKER_PREFETCH_MULTIPLIER 1 +ENV CELERY_TASK_ACKS_LATE true # Make port 8080 available to the world outside this container EXPOSE 8080 @@ -42,7 +45,7 @@ ENV FLASK_APP app.py ENV FLASK_DEBUG 0 ARG PYTVZHEN_STAGE=beta -ENV PYTVZHEN_STAGE=${PYTVZHEN_STAGE} +ENV PYTVZHEN_STAGE ${PYTVZHEN_STAGE} # Run supervisord to start both Flask and Celery CMD ["/usr/bin/supervisord"] diff --git a/Dockerfile-gpu-workload b/Dockerfile-gpu-workload new file mode 100644 index 0000000..bd5de5c --- /dev/null +++ b/Dockerfile-gpu-workload @@ -0,0 +1,30 @@ +FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-runtime AS base + +ENV PYTHONDONTWRITEBYTECODE 1 +ENV PYTHONUNBUFFERED 1 + +# Disable interactive debian +ENV TZ=America/New_York \ + DEBIAN_FRONTEND=noninteractive + +WORKDIR /app + +RUN apt-get update && apt-get install -y \ + ffmpeg \ + git \ + && rm -rf /var/lib/apt/lists/* + + +COPY workloads/requirements.txt /app/ + +RUN pip install --no-cache-dir -r requirements.txt + +COPY workloads /app/workloads/ +COPY src /app/src/ +COPY inference.py /app + +ENV LD_LIBRARY_PATH /opt/conda/lib/python3.11/site-packages/nvidia/cudnn/lib + +EXPOSE 8199 + +CMD ["python", "inference.py"] diff --git a/app.py b/app.py index a88c85b..4e87baf 100644 --- a/app.py +++ b/app.py @@ -4,11 +4,10 @@ import zipfile import shutil import uuid -from src.service.audio_processing.audio_remove import audio_remove -from src.service.audio_processing.transcribe_audio import transcribe_audio_en -from src.service.audio_processing.voice_connect import connect_voice -from src.service.translation import get_translator, srt_sentense_merge +from src.service.video_synthesis.voice_connect import connect_voice +from src.service.translation import get_translator from src.service.tts import get_tts_client +from src.workload_client import EasyVideoTransWorkloadClient from src.task_manager.celery_tasks.tasks import video_preview_task from src.task_manager.celery_tasks.celery_utils import get_queue_length from werkzeug.utils import secure_filename @@ -19,9 +18,10 @@ from prometheus_flask_exporter import PrometheusMetrics app = Flask(__name__, template_folder="./appendix/templates", static_folder="./appendix/static") -app.config.from_file("./configs/pytvzhen.json", load=json.load) +app.config.from_file("./configs/easyvideotrans.json", load=json.load) metrics = PrometheusMetrics(app) metrics.info('pytvzhen_web', 'Pytvzhen backend API', version='1.0.0') + PYTVZHEN_STAGE = 'PYTVZHEN_STAGE' pytvzhen_api_request_counter = metrics.counter( 'pytvzhen_api_request_counter', 'Request count by request paths', @@ -29,6 +29,12 @@ 'method': lambda: request.method, 'status': lambda r: r.status_code} ) +# Set up the workloads client to submit any GPU workloads to the EasyVideoTrans compute backend +gpu_workload = EasyVideoTransWorkloadClient( + audio_separation_endpoint=app.config['VOICE_BACKGROUND_SEPARATION_ENDPOINT'], + audio_transcribe_endpoint=app.config['AUDIO_TRANSCRIBE_ENDPOINT'], +) + def pytvzhen_stage(): return os.environ[PYTVZHEN_STAGE] if PYTVZHEN_STAGE in os.environ else 'default' @@ 
-283,9 +289,7 @@ def remove_audio_bg(video_id): f'not found at {output_path}, please extract it first')}), 404 try: - baseline_path = app.config['REMOVE_BACKGROUND_MUSIC_BASELINE_MODEL_PATH'] - audio_remove(audio_path, audio_no_bg_path, audio_bg_fn_path, baseline_path, - app.config['REMOVE_BACKGROUND_MUSIC_TORCH_DEVICE']) + audio_bg_fn_path, audio_no_bg_fn = gpu_workload.separate_audio(audio_fn) return jsonify({"message": log_info_return_str( f"Remove remove background music for {audio_fn} as {audio_no_bg_fn} and {audio_bg_fn_path} successfully."), "video_id": video_id}), 200 @@ -333,7 +337,6 @@ def audio_bg_serve(video_id): def transcribe(video_id): output_path = app.config['OUTPUT_PATH'] - transcribe_model = "medium" en_srt_fn, en_srt_merged_fn, audio_no_bg_fn = f'{video_id}_en.srt', f'{video_id}_en_merged.srt', f'{video_id}_no_bg.wav' en_srt_path, en_srt_merged_path, audio_no_bg_path = (os.path.join(output_path, en_srt_fn), @@ -351,10 +354,7 @@ def transcribe(video_id): f'not found at {audio_no_bg_path}, please extract it first')}), 404 try: - transcribe_audio_en(app.logger, path=audio_no_bg_path, modelName=transcribe_model, language="en", - srtFilePathAndName=en_srt_path) - srt_sentense_merge(app.logger, en_srt_path, en_srt_merged_path) - + gpu_workload.transcribe_audio(audio_no_bg_fn, [en_srt_fn, en_srt_merged_fn]) return jsonify({"message": log_info_return_str( f"Transcribed SRT from {audio_no_bg_fn} as {en_srt_fn} and {en_srt_merged_fn} successfully."), "video_id": video_id}), 200 diff --git a/configs/easyvideotrans.json b/configs/easyvideotrans.json new file mode 100644 index 0000000..ce0c40e --- /dev/null +++ b/configs/easyvideotrans.json @@ -0,0 +1,6 @@ +{ + "OUTPUT_PATH": "./output", + "VIDEO_MAX_DURATION": 3610, + "VOICE_BACKGROUND_SEPARATION_ENDPOINT": "http://localhost:8199/audio-sep", + "AUDIO_TRANSCRIBE_ENDPOINT": "http://localhost:8199/audio-transcribe" +} \ No newline at end of file diff --git a/configs/pytvzhen.json b/configs/pytvzhen.json deleted file mode 100644 index 13e966a..0000000 --- a/configs/pytvzhen.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "OUTPUT_PATH": "./output", - "VIDEO_MAX_DURATION": 3610, - "REMOVE_BACKGROUND_MUSIC_TORCH_DEVICE": "cuda:0", - "REMOVE_BACKGROUND_MUSIC_BASELINE_MODEL_PATH": "./models/baseline.pth" -} \ No newline at end of file diff --git a/inference.py b/inference.py new file mode 100644 index 0000000..c84bff9 --- /dev/null +++ b/inference.py @@ -0,0 +1,195 @@ +import os +import time +from pathlib import Path +import numpy as np +import soundfile as sf +import librosa +import torch +from functools import wraps + +from flask import Flask, request, jsonify +from prometheus_flask_exporter import PrometheusMetrics +from prometheus_client import Summary, Histogram, Gauge + +from workloads.lib.separator import Separator +from workloads.lib import spec_utils, nets +from workloads.lib.audio_processing.transcribe_audio import transcribe_audio_en +from workloads.lib.srt import srt_sentense_merge + +# Initialize the Flask app +app = Flask(__name__) + +# Integrate Prometheus metrics +metrics = PrometheusMetrics(app) +metrics.info("app_info", "EasyVideoTrans GPU Workloads Processing API", version="1.0.0") + +# Custom Prometheus metrics +INFERENCE_DURATION = Summary("inference_duration_seconds", "Time spent on inference") +TRANSCRIBE_DURATION = Summary("transcribe_duration_seconds", "Time spent on transcribe") +AUDIO_FILE_SIZE = Histogram("audio_file_size_bytes", "Size of input audio files", + buckets=[1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 
262144, 524288, 1048576, + 2097152, 4194304, 8388608]) +CURRENT_INFERENCE = Gauge("current_inference", "Number of ongoing inferences") + +# Model setup from https://github.com/tsurumeso/vocal-remover/tree/develop +MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'workloads/pretrained_models') +DEFAULT_MODEL_PATH = os.path.join(MODEL_DIR, 'baseline.pth') + +model = nets.CascadedNet(n_fft=2048, hop_length=1024, nout=32, nout_lstm=128) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model.load_state_dict(torch.load(DEFAULT_MODEL_PATH, map_location=device)) +model.to(device) +separator = Separator(model, device, batchsize=4, + cropsize=256, + postprocess=False) + +# Setup input / output configurations +INPUT_DIR = "workloads/static/outputs" +OUTPUT_DIR = "workloads/static/outputs" +os.makedirs(OUTPUT_DIR, exist_ok=True) + + +def load_spectrogram(file_path): + X, sample_rate = librosa.load( + file_path, sr=44100, mono=False, dtype=np.float32, res_type='kaiser_fast' + ) + + if X.ndim == 1: + # mono to stereo + X = np.asarray([X, X]) + + x_spec = spec_utils.wave_to_spectrogram(X, hop_length=1024, n_fft=2048) + return x_spec, sample_rate + + +@app.route("/") +def index(): + """ + Health check endpoint. + """ + return jsonify({"message": "Speech Separation API is running."}), 200 + + +def require_filename_points_to_existing_file(func): + @wraps(func) + def decorated_func(*args, **kwargs): + + if not request.is_json: + return jsonify({"message": "Missing JSON in request"}), 400 + + data = request.get_json() + if not data or "file_name" not in data: + return jsonify({"error": "Invalid request. Please provide 'file_name' in the JSON payload."}), 400 + + # Get the file path from the payload + file_name = data["file_name"] + file_path = os.path.join(INPUT_DIR, file_name) + + if not os.path.exists(file_path): + return jsonify({"error": f"File not found: {file_path}"}), 404 + + return func(file_path, *args, **kwargs) + + return decorated_func + + +def require_output_filenames(func): + @wraps(func) + def decorated_func(file_path, *args, **kwargs): + data = request.get_json() + + if "output_filenames" not in data: + return jsonify({"error": "Invalid request. Please provide 'output_filenames' in the JSON payload."}), 400 + + output_filenames = data["output_filenames"] + output_filepaths = [os.path.join(OUTPUT_DIR, name) for name in output_filenames] + + return func(file_path, output_filepaths, *args, **kwargs) + + return decorated_func + + +@app.route("/audio-sep", methods=["POST"]) +@require_filename_points_to_existing_file +def audio_separation(file_path): + """ + Endpoint to perform audio separation. + Accepts an audio file and returns separated sources. 
+ """ + + file_stem_name = Path(file_path).stem + + # Track the size of the input audio file + file_size = os.path.getsize(file_path) + AUDIO_FILE_SIZE.observe(file_size) + + # Perform source separation + app.logger.info(f"Processing file: {file_path}") + start_time = time.time() + CURRENT_INFERENCE.inc() # Increment the gauge for ongoing inferences + try: + x_spec, sample_rate = load_spectrogram(file_path) + app.logger.info(f"Done loading sound file: {file_path}") + + y_spec, v_spec = separator.separate_tta(x_spec) + + background_wave_fn, voice_wave_fn = f"{file_stem_name}_bg.wav", f"{file_stem_name}_no_bg.wav" + background_wave_path, voice_wave_path = os.path.join(OUTPUT_DIR, background_wave_fn), os.path.join( + OUTPUT_DIR, voice_wave_fn) + wave = spec_utils.spectrogram_to_wave(y_spec) + sf.write(background_wave_path, wave.T, int(sample_rate)) + app.logger.info(f"Done inversed stft for background, saved to: {background_wave_path}") + + wave = spec_utils.spectrogram_to_wave(v_spec) + sf.write(voice_wave_path, wave.T, int(sample_rate)) + app.logger.info(f"Done inversed stft for vocal, saved to: {voice_wave_path}") + + duration = time.time() - start_time + INFERENCE_DURATION.observe(duration) + CURRENT_INFERENCE.dec() # Decrement the gauge + + # Return the paths of the separated sources + response = { + "message": "Separation successful.", + "files": [background_wave_fn, voice_wave_fn], + "inference_duration_seconds": duration, + "input_audio_size_bytes": file_size, + } + return jsonify(response), 200 + except Exception as e: + print(f"Error during separation: {e}") + CURRENT_INFERENCE.dec() # Decrement the gauge in case of failure + return jsonify({"error": "An error occurred during audio separation."}), 500 + + +@app.route("/audio-transcribe", methods=["POST"]) +@require_filename_points_to_existing_file +@require_output_filenames +def audio_transcribe(file_path, output_filepaths): + app.logger.info(f"Transcribing file: {file_path}, output paths: {output_filepaths}") + + start_time = time.time() + CURRENT_INFERENCE.inc() # Increment the gauge for ongoing inferences + + try: + en_srt_path, en_srt_merged_path = output_filepaths + transcribe_audio_en(app.logger, path=file_path, modelName="medium", language="en", + srtFilePathAndName=en_srt_path) + srt_sentense_merge(app.logger, en_srt_path, en_srt_merged_path) + + duration = time.time() - start_time + TRANSCRIBE_DURATION.observe(duration) + CURRENT_INFERENCE.dec() # Decrement the gauge + response = { + "message": "Transcribe successful.", + "transcribe_duration_seconds": duration, + } + return jsonify(response), 200 + except Exception as e: + print(f"Error during separation: {e}") + CURRENT_INFERENCE.dec() # Decrement the gauge in case of failure + return jsonify({"error": "An error occurred during audio transcribe."}), 500 + + +if __name__ == '__main__': + app.run(host="0.0.0.0", port=8199) diff --git a/inference_test.py b/inference_test.py new file mode 100644 index 0000000..189ed75 --- /dev/null +++ b/inference_test.py @@ -0,0 +1,68 @@ +import pytest +import numpy as np + +from unittest.mock import patch + +from inference import app, separator + + +@pytest.fixture +def client(): + """Fixture to create a test client.""" + app.config["TESTING"] = True + with app.test_client() as client: + yield client + + +@patch("workloads.inference.os.path.exists", return_value=True) +@patch("workloads.inference.os.path.getsize", return_value=1024) # Mock file size +@patch("workloads.inference.load_spectrogram", return_value=("mock_spectrogram", 44100)) 
+@patch.object(separator, "separate_tta", return_value=("mock_bg_spec", "mock_v_spec")) +@patch("workloads.lib.spec_utils.spectrogram_to_wave", return_value=np.array([[0.1, 0.2], [0.3, 0.4]])) +@patch("workloads.inference.sf.write") # Mock sound file write function +def test_audio_separation_success(mock_sf_write, mock_spec_to_wave, mock_separate, mock_load_spec, mock_getsize, + mock_exists, client): + """Test successful audio separation.""" + response = client.post("/audio-sep", json={"file_name": "audio.wav"}) + assert response.status_code == 200 + data = response.get_json() + assert "message" in data and data["message"] == "Separation successful." + assert "files" in data and len(data["files"]) == 2 + assert "inference_duration_seconds" in data + assert "input_audio_size_bytes" in data and data["input_audio_size_bytes"] == 1024 + + +def test_audio_separation_missing_file_path(client): + """Test when 'file_path' is missing in the request.""" + response = client.post("/audio-sep", json={}) + assert response.status_code == 400 + data = response.get_json() + assert "error" in data and "Invalid request" in data["error"] + + +@patch("workloads.inference.os.path.exists", return_value=False) +def test_audio_separation_file_not_found(mock_exists, client): + """Test when the provided file path does not exist.""" + response = client.post("/audio-sep", json={"file_name": "invalid_path.wav"}) + assert response.status_code == 404 + data = response.get_json() + assert "error" in data and "File not found" in data["error"] + + +@patch("workloads.inference.os.path.exists", return_value=True) +@patch("workloads.inference.os.path.getsize", return_value=100) +@patch("workloads.inference.load_spectrogram", side_effect=Exception("Spectrogram error")) +def test_audio_separation_internal_error(mock_load_spectrogram, mock_getsize, mock_exists, client): + """Test when an internal error occurs during processing.""" + response = client.post("/audio-sep", json={"file_name": "audio.wav"}) + assert response.status_code == 500 + data = response.get_json() + assert "error" in data and "An error occurred during audio separation." in data["error"] + + +def test_health_check(client): + """Test the health check endpoint.""" + response = client.get("/") + assert response.status_code == 200 + data = response.get_json() + assert "message" in data and data["message"] == "Speech Separation API is running." 
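
The two routes added in inference.py can also be exercised end to end with plain HTTP, outside the Flask test client. A minimal smoke-test sketch, assuming the workload service is listening on localhost:8199 (the default in configs/easyvideotrans.json) and that a file named example.wav (a hypothetical name) already sits in workloads/static/outputs:

import requests

# Hypothetical input file; /audio-sep resolves "file_name" against workloads/static/outputs.
resp = requests.post("http://localhost:8199/audio-sep",
                     json={"file_name": "example.wav"}, timeout=180)
resp.raise_for_status()
# The response lists the background and vocal tracks, in that order.
bg_fn, voice_fn = resp.json()["files"]

# /audio-transcribe additionally takes "output_filenames": the raw SRT and the
# sentence-merged SRT, in that order.
resp = requests.post("http://localhost:8199/audio-transcribe",
                     json={"file_name": voice_fn,
                           "output_filenames": ["example_en.srt", "example_en_merged.srt"]},
                     timeout=180)
resp.raise_for_status()
print(resp.json()["message"])
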
diff --git a/requirements.txt b/requirements.txt index 63845d1..88cb0f5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,151 +1,18 @@ -aiohttp==3.9.3 -aiosignal==1.3.1 -annotated-types==0.6.0 -anyio==4.3.0 -asttokens==2.4.1 -async-timeout==4.0.3 -attrs==23.2.0 -audioread==3.0.1 -av==12.0.0 -backcall==0.2.0 -beautifulsoup4==4.12.3 -bleach==6.1.0 -blinker==1.7.0 -certifi==2024.2.2 -cffi==1.16.0 -chardet==5.2.0 -charset-normalizer==3.3.2 -click==8.1.7 -colorama==0.4.6 -coloredlogs==15.0.1 -ctranslate2==4.2.1 -decorator==4.4.2 deepl==1.17.0 -defusedxml==0.7.1 -distro==1.9.0 -docopt==0.6.2 -easygui==0.98.2 edge-tts==6.1.10 -exceptiongroup==1.2.0 -executing==2.0.1 -faster-whisper==1.0.2 -fastjsonschema==2.19.1 -filelock==3.13.1 flake8==7.0.0 Flask==3.0.3 -flatbuffers==24.3.25 -frozenlist==1.4.1 -fsspec==2024.3.1 -geckodriver-autoinstaller==0.1.0 -h11==0.14.0 -httpcore==1.0.4 -httpx==0.27.0 -huggingface-hub==0.23.0 -humanfriendly==10.0 -idna==3.6 -imageio==2.34.0 -imageio-ffmpeg==0.4.9 -ipython==8.12.3 -itsdangerous==2.2.0 -jedi==0.19.1 -Jinja2==3.1.3 -joblib==1.3.2 -jsonschema==4.21.1 -jsonschema-specifications==2023.12.1 -jupyter_client==8.6.1 -jupyter_core==5.7.2 -jupyterlab_pygments==0.3.0 -lazy_loader==0.3 -librosa==0.10.1 -llvmlite==0.42.0 -MarkupSafe==2.1.5 -matplotlib-inline==0.1.6 -mistune==3.0.2 -more-itertools==10.2.0 moviepy==1.0.3 -mpmath==1.3.0 -msgpack==1.0.8 -multidict==6.0.5 -nbclient==0.10.0 -nbconvert==7.16.3 -nbformat==5.10.4 -networkx==3.2.1 -numba==0.59.1 -numpy==1.26.4 -onnxruntime==1.18.0 openai==1.14.2 openai-whisper==20231117 -opencv-python==4.9.0.80 -packaging==24.0 -pandocfilters==1.5.1 -parso==0.8.3 -pexpect==4.9.0 -pickleshare==0.7.5 -pillow==10.2.0 -pip==23.3.1 -pipreqs==0.5.0 -platformdirs==4.2.0 -pooch==1.8.1 -proglog==0.1.10 prometheus_client==0.20.0 prometheus-flask-exporter==0.23.0 -prompt-toolkit==3.0.43 -protobuf==5.26.1 -ptyprocess==0.7.0 -pure-eval==0.2.2 -pycodestyle==2.11.1 -pycparser==2.22 -pydantic==2.6.4 -pydantic_core==2.16.3 -pydub==0.25.1 -pyflakes==3.2.0 -Pygments==2.17.2 -pygtrans==1.5.3 -PySocks==1.7.1 -pysubs2==1.6.1 -python-dateutil==2.9.0.post0 git+https://github.com/sutro-planet/pytubefix.git@main -PyYAML==6.0.1 -pyzmq==26.0.2 -referencing==0.34.0 -regex==2023.12.25 requests==2.31.0 -rpds-py==0.18.0 -scikit-learn==1.4.1.post1 scipy==1.12.0 -selenium==3.141.0 -semantic-version==2.10.0 -setuptools==68.2.2 -setuptools-rust==1.9.0 -six==1.16.0 -sniffio==1.3.1 -soundfile==0.12.1 -soupsieve==2.5 -soxr==0.3.7 -srt==3.4.1 -srt-deepl==0.9.1 -stable-ts==2.15.9 -stack-data==0.6.3 -sympy==1.12 -tenacity==8.3.0 -threadpoolctl==3.4.0 -tiktoken==0.6.0 -tinycss2==1.2.1 -tokenizers==0.19.1 -tomli==2.0.1 -torch==2.2.2 -torchaudio==2.2.2 -torchvision==0.17.2 -tornado==6.4 -tqdm==4.66.2 -traitlets==5.14.2 -typing_extensions==4.10.0 -urllib3==1.26.5 -wcwidth==0.2.13 -webencodings==0.5.1 -Werkzeug==3.0.2 -wheel==0.41.2 -yarg==0.1.9 -yarl==1.9.4 pika==1.3.2 celery==5.4.0 +srt==3.4.1 +pydub==0.25.1 +pygtrans==1.5.3 +tenacity==8.3.0 diff --git a/src/service/translation/__init__.py b/src/service/translation/__init__.py index 6d7d1aa..9948ac1 100644 --- a/src/service/translation/__init__.py +++ b/src/service/translation/__init__.py @@ -2,15 +2,12 @@ from .deepl_translator import DeepLTranslator from .google_translator import GoogleTranslator from .gpt_translator import GPTTranslator -from .srt import srt_sentense_merge, srt_to_text __all__ = [ "Translator", "DeepLTranslator", "GoogleTranslator", "GPTTranslator", - "srt_sentense_merge", - "srt_to_text", 
"get_translator", ] diff --git a/src/service/audio_processing/voice_connect.py b/src/service/video_synthesis/voice_connect.py similarity index 97% rename from src/service/audio_processing/voice_connect.py rename to src/service/video_synthesis/voice_connect.py index cc0dd8a..d14cdf3 100644 --- a/src/service/audio_processing/voice_connect.py +++ b/src/service/video_synthesis/voice_connect.py @@ -44,7 +44,8 @@ def connect_voice(logger, sourceDir, outputAndPath, warningFilePath): audioEndPosition = audioPosition + audio.duration_seconds * 1000 + MIN_GAP_DURATION * 1000 audioNextPosition = voiceMapSrt[i + 1].start.total_seconds() * 1000 if audioNextPosition < audioEndPosition: - speedUp = (audio.duration_seconds * 1000 + MIN_GAP_DURATION * 1000) / (audioNextPosition - audioPosition) + speedUp = (audio.duration_seconds * 1000 + MIN_GAP_DURATION * 1000) / ( + audioNextPosition - audioPosition) seconds = audioPosition / 1000.0 timeStr = str(datetime.timedelta(seconds=seconds)) if speedUp > MAX_SPEED_UP: diff --git a/src/test_workload_client.py b/src/test_workload_client.py new file mode 100644 index 0000000..1047c71 --- /dev/null +++ b/src/test_workload_client.py @@ -0,0 +1,92 @@ +import unittest +from unittest.mock import patch, MagicMock +from src.workload_client import EasyVideoTransWorkloadClient, WorkloadResponseError # Import your class + + +class TestEasyVideoTransWorkloadClient(unittest.TestCase): + + def setUp(self): + """Set up a test instance of the client.""" + self.client = EasyVideoTransWorkloadClient( + audio_separation_endpoint="http://localhost:8199/audio-sep", + audio_transcribe_endpoint="http://localhost:8199/audio-transcribe" + ) + + @patch("requests.post") + def test_separate_audio_success(self, mock_post): + """Test successful audio separation.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"files": ["bg.wav", "voice.wav"]} + mock_post.return_value = mock_response + + bg_file, voice_file = self.client.separate_audio("test_audio.wav") + + self.assertEqual(bg_file, "bg.wav") + self.assertEqual(voice_file, "voice.wav") + mock_post.assert_called_once_with( + "http://localhost:8199/audio-sep", + json={"file_name": "test_audio.wav"}, + timeout=120 + ) + + @patch("requests.post") + def test_separate_audio_error_response(self, mock_post): + """Test error handling when API returns a non-200 response.""" + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + mock_post.return_value = mock_response + + with self.assertRaises(WorkloadResponseError) as context: + self.client.separate_audio("test_audio.wav") + + self.assertEqual(context.exception.status_code, 500) + self.assertEqual(context.exception.message, "Internal Server Error") + + @patch("requests.post") + def test_separate_audio_invalid_response_format(self, mock_post): + """Test handling when the API returns an unexpected response format.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"files": ["bg.wav"]} # Only one file instead of two + mock_post.return_value = mock_response + + with self.assertRaises(WorkloadResponseError) as context: + self.client.separate_audio("test_audio.wav") + + self.assertEqual(context.exception.status_code, 500) + self.assertEqual(context.exception.message, "Invalid response format. 
Expected two separated files.") + + @patch("requests.post") + def test_transcribe_audio_success(self, mock_post): + """Test successful audio transcription.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_post.return_value = mock_response + + self.client.transcribe_audio("test_audio.wav", ["transcript.txt", "summary.txt"]) + + mock_post.assert_called_once_with( + "http://localhost:8199/audio-transcribe", + json={"file_name": "test_audio.wav", "output_filenames": ["transcript.txt", "summary.txt"]}, + timeout=180 + ) + + @patch("requests.post") + def test_transcribe_audio_error_response(self, mock_post): + """Test error handling when transcription API returns an error.""" + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.text = "Bad Request" + mock_post.return_value = mock_response + + with self.assertRaises(WorkloadResponseError) as context: + self.client.transcribe_audio("test_audio.wav", ["transcript.txt"]) + + self.assertEqual(context.exception.status_code, 400) + self.assertEqual(context.exception.message, "Bad Request") + + +if __name__ == "__main__": + unittest.main() diff --git a/src/workload_client.py b/src/workload_client.py new file mode 100644 index 0000000..f409c88 --- /dev/null +++ b/src/workload_client.py @@ -0,0 +1,51 @@ +import requests +from typing import Tuple, List + + +class WorkloadClientError(Exception): + """Base class for workloads client exceptions.""" + pass + + +class WorkloadResponseError(WorkloadClientError): + """Raised when the workloads backend returns an error response.""" + + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + super().__init__(f"Workload Error {status_code}: {message}") + + +class EasyVideoTransWorkloadClient: + + def __init__( + self, + audio_separation_endpoint="http://localhost:8199/audio-sep", + audio_transcribe_endpoint="http://localhost:8199/audio-transcribe" + ): + self.audio_separation_endpoint = audio_separation_endpoint + self.audio_transcribe_endpoint = audio_transcribe_endpoint + + def separate_audio(self, audio_filename: str) -> Tuple[str, str]: + payload = {"file_name": audio_filename} + + response = requests.post(self.audio_separation_endpoint, json=payload, timeout=180) + if response.status_code != 200: + raise WorkloadResponseError(response.status_code, response.text) + + response_data = response.json() + + separated_files = response_data.get("files", []) + if len(separated_files) != 2: + raise WorkloadResponseError(500, "Invalid response format. 
Expected two separated files.") + + bg_filename, voice_filename = separated_files + return bg_filename, voice_filename + + def transcribe_audio(self, audio_filename: str, output_filenames: List[str]) -> None: + payload = {"file_name": audio_filename, + "output_filenames": output_filenames} + + response = requests.post(self.audio_transcribe_endpoint, json=payload, timeout=180) + if response.status_code != 200: + raise WorkloadResponseError(response.status_code, response.text) diff --git a/work_space_tbd.py b/work_space_tbd.py deleted file mode 100644 index c41e03a..0000000 --- a/work_space_tbd.py +++ /dev/null @@ -1,484 +0,0 @@ -# from src.service.audio_processing.audio_remove import audio_remove -# # from tools_tbd.warning_file import WarningFile -# from src.data_models.workflow import Workflow -# import os -# import copy -# from pytube import YouTube -# from pytube.cli import on_progress -# import srt -# import requests -# from tqdm import tqdm -# from pydub import AudioSegment -# import asyncio -# import edge_tts -# import datetime -# from moviepy.editor import VideoFileClip -# import sys -# import traceback -# import tenacity -# -# PROXY = "" -# proxies = None -# TTS_MAX_TRY_TIMES = 16 -# CHATGPT_URL = "https://api.openai.com/v1/" -# GHATGPT_TERMS_FILE = "configs/gpt_terms.json" -# -# diagnosisLog = None -# executeLog = None -# -# # Default to utf-8 encoding -# os.environ['PYTHONIOENCODING'] = 'utf-8' -# os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # Force the GPU (CUDA) build -# -# -# # def stringToVoice(url, string, outputFile): -# # data = { -# # "text": string, -# # "text_language": "zh" -# # } -# # response = requests.post(url, json=data) -# # if response.status_code != 200: -# # return False -# # -# # with open(outputFile, "wb") as f: -# # f.write(response.content) -# # -# # return True -# # -# # -# # def srtToVoice(url, srtFileNameAndPath, outputDir): -# # # create output directory if not exists -# # if not os.path.exists(outputDir): -# # os.makedirs(outputDir) -# # -# # srtContent = open(srtFileNameAndPath, "r", encoding="utf-8").read() -# # subGenerator = srt.parse(srtContent) -# # subTitleList = list(subGenerator) -# # index = 1 -# # fileNames = [] -# # print("Start to convert srt to voice") -# # with tqdm(total=len(subTitleList)) as pbar: -# # for subTitle in subTitleList: -# # string = subTitle.content -# # fileName = str(index) + ".wav" -# # outputNameAndPath = os.path.join(outputDir, fileName) -# # fileNames.append(fileName) -# # tryTimes = 0 -# # -# # while tryTimes < TTS_MAX_TRY_TIMES: -# # if not stringToVoice(url, string, outputNameAndPath): -# # return False -# # -# # # Get the duration of outputNameAndPath -# # audio = AudioSegment.from_wav(outputNameAndPath) -# # duration = len(audio) -# # # Get the maximum volume -# # maxVolume = audio.max_dBFS -# # -# # # If the audio is shorter than 500ms, retry; the data is probably bad -# # if duration > 600 and maxVolume > -15: -# # break -# # -# # tryTimes += 1 -# # -# # if tryTimes >= TTS_MAX_TRY_TIMES: -# # print(f"Warning Failed to convert {fileName} to voice.") -# # print(f"Convert {fileName} duration: {duration}ms, max volume: {maxVolume}dB") -# # -# # index += 1 -# # pbar.update(1) # update progress bar -# # -# # voiceMapSrt = copy.deepcopy(subTitleList) -# # for i in range(len(voiceMapSrt)): -# # voiceMapSrt[i].content = fileNames[i] -# # voiceMapSrtContent = srt.compose(voiceMapSrt) -# # voiceMapSrtFileAndPath = os.path.join(outputDir, "voiceMap.srt") -# # with open(voiceMapSrtFileAndPath, "w", encoding="utf-8") as f: -# # f.write(voiceMapSrtContent) -# # -# # srtAtitionalFile = os.path.join(outputDir, "zh.srt") -# # 
with open(srtAtitionalFile, "w", encoding="utf-8") as f: -# # f.write(srtContent) -# # -# # print("Convert srt to voice successfully") -# # return True -# -# -# -# if __name__ == "__main__": -# paramDirPathAndName = input("please input the path and name of the parameter file (json format), or press enter " -# "to skip\n") -# if paramDirPathAndName == "": -# paramDirPathAndName = "data/workflow/default_param_dict.json" -# -# # Check that paramDirPathAndName exists and is a json file -# if not os.path.exists(paramDirPathAndName) or not os.path.isfile( -# paramDirPathAndName) or not paramDirPathAndName.endswith(".json"): -# print("Please select a valid parameter file.") -# exit(-1) -# -# workflow = Workflow(paramDirPathAndName) -# -# # TODO: change the proxy field to json -# proxies = None if not workflow.proxy else { -# 'http': f"{workflow.proxy}", -# 'https': f"{workflow.proxy}", -# 'socks5': f"{workflow.proxy}" -# } -# -# # create the working directory if it does not exist -# if not os.path.exists(workflow.work_path): -# os.makedirs(workflow.work_path) -# print(f"Directory {workflow.work_path} created.") -# -# # TODO: the logging system needs a rework -# # logFileName = "diagnosis.log" -# # diagnosisLog = WarningFile(os.path.join(workPath, logFileName)) -# # # Execution log files are named excute_yyyyMMdd_HHmmss.log -# # logFileName = "execute_" + datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + ".log" -# # executeLog = WarningFile(os.path.join(workPath, logFileName)) -# # -# # nowString = str(datetime.datetime.now()) -# # executeLog.write(f"Start at: {nowString}") -# # executeLog.write("Params\n" + json.dumps(paramDict, indent=4) + "\n") -# -# # Download the video -# # voiceFileName = f"{videoId}.mp4" -# # viedoFileNameAndPath = os.path.join(workPath, voiceFileName) -# # -# # if paramDict["download video"]: -# # print(f"Downloading video {videoId} to {viedoFileNameAndPath}") -# # try: -# # # If it already exists, skip the download -# # if os.path.exists(viedoFileNameAndPath): -# # print(f"Video {videoId} already exists.") -# # executeLog.write("[WORK -] Skip downloading video.") -# # print("Now at: " + str(datetime.datetime.now())) -# # else: -# # yt = YouTube(f'https://www.youtube.com/watch?v={videoId}', proxies=proxies, -# # on_progress_callback=on_progress) -# # video = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').asc().first() -# # video.download(output_path=workPath, filename=voiceFileName) -# # # go back to the script directory -# # executeLog.write( -# # f"[WORK o] Download video {videoId} to {viedoFileNameAndPath} whith {video.resolution}.") -# # except Exception as e: -# # logStr = f"[WORK x] Error: Program blocked while downloading video {videoId} to {viedoFileNameAndPath}." -# # executeLog.write(logStr) -# # error_str = traceback.format_exception_only(type(e), e)[-1].strip() -# # executeLog.write(error_str) -# # sys.exit(-1) -# # else: -# # logStr = "[WORK -] Skip downloading video." 
-# # executeLog.write(logStr) -# # -# # # try download more high-definition video -# # # The highest-resolution video must be downloaded separately, because 1080p videos downloaded by pytube have no audio track -# # voiceFhdFileName = f"{videoId}_fhd.mp4" -# # voiceFhdFileNameAndPath = os.path.join(workPath, voiceFhdFileName) -# # if paramDict["download fhd video"]: -# # try: -# # # If it already exists, skip the download -# # if os.path.exists(voiceFhdFileNameAndPath): -# # print(f"Video {videoId} already exists.") -# # executeLog.write("[WORK -] Skip downloading video.") -# # print("Now at: " + str(datetime.datetime.now())) -# # else: -# # print(f"Try to downloading more high-definition video {videoId} to {voiceFhdFileNameAndPath}") -# # yt = YouTube(f'https://www.youtube.com/watch?v={videoId}', proxies=proxies, -# # on_progress_callback=on_progress) -# # video = yt.streams.filter(progressive=False, file_extension='mp4').order_by('resolution').desc().first() -# # video.download(output_path=workPath, filename=voiceFhdFileName) -# # executeLog.write( -# # f"[WORK o] Download 1080p high-definition {videoId} to {voiceFhdFileNameAndPath} whith {video.resolution}.") -# # except Exception as e: -# # logStr = ( -# # f"[WORK x] Error: Program blocked while downloading high-definition video" -# # f" {videoId} to {voiceFhdFileNameAndPath} with {video.resolution}: {e}") -# # executeLog.write(logStr) -# # logStr = "Program will not exit for that the error is not critical." -# # executeLog.write(logStr) -# # else: -# # logStr = "[WORK -] Skip downloading high-definition video." -# # executeLog.write(logStr) -# # -# # # Print the current system time -# # print("Now at: " + str(datetime.datetime.now())) -# # -# # # Extract the audio track from the video -# # audioFileName = f"{videoId}.wav" -# # audioFileNameAndPath = os.path.join(workPath, audioFileName) -# # if paramDict["extract audio"]: -# # # remove the audio file if it exists -# # print(f"Extracting audio from {viedoFileNameAndPath} to {audioFileNameAndPath}") -# # try: -# # video = VideoFileClip(viedoFileNameAndPath) -# # audio = video.audio -# # audio.write_audiofile(audioFileNameAndPath) -# # executeLog.write( -# # f"[WORK o] Extract audio from {viedoFileNameAndPath} to {audioFileNameAndPath} successfully.") -# # except Exception as e: -# # logStr = f"[WORK x] Error: Program blocked while extracting audio from {viedoFileNameAndPath} to {audioFileNameAndPath}." -# # executeLog.write(logStr) -# # error_str = traceback.format_exception_only(type(e), e)[-1].strip() -# # executeLog.write(error_str) -# # sys.exit(-1) -# # else: -# # logStr = "[WORK -] Skip extracting audio." 
-# # executeLog.write(logStr) -# # -# # # Remove music from the audio -# # voiceName = videoId + "_voice.wav" -# # voiceNameAndPath = os.path.join(workPath, voiceName) -# # insturmentName = videoId + "_insturment.wav" -# # insturmentNameAndPath = os.path.join(workPath, insturmentName) -# # if paramDict["audio remove"]: -# # print(f"Removing music from {audioFileNameAndPath} to {voiceNameAndPath} and {insturmentNameAndPath}") -# # try: -# # audio_remove(audioFileNameAndPath, voiceNameAndPath, insturmentNameAndPath, audioRemoveModelNameAndPath, -# # "cuda:0") -# # executeLog.write( -# # f"[WORK o] Remove music from {audioFileNameAndPath} to {voiceNameAndPath} and {insturmentNameAndPath} successfully.") -# # except Exception as e: -# # logStr = (f"[WORK x] Error: Program blocked while removing music from {audioFileNameAndPath} " -# # f"to {voiceNameAndPath} and {insturmentNameAndPath}.") -# # executeLog.write(logStr) -# # error_str = traceback.format_exception_only(type(e), e)[-1].strip() -# # executeLog.write(error_str) -# # sys.exit(-1) -# # else: -# # logStr = "[WORK -] Skip removing music." -# # executeLog.write(logStr) -# # -# # # Speech to text -# # srtEnFileName = videoId + "_en.srt" -# # srtEnFileNameAndPath = os.path.join(workPath, srtEnFileName) -# # if paramDict["audio transcribe"]: -# # try: -# # print(f"Transcribing audio from {voiceNameAndPath} to {srtEnFileNameAndPath}") -# # transcribeAudioEn(voiceNameAndPath, paramDict["audio transcribe model"], "en", srtEnFileNameAndPath) -# # executeLog.write( -# # f"[WORK o] Transcribe audio from {voiceNameAndPath} to {srtEnFileNameAndPath} successfully.") -# # except Exception as e: -# # logStr = f"[WORK x] Error: Program blocked while transcribing audio from {voiceNameAndPath} to {srtEnFileNameAndPath}." -# # executeLog.write(logStr) -# # error_str = traceback.format_exception_only(type(e), e)[-1].strip() -# # executeLog.write(error_str) -# # sys.exit(-1) -# # else: -# # logStr = "[WORK -] Skip transcription." -# # executeLog.write(logStr) -# # -# # # Merge subtitle sentences -# # srtEnFileNameMerge = videoId + "_en_merge.srt" -# # srtEnFileNameMergeAndPath = os.path.join(workPath, srtEnFileNameMerge) -# # if paramDict["srt merge"]: -# # try: -# # print(f"Merging sentences in {srtEnFileNameAndPath} to {srtEnFileNameMergeAndPath}") -# # srtSentanceMerge(srtEnFileNameAndPath, srtEnFileNameMergeAndPath) -# # executeLog.write( -# # f"[WORK o] Merge sentences in {srtEnFileNameAndPath} to {srtEnFileNameMergeAndPath} successfully.") -# # except Exception as e: -# # logStr = f"[WORK x] Error: Program blocked while merging sentences in {srtEnFileNameAndPath} to {srtEnFileNameMergeAndPath}." -# # executeLog.write(logStr) -# # error_str = traceback.format_exception_only(type(e), e)[-1].strip() -# # executeLog.write(error_str) -# # sys.exit(-1) -# # else: -# # logStr = "[WORK -] Skip sentence merge." -# # executeLog.write(logStr) -# # -# # # Convert the merged English subtitles to plain text -# # tetEnFileName = videoId + "_en_merge.txt" -# # tetEnFileNameAndPath = os.path.join(workPath, tetEnFileName) -# # if paramDict["srt merge en to text"]: -# # try: -# # enText = srt_to_text(srtEnFileNameMergeAndPath) -# # print(f"Writing EN text to {tetEnFileNameAndPath}") -# # with open(tetEnFileNameAndPath, "w") as file: -# # file.write(enText) -# # executeLog.write(f"[WORK o] Write EN text to {tetEnFileNameAndPath} successfully.") -# # except Exception as e: -# # logStr = f"[WORK x] Error: Writing EN text to {tetEnFileNameAndPath} failed." 
-# # executeLog.write(logStr) -# # error_str = traceback.format_exception_only(type(e), e)[-1].strip() -# # executeLog.write(error_str) -# # # This is not a critical step, so do not exit the program -# # logStr = "Program will not exit for that the error is not critical." -# # executeLog.write(logStr) -# # else: -# # logStr = "[WORK -] Skip writing EN text." -# # executeLog.write(logStr) -# # -# # # Translate the subtitles -# # srtZhFileName = videoId + "_zh_merge.srt" -# # srtZhFileNameAndPath = os.path.join(workPath, srtZhFileName) -# # if paramDict["srt merge translate"]: -# # try: -# # print(f"Translating subtitle from {srtEnFileNameMergeAndPath} to {srtZhFileNameAndPath}") -# # if paramDict["srt merge translate tool"] == "deepl": -# # if paramDict["srt merge translate key"] == "": -# # logStr = "[WORK x] Error: DeepL API key is not provided. Please provide it in the parameter file." -# # executeLog.write(logStr) -# # sys.exit(-1) -# # srtFileDeeplTran(srtEnFileNameMergeAndPath, srtZhFileNameAndPath, paramDict["srt merge translate key"]) -# # elif 'gpt' in paramDict["srt merge translate tool"]: -# # if paramDict['srt merge translate key'] == '': -# # logStr = "[WORK x] Error: GPT API key is not provided. Please provide it in the parameter file." -# # executeLog.write(logStr) -# # sys.exit(-1) -# # srtFileGPTTran(paramDict['srt merge translate tool'], -# # proxies, -# # srtEnFileNameMergeAndPath, -# # srtZhFileNameAndPath, -# # paramDict['srt merge translate key']) -# # else: -# # srtFileGoogleTran(srtEnFileNameMergeAndPath, srtZhFileNameAndPath) -# # executeLog.write( -# # f"[WORK o] Translate subtitle from {srtEnFileNameMergeAndPath} to {srtZhFileNameAndPath} successfully.") -# # except Exception as e: -# # logStr = f"[WORK x] Error: Program blocked while translating subtitle from {srtEnFileNameMergeAndPath} to {srtZhFileNameAndPath}." -# # executeLog.write(logStr) -# # error_str = traceback.format_exception_only(type(e), e)[-1].strip() -# # executeLog.write(error_str) -# # sys.exit(-1) -# # else: -# # logStr = "[WORK -] Skip subtitle translation." -# # executeLog.write(logStr) -# # -# # # Convert the Chinese subtitles to plain text -# # textZhFileName = videoId + "_zh_merge.txt" -# # textZhFileNameAndPath = os.path.join(workPath, textZhFileName) -# # if paramDict["srt merge zh to text"]: -# # try: -# # zhText = srt_to_text(srtZhFileNameAndPath) -# # print(f"Writing ZH text to {textZhFileNameAndPath}") -# # with open(textZhFileNameAndPath, "w", encoding="utf-8") as file: -# # file.write(zhText) -# # executeLog.write(f"[WORK o] Write ZH text to {textZhFileNameAndPath} successfully.") -# # except Exception as e: -# # logStr = f"[WORK x] Error: Writing ZH text to {textZhFileNameAndPath} failed." -# # executeLog.write(logStr) -# # error_str = traceback.format_exception_only(type(e), e)[-1].strip() -# # executeLog.write(error_str) -# # # This is not a critical step, so do not exit the program -# # logStr = "Program will not exit for that the error is not critical." -# # executeLog.write(logStr) -# # else: -# # logStr = "[WORK -] Skip writing ZH text." 
-# # executeLog.write(logStr) -# # -# # # Convert the subtitles to speech -# # ttsSelect = paramDict["TTS"] -# # voiceDir = os.path.join(workPath, videoId + "_zh_source") -# # voiceSrcSrtName = "zh.srt" -# # voiceSrcSrtNameAndPath = os.path.join(voiceDir, voiceSrcSrtName) -# # voiceSrcMapName = "voiceMap.srt" -# # voiceSrcMapNameAndPath = os.path.join(voiceDir, voiceSrcMapName) -# # if paramDict["srt to voice srouce"]: -# # try: -# # if ttsSelect == "GPT-SoVITS": -# # print(f"Converting subtitle to voice by GPT-SoVITS in {srtZhFileNameAndPath} to {voiceDir}") -# # voiceUrl = paramDict["TTS param"] -# # srtToVoice(voiceUrl, srtZhFileNameAndPath, voiceDir) -# # else: -# # charator = paramDict["TTS param"] -# # if charator == "": -# # srtToVoiceEdge(srtZhFileNameAndPath, voiceDir) -# # else: -# # srtToVoiceEdge(srtZhFileNameAndPath, voiceDir, charator) -# # print(f"Converting subtitle to voice by EdgeTTS in {srtZhFileNameAndPath} to {voiceDir}") -# # executeLog.write( -# # f"[WORK o] Convert subtitle to voice in {srtZhFileNameAndPath} to {voiceDir} successfully.") -# # except Exception as e: -# # logStr = f"[WORK x] Error: Program blocked while converting subtitle to voice in {srtZhFileNameAndPath} to {voiceDir}." -# # executeLog.write(logStr) -# # error_str = traceback.format_exception_only(type(e), e)[-1].strip() -# # executeLog.write(error_str) -# # sys.exit(-1) -# # else: -# # logStr = "[WORK -] Skip voice conversion." -# # executeLog.write(logStr) -# # -# # # Connect the generated voice clips -# # voiceConnectedName = videoId + "_zh.wav" -# # voiceConnectedNameAndPath = os.path.join(workPath, voiceConnectedName) -# # if paramDict["voice connect"]: -# # try: -# # print(f"Connecting voice in {voiceDir} to {voiceConnectedNameAndPath}") -# # ret = voiceConnect(voiceDir, voiceConnectedNameAndPath) -# # if ret: -# # executeLog.write(f"[WORK o] Connect voice in {voiceDir} to {voiceConnectedNameAndPath} successfully.") -# # else: -# # executeLog.write(f"[WORK x] Connect voice in {voiceDir} to {voiceConnectedNameAndPath} failed.") -# # sys.exit(-1) -# # except Exception as e: -# # logStr = f"[WORK x] Error: Program blocked while connecting voice in {voiceDir} to {voiceConnectedNameAndPath}." -# # executeLog.write(logStr) -# # error_str = traceback.format_exception_only(type(e), e)[-1].strip() -# # executeLog.write(error_str) -# # sys.exit(-1) -# # else: -# # logStr = "[WORK -] Skip voice connection." -# # executeLog.write(logStr) -# # -# # # Transcribe the synthesized speech back to text -# # srtVoiceFileName = videoId + "_zh.srt" -# # srtVoiceFileNameAndPath = os.path.join(workPath, srtVoiceFileName) -# # if paramDict["audio zh transcribe"]: -# # try: -# # if os.path.exists(srtVoiceFileNameAndPath): -# # print("srtVoiceFileNameAndPath exists.") -# # else: -# # print(f"Transcribing audio from {voiceConnectedNameAndPath} to {srtVoiceFileNameAndPath}") -# # transcribeAudioZh(voiceConnectedNameAndPath, paramDict["audio zh transcribe model"], "zh", -# # srtVoiceFileNameAndPath) -# # executeLog.write( -# # f"[WORK o] Transcribe audio from {voiceConnectedNameAndPath} to {srtVoiceFileNameAndPath} successfully.") -# # except Exception as e: -# # logStr = f"[WORK x] Error: Program blocked while transcribing audio from {voiceConnectedNameAndPath} to {srtVoiceFileNameAndPath}." -# # executeLog.write(logStr) -# # error_str = traceback.format_exception_only(type(e), e)[-1].strip() -# # executeLog.write(error_str) -# # sys.exit(-1) -# # else: -# # logStr = "[WORK -] Skip transcription." 
-# # executeLog.write(logStr) -# # -# # # Generate the Chinese preview video -# # previewVideoName = videoId + "_preview.mp4" -# # previewVideoNameAndPath = os.path.join(workPath, previewVideoName) -# # if paramDict["video zh preview"]: -# # try: -# # sourceVideoNameAndPath = "" -# # if os.path.exists(voiceFhdFileNameAndPath): -# # sourceVideoNameAndPath = voiceFhdFileNameAndPath -# # elif os.path.exists(viedoFileNameAndPath): -# # print( -# # f"Cannot find high-definition video, use low-definition video {viedoFileNameAndPath} -# for preview video {previewVideoNameAndPath}") -# # sourceVideoNameAndPath = viedoFileNameAndPath -# # else: -# # logStr = f"[WORK x] Error: Cannot find source video for preview video {previewVideoNameAndPath}." -# # executeLog.write(logStr) -# # sys.exit(-1) -# # -# # print(f"Generating zh preview video in {previewVideoNameAndPath}") -# # zhVideoPreview(sourceVideoNameAndPath, voiceConnectedNameAndPath, insturmentNameAndPath, -# # srtVoiceFileNameAndPath, previewVideoNameAndPath) -# # executeLog.write(f"[WORK o] Generate zh preview video in {previewVideoNameAndPath} successfully.") -# # except Exception as e: -# # logStr = f"[WORK x] Error: Program blocked while generating zh preview video in {previewVideoNameAndPath}." -# # executeLog.write(logStr) -# # error_str = traceback.format_exception_only(type(e), e)[-1].strip() -# # executeLog.write(error_str) -# # sys.exit(-1) -# # else: -# # logStr = "[WORK -] Skip zh preview video." -# # executeLog.write(logStr) -# # -# # executeLog.write("All done!!") -# # print("dir: " + workPath) -# # -# # # press any key to exit -# # input("Press any key to exit...") diff --git a/workloads/.dockerignore b/workloads/.dockerignore new file mode 100644 index 0000000..395f665 --- /dev/null +++ b/workloads/.dockerignore @@ -0,0 +1,4 @@ +*.wav +.keep +.DS_Store +.pytest_cache diff --git a/models/__init__.py b/workloads/__init__.py similarity index 100% rename from models/__init__.py rename to workloads/__init__.py diff --git a/models/audio_removal_model/__init__.py b/workloads/lib/__init__.py similarity index 100% rename from models/audio_removal_model/__init__.py rename to workloads/lib/__init__.py diff --git a/src/service/audio_processing/__init__.py b/workloads/lib/audio_processing/__init__.py similarity index 100% rename from src/service/audio_processing/__init__.py rename to workloads/lib/audio_processing/__init__.py diff --git a/src/service/audio_processing/transcribe_audio.py b/workloads/lib/audio_processing/transcribe_audio.py similarity index 100% rename from src/service/audio_processing/transcribe_audio.py rename to workloads/lib/audio_processing/transcribe_audio.py diff --git a/models/audio_removal_model/dataset.py b/workloads/lib/dataset.py similarity index 99% rename from models/audio_removal_model/dataset.py rename to workloads/lib/dataset.py index f11fdc1..641135d 100644 --- a/models/audio_removal_model/dataset.py +++ b/workloads/lib/dataset.py @@ -6,7 +6,7 @@ from tqdm import tqdm try: - from models.audio_removal_model import spec_utils + from workloads.lib import spec_utils except ModuleNotFoundError: import spec_utils diff --git a/models/audio_removal_model/layers.py b/workloads/lib/layers.py similarity index 98% rename from models/audio_removal_model/layers.py rename to workloads/lib/layers.py index 3d9e4f5..0197961 100644 --- a/models/audio_removal_model/layers.py +++ b/workloads/lib/layers.py @@ -2,7 +2,7 @@ from torch import nn import torch.nn.functional as F -from models.audio_removal_model import spec_utils +from workloads.lib import 
spec_utils class Conv2DBNActiv(nn.Module): diff --git a/models/audio_removal_model/nets.py b/workloads/lib/nets.py similarity index 98% rename from models/audio_removal_model/nets.py rename to workloads/lib/nets.py index 188facf..dc61303 100644 --- a/models/audio_removal_model/nets.py +++ b/workloads/lib/nets.py @@ -2,7 +2,7 @@ from torch import nn import torch.nn.functional as F -from models.audio_removal_model import layers +from workloads.lib import layers class BaseNet(nn.Module): diff --git a/src/service/audio_processing/audio_remove.py b/workloads/lib/separator.py similarity index 59% rename from src/service/audio_processing/audio_remove.py rename to workloads/lib/separator.py index 67d424e..abc805b 100644 --- a/src/service/audio_processing/audio_remove.py +++ b/workloads/lib/separator.py @@ -1,14 +1,9 @@ -import librosa import numpy as np -import soundfile as sf import torch from tqdm import tqdm -from models.audio_removal_model import nets, dataset, spec_utils - -AUDIO_REMOVE_DEVICE = "gpu" -AUDIO_REMOVE_FFT_SIZE = 2048 -AUDIO_REMOVE_HOP_SIZE = 1024 +from workloads.lib import dataset +from workloads.lib import spec_utils class Separator(object): @@ -98,56 +93,3 @@ def separate_tta(self, X_spec): y_spec, v_spec = self._postprocess(X_spec, mask) return y_spec, v_spec - - -def audio_remove(audioFileNameAndPath, voiceFileNameAndPath, instrumentFileNameAndPath, modelNameAndPath, - pytorchDevice): - if pytorchDevice not in ["cpu", "cuda:0"]: - raise ValueError("Invalid device: {}, valid choices are cpu or cuda:0. ".format(AUDIO_REMOVE_DEVICE)) - - device = torch.device(pytorchDevice) - - print("Loading model " + pytorchDevice) - model = nets.CascadedNet(AUDIO_REMOVE_FFT_SIZE, AUDIO_REMOVE_HOP_SIZE, 32, 128) # model parameters - model.load_state_dict(torch.load(modelNameAndPath, map_location='cpu')) - model.to(device) - print("Model loaded") - - print('loading wave source ' + audioFileNameAndPath) - X, sr = librosa.load( - audioFileNameAndPath, sr=44100, mono=False, dtype=np.float32, res_type='kaiser_fast' - ) - print("Wave source loaded") - - if X.ndim == 1: - # mono to stereo - X = np.asarray([X, X]) - - print('stft of wave source...', end=' ') - X_spec = spec_utils.wave_to_spectrogram(X, AUDIO_REMOVE_HOP_SIZE, AUDIO_REMOVE_FFT_SIZE) - print('done') - - sp = Separator( - model=model, - device=device, - batchsize=4, - cropsize=256, - postprocess=False - ) - - y_spec, v_spec = sp.separate_tta(X_spec) - print('inverse stft of instruments...', end=' ') - wave = spec_utils.spectrogram_to_wave(y_spec, AUDIO_REMOVE_HOP_SIZE) - print('done') - sf.write(instrumentFileNameAndPath, wave.T, sr) - - print('inverse stft of vocals...', end=' ') - wave = spec_utils.spectrogram_to_wave(v_spec, hop_length=AUDIO_REMOVE_HOP_SIZE) - print('done') - sf.write(voiceFileNameAndPath, wave.T, sr) - - -if __name__ == '__main__': - audio_remove("d:\\document\\AI_Work\\whisper\\videos\\proxy\\RXXRguaHZs0.wav", - "d:\\document\\AI_Work\\whisper\\videos\\proxy\\RXXRguaHZs0_voice.wav", - "d:\\document\\AI_Work\\whisper\\videos\\proxy\\RXXRguaHZs0_instrument.wav") diff --git a/models/audio_removal_model/spec_utils.py b/workloads/lib/spec_utils.py similarity index 100% rename from models/audio_removal_model/spec_utils.py rename to workloads/lib/spec_utils.py diff --git a/src/service/translation/srt.py b/workloads/lib/srt.py similarity index 100% rename from src/service/translation/srt.py rename to workloads/lib/srt.py diff --git a/models/audio_removal_model/utils.py b/workloads/lib/utils.py similarity index 100% rename from 
models/audio_removal_model/utils.py rename to workloads/lib/utils.py diff --git a/models/baseline.pth b/workloads/pretrained_models/baseline.pth similarity index 100% rename from models/baseline.pth rename to workloads/pretrained_models/baseline.pth diff --git a/workloads/requirements.txt b/workloads/requirements.txt new file mode 100644 index 0000000..9e6e5f3 --- /dev/null +++ b/workloads/requirements.txt @@ -0,0 +1,12 @@ +prometheus_client +prometheus-flask-exporter +librosa~=0.10.0 +matplotlib~=3.8.0 +opencv_python~=4.8.0 +resampy~=0.4.0 +tqdm~=4.66.0 +numpy~=1.26.4 +pytest +srt==3.4.1 +srt-deepl==0.9.1 +faster-whisper==1.1.1 diff --git a/models/.gitkeep b/workloads/static/outputs/.keep similarity index 100% rename from models/.gitkeep rename to workloads/static/outputs/.keep
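
Taken together, the web service now reaches the GPU backend exclusively through src/workload_client.py. A minimal usage sketch of the client, assuming a workload container is reachable at the endpoints from configs/easyvideotrans.json and that example.wav (a hypothetical file) already exists in the shared output directory:

import logging

from src.workload_client import EasyVideoTransWorkloadClient, WorkloadResponseError

client = EasyVideoTransWorkloadClient(
    audio_separation_endpoint="http://localhost:8199/audio-sep",
    audio_transcribe_endpoint="http://localhost:8199/audio-transcribe",
)

try:
    # Returns the (background, vocals) file names reported by the backend.
    bg_fn, voice_fn = client.separate_audio("example.wav")  # hypothetical input
    # Writes the raw and sentence-merged SRT files on the backend side.
    client.transcribe_audio(voice_fn, ["example_en.srt", "example_en_merged.srt"])
except WorkloadResponseError as e:
    logging.error("workload backend returned %s: %s", e.status_code, e.message)
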