Add tts_rust_server.py example

This commit is contained in:
Václav Volhejn 2025-07-02 22:01:01 +02:00
parent da83b4b63f
commit df36dfd918
3 changed files with 198 additions and 1 deletions

View File

@ -222,7 +222,35 @@ and just prefix the command above with `uvx --with moshi`.
### Rust server
Example coming soon.
The Rust implementation provides a server that can process multiple streaming
queries in parallel.
In order to run the server, install the [moshi-server
crate](https://crates.io/crates/moshi-server) via the following command. The
server code can be found in the
[kyutai-labs/moshi](https://github.com/kyutai-labs/moshi/tree/main/rust/moshi-server)
repository.
```bash
cargo install --features cuda moshi-server
```
Then the server can be started via the following command using the config file
from this repository.
```bash
moshi-server worker --config configs/config-tts.toml
```
Once the server has started you can connect to it using our script as follows:
```bash
# From stdin, plays audio immediately
echo "Hey, how are you?" | python scripts/tts_rust_server.py - -
# From text file to audio file
python scripts/tts_rust_server.py text_to_say.txt audio_output.wav
```
### MLX implementation

20
configs/config-tts.toml Normal file
View File

@ -0,0 +1,20 @@
static_dir = "./static/"
log_dir = "$HOME/tmp/tts-logs"
instance_name = "tts"
authorized_ids = ["public_token"]
[modules.tts_py]
type = "Py"
path = "/api/tts_streaming"
text_tokenizer_file = "hf://kyutai/unmute/test_en_fr_audio_8000.model"
batch_size = 8 # Adjust to your GPU memory capacity
text_bos_token = 1
[modules.tts_py.py]
log_folder = "$HOME/tmp/moshi-server-logs"
voice_folder = "hf-snapshot://kyutai/tts-voices/**/*.safetensors"
default_voice = "unmute-prod-website/default_voice.wav"
cfg_coef = 2.0
cfg_is_no_text = true
padding_between = 1
n_q = 24

149
scripts/tts_rust_server.py Normal file
View File

@ -0,0 +1,149 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "msgpack",
# "numpy",
# "sphn",
# "websockets",
# "sounddevice",
# "tqdm",
# ]
# ///
import argparse
import asyncio
import sys
from urllib.parse import urlencode
import msgpack
import numpy as np
import sounddevice as sd
import sphn
import tqdm
import websockets
SAMPLE_RATE = 24000
TTS_TEXT = "Hello, this is a test of the moshi text to speech system, this should result in some nicely sounding generated voice."
DEFAULT_DSM_TTS_REPO = "kyutai/tts-1.6b-en_fr"
DEFAULT_DSM_TTS_VOICE_REPO = "kyutai/tts-voices"
AUTH_TOKEN = "public_token"
async def receive_messages(websocket: websockets.ClientConnection, output_queue):
with tqdm.tqdm(desc="Receiving audio", unit=" seconds generated") as pbar:
accumulated_samples = 0
last_seconds = 0
async for message_bytes in websocket:
msg = msgpack.unpackb(message_bytes)
if msg["type"] == "Audio":
pcm = np.array(msg["pcm"]).astype(np.float32)
await output_queue.put(pcm)
accumulated_samples += len(msg["pcm"])
current_seconds = accumulated_samples // SAMPLE_RATE
if current_seconds > last_seconds:
pbar.update(current_seconds - last_seconds)
last_seconds = current_seconds
print("End of audio.")
await output_queue.put(None) # Signal end of audio
async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None]):
if out == "-":
should_exit = False
def audio_callback(outdata, _a, _b, _c):
nonlocal should_exit
try:
pcm_data = output_queue.get_nowait()
if pcm_data is not None:
outdata[:, 0] = pcm_data
else:
should_exit = True
except asyncio.QueueEmpty:
outdata[:] = 0
with sd.OutputStream(
samplerate=SAMPLE_RATE,
blocksize=1920,
channels=1,
callback=audio_callback,
):
while True:
if should_exit:
break
await asyncio.sleep(1)
else:
frames = []
while True:
item = await output_queue.get()
if item is None:
break
frames.append(item)
sphn.write_wav(out, np.concat(frames, -1), SAMPLE_RATE)
print(f"Saved audio to {out}")
async def websocket_client():
parser = argparse.ArgumentParser(description="Use the TTS streaming API")
parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
parser.add_argument(
"out", type=str, help="Output file to generate, use - for playing the audio"
)
parser.add_argument(
"--hf-repo",
type=str,
default=DEFAULT_DSM_TTS_REPO,
help="HF repo in which to look for the pretrained models.",
)
parser.add_argument(
"--voice-repo",
default=DEFAULT_DSM_TTS_VOICE_REPO,
help="HF repo in which to look for pre-computed voice embeddings.",
)
parser.add_argument(
"--voice",
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
help="The voice to use, relative to the voice repo root. "
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
)
parser.add_argument(
"--url",
help="The URL of the server to which to send the audio",
default="ws://127.0.0.1:8080",
)
parser.add_argument("--api-key", default="public_token")
args = parser.parse_args()
params = {"voice": args.voice, "format": "PcmMessagePack"}
uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
print(uri)
# TODO: stream the text instead of sending it all at once
if args.inp == "-":
if sys.stdin.isatty(): # Interactive
print("Enter text to synthesize (Ctrl+D to end input):")
text_to_tts = sys.stdin.read().strip()
else:
with open(args.inp, "r") as fobj:
text_to_tts = fobj.read().strip()
headers = {"kyutai-api-key": args.api_key}
async with websockets.connect(uri, additional_headers=headers) as websocket:
await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts}))
await websocket.send(msgpack.packb({"type": "Eos"}))
output_queue = asyncio.Queue()
receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
await asyncio.gather(receive_task, output_audio_task)
if __name__ == "__main__":
asyncio.run(websocket_client())