diff --git a/README.md b/README.md
index 7954612..d2f325a 100644
--- a/README.md
+++ b/README.md
@@ -239,7 +239,35 @@ and just prefix the command above with `uvx --with moshi`.
Rust server
-Example coming soon.
+
+The Rust implementation provides a server that can process multiple streaming
+queries in parallel.
+
+To run the server, install the
+[moshi-server crate](https://crates.io/crates/moshi-server) with the command
+below. The server code lives in the
+[kyutai-labs/moshi](https://github.com/kyutai-labs/moshi/tree/main/rust/moshi-server)
+repository.
+```bash
+cargo install --features cuda moshi-server
+```
+
+The server can then be started with the following command, using the config
+file from this repository.
+
+```bash
+moshi-server worker --config configs/config-tts.toml
+```
+
+Once the server has started, you can connect to it using our script as follows:
+```bash
+# From stdin, plays audio immediately
+echo "Hey, how are you?" | python scripts/tts_rust_server.py - -
+
+# From text file to audio file
+python scripts/tts_rust_server.py text_to_say.txt audio_output.wav
+```
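+
+If you would rather talk to the server from your own code, the helper script
+above boils down to a short websocket exchange: send the text as a msgpack
+`Text` message followed by an `Eos` message, then collect the `Audio`
+messages that come back. The sketch below is a minimal, non-streaming variant
+of `scripts/tts_rust_server.py`; the endpoint, message types, `public_token`
+API key and 24 kHz sample rate are taken from that script and
+`configs/config-tts.toml`, while the `tts` helper itself is just illustrative.
+
+```python
+# Minimal sketch of the websocket protocol used by scripts/tts_rust_server.py.
+# Assumes the server configured with configs/config-tts.toml runs on localhost:8080.
+import asyncio
+from urllib.parse import urlencode
+
+import msgpack
+import numpy as np
+import sphn
+import websockets
+
+
+async def tts(text: str, out_wav: str):
+    params = {
+        "voice": "expresso/ex03-ex01_happy_001_channel1_334s.wav",
+        "format": "PcmMessagePack",
+    }
+    uri = f"ws://127.0.0.1:8080/api/tts_streaming?{urlencode(params)}"
+    headers = {"kyutai-api-key": "public_token"}
+    async with websockets.connect(uri, additional_headers=headers) as ws:
+        # The text goes out as a msgpack "Text" message, terminated by "Eos".
+        await ws.send(msgpack.packb({"type": "Text", "text": text}))
+        await ws.send(msgpack.packb({"type": "Eos"}))
+        # "Audio" messages carry chunks of float samples at 24 kHz; the server
+        # closes the connection once generation is done.
+        chunks = []
+        async for message in ws:
+            msg = msgpack.unpackb(message)
+            if msg["type"] == "Audio":
+                chunks.append(np.array(msg["pcm"], dtype=np.float32))
+    sphn.write_wav(out_wav, np.concatenate(chunks, -1), 24000)
+
+
+asyncio.run(tts("Hey, how are you?", "tts_output.wav"))
+```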
diff --git a/configs/config-tts.toml b/configs/config-tts.toml
new file mode 100644
index 0000000..829fc98
--- /dev/null
+++ b/configs/config-tts.toml
@@ -0,0 +1,20 @@
+static_dir = "./static/"
+log_dir = "$HOME/tmp/tts-logs"
+instance_name = "tts"
+authorized_ids = ["public_token"]
+
+[modules.tts_py]
+type = "Py"
+path = "/api/tts_streaming"
+text_tokenizer_file = "hf://kyutai/unmute/test_en_fr_audio_8000.model"
+batch_size = 8 # Adjust to your GPU memory capacity
+text_bos_token = 1
+
+[modules.tts_py.py]
+log_folder = "$HOME/tmp/moshi-server-logs"
+voice_folder = "hf-snapshot://kyutai/tts-voices/**/*.safetensors"
+default_voice = "unmute-prod-website/default_voice.wav"
+cfg_coef = 2.0
+cfg_is_no_text = true
+padding_between = 1
+n_q = 24
diff --git a/scripts/tts_mlx.py b/scripts/tts_mlx.py
index 70dca65..e59f962 100644
--- a/scripts/tts_mlx.py
+++ b/scripts/tts_mlx.py
@@ -2,7 +2,7 @@
# requires-python = ">=3.12"
# dependencies = [
# "huggingface_hub",
-# "moshi_mlx>=0.2.8",
+# "moshi_mlx @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi_mlx&subdirectory=moshi_mlx",
# "numpy",
# "sounddevice",
# ]
@@ -14,19 +14,20 @@ import queue
import sys
import time
-import numpy as np
import mlx.core as mx
import mlx.nn as nn
+import numpy as np
import sentencepiece
-import sphn
-import time
-
import sounddevice as sd
-
-from moshi_mlx.client_utils import make_log
+import sphn
from moshi_mlx import models
+from moshi_mlx.client_utils import make_log
+from moshi_mlx.models.tts import (
+ DEFAULT_DSM_TTS_REPO,
+ DEFAULT_DSM_TTS_VOICE_REPO,
+ TTSModel,
+)
from moshi_mlx.utils.loaders import hf_get
-from moshi_mlx.models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO
def log(level: str, msg: str):
@@ -34,15 +35,32 @@ def log(level: str, msg: str):
def main():
- parser = argparse.ArgumentParser(prog='moshi-tts', description='Run Moshi')
+ parser = argparse.ArgumentParser(
+ description="Run Kyutai TTS using the PyTorch implementation"
+ )
parser.add_argument("inp", type=str, help="Input file, use - for stdin")
- parser.add_argument("out", type=str, help="Output file to generate, use - for playing the audio")
- parser.add_argument("--hf-repo", type=str, default=DEFAULT_DSM_TTS_REPO,
- help="HF repo in which to look for the pretrained models.")
- parser.add_argument("--voice-repo", default=DEFAULT_DSM_TTS_VOICE_REPO,
- help="HF repo in which to look for pre-computed voice embeddings.")
- parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav")
- parser.add_argument("--quantize", type=int, help="The quantization to be applied, e.g. 8 for 8 bits.")
+ parser.add_argument(
+ "out", type=str, help="Output file to generate, use - for playing the audio"
+ )
+ parser.add_argument(
+ "--hf-repo",
+ type=str,
+ default=DEFAULT_DSM_TTS_REPO,
+ help="HF repo in which to look for the pretrained models.",
+ )
+ parser.add_argument(
+ "--voice-repo",
+ default=DEFAULT_DSM_TTS_VOICE_REPO,
+ help="HF repo in which to look for pre-computed voice embeddings.",
+ )
+ parser.add_argument(
+ "--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
+ )
+ parser.add_argument(
+ "--quantize",
+ type=int,
+ help="The quantization to be applied, e.g. 8 for 8 bits.",
+ )
args = parser.parse_args()
mx.random.seed(299792458)
@@ -96,7 +114,7 @@ def main():
if tts_model.valid_cfg_conditionings:
# Model was trained with CFG distillation.
cfg_coef_conditioning = tts_model.cfg_coef
- tts_model.cfg_coef = 1.
+ tts_model.cfg_coef = 1.0
cfg_is_no_text = False
cfg_is_no_prefix = False
else:
@@ -118,13 +136,16 @@ def main():
voices = [tts_model.get_voice_path(args.voice)]
else:
voices = []
- all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
+ all_attributes = [
+ tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
+ ]
wav_frames = queue.Queue()
- def _on_audio_hook(audio_tokens):
- if (audio_tokens == -1).any():
+
+ def _on_frame(frame):
+ if (frame == -1).any():
return
- _pcm = tts_model.mimi.decode_step(audio_tokens[None, :, None])
+ _pcm = tts_model.mimi.decode_step(frame[:, :, None])
_pcm = np.array(mx.clip(_pcm[0, 0], -1, 1))
wav_frames.put_nowait(_pcm)
@@ -136,7 +157,7 @@ def main():
all_attributes,
cfg_is_no_prefix=cfg_is_no_prefix,
cfg_is_no_text=cfg_is_no_text,
- on_audio_hook=_on_audio_hook,
+ on_frame=_on_frame,
)
frames = mx.concat(result.frames, axis=-1)
total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate
@@ -146,16 +167,20 @@ def main():
return result
if args.out == "-":
+
def audio_callback(outdata, _a, _b, _c):
try:
pcm_data = wav_frames.get(block=False)
outdata[:, 0] = pcm_data
except queue.Empty:
outdata[:] = 0
- with sd.OutputStream(samplerate=mimi.sample_rate,
- blocksize=1920,
- channels=1,
- callback=audio_callback):
+
+ with sd.OutputStream(
+ samplerate=mimi.sample_rate,
+ blocksize=1920,
+ channels=1,
+ callback=audio_callback,
+ ):
run()
time.sleep(3)
while True:
@@ -163,6 +188,7 @@ def main():
break
time.sleep(1)
else:
+ run()
frames = []
while True:
try:
diff --git a/scripts/tts_rust_server.py b/scripts/tts_rust_server.py
new file mode 100644
index 0000000..9b67dc8
--- /dev/null
+++ b/scripts/tts_rust_server.py
@@ -0,0 +1,138 @@
+# /// script
+# requires-python = ">=3.12"
+# dependencies = [
+# "msgpack",
+# "numpy",
+# "sphn",
+# "websockets",
+# "sounddevice",
+# "tqdm",
+# ]
+# ///
+import argparse
+import asyncio
+import sys
+from urllib.parse import urlencode
+
+import msgpack
+import numpy as np
+import sounddevice as sd
+import sphn
+import tqdm
+import websockets
+
+SAMPLE_RATE = 24000
+
+TTS_TEXT = "Hello, this is a test of the moshi text to speech system, this should result in some nicely sounding generated voice."
+DEFAULT_DSM_TTS_VOICE_REPO = "kyutai/tts-voices"
+AUTH_TOKEN = "public_token"
+
+
+async def receive_messages(websocket: websockets.ClientConnection, output_queue):
+ with tqdm.tqdm(desc="Receiving audio", unit=" seconds generated") as pbar:
+ accumulated_samples = 0
+ last_seconds = 0
+
+ async for message_bytes in websocket:
+ msg = msgpack.unpackb(message_bytes)
+
+ if msg["type"] == "Audio":
+ pcm = np.array(msg["pcm"]).astype(np.float32)
+ await output_queue.put(pcm)
+
+ accumulated_samples += len(msg["pcm"])
+ current_seconds = accumulated_samples // SAMPLE_RATE
+ if current_seconds > last_seconds:
+ pbar.update(current_seconds - last_seconds)
+ last_seconds = current_seconds
+
+ print("End of audio.")
+ await output_queue.put(None) # Signal end of audio
+
+
+async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None]):
+ if out == "-":
+ should_exit = False
+
+ def audio_callback(outdata, _a, _b, _c):
+ nonlocal should_exit
+
+ try:
+ pcm_data = output_queue.get_nowait()
+ if pcm_data is not None:
+ outdata[:, 0] = pcm_data
+ else:
+ should_exit = True
+ outdata[:] = 0
+ except asyncio.QueueEmpty:
+ outdata[:] = 0
+
+ with sd.OutputStream(
+ samplerate=SAMPLE_RATE,
+ blocksize=1920,
+ channels=1,
+ callback=audio_callback,
+ ):
+ while True:
+ if should_exit:
+ break
+ await asyncio.sleep(1)
+ else:
+ frames = []
+ while True:
+ item = await output_queue.get()
+ if item is None:
+ break
+ frames.append(item)
+
+ sphn.write_wav(out, np.concat(frames, -1), SAMPLE_RATE)
+ print(f"Saved audio to {out}")
+
+
+async def websocket_client():
+ parser = argparse.ArgumentParser(description="Use the TTS streaming API")
+ parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
+ parser.add_argument(
+ "out", type=str, help="Output file to generate, use - for playing the audio"
+ )
+ parser.add_argument(
+ "--voice",
+ default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
+ help="The voice to use, relative to the voice repo root. "
+ f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
+ )
+ parser.add_argument(
+ "--url",
+ help="The URL of the server to which to send the audio",
+ default="ws://127.0.0.1:8080",
+ )
+ parser.add_argument("--api-key", default="public_token")
+ args = parser.parse_args()
+
+ params = {"voice": args.voice, "format": "PcmMessagePack"}
+ uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
+ print(uri)
+
+ # TODO: stream the text instead of sending it all at once
+ if args.inp == "-":
+ if sys.stdin.isatty(): # Interactive
+ print("Enter text to synthesize (Ctrl+D to end input):")
+ text_to_tts = sys.stdin.read().strip()
+ else:
+ with open(args.inp, "r") as fobj:
+ text_to_tts = fobj.read().strip()
+
+ headers = {"kyutai-api-key": args.api_key}
+
+ async with websockets.connect(uri, additional_headers=headers) as websocket:
+ await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts}))
+ await websocket.send(msgpack.packb({"type": "Eos"}))
+
+ output_queue = asyncio.Queue()
+ receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
+ output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
+ await asyncio.gather(receive_task, output_audio_task)
+
+
+if __name__ == "__main__":
+ asyncio.run(websocket_client())