Skip to content

Commit

Permalink
Elevenlabs integration (#25)
Browse files Browse the repository at this point in the history
* Initial Elevenlabs integration

* Redirect sys std logging

* Bug fixes

* Add elevenlabs link to readme
  • Loading branch information
WolfwithSword authored Jan 14, 2025
1 parent 38a3745 commit 181e3fd
Show file tree
Hide file tree
Showing 23 changed files with 873 additions and 107 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ user_token.json
config.ini
.tcdnd-cache/
*.db
logs/
*.db-journal

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# TwitchChatDND
Host a D&D Session with Twitch Chat and TTS Voices!

ElevenLabs API is optional, but gives you access to several new voices.
Too add a new voice, copy its voice_id into the add menu after configuring your API key.

The free tier will work for testing purposes (locks us out of some fancy voices plus only has a 10k credit limit), but right now we hardcode a turbo model which is half credit usage since we don't need perfection.

[Elevenlabs subscription plans](https://try.elevenlabs.io/chatdnd)
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ Quart==0.20.0
SQLAlchemy==2.0.36
aiosqlite==0.20.0
pillow==11.1.0
requests==2.32.3
requests==2.32.3
elevenlabs==1.50.3
CTkListbox==1.4
static-ffmpeg==2.7
5 changes: 5 additions & 0 deletions src/chatdnd/events/tts_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from helpers import Event

on_elevenlabs_connect = Event()
request_elevenlabs_connect = Event()
on_elevenlabs_test_speak = Event()
16 changes: 16 additions & 0 deletions src/custom_logger/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ def __init__(self, log_file):
formatter = logging.Formatter(_format)
self.setFormatter(formatter)

class RedirectSysLogger(object):
def __init__(self, logger, level):
self.logger = logger
self.level = level
self.linebuf = ''

def write(self, buf):
for line in buf.rstrip().splitlines():
if line.strip() and len(line.strip()) > 1:
self.logger.log(self.level, line.rstrip())

def flush(self):
pass

class CustomLogger:
def __init__(self, name):
self.logger = logging.getLogger(name)
Expand Down Expand Up @@ -65,3 +79,5 @@ def shutdown(self):
self.listener.stop()

logger = CustomLogger("ChatDND").logger
sys.stdout = RedirectSysLogger(logger, logging.INFO)
sys.stderr = RedirectSysLogger(logger, logging.ERROR)
3 changes: 2 additions & 1 deletion src/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from data.member import Member
from data.voices import Voice
from data.session import Session, SessionState

__all__ = [Member, Session, SessionState]
__all__ = [Member, Session, SessionState, Voice]
35 changes: 29 additions & 6 deletions src/data/member.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, Mapped, mapped_column
from sqlalchemy import String, Integer, JSON, asc
from sqlalchemy.orm import sessionmaker, Mapped, mapped_column, relationship
from sqlalchemy import String, Integer, JSON, asc, ForeignKey, null
from sqlalchemy.future import select

from custom_logger.logger import logger

from data.base import Base
from data.voices import Voice

from sqlalchemy.future import select
from db import async_session

from data.voices import fetch_voice

# We don't need to pass the DB object around after it's been initialized by main
# Simply import async_session from it, or objects made from it around


class Member(Base):
__tablename__ = "members"

id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
pfp_url: Mapped[str] = mapped_column(String, default="")
num_sessions: Mapped[int] = mapped_column(Integer, default=0) # increment on end/complete session
preferred_tts: Mapped[str] = mapped_column(String, default="") # local vs cloud, and what name for tts voice
preferred_tts_uid: Mapped[str] = mapped_column(ForeignKey("voices.uid"), nullable=True)
preferred_tts: Mapped[Voice] = relationship("Voice", lazy="subquery")

data: Mapped[dict] = mapped_column(JSON, default=dict) # We can store arbitrary data in here if we need extra columns and stuff later, just need to be safe with checking
# elsewise, we will need to setup alembic and migrations with an updater script/function/exe
Expand Down Expand Up @@ -71,15 +78,31 @@ async def _upsert_member(name: str, pfp_url: str) -> Member:
session.add(new_member)
return new_member

async def update_tts(member: Member, preferred_tts: str = ""):
async def update_tts(member: Member, voice_id: str):
async with async_session() as session:
async with session.begin():
voice_in_db = await fetch_voice(uid=voice_id)
member_in_db = await session.get(Member, member.id)
if member_in_db:
member_in_db.preferred_tts = preferred_tts
if member_in_db and voice_in_db:
member.preferred_tts_uid = voice_id
member.preferred_tts = voice_in_db

member_in_db.preferred_tts_uid = voice_id
member_in_db.preferred_tts = voice_in_db
await session.commit()


async def remove_tts(voice_id: str):
if not voice_id:
return
async with async_session() as session:
async with session.begin():
result = await session.execute(select(Member).where(Member.preferred_tts_uid == voice_id))
for member in result.scalars().all():
member.preferred_tts_uid = null()
await session.commit()


async def fetch_member(name: str) -> Member | None:
name = name.lower()
async with async_session() as session:
Expand Down
133 changes: 133 additions & 0 deletions src/data/voices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, Mapped, mapped_column
from sqlalchemy import String, Integer, asc, CheckConstraint
from sqlalchemy.future import select

from data.base import Base

from sqlalchemy.future import select
from db import async_session

from custom_logger.logger import logger
from helpers.constants import SOURCES, SOURCE_11L, SOURCE_LOCAL

class Voice(Base):
__tablename__ = "voices"

id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True, nullable=False) # Unique?
uid: Mapped[str] = mapped_column(String, unique=True, nullable=False)
source: Mapped[str] = mapped_column(String, default=SOURCE_LOCAL, nullable=False)

__table_args__ = (
CheckConstraint(
f"source IN ({', '.join(repr(value) for value in SOURCES)})",
name="chk_source_valid_val"
),
)

def __init__(self, name: str, uid: str, source: str):
if source not in SOURCES:
return None
self.name: str = name
self.uid: str = uid
self.source: str = source

def __eq__(self, other):
return self.name == other.name and self.uid == other.uid and self.source == other.source


def __hash__(self):
return hash(f"{self.name}{self.uid}{self.source}")


def __repr__(self):
return f"Voice(name='{self.name}', uid='{self.uid}', source='{self.source}')"

def __lt__(self, other):
return self.name < other.name

def __gt__(self, other):
return self.name > other.name


async def _upsert_voice(name: str, uid: str, source: str) -> Voice | None:
if source not in SOURCES:
return None
async with async_session() as session:
async with session.begin():
query = select(Voice).where(Voice.name == name).where(Voice.uid == uid).where(Voice.source == source)
result = await session.execute(query)
voice = result.scalars().first()

if voice:
return voice
else:
# Create new voice
new_voice = Voice(name=name, uid=uid, source=source)
session.add(new_voice)
return new_voice


async def delete_voice(uid: str = None, source: str = None) -> bool:
if not uid:
return False
async with async_session() as session:
voice = await fetch_voice(uid=uid, source=source)
if voice:
await session.delete(voice)
await session.commit()
return True
return False


async def fetch_voice(name: str = None, uid: str = None, source: str = None) -> Voice | None:
if not any([name, uid]):
return None
async with async_session() as session:
query = select(Voice)
if uid:
query = query.where(Voice.uid == uid)
if name:
query = query.where(Voice.name == name)
if source:
query = query.where(Voice.source == source)
result = await session.execute(query)
return result.scalars().first()


async def fetch_voices(source: str = None, limit: int = 100) -> list[Voice]:

async with async_session() as session:
query = select(Voice)
if source:
query = query.where(Voice.source == source)
query = query.limit(limit)
result = await session.execute(query)
res = result.scalars().all()
if not res and source == 'elevenlabs':
logger.info("Adding default ElevenLabs voice 'Will'")
v = await _upsert_voice(name="Will", uid="bIHbv24MWmeRgasZH58o", source=SOURCE_11L)
return [v]
return res


async def fetch_paginated_voices(page: int, per_page: int=20,
name_filter: str = None, filter_source: str = None) -> list[Voice]:
if not exclude_names:
exclude_names = []

async with async_session() as session:
query = select(Voice).order_by(asc(Voice.name))

if name_filter:
query = query.where(Voice.name.like(f"%{name_filter.lower()}%"))

if filter_source:
query = query.where(Voice.source == source)

offset = (page -1) * per_page
query = query.offset(offset).limit(per_page)

result = await session.execute(query)
return result.scalars().all()
5 changes: 4 additions & 1 deletion src/db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
import os

DATABASE_URL = "sqlite+aiosqlite:///tcdnd_data.db"
engine = create_async_engine(DATABASE_URL, echo=False)
Expand All @@ -9,8 +10,10 @@
import logging
from data.base import Base


debug_mode = os.environ['TCDND_DEBUG_MODE'] == '1'
logging.getLogger("sqlalchemy.engine").setLevel(logging.DEBUG if debug_mode else logging.INFO)
logging.getLogger("sqlalchemy.engine").handlers = logger.handlers
logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)

async def initialize_database():
#Initialize the database and create tables.#
Expand Down
10 changes: 10 additions & 0 deletions src/helpers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def setup(self, path: str):
if not self.has_option(section="CACHE", option="cache_expiry"):
needs_init = True
self.set(section="CACHE", option="cache_expiry", value=str(7*24*60*60))
if not self.has_option(section="CACHE", option="tts_cache_expiry"):
needs_init = True
self.set(section="CACHE", option="tts_cache_expiry", value=str(7*24*60*60*4*3))

if not self.has_section("TWITCH"):
needs_init = True
Expand All @@ -60,6 +63,13 @@ def setup(self, path: str):
needs_init = True
self.set(section="DND", option="party_size", value=str(4))

if not self.has_section("ELEVENLABS"):
needs_init = True
self.add_section("ELEVENLABS")
if not self.has_option(section="ELEVENLABS", option="api_key"):
needs_init = True
self.set(section="ELEVENLABS", option="api_key", value='')

if needs_init:
self.write_updates()

Expand Down
8 changes: 8 additions & 0 deletions src/helpers/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

SOURCE_LOCAL = "local"
SOURCE_11L = "elevenlabs"

SOURCES = [
SOURCE_LOCAL,
SOURCE_11L
]
33 changes: 32 additions & 1 deletion src/helpers/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import os, sys
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Coroutine, TypeVar

T = TypeVar("T")

def get_resource_path(relative_path, from_resources: bool = False):
# Get absolute path to a resource, frozen or local (relative to helpers/utils.py)
Expand All @@ -13,4 +19,29 @@ def get_resource_path(relative_path, from_resources: bool = False):
base_path = os.path.dirname(os.path.abspath(__file__))
else:
base_path = os.path.dirname(os.path.abspath(__file__))
return os.path.join(base_path, relative_path)
return os.path.join(base_path, relative_path)


def run_coroutine_sync(coroutine: Coroutine[Any, Any, T], timeout: float = 30) -> T:
def run_in_new_loop():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(coroutine)
finally:
new_loop.close()

try:
loop = asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coroutine)

if threading.current_thread() is threading.main_thread():
if not loop.is_running():
return loop.run_until_complete(coroutine)
else:
with ThreadPoolExecutor() as pool:
future = pool.submit(run_in_new_loop)
return future.result(timeout=timeout)
else:
return asyncio.run_coroutine_threadsafe(coroutine, loop).result()
Loading

0 comments on commit 181e3fd

Please sign in to comment.