diff --git a/ct2_gui.py b/ct2_gui.py index 1b64fe2..8901052 100644 --- a/ct2_gui.py +++ b/ct2_gui.py @@ -1,13 +1,7 @@ -from PySide6.QtWidgets import ( - QApplication, QWidget, QVBoxLayout, QPushButton, QLabel, - QComboBox, QHBoxLayout, QGroupBox, QMessageBox -) +from PySide6.QtWidgets import QApplication, QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox, QHBoxLayout, QGroupBox from PySide6.QtCore import Qt from ct2_logic import VoiceRecorder import yaml -import logging - -logger = logging.getLogger(__name__) class MyWindow(QWidget): def __init__(self, cuda_available=False): @@ -25,10 +19,9 @@ def __init__(self, cuda_available=False): config = yaml.safe_load(f) model = config.get("model_name", "base.en") quantization = config.get("quantization_type", "int8") - device = config.get("device_type", "cpu") + device = config.get("device_type", "auto") self.supported_quantizations = config.get("supported_quantizations", {"cpu": [], "cuda": []}) except FileNotFoundError: - logger.warning("config.yaml not found. Using default settings.") model, quantization, device = "base.en", "int8", "cpu" self.supported_quantizations = {"cpu": [], "cuda": []} @@ -37,7 +30,7 @@ def __init__(self, cuda_available=False): layout.addWidget(self.record_button) self.stop_button = QPushButton("Stop and Transcribe", self) - self.stop_button.clicked.connect(self.recorder.save_audio) + self.stop_button.clicked.connect(self.recorder.stop_recording) layout.addWidget(self.stop_button) settings_group = QGroupBox("Settings") @@ -128,7 +121,8 @@ def set_widgets_enabled(self, enabled): self.quantization_dropdown.setEnabled(enabled) self.device_dropdown.setEnabled(enabled) self.update_model_btn.setEnabled(enabled) - if not enabled: - QApplication.setOverrideCursor(Qt.WaitCursor) - else: - QApplication.restoreOverrideCursor() + + def closeEvent(self, event): + if hasattr(self, 'recorder'): + self.recorder.stop_all_threads() + super().closeEvent(event) diff --git a/ct2_logic.py b/ct2_logic.py index 4c13a38..d75b345 100644 --- a/ct2_logic.py +++ b/ct2_logic.py @@ -2,13 +2,16 @@ import numpy as np import wave import os -import tempfile +import psutil from PySide6.QtWidgets import QApplication -from PySide6.QtCore import QObject, Signal, Slot, QThread +from PySide6.QtCore import QObject, Signal, Slot, QThread, QMutex, QWaitCondition from faster_whisper import WhisperModel import yaml -import threading import logging +import tempfile +from contextlib import contextmanager +from pathlib import Path +import queue logger = logging.getLogger(__name__) @@ -24,17 +27,27 @@ def __init__(self, model_name, quantization_type, device_type): def run(self): try: + if self.isInterruptionRequested(): + return + if self.model_name.startswith("distil-whisper"): model_str = f"ctranslate2-4you/{self.model_name}-ct2-{self.quantization_type}" else: model_str = f"ctranslate2-4you/whisper-{self.model_name}-ct2-{self.quantization_type}" + if self.isInterruptionRequested(): + return + model = WhisperModel( model_str, device=self.device_type, compute_type=self.quantization_type, - cpu_threads=26 + cpu_threads=psutil.cpu_count(logical=False) ) + + if self.isInterruptionRequested(): + return + self.model_loaded.emit(model, self.model_name) except Exception as e: error_message = f"Error loading model: {str(e)}" @@ -52,7 +65,14 @@ def __init__(self, model, audio_file): def run(self): try: + if self.isInterruptionRequested(): + return + segments, _ = self.model.transcribe(self.audio_file) + + if self.isInterruptionRequested(): + return + clipboard_text = "\n".join([segment.text for segment in segments]) self.transcription_done.emit(clipboard_text) except Exception as e: @@ -61,10 +81,58 @@ def run(self): self.error_occurred.emit(error_message) finally: try: - os.remove(self.audio_file) + Path(self.audio_file).unlink(missing_ok=True) except OSError as e: logger.warning(f"Error deleting temporary file: {e}") +class RecordingThread(QThread): + update_status_signal = Signal(str) + recording_error = Signal(str) + recording_finished = Signal() + + def __init__(self, samplerate, channels, dtype): + super().__init__() + self.samplerate = samplerate + self.channels = channels + self.dtype = dtype + self.is_recording = QWaitCondition() + self.mutex = QMutex() + self.buffer = queue.Queue() + + @contextmanager + def audio_stream(self): + stream = sd.InputStream(samplerate=self.samplerate, channels=self.channels, dtype=self.dtype, callback=self.audio_callback) + try: + with stream: + yield + finally: + stream.close() + + def audio_callback(self, indata, frames, time, status): + if status: + logger.warning(status) + self.buffer.put(indata.copy()) + + def run(self): + self.mutex.lock() + self.update_status_signal.emit("Recording...") + + try: + with self.audio_stream(): + while not self.isInterruptionRequested(): + self.is_recording.wait(self.mutex) + except Exception as e: + error_message = f"Recording error: {e}" + logger.error(error_message) + self.recording_error.emit(error_message) + finally: + self.mutex.unlock() + self.recording_finished.emit() + + def stop(self): + self.requestInterruption() + self.is_recording.wakeAll() + class VoiceRecorder(QObject): update_status_signal = Signal(str) enable_widgets_signal = Signal(bool) @@ -75,15 +143,14 @@ def __init__(self, window, samplerate=44100, channels=1, dtype='int16'): self.channels = channels self.dtype = dtype self.window = window - self.is_recording = False - self.frames = [] self.model = None - self.model_lock = threading.Lock() + self.model_mutex = QMutex() self.load_settings() def load_settings(self): + config_path = Path("config.yaml") try: - with open("config.yaml", "r") as f: + with config_path.open("r") as f: config = yaml.safe_load(f) model_name = config.get("model_name", "base.en") quantization_type = config.get("quantization_type", "int8") @@ -99,13 +166,14 @@ def save_settings(self, model_name, quantization_type, device_type): "quantization_type": quantization_type, "device_type": device_type } - with open("config.yaml", "w") as f: + config_path = Path("config.yaml") + with config_path.open("w") as f: yaml.safe_dump(config, f) def update_model(self, model_name, quantization_type, device_type): self.enable_widgets_signal.emit(False) self.update_status_signal.emit(f"Updating model to {model_name}...") - + self.model_loader_thread = ModelLoaderThread(model_name, quantization_type, device_type) self.model_loader_thread.model_loaded.connect(self.on_model_loaded) self.model_loader_thread.error_occurred.connect(self.on_model_load_error) @@ -113,8 +181,9 @@ def update_model(self, model_name, quantization_type, device_type): @Slot(object, str) def on_model_loaded(self, model, model_name): - with self.model_lock: - self.model = model + self.model_mutex.lock() + self.model = model + self.model_mutex.unlock() self.save_settings(model_name, self.model_loader_thread.quantization_type, self.model_loader_thread.device_type) self.update_status_signal.emit(f"Model updated to {model_name} on {self.model_loader_thread.device_type} device") self.enable_widgets_signal.emit(True) @@ -126,12 +195,14 @@ def on_model_load_error(self, error_message): def transcribe_audio(self, audio_file): self.update_status_signal.emit("Transcribing audio...") - with self.model_lock: - if self.model is None: - self.update_status_signal.emit("No model loaded.") - self.enable_widgets_signal.emit(True) - return - model = self.model + self.model_mutex.lock() + if self.model is None: + self.model_mutex.unlock() + self.update_status_signal.emit("No model loaded.") + self.enable_widgets_signal.emit(True) + return + model = self.model + self.model_mutex.unlock() self.transcription_thread = TranscriptionThread(model, audio_file) self.transcription_thread.transcription_done.connect(self.on_transcription_done) @@ -149,34 +220,48 @@ def on_transcription_error(self, error_message): self.update_status_signal.emit(error_message) self.enable_widgets_signal.emit(True) - def record_audio(self): - self.update_status_signal.emit("Recording...") - def callback(indata, frames, time, status): - if status: - logger.warning(status) - self.frames.append(indata.copy()) - try: - with sd.InputStream(samplerate=self.samplerate, channels=self.channels, dtype=self.dtype, callback=callback): - while self.is_recording: - sd.sleep(100) - except Exception as e: - error_message = f"Recording error: {e}" - logger.error(error_message) - self.update_status_signal.emit(error_message) - self.enable_widgets_signal.emit(True) + @Slot(str) + def on_recording_error(self, error_message): + self.update_status_signal.emit(error_message) + self.enable_widgets_signal.emit(True) + + @Slot() + def on_recording_finished(self): + self.save_audio() + + def start_recording(self): + if not hasattr(self, 'recording_thread') or not self.recording_thread.isRunning(): + self.recording_thread = RecordingThread(self.samplerate, self.channels, self.dtype) + self.recording_thread.update_status_signal.connect(self.update_status_signal) + self.recording_thread.recording_error.connect(self.on_recording_error) + self.recording_thread.recording_finished.connect(self.on_recording_finished) + self.recording_thread.start() + else: + self.update_status_signal.emit("Already recording.") + + def stop_recording(self): + if hasattr(self, 'recording_thread') and self.recording_thread.isRunning(): + self.recording_thread.stop() + else: + self.update_status_signal.emit("Not currently recording.") def save_audio(self): - self.is_recording = False self.enable_widgets_signal.emit(False) - temp_filename = tempfile.mktemp(suffix=".wav") - data = np.concatenate(self.frames, axis=0) + audio_data = [] + while not self.recording_thread.buffer.empty(): + audio_data.append(self.recording_thread.buffer.get()) + data = np.concatenate(audio_data) + + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file: + temp_filename = temp_file.name + try: with wave.open(temp_filename, "wb") as wf: wf.setnchannels(self.channels) - wf.setsampwidth(2) # Always 2 for int16 + wf.setsampwidth(2) wf.setframerate(self.samplerate) wf.writeframes(data.tobytes()) - + self.update_status_signal.emit("Audio saved, starting transcription...") self.transcribe_audio(temp_filename) except Exception as e: @@ -184,12 +269,16 @@ def save_audio(self): logger.error(error_message) self.update_status_signal.emit(error_message) self.enable_widgets_signal.emit(True) - finally: - self.frames.clear() - def start_recording(self): - if not self.is_recording: - self.is_recording = True - threading.Thread(target=self.record_audio).start() - else: - self.update_status_signal.emit("Already recording.") + def stop_all_threads(self): + if hasattr(self, 'recording_thread') and self.recording_thread.isRunning(): + self.recording_thread.stop() + self.recording_thread.wait(timeout=5000) + + if hasattr(self, 'model_loader_thread') and self.model_loader_thread.isRunning(): + self.model_loader_thread.requestInterruption() + self.model_loader_thread.wait(timeout=5000) + + if hasattr(self, 'transcription_thread') and self.transcription_thread.isRunning(): + self.transcription_thread.requestInterruption() + self.transcription_thread.wait(timeout=5000) \ No newline at end of file diff --git a/ct2_main.py b/ct2_main.py index 7be578b..90611c8 100644 --- a/ct2_main.py +++ b/ct2_main.py @@ -1,29 +1,22 @@ import sys import os from pathlib import Path -import logging - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +import queue +from contextlib import contextmanager def set_cuda_paths(): - try: - venv_base = Path(sys.executable).parent.parent - nvidia_base_path = venv_base / 'Lib' / 'site-packages' / 'nvidia' - cuda_path = nvidia_base_path / 'cuda_runtime' / 'bin' - cublas_path = nvidia_base_path / 'cublas' / 'bin' - cudnn_path = nvidia_base_path / 'cudnn' / 'bin' - paths_to_add = [str(cuda_path), str(cublas_path), str(cudnn_path)] - env_vars = ['CUDA_PATH', 'CUDA_PATH_V12_1', 'PATH'] - - for env_var in env_vars: - current_value = os.environ.get(env_var, '') - new_value = os.pathsep.join(paths_to_add + [current_value] if current_value else paths_to_add) - os.environ[env_var] = new_value - logger.info("CUDA paths set successfully.") - except Exception as e: - logger.error(f"Failed to set CUDA paths: {e}") + venv_base = Path(sys.executable).parent.parent + nvidia_base_path = venv_base / 'Lib' / 'site-packages' / 'nvidia' + cuda_path = nvidia_base_path / 'cuda_runtime' / 'bin' + cublas_path = nvidia_base_path / 'cublas' / 'bin' + cudnn_path = nvidia_base_path / 'cudnn' / 'bin' + paths_to_add = [str(cuda_path), str(cublas_path), str(cudnn_path)] + env_vars = ['CUDA_PATH', 'CUDA_PATH_V12_1', 'PATH'] + + for env_var in env_vars: + current_value = os.environ.get(env_var, '') + new_value = os.pathsep.join(paths_to_add + [current_value] if current_value else paths_to_add) + os.environ[env_var] = new_value set_cuda_paths() @@ -35,9 +28,10 @@ def set_cuda_paths(): quantization_checker = CheckQuantizationSupport() cuda_available = quantization_checker.has_cuda_device() quantization_checker.update_supported_quantizations() - + app = QApplication(sys.argv) app.setStyle('Fusion') window = MyWindow(cuda_available) window.show() - sys.exit(app.exec()) + + sys.exit(app.exec()) \ No newline at end of file diff --git a/ct2_utils.py b/ct2_utils.py index 0608bd9..60fda96 100644 --- a/ct2_utils.py +++ b/ct2_utils.py @@ -1,35 +1,24 @@ import ctranslate2 import yaml -import logging - -logger = logging.getLogger(__name__) +import platform class CheckQuantizationSupport: + excluded_types = ['int16', 'int8', 'int8_float32', 'int8_float16', 'int8_bfloat16'] - + def has_cuda_device(self): - try: - cuda_device_count = ctranslate2.get_cuda_device_count() - return cuda_device_count > 0 - except Exception as e: - logger.error(f"Error checking CUDA devices: {e}") - return False + cuda_device_count = ctranslate2.get_cuda_device_count() + return cuda_device_count > 0 def get_supported_quantizations_cuda(self): - try: - cuda_quantizations = ctranslate2.get_supported_compute_types("cuda") - return [q for q in cuda_quantizations if q not in self.excluded_types] - except Exception as e: - logger.error(f"Error getting CUDA quantizations: {e}") - return [] + cuda_quantizations = ctranslate2.get_supported_compute_types("cuda") + excluded_types = self.excluded_types + return [q for q in cuda_quantizations if q not in excluded_types] def get_supported_quantizations_cpu(self): - try: - cpu_quantizations = ctranslate2.get_supported_compute_types("cpu") - return [q for q in cpu_quantizations if q not in self.excluded_types] - except Exception as e: - logger.error(f"Error getting CPU quantizations: {e}") - return [] + cpu_quantizations = ctranslate2.get_supported_compute_types("cpu") + excluded_types = self.excluded_types + return [q for q in cpu_quantizations if q not in excluded_types] def update_supported_quantizations(self): cpu_quantizations = self.get_supported_quantizations_cpu() @@ -38,8 +27,6 @@ def update_supported_quantizations(self): if self.has_cuda_device(): cuda_quantizations = self.get_supported_quantizations_cuda() self._update_supported_quantizations_in_config("cuda", cuda_quantizations) - else: - self._update_supported_quantizations_in_config("cuda", []) def _update_supported_quantizations_in_config(self, device, quantizations): try: @@ -54,5 +41,4 @@ def _update_supported_quantizations_in_config(self, device, quantizations): config["supported_quantizations"][device] = quantizations with open("config.yaml", "w") as f: - yaml.safe_dump(config, f) - logger.info(f"Updated supported quantizations for {device}: {quantizations}") + yaml.safe_dump(config, f, default_style="'") \ No newline at end of file diff --git a/setup.py b/setup.py index 1a76040..ee64e53 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ def install_libraries_with_retry(max_retries=3, delay=3): "packaging==24.1", "pip==24.2", "protobuf==5.28.2", + "psutil==6.0.0", "pycparser==2.22", "pyreadline3==3.5.4", "PyYAML==6.0.1",