From 181e3fd7923eb6ebdf652dbeede24008c2d5b267 Mon Sep 17 00:00:00 2001 From: WolfwithSword Date: Tue, 14 Jan 2025 14:37:36 -0400 Subject: [PATCH] Elevenlabs integration (#25) * Initial Elevenlabs integration * Redirect sys std logging * Bug fixes * Add elevenlabs link to readme --- .gitignore | 2 + README.md | 6 + requirements.txt | 5 +- src/chatdnd/events/tts_events.py | 5 + src/custom_logger/logger.py | 16 +++ src/data/__init__.py | 3 +- src/data/member.py | 35 +++++- src/data/voices.py | 133 +++++++++++++++++++++ src/db.py | 5 +- src/helpers/config.py | 10 ++ src/helpers/constants.py | 8 ++ src/helpers/utils.py | 33 +++++- src/main.py | 27 +++-- src/server/app.py | 76 +++++++++--- src/server/static/overlay.html | 13 +- src/tts/__init__.py | 3 +- src/tts/elevenlabs_tts.py | 174 +++++++++++++++++++++++++++ src/tts/local_tts.py | 70 +++++------ src/tts/tts.py | 31 ++++- src/ui/tabs/home.py | 53 ++++++++- src/ui/tabs/settings.py | 196 +++++++++++++++++++++++++++++-- src/ui/tabs/users.py | 2 +- src/ui/widgets/member_card.py | 74 +++++++++--- 23 files changed, 873 insertions(+), 107 deletions(-) create mode 100644 src/chatdnd/events/tts_events.py create mode 100644 src/data/voices.py create mode 100644 src/helpers/constants.py create mode 100644 src/tts/elevenlabs_tts.py diff --git a/.gitignore b/.gitignore index bb49f65..60a75bb 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ user_token.json config.ini .tcdnd-cache/ *.db +logs/ +*.db-journal # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/README.md b/README.md index cddb293..355ed6c 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,9 @@ # TwitchChatDND Host a D&D Session with Twitch Chat and TTS Voices! +ElevenLabs API is optional, but gives you access to several new voices. +To add a new voice, copy its voice_id into the add menu after configuring your API key. 
+ +The free tier will work for testing purposes (locks us out of some fancy voices plus only has a 10k credit limit), but right now we hardcode a turbo model which is half credit usage since we don't need perfection. + +[Elevenlabs subscription plans](https://try.elevenlabs.io/chatdnd) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c19e323..e3ea27a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,7 @@ Quart==0.20.0 SQLAlchemy==2.0.36 aiosqlite==0.20.0 pillow==11.1.0 -requests==2.32.3 \ No newline at end of file +requests==2.32.3 +elevenlabs==1.50.3 +CTkListbox==1.4 +static-ffmpeg==2.7 \ No newline at end of file diff --git a/src/chatdnd/events/tts_events.py b/src/chatdnd/events/tts_events.py new file mode 100644 index 0000000..65691ca --- /dev/null +++ b/src/chatdnd/events/tts_events.py @@ -0,0 +1,5 @@ +from helpers import Event + +on_elevenlabs_connect = Event() +request_elevenlabs_connect = Event() +on_elevenlabs_test_speak = Event() diff --git a/src/custom_logger/logger.py b/src/custom_logger/logger.py index 29baae4..dacd5fa 100644 --- a/src/custom_logger/logger.py +++ b/src/custom_logger/logger.py @@ -24,6 +24,20 @@ def __init__(self, log_file): formatter = logging.Formatter(_format) self.setFormatter(formatter) +class RedirectSysLogger(object): + def __init__(self, logger, level): + self.logger = logger + self.level = level + self.linebuf = '' + + def write(self, buf): + for line in buf.rstrip().splitlines(): + if line.strip() and len(line.strip()) > 1: + self.logger.log(self.level, line.rstrip()) + + def flush(self): + pass + class CustomLogger: def __init__(self, name): self.logger = logging.getLogger(name) @@ -65,3 +79,5 @@ def shutdown(self): self.listener.stop() logger = CustomLogger("ChatDND").logger +sys.stdout = RedirectSysLogger(logger, logging.INFO) +sys.stderr = RedirectSysLogger(logger, logging.ERROR) \ No newline at end of file diff --git a/src/data/__init__.py b/src/data/__init__.py index 
ea3f6b9..1100074 100644 --- a/src/data/__init__.py +++ b/src/data/__init__.py @@ -1,4 +1,5 @@ from data.member import Member +from data.voices import Voice from data.session import Session, SessionState -__all__ = [Member, Session, SessionState] \ No newline at end of file +__all__ = [Member, Session, SessionState, Voice] \ No newline at end of file diff --git a/src/data/member.py b/src/data/member.py index 0cb04ae..a454ba9 100644 --- a/src/data/member.py +++ b/src/data/member.py @@ -1,16 +1,22 @@ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from sqlalchemy.orm import sessionmaker, Mapped, mapped_column -from sqlalchemy import String, Integer, JSON, asc +from sqlalchemy.orm import sessionmaker, Mapped, mapped_column, relationship +from sqlalchemy import String, Integer, JSON, asc, ForeignKey, null from sqlalchemy.future import select +from custom_logger.logger import logger + from data.base import Base +from data.voices import Voice from sqlalchemy.future import select from db import async_session +from data.voices import fetch_voice + # We don't need to pass the DB object around after it's been initialized by main # Simply import async_session from it, or objects made from it around + class Member(Base): __tablename__ = "members" @@ -18,7 +24,8 @@ class Member(Base): name: Mapped[str] = mapped_column(String, unique=True, nullable=False) pfp_url: Mapped[str] = mapped_column(String, default="") num_sessions: Mapped[int] = mapped_column(Integer, default=0) # increment on end/complete session - preferred_tts: Mapped[str] = mapped_column(String, default="") # local vs cloud, and what name for tts voice + preferred_tts_uid: Mapped[str] = mapped_column(ForeignKey("voices.uid"), nullable=True) + preferred_tts: Mapped[Voice] = relationship("Voice", lazy="subquery") data: Mapped[dict] = mapped_column(JSON, default=dict) # We can store arbitrary data in here if we need extra columns and stuff later, just need to be safe with checking # elsewise, we 
will need to setup alembic and migrations with an updater script/function/exe @@ -71,15 +78,31 @@ async def _upsert_member(name: str, pfp_url: str) -> Member: session.add(new_member) return new_member -async def update_tts(member: Member, preferred_tts: str = ""): +async def update_tts(member: Member, voice_id: str): async with async_session() as session: async with session.begin(): + voice_in_db = await fetch_voice(uid=voice_id) member_in_db = await session.get(Member, member.id) - if member_in_db: - member_in_db.preferred_tts = preferred_tts + if member_in_db and voice_in_db: + member.preferred_tts_uid = voice_id + member.preferred_tts = voice_in_db + + member_in_db.preferred_tts_uid = voice_id + member_in_db.preferred_tts = voice_in_db await session.commit() +async def remove_tts(voice_id: str): + if not voice_id: + return + async with async_session() as session: + async with session.begin(): + result = await session.execute(select(Member).where(Member.preferred_tts_uid == voice_id)) + for member in result.scalars().all(): + member.preferred_tts_uid = null() + await session.commit() + + async def fetch_member(name: str) -> Member | None: name = name.lower() async with async_session() as session: diff --git a/src/data/voices.py b/src/data/voices.py new file mode 100644 index 0000000..76a59a4 --- /dev/null +++ b/src/data/voices.py @@ -0,0 +1,133 @@ +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker, Mapped, mapped_column +from sqlalchemy import String, Integer, asc, CheckConstraint +from sqlalchemy.future import select + +from data.base import Base + +from sqlalchemy.future import select +from db import async_session + +from custom_logger.logger import logger +from helpers.constants import SOURCES, SOURCE_11L, SOURCE_LOCAL + +class Voice(Base): + __tablename__ = "voices" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + name: Mapped[str] = mapped_column(String, unique=True, nullable=False) 
# Unique? + uid: Mapped[str] = mapped_column(String, unique=True, nullable=False) + source: Mapped[str] = mapped_column(String, default=SOURCE_LOCAL, nullable=False) + + __table_args__ = ( + CheckConstraint( + f"source IN ({', '.join(repr(value) for value in SOURCES)})", + name="chk_source_valid_val" + ), + ) + + def __init__(self, name: str, uid: str, source: str): + if source not in SOURCES: + return None + self.name: str = name + self.uid: str = uid + self.source: str = source + + def __eq__(self, other): + return self.name == other.name and self.uid == other.uid and self.source == other.source + + + def __hash__(self): + return hash(f"{self.name}{self.uid}{self.source}") + + + def __repr__(self): + return f"Voice(name='{self.name}', uid='{self.uid}', source='{self.source}')" + + def __lt__(self, other): + return self.name < other.name + + def __gt__(self, other): + return self.name > other.name + + +async def _upsert_voice(name: str, uid: str, source: str) -> Voice | None: + if source not in SOURCES: + return None + async with async_session() as session: + async with session.begin(): + query = select(Voice).where(Voice.name == name).where(Voice.uid == uid).where(Voice.source == source) + result = await session.execute(query) + voice = result.scalars().first() + + if voice: + return voice + else: + # Create new voice + new_voice = Voice(name=name, uid=uid, source=source) + session.add(new_voice) + return new_voice + + +async def delete_voice(uid: str = None, source: str = None) -> bool: + if not uid: + return False + async with async_session() as session: + voice = await fetch_voice(uid=uid, source=source) + if voice: + await session.delete(voice) + await session.commit() + return True + return False + + +async def fetch_voice(name: str = None, uid: str = None, source: str = None) -> Voice | None: + if not any([name, uid]): + return None + async with async_session() as session: + query = select(Voice) + if uid: + query = query.where(Voice.uid == uid) + if name: 
+ query = query.where(Voice.name == name) + if source: + query = query.where(Voice.source == source) + result = await session.execute(query) + return result.scalars().first() + + +async def fetch_voices(source: str = None, limit: int = 100) -> list[Voice]: + + async with async_session() as session: + query = select(Voice) + if source: + query = query.where(Voice.source == source) + query = query.limit(limit) + result = await session.execute(query) + res = result.scalars().all() + if not res and source == 'elevenlabs': + logger.info("Adding default ElevenLabs voice 'Will'") + v = await _upsert_voice(name="Will", uid="bIHbv24MWmeRgasZH58o", source=SOURCE_11L) + return [v] + return res + + +async def fetch_paginated_voices(page: int, per_page: int=20, + name_filter: str = None, filter_source: str = None) -> list[Voice]: + if not exclude_names: + exclude_names = [] + + async with async_session() as session: + query = select(Voice).order_by(asc(Voice.name)) + + if name_filter: + query = query.where(Voice.name.like(f"%{name_filter.lower()}%")) + + if filter_source: + query = query.where(Voice.source == source) + + offset = (page -1) * per_page + query = query.offset(offset).limit(per_page) + + result = await session.execute(query) + return result.scalars().all() diff --git a/src/db.py b/src/db.py index 1f4df4e..402b98c 100644 --- a/src/db.py +++ b/src/db.py @@ -1,5 +1,6 @@ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker +import os DATABASE_URL = "sqlite+aiosqlite:///tcdnd_data.db" engine = create_async_engine(DATABASE_URL, echo=False) @@ -9,8 +10,10 @@ import logging from data.base import Base + +debug_mode = os.environ['TCDND_DEBUG_MODE'] == '1' +logging.getLogger("sqlalchemy.engine").setLevel(logging.DEBUG if debug_mode else logging.INFO) logging.getLogger("sqlalchemy.engine").handlers = logger.handlers -logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) async def initialize_database(): #Initialize 
the database and create tables.# diff --git a/src/helpers/config.py b/src/helpers/config.py index 15588ab..b8db9c6 100644 --- a/src/helpers/config.py +++ b/src/helpers/config.py @@ -45,6 +45,9 @@ def setup(self, path: str): if not self.has_option(section="CACHE", option="cache_expiry"): needs_init = True self.set(section="CACHE", option="cache_expiry", value=str(7*24*60*60)) + if not self.has_option(section="CACHE", option="tts_cache_expiry"): + needs_init = True + self.set(section="CACHE", option="tts_cache_expiry", value=str(7*24*60*60*4*3)) if not self.has_section("TWITCH"): needs_init = True @@ -60,6 +63,13 @@ def setup(self, path: str): needs_init = True self.set(section="DND", option="party_size", value=str(4)) + if not self.has_section("ELEVENLABS"): + needs_init = True + self.add_section("ELEVENLABS") + if not self.has_option(section="ELEVENLABS", option="api_key"): + needs_init = True + self.set(section="ELEVENLABS", option="api_key", value='') + if needs_init: self.write_updates() diff --git a/src/helpers/constants.py b/src/helpers/constants.py new file mode 100644 index 0000000..f10bd2f --- /dev/null +++ b/src/helpers/constants.py @@ -0,0 +1,8 @@ + +SOURCE_LOCAL = "local" +SOURCE_11L = "elevenlabs" + +SOURCES = [ + SOURCE_LOCAL, + SOURCE_11L +] \ No newline at end of file diff --git a/src/helpers/utils.py b/src/helpers/utils.py index 21f1377..dd52e5e 100644 --- a/src/helpers/utils.py +++ b/src/helpers/utils.py @@ -1,4 +1,10 @@ import os, sys +import asyncio +import threading +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Coroutine, TypeVar + +T = TypeVar("T") def get_resource_path(relative_path, from_resources: bool = False): # Get absolute path to a resource, frozen or local (relative to helpers/utils.py) @@ -13,4 +19,29 @@ def get_resource_path(relative_path, from_resources: bool = False): base_path = os.path.dirname(os.path.abspath(__file__)) else: base_path = os.path.dirname(os.path.abspath(__file__)) - return 
os.path.join(base_path, relative_path) \ No newline at end of file + return os.path.join(base_path, relative_path) + + +def run_coroutine_sync(coroutine: Coroutine[Any, Any, T], timeout: float = 30) -> T: + def run_in_new_loop(): + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + return new_loop.run_until_complete(coroutine) + finally: + new_loop.close() + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coroutine) + + if threading.current_thread() is threading.main_thread(): + if not loop.is_running(): + return loop.run_until_complete(coroutine) + else: + with ThreadPoolExecutor() as pool: + future = pool.submit(run_in_new_loop) + return future.result(timeout=timeout) + else: + return asyncio.run_coroutine_threadsafe(coroutine, loop).result() \ No newline at end of file diff --git a/src/main.py b/src/main.py index 13fb632..9366180 100644 --- a/src/main.py +++ b/src/main.py @@ -10,9 +10,12 @@ def parse_args(): parser.add_argument('--debug', action='store_true', help="Enable debug logging level.") return parser.parse_args() +cwd = os.getcwd() args = parse_args() os.environ['TCDND_DEBUG_MODE'] = '1' if args.debug else '0' +import static_ffmpeg + from queue import Queue _tasks = Queue() import helpers.event as _event_module @@ -34,13 +37,24 @@ def parse_args(): from custom_logger.logger import logger from db import initialize_database -cwd = os.getcwd() +import static_ffmpeg +logger.info("Setting up ffmpeg...") +static_ffmpeg.add_paths() +logger.info("Done setting up ffmpeg") + +async def run_db_init(): + await initialize_database() +asyncio.run(run_db_init())#"DB-Setup" + config_path = os.path.join(cwd, 'config.ini') cache_dir = os.path.join(cwd, '.tcdnd-cache/') config = Config() config.setup(config_path) +if not config.has_option(section="CACHE",option="directory"): + config.set(section="CACHE", option="directory", value=cache_dir) + twitch_utils = TwitchUtils(config, cache_dir) session_mgr: 
SessionManager = SessionManager() @@ -50,9 +64,7 @@ def parse_args(): APP_RUNNING = True - async def run_twitch(): - async def try_setup(): while not config.twitch_auth: await asyncio.sleep(5) @@ -93,9 +105,6 @@ async def try_channel(): try: if chat.channel: return True - # await chat.start(twitch_utils) - - # ui_settings_twitch_channel_update_event.trigger([twitch_utils]) await asyncio.sleep(4) if chat.channel: return True @@ -122,7 +131,6 @@ async def try_channel(): logger.info("Starting") - async def run_server(): await server.run_task(host="0.0.0.0") @@ -135,7 +143,6 @@ async def run_ui(): APP_RUNNING = False await asyncio.sleep(2) sys.exit(0) - # app.mainloop() async def run_queued_tasks(): while APP_RUNNING: @@ -155,10 +162,9 @@ async def run_queued_tasks(): except Exception as e: logger.error(f"Error in queued task: {callback} ({args}) - {e}") - async def run_all(): + tasks = [ - asyncio.create_task(initialize_database(), name="DB-Setup"), asyncio.create_task(run_server(), name="Server"), asyncio.create_task(run_twitch(), name="Twitch"), asyncio.create_task(run_ui(), name="UI"), @@ -177,6 +183,5 @@ async def run_all(): except asyncio.CancelledError: logger.warning(f"{task.get_name()} task was cancelled") - if __name__ == "__main__": asyncio.run(run_all()) \ No newline at end of file diff --git a/src/server/app.py b/src/server/app.py index 9be8702..c9b1a04 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -1,12 +1,15 @@ import asyncio from asyncio import Queue -from quart import Quart, redirect, request, jsonify, websocket, render_template, send_from_directory, Response +from functools import wraps +from quart import Quart, redirect, request, jsonify, websocket, render_template, send_from_directory, Response, copy_current_websocket_context import os, sys from data import Member -from tts import LocalTTS +from data.voices import fetch_voice +from tts import LocalTTS, ElevenLabsTTS from helpers import TCDNDConfig as Config +from helpers.utils import 
run_coroutine_sync from custom_logger.logger import logger from chatdnd.events.chat_events import chat_say_command @@ -14,6 +17,9 @@ from chatdnd.events.web_events import on_overlay_open from helpers.utils import get_resource_path +from data.voices import fetch_voice +from helpers.constants import SOURCE_11L, SOURCE_LOCAL + STATIC_DIR = get_resource_path("../server/static", from_resources=True) message_queue = Queue() @@ -21,16 +27,49 @@ clients = set() overlay_clients = set() +def collect_tts_websockets(func): + @wraps(func) + async def wrapper(*args, **kwargs): + global clients + clients.add(websocket._get_current_object()) + try: + return await func(*args, **kwargs) + finally: + clients.discard(websocket._get_current_object()) + return wrapper + +def collect_member_websockets(func): + @wraps(func) + async def wrapper(*args, **kwargs): + global overlay_clients + overlay_clients.add(websocket._get_current_object()) + try: + return await func(*args, **kwargs) + finally: + overlay_clients.discard(websocket._get_current_object()) + return wrapper + +async def broadcast_tts(chunk): + for websock in clients: + await asyncio.wait_for(websock.send(chunk), timeout=10) + +async def broadcast_member_update(message): + for websock in overlay_clients: + await asyncio.wait_for(websock.send_json(message), timeout=5) + class ServerApp(): def __init__(self, config: Config): self.app = Quart(__name__) self.config = config - self.tts = LocalTTS(config) # TODO both local and cloud + self.tts = { + SOURCE_LOCAL: LocalTTS(config), + SOURCE_11L: ElevenLabsTTS(config, full_instance=True) + } self._setup_routes() self._party: set[Member] = set() - # Setup here temporarily for POC + # Setup here temporarily for POC - or just keep tbh f chat_say_command.addListener(self.chat_say) on_party_update.addListener(self.send_members) on_overlay_open.trigger() @@ -38,8 +77,8 @@ def __init__(self, config: Config): def _setup_routes(self): @self.app.websocket("/ws/tts") + @collect_tts_websockets 
async def audio_stream(): - clients.add(websocket) try: logger.debug("tts ws opened") await websocket.send_json({"type":"heartbeat"}) @@ -55,13 +94,22 @@ async def audio_stream(): duration = 0 last_chunk_duration = 0 send_bounce = False - async for chunk, _duration in self.tts.get_stream(message, '' if not member else member.preferred_tts): + + tts_type = SOURCE_LOCAL + voice_id = '' + if member and member.preferred_tts_uid: + _voice = await fetch_voice(uid=member.preferred_tts_uid) + if _voice: + voice_id = member.preferred_tts_uid + tts_type = _voice.source + async for chunk, _duration in self.tts[tts_type].get_stream(message, voice_id): # TODO: Allow for break / interruption from emergency stuff - also hide stuff. Or yknow, just instruct to hide the browser source. + # Yeah, to mute, best to just hide the browser source. if not send_bounce: send_bounce = True await self.animate_member(member.name, "bounce") await members_queue.put(speech_message) - await asyncio.wait_for(websocket.send(chunk), timeout=10) + await broadcast_tts(chunk) duration += _duration last_chunk_duration = _duration await asyncio.sleep(last_chunk_duration) @@ -69,7 +117,7 @@ async def audio_stream(): "type": "endspeech" } await self.animate_member(member.name, "idle") - await asyncio.sleep(0.3) + await asyncio.sleep(0.05) await members_queue.put(speech_message) await asyncio.sleep(0.1) @@ -78,33 +126,31 @@ async def audio_stream(): pass finally: logger.debug("tts ws closed") - clients.discard(websocket) @self.app.websocket("/ws/members") + @collect_member_websockets async def user_overlay_ws(): - overlay_clients.add(websocket) try: await websocket.send_json({"type":"heartbeat"}) logger.debug('overlay ws opened') + await asyncio.sleep(0.3) on_overlay_open.trigger() - await asyncio.sleep(0.5) + await asyncio.sleep(0.3) while True: while not members_queue.empty(): message = await members_queue.get() logger.info(f"member msg {message}") - await asyncio.wait_for(websocket.send_json(message), 
timeout=5) + await broadcast_member_update(message) + await asyncio.sleep(0.05) await asyncio.sleep(0.5) finally: logger.debug('overlay ws closed') - overlay_clients.discard(websocket) @self.app.route('/overlay') async def overlay(): return await send_from_directory(STATIC_DIR, 'overlay.html') async def chat_say(self, member: Member, text: str): - # If client was connected but dc'd, this can revive connection when it runs. But also, we don't want things to queue up forever... Might not be a problem, needs hard testing later - # But cannot simply do an if check here for clients await message_queue.put((member, text)) diff --git a/src/server/static/overlay.html b/src/server/static/overlay.html index 4e04dd1..e953f59 100644 --- a/src/server/static/overlay.html +++ b/src/server/static/overlay.html @@ -201,14 +201,12 @@ }); } - function playFromQueue() { if (audioQueue.length > 0 && !isPlaying) { const audioBuffer = audioQueue.shift(); const source = audioContext.createBufferSource(); source.buffer = audioBuffer; source.connect(audioContext.destination); - source.onended = () => { isPlaying = false; if (audioQueue.length > 0) { @@ -411,6 +409,17 @@ } if (immediately) { speechBubble.remove(); + const userCards = document.querySelectorAll('.user-card'); + userCards.forEach((card) => { + let _name = card.querySelector('span').textContent; + let image = card.querySelector('img'); + image.classList.forEach((cls) => { + if (cls.startsWith('anim-')) { + image.classList.remove(cls); + } + }); + + }); } else { setTimeout(() => { diff --git a/src/tts/__init__.py b/src/tts/__init__.py index 5a5ea1d..74c34c5 100644 --- a/src/tts/__init__.py +++ b/src/tts/__init__.py @@ -1,3 +1,4 @@ from tts.local_tts import LocalTTS +from tts.elevenlabs_tts import ElevenLabsTTS -__all__ = [LocalTTS] \ No newline at end of file +__all__ = [LocalTTS, ElevenLabsTTS] \ No newline at end of file diff --git a/src/tts/elevenlabs_tts.py b/src/tts/elevenlabs_tts.py new file mode 100644 index 
0000000..2fcdb3d --- /dev/null +++ b/src/tts/elevenlabs_tts.py @@ -0,0 +1,174 @@ +import asyncio +import threading +import base64, io, struct + +from tts.tts import TTS, create_wav_header + +from elevenlabs.client import AsyncElevenLabs +from elevenlabs.client import ElevenLabs +from elevenlabs.types import Voice as ELVoice +from elevenlabs import play + +from diskcache import Cache + +from helpers import TCDNDConfig as Config +from helpers.utils import run_coroutine_sync +from helpers.constants import SOURCE_11L +from custom_logger.logger import logger + +from chatdnd.events.tts_events import on_elevenlabs_connect, request_elevenlabs_connect, on_elevenlabs_test_speak + +from data.voices import _upsert_voice, fetch_voices +from data import Voice + +FORMAT = 'pcm_22050' # Match local tts quality, not top but still good +MODEL = 'eleven_flash_v2_5'#'eleven_monolingual_v1' + + +class ElevenLabsTTS(TTS): + def __init__(self, config: Config, full_instance: bool = False): + super().__init__(config=None) + self.full_instance = full_instance + self.config = config + self.client: AsyncElevenLabs = None + + if self.full_instance: + request_elevenlabs_connect.addListener(self.setup) + self.setup() + on_elevenlabs_test_speak.addListener(self.test_speak) + + self.sample_rate = int(FORMAT.split("_")[-1]) + self.bits_per_sample = 16 + self.num_channels = 1 + + self.max_chunk_size = 1024*8*8*2*2 # 256kb + + if config.getboolean(section="CACHE", option="enabled"): + # Caching in general, but for here, it's specific to API results + cache_dir = config.get(section="CACHE", option="directory", fallback=None) + if not cache_dir: + self.cache = Cache() + else: + self.cache = Cache(directory=cache_dir) + else: + self.cache = None + + @property + def voices(self) -> dict: + d = {} + _voices = run_coroutine_sync(fetch_voices(source=SOURCE_11L)) + for v in _voices: + if v: + d.setdefault(f"{v.name} ({v.uid})", v.uid) + return d + + def setup(self): + + if self.full_instance: + 
on_elevenlabs_connect.trigger([False]) + if key := self.config.get(section="ELEVENLABS", option="api_key"): + try: + test_client = ElevenLabs(api_key=key) + # This will cause an exception if invalid api key + user = test_client.user.get() # TODO : display in settings the amount left in credits in user get or get subscription? Or swap ppl to local when low, or just popup warn? + + self.client = AsyncElevenLabs(api_key=key) + if self.full_instance: + on_elevenlabs_connect.trigger([True]) + except Exception as e: + logger.warn(f"ElevenLabs Exception: {e}") + if self.full_instance: + on_elevenlabs_connect.trigger([False]) + + + async def audio_stream_generator(self, text="Hello World!", voice_id: str = None): + if not voice_id or not self.client: + return + + async def fetch_stream(client, text, voice_id, model, _format): + output = io.BytesIO() + async for chunk in client.text_to_speech.convert_as_stream(text=text, + voice_id=voice_id, + model_id=model, + output_format=_format): + output.write(chunk) + output.seek(0) + return output + return await fetch_stream(self.client, text, voice_id, MODEL, FORMAT) + + async def get_stream(self, text="Hello World!", voice_id: str = None): + if not voice_id or not self.client: + yield None, None + + output = await self.audio_stream_generator(text, voice_id) + + header = create_wav_header(self.sample_rate, self.bits_per_sample, self.num_channels, len(output.getvalue())) + chunk_size = min(self.max_chunk_size, len(output.getvalue())) + chunk = output.read(chunk_size) + + while chunk: + duration = (len(chunk) / (self.sample_rate * self.num_channels * (self.bits_per_sample // 8))) + await asyncio.sleep(duration) + yield (header + chunk, duration) + chunk = output.read(chunk_size) + + + def get_voice_object(self, voice_id: str = "", run_sync_always: bool = False) -> ELVoice | None : + # Voice by id, will attempt to find it and cache it + if not voice_id or not self.config.get(section="ELEVENLABS", option="api_key", fallback=None): + 
return None + + key = f"11l.voice.{voice_id}" + v: ELVoice = None + if self.cache is not None: + v = self.cache.get(key=key, default=None) + if v: + logger.debug(f"Fetched cached preview audio for `{voice_id}`") + + if not v: + client = ElevenLabs(api_key=self.config.get(section="ELEVENLABS", option="api_key")) + v = None + try: + v = client.voices.get(voice_id) + except: + pass + if v: + self.cache.set(key=key, expire=self.config.getint(section="CACHE", option="tts_cache_expiry", fallback=7*24*60*60*4*3), value=v) + + if v: + if run_coroutine_sync: + run_coroutine_sync(_upsert_voice(name=v.name, uid=v.voice_id, source=SOURCE_11L)) + elif loop := asyncio.get_event_loop(): + asyncio.create_task(_upsert_voice(name=v.name, uid=v.voice_id, source=SOURCE_11L)) + else: + run_coroutine_sync(_upsert_voice(name=v.name, uid=v.voice_id, source=SOURCE_11L)) + return v + + + def test_speak(self, text:str = "Hello there. How are you?", voice_id: str = None): + if not voice_id or not self.config.get(section="ELEVENLABS", option="api_key", fallback=None): + return + # Could use `preview_url` from Voice object, but it is different for each voice, likely best to spend a few credits for consistency and cache it for a long time + voice_o = self.get_voice_object(voice_id) + + key = f"11l.preview.{voice_id}" + audio = None + if self.cache is not None: + audio_l = self.cache.get(key=key, default=None) + if audio_l: + audio = iter(audio_l) + logger.debug(f"Fetched cached preview audio for `{voice_id}`") + if not audio: + try: + client = ElevenLabs(api_key=self.config.get(section="ELEVENLABS", option="api_key")) + user = client.user.get() # Trigger bad api key + except: + on_elevenlabs_connect.trigger([False]) # needed? 
+ return + audio = list(client.generate(text=text, voice=voice_id, model=MODEL)) + self.cache.set(key=key, expire=self.config.getint(section="CACHE", option="tts_cache_expiry", fallback=7*24*60*60*4*3), value=list(audio)) + audio = iter(audio) + + thread = threading.Thread(target=play, args=(audio,)) + thread.daemon = True + thread.start() diff --git a/src/tts/local_tts.py b/src/tts/local_tts.py index 8a1e8b6..df7862f 100644 --- a/src/tts/local_tts.py +++ b/src/tts/local_tts.py @@ -3,39 +3,21 @@ import base64, io, struct import pyttsx4 -from tts.tts import TTS +from tts.tts import TTS, create_wav_header from helpers import TCDNDConfig as Config +from helpers.utils import run_coroutine_sync +from helpers.constants import SOURCE_LOCAL from custom_logger.logger import logger +from data.voices import _upsert_voice, fetch_voices +from data import Voice -# TODO: For local TTS, there is a slight minor clipping when transitioning between chunks. Mitigated with a large chunk size, need better solution - -# May be needed by cloud tts as well? so prob candidate to move -def create_wav_header(sample_rate, bits_per_sample, num_channels, data_size): - # WAV File header - byte_rate = sample_rate * num_channels * bits_per_sample // 8 - block_align = num_channels * bits_per_sample // 8 - - header = struct.pack( - '<4sI4s4sIHHIIHH4sI', - b'RIFF', # ChunkID - 36 + data_size, # ChunkSize - b'WAVE', # Format - b'fmt ', # Subchunk1ID - 16, # Subchunk1Size - 1, # AudioFormat (PCM) - num_channels, # NumChannels - sample_rate, # SampleRate - byte_rate, # ByteRate - block_align, # BlockAlign - bits_per_sample, # BitsPerSample - b'data', # Subchunk2ID - data_size, # Subchunk2Size - ) - return header - -class LocalTTS(TTS): # TODO Refactor with some inheritance from a TTS class, so we can abstract both Local and Cloud TTS later on + +# TODO: For local TTS, there is a slight minor clipping when transitioning between chunks. Mitigated with a large chunk size, need better solution? 
may be fixed + + +class LocalTTS(TTS): def __init__(self, config: Config): super().__init__(config=None) @@ -47,17 +29,23 @@ def __init__(self, config: Config): engine = pyttsx4.init() for v in engine.getProperty('voices'): - self.voices.setdefault(v.name, v.id) + run_coroutine_sync(_upsert_voice(name=v.name, uid=v.id, source=SOURCE_LOCAL)) + @property + def voices(self) -> dict: + d = {} + _voices = run_coroutine_sync(fetch_voices(source='local')) + for v in _voices: + d.setdefault(f"{v.name}", v.uid) + return d - def audio_stream_generator(self, text="Hello World!", voice: str = ''): + def audio_stream_generator(self, text="Hello World!", voice_id: str = None): engine = pyttsx4.init() # We are using the fork for x4 as it works with outputting to bytesIO output = io.BytesIO() - if voice and voice in self.get_voices().keys(): - engine.setProperty('voice', self.get_voices()[voice]) + if voice_id and voice_id in self.get_voices().values(): + engine.setProperty('voice', voice_id) - # TODO Uses default tts on at the moment. Can configure later engine.setProperty('rate', 150) # Speed of speech engine.setProperty('volume', 1) # Volume level (0.0 to 1.0) @@ -71,8 +59,8 @@ def audio_stream_generator(self, text="Hello World!", voice: str = ''): return output - async def get_stream(self, text="Hello World!", voice: str = ''): - output = self.audio_stream_generator(text, voice) + async def get_stream(self, text="Hello World!", voice_id: str = ''): + output = self.audio_stream_generator(text, voice_id) header = create_wav_header(self.sample_rate, self.bits_per_sample, self.num_channels, len(output.getvalue())) chunk_size = min(self.max_chunk_size, len(output.getvalue())) chunk = output.read(chunk_size) @@ -83,13 +71,13 @@ async def get_stream(self, text="Hello World!", voice: str = ''): yield (header + chunk, duration) chunk = output.read(chunk_size) - def test_speak(self, text:str ="Hello there. 
How are you?", voice:str = None): - def _run(text, voice): + def test_speak(self, text:str ="Hello there. How are you?", voice_id: str = None): + def _run(text, voice_id): engine = pyttsx4.init() - if voice and voice in self.voices.keys(): - engine.setProperty('voice', self.voices.get(voice)) + if voice_id in self.voices.values(): + engine.setProperty('voice', voice_id) engine.say(text) engine.runAndWait() - thread = threading.Thread(target=_run, args=(text, voice)) + thread = threading.Thread(target=_run, args=(text, voice_id)) thread.daemon = True - thread.start() \ No newline at end of file + thread.start() diff --git a/src/tts/tts.py b/src/tts/tts.py index a7d800b..df06587 100644 --- a/src/tts/tts.py +++ b/src/tts/tts.py @@ -1,5 +1,31 @@ from helpers import TCDNDConfig as Config from custom_logger.logger import logger +import base64, io, struct +from helpers.utils import run_coroutine_sync + +def create_wav_header(sample_rate, bits_per_sample, num_channels, data_size): + # WAV File header + byte_rate = sample_rate * num_channels * bits_per_sample // 8 + block_align = num_channels * bits_per_sample // 8 + + header = struct.pack( + '<4sI4s4sIHHIIHH4sI', + b'RIFF', # ChunkID + 36 + data_size, # ChunkSize + b'WAVE', # Format + b'fmt ', # Subchunk1ID + 16, # Subchunk1Size + 1, # AudioFormat (PCM) + num_channels, # NumChannels + sample_rate, # SampleRate + byte_rate, # ByteRate + block_align, # BlockAlign + bits_per_sample, # BitsPerSample + b'data', # Subchunk2ID + data_size, # Subchunk2Size + ) + return header + class TTS(): @@ -9,4 +35,7 @@ def __init__(self, config: Config): self.config = config def get_voices(self) -> dict: - return self.voices \ No newline at end of file + return self.voices + + async def get_stream(self): + yield (None, 0) \ No newline at end of file diff --git a/src/ui/tabs/home.py b/src/ui/tabs/home.py index 9fe4855..c1cc009 100644 --- a/src/ui/tabs/home.py +++ b/src/ui/tabs/home.py @@ -3,7 +3,7 @@ from ui.widgets.member_card import 
MemberCard from twitch.chat import ChatController -from chatdnd.events.chat_events import chat_on_join_queue, chat_bot_on_connect +from chatdnd.events.chat_events import chat_on_join_queue, chat_bot_on_connect, chat_say_command class HomeTab(): @@ -77,6 +77,52 @@ def __init__(self, parent, chat_ctrl: ChatController): ################################## + ####### Chat View ####### + + # TODO Listen to chat events, display :. Append to scrollframe, make it move. Make visually distinct? + self._chat_frame = ctk.CTkScrollableFrame(self.parent, width=288, height=574) + self._chat_frame.place(relx=0.740, rely = 0.057) + chat_say_command.addListener(self._add_chat_msg) + + ################################## + + + def _add_chat_msg(self, member, message): + name = member.name + + msg_frame = ctk.CTkFrame(self._chat_frame) + msg_frame.pack(fill='x', pady=5, padx=2, anchor='w') + + username_label = ctk.CTkLabel( + msg_frame, + text=f"{name}:", + font=("Arial", 12, "bold"), + text_color="#e8e8e8", + justify='left', + anchor='nw' + ) + username_label.grid(row=0, column=0, padx=(6,3), pady=(2,4), sticky='nw') + + available_width = self._chat_frame.winfo_width() - 30 + + message_label = ctk.CTkLabel( + msg_frame, + text=message, + font=("Arial", 12), + wraplength=available_width, + text_color="#c8c8c8", + justify='left', + anchor='w' + ) + message_label.grid(row=1, column=0, padx=(6,2), pady=(2,8), sticky='w') + + self._chat_frame.after(100, self._chat_frame._parent_canvas.yview_moveto(1.0)) # TODO doesnt seem to scroll all the way to bottom? 
+ + + def _clear_chat_log(self): + for widget in self._chat_frame.winfo_children(): + widget.destroy() + def _fill_party_frame(self): for child in self._party_frame.winfo_children(): @@ -86,7 +132,7 @@ def _fill_party_frame(self): for index, member in enumerate(sorted(self.chat_ctrl.session_mgr.session.party)): row = index // columns col = index % columns - member_card = MemberCard(self._party_frame, member, width=130, height=170, textsize=10) + member_card = MemberCard(self._party_frame, member, self.config, width=130, height=170, textsize=10) member_card.grid(row=row, column=col, padx=(35, 10), pady=(12,12), sticky="w") @@ -109,6 +155,7 @@ def _open_session(self): self.start_session.configure(state="normal") self.end_button.configure(state="disabled") self._fill_party_frame() + self._clear_chat_log() def _start_session(self): @@ -122,6 +169,7 @@ def _start_session(self): self.start_session.configure(state="disabled") self.end_button.configure(state="normal") self._fill_party_frame() + self._clear_chat_log() def _end_session(self): @@ -134,6 +182,7 @@ def _end_session(self): self.start_session.configure(state="disabled") self.end_button.configure(state="disabled") self._fill_party_frame() + self._clear_chat_log() def add_queue_user(self, name): diff --git a/src/ui/tabs/settings.py b/src/ui/tabs/settings.py index 423df25..fa4c2f5 100644 --- a/src/ui/tabs/settings.py +++ b/src/ui/tabs/settings.py @@ -1,11 +1,18 @@ import customtkinter as ctk +from CTkListbox import * from custom_logger.logger import logger from helpers import TCDNDConfig as Config +from helpers.utils import run_coroutine_sync +from helpers.constants import SOURCE_11L from twitch.utils import TwitchUtils +from tts import ElevenLabsTTS +from data.voices import delete_voice +from data.member import remove_tts from chatdnd.events.ui_events import * from chatdnd.events.chat_events import * from chatdnd.events.twitchutils_events import twitchutils_twitch_on_connect_event +from chatdnd.events.tts_events 
import request_elevenlabs_connect, on_elevenlabs_connect, on_elevenlabs_test_speak class SettingsTab(): # TODO Someone, please, clean this POS up. I'm begging. Nvm it's actually sorta clean, not really, but good enough @@ -79,21 +86,121 @@ def __init__(self, parent, config: Config, twitch_utils: TwitchUtils): row+=1 column=0 label_web = ctk.CTkLabel(self.parent, text="Browser Source", anchor="w", font=header_font) - label_web.grid(row=row, column=column, padx=10, pady=(50,10)) + label_web.grid(row=row, column=column, padx=10, pady=(30,10)) row+=1 - label_bs = ctk.CTkLabel(self.parent, text="http://localhost:5000/overlay") # TODO copy to clipboard button, dynamic string with port from config. - label_bs.grid(row=row, column=column, padx=(10,10), pady=(10, 2)) + label_bs_var = ctk.StringVar(value=f"http://localhost:{self.config.get(section="SERVER", option='port', fallback='5000')}/overlay") # TODO update dynamically when port adjustment is implemented + label_bs = ctk.CTkEntry(self.parent, textvariable=label_bs_var, state="readonly", width=190) + label_bs.grid(row=row, column=column, padx=(10,2), pady=(10, 2)) + def _copy_web_clipboard(): + self.parent.clipboard_clear() + self.parent.clipboard_append(label_bs.get()) + button_web = ctk.CTkButton(self.parent, width=30, height=30, text="Copy", command=_copy_web_clipboard) + column += 1 + button_web.grid(row=row, column=column, padx=(2, 10), pady=(10,2), sticky='w') # TODO: Port configuration. Can we easily restart the quart server while live, or require application restart? 
- # TODO: Copy button for the overlay URL - - ######### TTS ########## + ######### 11L TTS ########## row+=1 column=0 - label_tts = ctk.CTkLabel(self.parent, text="TTS", anchor="w", font=header_font) - label_tts.grid(row=row, column=0, padx=10, pady=(50,10)) - # Select a default for new members, later later later on toggle for use elevenlabs too (needs its own config section) + label_tts = ctk.CTkLabel(self.parent, text="ElevenLabs TTS", anchor="w", font=header_font) + label_tts.grid(row=row, column=column, padx=10, pady=(40,10)) + column+=1 + self.e11labs_con_label = ctk.CTkLabel(self.parent, text="ElevenLabs Disconnected", text_color="red") + self.e11labs_con_label.grid(row=row, column=column, padx=10, pady=(30,2)) + + row += 1 + column = 0 + button_el = ctk.CTkButton(self.parent,height=30, text="Save", command=self._update_el_settings) + button_el.grid(row=row, column=column, padx=10, pady=(20,10)) + column +=1 + el_api_label = ctk.CTkLabel(self.parent, text="API Key") + el_api_label.grid(row=row, column=column, padx=(10,10), pady=(10,2)) + column +=1 + el_voices_label = ctk.CTkLabel(self.parent, text="ElevenLabs Added Voices") + el_voices_label.grid(row=row, column=column, padx=(10,10), pady=(10,2)) + + + column=1 + row+=1 + self.el_api_key_var = ctk.StringVar(value=self.config.get(section="ELEVENLABS", option="api_key", fallback="")) + el_api_key_entry = ctk.CTkEntry(self.parent, width=180, height=30, border_width=1, fg_color="white", placeholder_text="API Key", text_color="black", textvariable=self.el_api_key_var) + el_api_key_entry.configure(justify="center", show="*") + el_api_key_entry.grid(row=row, column=column, padx=(20,20), pady=(2, 20), sticky='n') + + self.e11_voices = CTkListbox(self.parent, width=450, height=200, command=self._on_voice_option_select) + + column +=1 + self.e11_voices.grid(row=row, column=column, columnspan=3) + column += 3 + + # could undo some of these self's if in a method? 
+ self.add_v_button = ctk.CTkButton(self.parent,height=30, text="Add Voice", command=self.open_edit_popup) + self.add_v_button.grid(row=row, column=column, padx=10, pady=(2,10), sticky='n') + self.del_v_button = ctk.CTkButton(self.parent,height=30, text="Remove Voice", fg_color="#b1363d", hover_color="#772429", command=self._delete_voice) + self.del_v_button.grid(row=row, column=column, padx=10, pady=(42,10), sticky='n') + self.preview_v_button = ctk.CTkButton(self.parent,height=30, text="Preview Voice", command=self._preview_e11_voice) + self.preview_v_button.grid(row=row, column=column, padx=10, pady=(138,10), sticky='n') + + on_elevenlabs_connect.addListener(self._update_elevenlabs_connection) + request_elevenlabs_connect.trigger() + + + def open_edit_popup(self, event=None): + AddVoiceCard(self.config, self._update_voice_list) + + + def _on_voice_option_select(self, option): + if option: + self.preview_v_button.configure(state="normal") + if self.e11_voices.size() > 1: + self.del_v_button.configure(state="normal") + else: + self.del_v_button.configure(state="disabled") + else: + self.del_v_button.configure(state="disabled") + self.preview_v_button.configure(state="disabled") + + + def _preview_e11_voice(self): + option = self.e11_voices.get() + if not option: + return + client = ElevenLabsTTS(self.config) + if uid := client.get_voices()[option]: + on_elevenlabs_test_speak.trigger(["Hello there. 
How are you?", uid]) + + + def _update_voice_list(self): + client = ElevenLabsTTS(self.config) + if self.e11_voices.size(): + self.e11_voices.selection_clear() + while self.e11_voices.size(): + self.e11_voices.deactivate('END') + self.e11_voices.delete('END') + for k in client.get_voices().keys(): + self.e11_voices.insert("END", option=k) + self.del_v_button.configure(state="disabled") + self.preview_v_button.configure(state="disabled") + + + def _delete_voice(self): + if self.e11_voices.size() <= 1: + return + option = self.e11_voices.get() + client = ElevenLabsTTS(self.config) + result = None + if uid := client.get_voices()[option]: + result0 = run_coroutine_sync(remove_tts(voice_id=uid)) + result1 = run_coroutine_sync(delete_voice(uid=uid, source=SOURCE_11L)) + if result1: + self._update_voice_list() + + + def _update_el_settings(self): + self.config.set(section="ELEVENLABS", option="api_key", value=str(self.el_api_key_var.get())) + self.config.write_updates() + request_elevenlabs_connect.trigger() def _update_bot_settings(self): @@ -106,6 +213,24 @@ def _update_bot_settings(self): ui_settings_twitch_channel_update_event.trigger([True, self.twitch_utils, 5]) + def _update_elevenlabs_connection(self, status: bool): + if status: + self.e11labs_con_label.configure(text="ElevenLabs Connected", text_color="green") + self._update_voice_list() + self.add_v_button.configure(state="normal") + else: + self.e11labs_con_label.configure(text="ElevenLabs Disconnected", text_color="red") + self.add_v_button.configure(state="disabled") + if self.e11_voices.size(): + self.e11_voices.selection_clear() + while self.e11_voices.size(): + self.e11_voices.deactivate('END') + self.e11_voices.delete('END') + self.del_v_button.configure(state="disabled") + self.preview_v_button.configure(state="disabled") + self.parent.focus() + + def _update_bot_connection(self, status: bool): if status: self.chat_con_label.configure(text="Chat Connected", text_color="green") @@ -113,9 +238,62 @@ def 
_update_bot_connection(self, status: bool): self.chat_con_label.configure(text="Chat Disconnected", text_color="red") self.parent.focus() + def _update_twitch_connect(self, status: bool, twitchutils = None): if status: self.t_con_label.configure(text="Twitch Connected", text_color="green") else: self.t_con_label.configure(text="Twitch Disconnected", text_color="red") self.parent.focus() + + +class AddVoiceCard(ctk.CTkToplevel): + open_popup = None + + def __init__(self, config: Config, update_list_callback: callable): + if AddVoiceCard.open_popup is not None: + AddVoiceCard.open_popup.focus_set() + return + + super().__init__() + AddVoiceCard.open_popup = self + self.config: Config = config + self.update_list_callback = update_list_callback + self.title(f"Add new ElevenLabs Voice") + self.geometry("400x400") + self.resizable(False, False) + self.protocol("WM_DELETE_WINDOW", self.close_popup) + + self.tts = ElevenLabsTTS(self.config) + self.attributes("-topmost", True) + self.create_widgets() + + + def create_widgets(self): + label1 = ctk.CTkLabel(self, text="ElevenLabs Voice Id:") + label1.pack(pady=(20, 5)) + + self.voice_id_var = ctk.StringVar() + self.voice_id_input = ctk.CTkEntry(self, width=160, height=30, textvariable=self.voice_id_var) + self.voice_id_input.pack(pady=(10,10)) + + self.label_warn = ctk.CTkLabel(self, text="", text_color="red") + self.label_warn.pack(pady=10) + + self.save_button = ctk.CTkButton(self, text="Add", command=self.save_changes) + self.save_button.pack(pady=10) + + + def save_changes(self): + elvoice = self.tts.get_voice_object(voice_id=self.voice_id_var.get(), run_sync_always=True) + if elvoice: + self.close_popup() + else: + self.label_warn.configure(text="Voice Id not found!") + # keep open and change label to error + + + def close_popup(self): + AddVoiceCard.open_popup = None + self.update_list_callback() + self.destroy() diff --git a/src/ui/tabs/users.py b/src/ui/tabs/users.py index 1c5a9b7..fd4cf0e 100644 --- 
a/src/ui/tabs/users.py +++ b/src/ui/tabs/users.py @@ -69,7 +69,7 @@ async def load_members(self): new_cards = list() columns = 6 for member in members: - member_card = MemberCard(self.members_list_frame, member) + member_card = MemberCard(self.members_list_frame, member,self.chat_ctrl.config) new_cards.append(member_card) new_cards = sorted(new_cards[:], key=lambda x: x.member.name) for index, card in enumerate(new_cards): diff --git a/src/ui/widgets/member_card.py b/src/ui/widgets/member_card.py index 0316c0f..e938765 100644 --- a/src/ui/widgets/member_card.py +++ b/src/ui/widgets/member_card.py @@ -7,13 +7,17 @@ from data import Member from custom_logger.logger import logger -from tts import LocalTTS -from data.member import update_tts +from tts import LocalTTS, ElevenLabsTTS +from data.member import update_tts, fetch_member +from helpers import TCDNDConfig as Config +from helpers.utils import run_coroutine_sync +from helpers.constants import SOURCES, SOURCE_11L, SOURCE_LOCAL class MemberCard(ctk.CTkFrame): - def __init__(self, parent, member: Member, width=160, height=200, textsize=12, *args, **kwargs): + def __init__(self, parent, member: Member, config: Config, width=160, height=200, textsize=12, *args, **kwargs): super().__init__(parent, width=width, height=height, *args, **kwargs) self.member: Member = member + self.config: Config = config self.width = width self.height = height self.textsize = textsize @@ -52,12 +56,12 @@ def setup_pfp(self): self.bg_label.bind("", self.open_edit_popup) def open_edit_popup(self, event=None): - MemberEditCard(self.member) + MemberEditCard(self.member, self.config) class MemberEditCard(ctk.CTkToplevel): open_popup = None - def __init__(self, member: Member): + def __init__(self, member: Member, config: Config): if MemberEditCard.open_popup is not None: MemberEditCard.open_popup.focus_set() return @@ -65,31 +69,58 @@ def __init__(self, member: Member): super().__init__() MemberEditCard.open_popup = self self.member: Member = 
member + self.config: Config = config self.title(f"Edit {self.member.name.upper()}") self.geometry("400x400") self.resizable(False, False) - self.localTTS = LocalTTS(None) + self.protocol("WM_DELETE_WINDOW", self.close_popup) + + self.tts = { + SOURCE_LOCAL: LocalTTS(self.config), + SOURCE_11L: ElevenLabsTTS(self.config) + } self.attributes("-topmost", True) self.create_widgets() - - self.protocol("WM_DELETE_WINDOW", self.close_popup) #self.deiconify() def create_widgets(self): # TODO add stuff, make pretty, idk - self.label = ctk.CTkLabel(self, text="Preferred TTS:") + label1 = ctk.CTkLabel(self, text="TTS Voice Source:") + label1.pack(pady=(20, 5)) + + + self.label = ctk.CTkLabel(self, text="TTS Voice:") self.label.pack(pady=(20, 5)) + current_source = SOURCE_LOCAL - self.tts_options = list(self.localTTS.get_voices().keys()) # Get keys from the voices dict + db_member = run_coroutine_sync(fetch_member(name=self.member.name)) + if db_member.preferred_tts_uid and db_member.preferred_tts: + current_source = db_member.preferred_tts.source + + self.tts_source_var = ctk.StringVar(value=current_source) + if not self.config.get(section="ELEVENLABS", option="api_key"): + SOURCES.remove(SOURCE_11L) + self.tts_source_dropdown = ctk.CTkOptionMenu(self, values=SOURCES, variable=self.tts_source_var, command=self._update_voicelist) + + voices = self.tts[current_source].get_voices() + self.tts_options = list(voices.keys()) self.tts_dropdown = ctk.CTkOptionMenu( self, values=self.tts_options ) - self.tts_dropdown.set(self.member.preferred_tts or self.tts_options[0]) + if self.member.preferred_tts_uid and self.member.preferred_tts_uid in voices.values(): + for k, v in voices.items(): + if v == self.member.preferred_tts_uid: + self.tts_dropdown.set(value=k) + break + else: + self.tts_dropdown.set(value=self.tts_options[0]) + + self.tts_source_dropdown.pack(pady=(5, 20)) self.tts_dropdown.pack(pady=(5, 20)) self.test_button = ctk.CTkButton(self, text="Preview", 
command=self.test_tts) @@ -99,15 +130,30 @@ def create_widgets(self): self.save_button.pack(pady=10) + def _update_voicelist(self, choice): + voices = self.tts[choice].get_voices() + self.tts_options = list(voices.keys()) + self.tts_dropdown.configure(values=self.tts_options) + if self.member.preferred_tts_uid and self.member.preferred_tts_uid in voices.values(): + for k, v in voices.items(): + if v == self.member.preferred_tts_uid: + self.tts_dropdown.set(value=k) + break + else: + self.tts_dropdown.set(value=self.tts_options[0]) + + def test_tts(self): - self.localTTS.test_speak(voice=self.tts_dropdown.get()) + voice_id = self.tts[self.tts_source_var.get()].get_voices()[self.tts_dropdown.get()] + self.tts[self.tts_source_var.get()].test_speak(voice_id = voice_id) def save_changes(self): new_tts = self.tts_dropdown.get() - self.member.preferred_tts = new_tts + voices = self.tts[self.tts_source_var.get()].get_voices() + voice_id = voices[new_tts] - asyncio.create_task(update_tts(self.member, new_tts)) + asyncio.create_task(update_tts(self.member, voice_id)) logger.info(f"Updated preferred_tts for {self.member.name} to {new_tts}") def close_popup(self):