Add Rust server usage example (#32)

* Run Ruff on tts_mlx.py

* Add tts_rust_server.py example

* Remove unused HF repo arguments and reset audio output data in TTS server script
Václav Volhejn 2025-07-03 09:47:50 +02:00 committed by GitHub
parent d92e4c2695
commit ef52b8ef0f
4 changed files with 233 additions and 22 deletions

README.md

@@ -239,7 +239,35 @@ and just prefix the command above with `uvx --with moshi`.
 <details>
 <summary>Rust server</summary>
-Example coming soon.
+The Rust implementation provides a server that can process multiple streaming
+queries in parallel.
+
+In order to run the server, install the [moshi-server
+crate](https://crates.io/crates/moshi-server) via the following command. The
+server code can be found in the
+[kyutai-labs/moshi](https://github.com/kyutai-labs/moshi/tree/main/rust/moshi-server)
+repository.
+
+```bash
+cargo install --features cuda moshi-server
+```
+
+Then the server can be started via the following command, using the config file
+from this repository.
+
+```bash
+moshi-server worker --config configs/config-tts.toml
+```
+
+Once the server has started, you can connect to it using our script as follows:
+
+```bash
+# From stdin, plays audio immediately
+echo "Hey, how are you?" | python scripts/tts_rust_server.py - -
+
+# From a text file to an audio file
+python scripts/tts_rust_server.py text_to_say.txt audio_output.wav
+```
 </details>
 <details>
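Note that `scripts/tts_rust_server.py` (added below) declares its dependencies as inline script metadata (the `# /// script` block at the top of the file), so it should also be runnable via [uv](https://docs.astral.sh/uv/) without installing anything first:

```bash
echo "Hey, how are you?" | uv run scripts/tts_rust_server.py - -
```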

configs/config-tts.toml (new file, 20 lines)

@@ -0,0 +1,20 @@
static_dir = "./static/"
log_dir = "$HOME/tmp/tts-logs"
instance_name = "tts"
authorized_ids = ["public_token"]

[modules.tts_py]
type = "Py"
path = "/api/tts_streaming"
text_tokenizer_file = "hf://kyutai/unmute/test_en_fr_audio_8000.model"
batch_size = 8 # Adjust to your GPU memory capacity
text_bos_token = 1

[modules.tts_py.py]
log_folder = "$HOME/tmp/moshi-server-logs"
voice_folder = "hf-snapshot://kyutai/tts-voices/**/*.safetensors"
default_voice = "unmute-prod-website/default_voice.wav"
cfg_coef = 2.0
cfg_is_no_text = true
padding_between = 1
n_q = 24
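The module's `path` is the WebSocket route that clients connect to, and any entry in `authorized_ids` is what a client must present in the `kyutai-api-key` header, which is exactly what `scripts/tts_rust_server.py` below does. A minimal sketch of the handshake, assuming the server from this config is running locally on `127.0.0.1:8080` (the client script's default):

```python
# Minimal handshake sketch (assumes the server from config-tts.toml is running
# and listening on 127.0.0.1:8080, the default used by tts_rust_server.py).
import asyncio

import msgpack
import websockets


async def main():
    # `path = "/api/tts_streaming"` in the config is the WebSocket route.
    uri = "ws://127.0.0.1:8080/api/tts_streaming?voice=unmute-prod-website/default_voice.wav&format=PcmMessagePack"
    # Any entry of `authorized_ids` works as the API key.
    headers = {"kyutai-api-key": "public_token"}
    async with websockets.connect(uri, additional_headers=headers) as ws:
        await ws.send(msgpack.packb({"type": "Text", "text": "Hello!"}))
        await ws.send(msgpack.packb({"type": "Eos"}))
        # The server streams msgpack "Audio" messages back until it closes.
        async for message in ws:
            msg = msgpack.unpackb(message)
            if msg["type"] == "Audio":
                print(f"received {len(msg['pcm'])} PCM samples")


asyncio.run(main())
```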

scripts/tts_mlx.py

@@ -14,19 +14,20 @@ import queue
 import sys
 import time
-import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
+import numpy as np
 import sentencepiece
-import sphn
-import time
 import sounddevice as sd
-from moshi_mlx.client_utils import make_log
+import sphn
 from moshi_mlx import models
+from moshi_mlx.client_utils import make_log
+from moshi_mlx.models.tts import (
+    DEFAULT_DSM_TTS_REPO,
+    DEFAULT_DSM_TTS_VOICE_REPO,
+    TTSModel,
+)
 from moshi_mlx.utils.loaders import hf_get
-from moshi_mlx.models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO

 def log(level: str, msg: str):
@@ -34,15 +35,32 @@ def log(level: str, msg: str):
 def main():
-    parser = argparse.ArgumentParser(prog='moshi-tts', description='Run Moshi')
+    parser = argparse.ArgumentParser(
+        description="Run Kyutai TTS using the PyTorch implementation"
+    )
     parser.add_argument("inp", type=str, help="Input file, use - for stdin")
-    parser.add_argument("out", type=str, help="Output file to generate, use - for playing the audio")
-    parser.add_argument("--hf-repo", type=str, default=DEFAULT_DSM_TTS_REPO,
-                        help="HF repo in which to look for the pretrained models.")
-    parser.add_argument("--voice-repo", default=DEFAULT_DSM_TTS_VOICE_REPO,
-                        help="HF repo in which to look for pre-computed voice embeddings.")
-    parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav")
-    parser.add_argument("--quantize", type=int, help="The quantization to be applied, e.g. 8 for 8 bits.")
+    parser.add_argument(
+        "out", type=str, help="Output file to generate, use - for playing the audio"
+    )
+    parser.add_argument(
+        "--hf-repo",
+        type=str,
+        default=DEFAULT_DSM_TTS_REPO,
+        help="HF repo in which to look for the pretrained models.",
+    )
+    parser.add_argument(
+        "--voice-repo",
+        default=DEFAULT_DSM_TTS_VOICE_REPO,
+        help="HF repo in which to look for pre-computed voice embeddings.",
+    )
+    parser.add_argument(
+        "--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
+    )
+    parser.add_argument(
+        "--quantize",
+        type=int,
+        help="The quantization to be applied, e.g. 8 for 8 bits.",
+    )
     args = parser.parse_args()

     mx.random.seed(299792458)
@@ -96,7 +114,7 @@ def main():
     if tts_model.valid_cfg_conditionings:
         # Model was trained with CFG distillation.
         cfg_coef_conditioning = tts_model.cfg_coef
-        tts_model.cfg_coef = 1.
+        tts_model.cfg_coef = 1.0
         cfg_is_no_text = False
         cfg_is_no_prefix = False
     else:
@@ -118,9 +136,12 @@ def main():
         voices = [tts_model.get_voice_path(args.voice)]
     else:
         voices = []
-    all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
+    all_attributes = [
+        tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
+    ]

     wav_frames = queue.Queue()

     def _on_frame(frame):
         if (frame == -1).any():
             return
@@ -146,16 +167,20 @@ def main():
         return result

     if args.out == "-":

         def audio_callback(outdata, _a, _b, _c):
             try:
                 pcm_data = wav_frames.get(block=False)
                 outdata[:, 0] = pcm_data
             except queue.Empty:
                 outdata[:] = 0

-        with sd.OutputStream(samplerate=mimi.sample_rate,
-                             blocksize=1920,
-                             channels=1,
-                             callback=audio_callback):
+        with sd.OutputStream(
+            samplerate=mimi.sample_rate,
+            blocksize=1920,
+            channels=1,
+            callback=audio_callback,
+        ):
             run()
             time.sleep(3)
         while True:

scripts/tts_rust_server.py (new file, 138 lines)

@@ -0,0 +1,138 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "msgpack",
#     "numpy",
#     "sphn",
#     "websockets",
#     "sounddevice",
#     "tqdm",
# ]
# ///
import argparse
import asyncio
import sys
from urllib.parse import urlencode

import msgpack
import numpy as np
import sounddevice as sd
import sphn
import tqdm
import websockets

SAMPLE_RATE = 24000
TTS_TEXT = "Hello, this is a test of the moshi text to speech system, this should result in some nicely sounding generated voice."
DEFAULT_DSM_TTS_VOICE_REPO = "kyutai/tts-voices"
AUTH_TOKEN = "public_token"


async def receive_messages(websocket: websockets.ClientConnection, output_queue):
    with tqdm.tqdm(desc="Receiving audio", unit=" seconds generated") as pbar:
        accumulated_samples = 0
        last_seconds = 0
        async for message_bytes in websocket:
            msg = msgpack.unpackb(message_bytes)
            if msg["type"] == "Audio":
                pcm = np.array(msg["pcm"]).astype(np.float32)
                await output_queue.put(pcm)
                accumulated_samples += len(msg["pcm"])
                current_seconds = accumulated_samples // SAMPLE_RATE
                if current_seconds > last_seconds:
                    pbar.update(current_seconds - last_seconds)
                    last_seconds = current_seconds

    print("End of audio.")
    await output_queue.put(None)  # Signal end of audio


async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None]):
    if out == "-":
        should_exit = False

        def audio_callback(outdata, _a, _b, _c):
            nonlocal should_exit
            try:
                pcm_data = output_queue.get_nowait()
                if pcm_data is not None:
                    outdata[:, 0] = pcm_data
                else:
                    should_exit = True
                    outdata[:] = 0
            except asyncio.QueueEmpty:
                outdata[:] = 0

        with sd.OutputStream(
            samplerate=SAMPLE_RATE,
            blocksize=1920,
            channels=1,
            callback=audio_callback,
        ):
            while True:
                if should_exit:
                    break
                await asyncio.sleep(1)
    else:
        frames = []
        while True:
            item = await output_queue.get()
            if item is None:
                break
            frames.append(item)
        sphn.write_wav(out, np.concat(frames, -1), SAMPLE_RATE)
        print(f"Saved audio to {out}")


async def websocket_client():
    parser = argparse.ArgumentParser(description="Use the TTS streaming API")
    parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
    parser.add_argument(
        "out", type=str, help="Output file to generate, use - for playing the audio"
    )
    parser.add_argument(
        "--voice",
        default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
        help="The voice to use, relative to the voice repo root. "
        f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
    )
    parser.add_argument(
        "--url",
        help="The URL of the server to which to send the audio",
        default="ws://127.0.0.1:8080",
    )
    parser.add_argument("--api-key", default="public_token")
    args = parser.parse_args()

    params = {"voice": args.voice, "format": "PcmMessagePack"}
    uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
    print(uri)

    # TODO: stream the text instead of sending it all at once
    if args.inp == "-":
        if sys.stdin.isatty():  # Interactive
            print("Enter text to synthesize (Ctrl+D to end input):")
        text_to_tts = sys.stdin.read().strip()
    else:
        with open(args.inp, "r") as fobj:
            text_to_tts = fobj.read().strip()

    headers = {"kyutai-api-key": args.api_key}
    async with websockets.connect(uri, additional_headers=headers) as websocket:
        await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts}))
        await websocket.send(msgpack.packb({"type": "Eos"}))

        output_queue = asyncio.Queue()
        receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
        output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
        await asyncio.gather(receive_task, output_audio_task)


if __name__ == "__main__":
    asyncio.run(websocket_client())
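The `TODO` in `websocket_client` suggests streaming the text rather than sending it in a single message. A hedged sketch of what that could look like, assuming the server accepts multiple `Text` messages before the final `Eos`:

```python
import msgpack
import websockets


async def send_text_streaming(
    websocket: websockets.ClientConnection, text: str
) -> None:
    """Send text word by word, then signal the end of the stream."""
    for word in text.split():
        # Assumption: the server accepts any number of "Text" messages
        # before "Eos" and starts synthesizing as text arrives.
        await websocket.send(msgpack.packb({"type": "Text", "text": word + " "}))
    await websocket.send(msgpack.packb({"type": "Eos"}))
```

This would let synthesis begin while later text is still being read, e.g. line by line from stdin.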