Skip to content

Commit 1f60e9d

Browse files
committed
feat: add tts functionality
1 parent b5ae573 commit 1f60e9d

File tree

2 files changed

+261
-3
lines changed

2 files changed

+261
-3
lines changed

go2_robot_sdk/go2_robot_sdk/tts.py

+257
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
#!/usr/bin/env python3
2+
3+
# BSD 3-Clause License
4+
#
5+
# Copyright (c) 2024, The RoboVerse community
6+
# All rights reserved.
7+
#
8+
# Redistribution and use in source and binary forms, with or without
9+
# modification, are permitted provided that the following conditions are met:
10+
#
11+
# * Redistributions of source code must retain the above copyright notice, this
12+
# list of conditions and the following disclaimer.
13+
#
14+
# * Redistributions in binary form must reproduce the above copyright notice,
15+
# this list of conditions and the following disclaimer in the documentation
16+
# and/or other materials provided with the distribution.
17+
#
18+
# * Neither the name of the copyright holder nor the names of its
19+
# contributors may be used to endorse or promote products derived from
20+
# this software without specific prior written permission.
21+
#
22+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
23+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
24+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
25+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
26+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
27+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
28+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
30+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32+
33+
import base64
34+
from datetime import datetime
35+
import io
36+
import json
37+
import os
38+
import time
39+
40+
from pydub import AudioSegment
41+
from pydub.playback import play
42+
import rclpy
43+
from rclpy.node import Node
44+
import requests
45+
from std_msgs.msg import String
46+
from scripts.go2_constants import RTC_TOPIC
47+
from unitree_go.msg import WebRtcReq
48+
49+
# flake8: noqa: Q000
50+
51+
52+
class TTSNode(Node):
    """ROS 2 node that converts text received on ``/tts`` into speech.

    Each ``std_msgs/String`` message is sent to the ElevenLabs text-to-speech
    HTTP API. The returned MP3 is converted to WAV, written to the
    ``tts_output`` directory (relative to the current working directory) and
    then either played on the local machine or streamed to the robot's audio
    hub as ``WebRtcReq`` messages published on ``/webrtc_req``.

    Parameters (ROS):
        elevenlabs_api_key: API key sent in the ``xi-api-key`` header.
        local_playback: play locally when true, otherwise send to the robot.
    """

    # Value substituted into the text-to-speech URL path to select the voice.
    DEFAULT_VOICE = "XrExE9yKIg1WjnnlVkGX"

    # Upper bound in seconds for the HTTP request. Without it a stalled
    # connection would block the subscription callback forever.
    REQUEST_TIMEOUT_S = 30.0

    # api_id values used on the audio hub topic.
    AUDIO_START_API_ID = 4001
    AUDIO_END_API_ID = 4002
    AUDIO_BLOCK_API_ID = 4003

    # Pauses (seconds) after the start request and after each audio block;
    # the per-block pause keeps the publisher from flooding the robot.
    START_DELAY_S = 0.1
    CHUNK_DELAY_S = 0.15

    def __init__(self):
        """Declare parameters and set up the subscription and publisher.

        When no API key is configured an error is logged and the constructor
        returns early: no subscription, publisher or output directory is
        created, so the node stays alive but never handles a request.
        """
        super().__init__("tts_node")

        # Initialize parameters
        self.declare_parameter("elevenlabs_api_key", "")
        self.declare_parameter("local_playback", False)  # Default to robot playback

        self.api_key = self.get_parameter("elevenlabs_api_key").value
        self.local_playback = self.get_parameter("local_playback").value

        if not self.api_key:
            self.get_logger().error("ElevenLabs API key not provided!")
            return

        # Create subscription for TTS requests
        self.subscription = self.create_subscription(
            String, "/tts", self.tts_callback, 10
        )

        # Create publisher for robot audio hub requests
        self.audio_pub = self.create_publisher(WebRtcReq, "/webrtc_req", 10)

        # Create output directory for wave files
        self.output_dir = "tts_output"
        os.makedirs(self.output_dir, exist_ok=True)

        self.get_logger().info(
            f'TTS Node initialized ({"local" if self.local_playback else "robot"} playback)'
        )

    def tts_callback(self, msg):
        """Handle one incoming TTS request.

        Args:
            msg: ``std_msgs/String`` whose ``data`` field is the text to speak.

        Any exception is logged rather than propagated so that one bad
        request does not take the node down.
        """
        voice_name = self.DEFAULT_VOICE
        try:
            self.get_logger().info(
                f'Received TTS request: "{msg.data}" with voice: {voice_name}'
            )

            # Call ElevenLabs API
            audio_data = self.generate_speech(msg.data, voice_name)
            if not audio_data:
                self.get_logger().error("Failed to generate speech")
                return

            # Save to WAV file
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"{self.output_dir}/tts_{timestamp}.wav"
            wav_data = self.save_wav(audio_data, filename)

            if self.local_playback:
                # Local playback decodes the MP3 itself, so it can still run
                # when the WAV conversion failed.
                self.play_audio(audio_data)
            elif wav_data is not None:
                self.play_on_robot(wav_data)

            if wav_data is None:
                # save_wav() has already logged the cause; do not report
                # success for a file that was never written.
                self.get_logger().error("Failed to convert speech to WAV")
                return

            self.get_logger().info(
                f"Successfully processed TTS request. Saved to {filename}"
            )

        except Exception as e:
            self.get_logger().error(f"Error processing TTS request: {str(e)}")

    def generate_speech(self, text, voice_name):
        """Request synthesized speech from the ElevenLabs API.

        Args:
            text: text to synthesize.
            voice_name: value placed in the text-to-speech URL path.

        Returns:
            The response body (requested as ``audio/mpeg``) as bytes, or
            ``None`` when the request fails, times out or returns an HTTP
            error status.
        """
        url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_name}"

        headers = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "xi-api-key": self.api_key,
        }

        data = {
            "text": text,
            "model_id": "eleven_turbo_v2_5",
            "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
        }

        try:
            response = requests.post(
                url, json=data, headers=headers, timeout=self.REQUEST_TIMEOUT_S
            )
            response.raise_for_status()
            return response.content

        except requests.exceptions.RequestException as e:
            # Timeouts are RequestException subclasses and end up here too.
            self.get_logger().error(f"API request failed: {str(e)}")
            return None

    def save_wav(self, audio_data, filename):
        """Convert MP3 bytes to WAV, write them to ``filename`` and return them.

        Args:
            audio_data: MP3-encoded audio bytes.
            filename: destination path of the WAV file.

        Returns:
            The WAV-encoded bytes, or ``None`` if decoding or writing failed.
        """
        try:
            # Convert MP3 to WAV
            audio = AudioSegment.from_mp3(io.BytesIO(audio_data))

            # Encode once in memory and reuse the bytes for both the file and
            # the return value; the file is closed by the context manager.
            wav_io = io.BytesIO()
            audio.export(wav_io, format="wav")
            wav_data = wav_io.getvalue()

            with open(filename, "wb") as wav_file:
                wav_file.write(wav_data)
            self.get_logger().info(f"Saved WAV file: {filename}")

            return wav_data

        except Exception as e:
            self.get_logger().error(f"Error saving WAV file: {str(e)}")
            return None

    def play_audio(self, audio_data):
        """Play MP3 bytes on the local machine using pydub.

        Errors are logged and otherwise ignored.
        """
        try:
            audio = AudioSegment.from_mp3(io.BytesIO(audio_data))
            play(audio)
        except Exception as e:
            self.get_logger().error(f"Error playing audio: {str(e)}")

    def split_into_chunks(self, data, chunk_size=16 * 1024):
        """Split a sliceable sequence into consecutive pieces.

        Args:
            data: any sequence supporting ``len()`` and slicing.
            chunk_size: maximum length of each piece; must be positive.

        Returns:
            A list of slices of ``data``; the last one may be shorter.
            Empty input yields an empty list.

        Raises:
            ValueError: if ``chunk_size`` is not positive. A negative step
                would otherwise silently yield no chunks at all.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        return [
            data[start:start + chunk_size]
            for start in range(0, len(data), chunk_size)
        ]

    def _make_audio_request(self, api_id, parameter=""):
        """Build a ``WebRtcReq`` addressed to the audio hub topic."""
        req = WebRtcReq()
        req.api_id = api_id
        req.priority = 0
        req.parameter = parameter
        req.topic = RTC_TOPIC["AUDIO_HUB_REQ"]
        return req

    def play_on_robot(self, wav_data):
        """Send WAV audio to the robot's audio hub in chunks.

        The WAV bytes are base64-encoded, split into blocks and published
        between a start and an end request. The call blocks for roughly the
        duration of the clip plus one second before sending the end request.

        Args:
            wav_data: WAV-encoded audio bytes.

        Errors are logged and otherwise ignored.
        """
        if not wav_data:
            self.get_logger().error("No WAV data to send to robot")
            return

        try:
            b64_encoded = base64.b64encode(wav_data).decode("utf-8")
            chunks = self.split_into_chunks(b64_encoded)
            total_chunks = len(chunks)

            # Decode before touching the robot so that invalid audio fails
            # here instead of after the start request has been sent.
            # len() of an AudioSegment is its duration in milliseconds.
            duration_s = len(AudioSegment.from_wav(io.BytesIO(wav_data))) / 1000.0

            self.get_logger().info(f"Sending audio in {total_chunks} chunks")

            # Start audio
            self.audio_pub.publish(
                self._make_audio_request(self.AUDIO_START_API_ID)
            )

            try:
                time.sleep(self.START_DELAY_S)

                # Send WAV data in chunks (block indices are 1-based)
                for chunk_idx, chunk in enumerate(chunks, 1):
                    audio_block = {
                        "current_block_index": chunk_idx,
                        "total_block_number": total_chunks,
                        "block_content": chunk,
                    }
                    self.audio_pub.publish(
                        self._make_audio_request(
                            self.AUDIO_BLOCK_API_ID, json.dumps(audio_block)
                        )
                    )
                    self.get_logger().info(
                        f"Sent chunk {chunk_idx}/{total_chunks} ({len(chunk)} bytes)"
                    )

                    # Small delay between chunks to prevent flooding
                    time.sleep(self.CHUNK_DELAY_S)

                # Wait until playback finished
                self.get_logger().info(
                    f"Waiting for audio playback ({duration_s:.2f} seconds)..."
                )
                time.sleep(duration_s + 1)
            finally:
                # End audio: always pair the start request with an end
                # request, even if streaming failed part-way through.
                self.audio_pub.publish(
                    self._make_audio_request(self.AUDIO_END_API_ID)
                )

            self.get_logger().info("Completed sending audio to robot")

        except Exception as e:
            self.get_logger().error(f"Error sending audio to robot: {str(e)}")
241+
242+
243+
def main(args=None):
    """Run the TTS node until it is interrupted.

    Args:
        args: command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    node = None

    try:
        # Construct inside the try block so that a failing constructor still
        # reaches the shutdown code below.
        node = TTSNode()
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        if node is not None:
            node.destroy_node()
        # The context may already have been shut down (for example by
        # rclpy's signal handling on Ctrl-C); a second shutdown would raise.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == "__main__":
    main()

go2_robot_sdk/setup.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@
4545
(os.path.join('share', package_name, 'calibration'), glob(os.path.join('calibration', '*'))),
4646
(os.path.join('share', package_name, 'external_lib'), ['external_lib/libvoxel.wasm']),
4747
(os.path.join('share', package_name, 'external_lib/aioice'), glob(os.path.join('external_lib/aioice/src/aioice', '*'))),
48-
49-
48+
49+
5050
],
5151
install_requires=['setuptools'],
5252
zip_safe=True,
@@ -58,7 +58,8 @@
5858
entry_points={
5959
'console_scripts': [
6060
'go2_driver_node = go2_robot_sdk.go2_driver_node:main',
61-
'lidar_to_pointcloud = go2_robot_sdk.lidar_to_point:main'
61+
'lidar_to_pointcloud = go2_robot_sdk.lidar_to_point:main',
62+
'tts_node = go2_robot_sdk.tts:main'
6263
],
6364
},
6465
)

0 commit comments

Comments
 (0)