diff --git a/README.md b/README.md
index 7954612..d2f325a 100644
--- a/README.md
+++ b/README.md
@@ -239,7 +239,35 @@ and just prefix the command above with `uvx --with moshi`.
 Rust server
-Example coming soon.
+
+The Rust implementation provides a server that can process multiple streaming
+queries in parallel.
+
+To run the server, install the [moshi-server
+crate](https://crates.io/crates/moshi-server) with the following command. The
+server code can be found in the
+[kyutai-labs/moshi](https://github.com/kyutai-labs/moshi/tree/main/rust/moshi-server)
+repository.
+
+```bash
+cargo install --features cuda moshi-server
+```
+
+Then start the server with the following command, using the config file from
+this repository:
+
+```bash
+moshi-server worker --config configs/config-tts.toml
+```
+
+Once the server has started, you can connect to it using our script as follows:
+
+```bash
+# From stdin, plays audio immediately
+echo "Hey, how are you?" | python scripts/tts_rust_server.py - -
+
+# From text file to audio file
+python scripts/tts_rust_server.py text_to_say.txt audio_output.wav
+```
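For reference, here is a minimal sketch of what `scripts/tts_rust_server.py` (added below) does under the hood: it sends msgpack-encoded `Text` and `Eos` messages over a websocket and collects the `Audio` frames the server streams back. The endpoint, voice, and `public_token` key are the defaults from `configs/config-tts.toml`; adjust them to match your deployment.

```python
# Minimal client sketch for the /api/tts_streaming endpoint, assuming the
# defaults from configs/config-tts.toml (authorized id "public_token").
import asyncio

import msgpack
import websockets


async def synthesize(text: str) -> list:
    params = "voice=expresso/ex03-ex01_happy_001_channel1_334s.wav&format=PcmMessagePack"
    uri = f"ws://127.0.0.1:8080/api/tts_streaming?{params}"
    headers = {"kyutai-api-key": "public_token"}
    pcm_frames = []
    async with websockets.connect(uri, additional_headers=headers) as ws:
        # The protocol is msgpack messages: the text, then an end-of-stream marker.
        await ws.send(msgpack.packb({"type": "Text", "text": text}))
        await ws.send(msgpack.packb({"type": "Eos"}))
        async for message in ws:
            msg = msgpack.unpackb(message)
            if msg["type"] == "Audio":
                pcm_frames.append(msg["pcm"])  # 24 kHz PCM chunks
    return pcm_frames


if __name__ == "__main__":
    frames = asyncio.run(synthesize("Hey, how are you?"))
    print(f"received {len(frames)} audio chunks")
```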
diff --git a/configs/config-tts.toml b/configs/config-tts.toml
new file mode 100644
index 0000000..829fc98
--- /dev/null
+++ b/configs/config-tts.toml
@@ -0,0 +1,20 @@
+static_dir = "./static/"
+log_dir = "$HOME/tmp/tts-logs"
+instance_name = "tts"
+authorized_ids = ["public_token"]
+
+[modules.tts_py]
+type = "Py"
+path = "/api/tts_streaming"
+text_tokenizer_file = "hf://kyutai/unmute/test_en_fr_audio_8000.model"
+batch_size = 8 # Adjust to your GPU memory capacity
+text_bos_token = 1
+
+[modules.tts_py.py]
+log_folder = "$HOME/tmp/moshi-server-logs"
+voice_folder = "hf-snapshot://kyutai/tts-voices/**/*.safetensors"
+default_voice = "unmute-prod-website/default_voice.wav"
+cfg_coef = 2.0
+cfg_is_no_text = true
+padding_between = 1
+n_q = 24
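The `batch_size` setting bounds how many streaming queries the server processes in parallel, and `$HOME`-style variables in the paths are expanded by the server. As a quick pre-flight check, the config can be parsed with the standard library; the helper below is hypothetical and its validation rules are illustrative only, not part of moshi-server.

```python
# Hypothetical pre-flight check for configs/config-tts.toml; field names
# mirror the config above, the checks themselves are illustrative only.
import os
import tomllib


def check_tts_config(path: str = "configs/config-tts.toml") -> None:
    with open(path, "rb") as fobj:
        config = tomllib.load(fobj)
    module = config["modules"]["tts_py"]
    assert module["path"] == "/api/tts_streaming"
    assert module["batch_size"] >= 1, "need at least one concurrent stream"
    # The server expands $HOME etc. in paths; mimic that to inspect them.
    log_dir = os.path.expandvars(config["log_dir"])
    print(f"instance {config['instance_name']!r}: "
          f"batch_size={module['batch_size']}, logs in {log_dir}")


if __name__ == "__main__":
    check_tts_config()
```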
diff --git a/scripts/tts_mlx.py b/scripts/tts_mlx.py
index 9d89295..e59f962 100644
--- a/scripts/tts_mlx.py
+++ b/scripts/tts_mlx.py
@@ -14,19 +14,20 @@ import queue
 import sys
 import time
 
-import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
+import numpy as np
 import sentencepiece
-import sphn
-import time
-
 import sounddevice as sd
-
-from moshi_mlx.client_utils import make_log
+import sphn
 from moshi_mlx import models
+from moshi_mlx.client_utils import make_log
+from moshi_mlx.models.tts import (
+    DEFAULT_DSM_TTS_REPO,
+    DEFAULT_DSM_TTS_VOICE_REPO,
+    TTSModel,
+)
 from moshi_mlx.utils.loaders import hf_get
-from moshi_mlx.models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO
 
 
 def log(level: str, msg: str):
@@ -34,15 +35,32 @@ def log(level: str, msg: str):
 
 
 def main():
-    parser = argparse.ArgumentParser(prog='moshi-tts', description='Run Moshi')
+    parser = argparse.ArgumentParser(
+        description="Run Kyutai TTS using the MLX implementation"
+    )
     parser.add_argument("inp", type=str, help="Input file, use - for stdin")
-    parser.add_argument("out", type=str, help="Output file to generate, use - for playing the audio")
-    parser.add_argument("--hf-repo", type=str, default=DEFAULT_DSM_TTS_REPO,
-                        help="HF repo in which to look for the pretrained models.")
-    parser.add_argument("--voice-repo", default=DEFAULT_DSM_TTS_VOICE_REPO,
-                        help="HF repo in which to look for pre-computed voice embeddings.")
-    parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav")
-    parser.add_argument("--quantize", type=int, help="The quantization to be applied, e.g. 8 for 8 bits.")
+    parser.add_argument(
+        "out", type=str, help="Output file to generate, use - for playing the audio"
+    )
+    parser.add_argument(
+        "--hf-repo",
+        type=str,
+        default=DEFAULT_DSM_TTS_REPO,
+        help="HF repo in which to look for the pretrained models.",
+    )
+    parser.add_argument(
+        "--voice-repo",
+        default=DEFAULT_DSM_TTS_VOICE_REPO,
+        help="HF repo in which to look for pre-computed voice embeddings.",
+    )
+    parser.add_argument(
+        "--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
+    )
+    parser.add_argument(
+        "--quantize",
+        type=int,
+        help="The quantization to be applied, e.g. 8 for 8 bits.",
+    )
 
     args = parser.parse_args()
     mx.random.seed(299792458)
@@ -96,7 +114,7 @@ def main():
     if tts_model.valid_cfg_conditionings:
         # Model was trained with CFG distillation.
         cfg_coef_conditioning = tts_model.cfg_coef
-        tts_model.cfg_coef = 1.
+        tts_model.cfg_coef = 1.0
         cfg_is_no_text = False
         cfg_is_no_prefix = False
     else:
@@ -118,9 +136,12 @@ def main():
         voices = [tts_model.get_voice_path(args.voice)]
     else:
         voices = []
-    all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
+    all_attributes = [
+        tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
+    ]
 
     wav_frames = queue.Queue()
+
     def _on_frame(frame):
         if (frame == -1).any():
             return
@@ -146,16 +167,20 @@ def main():
         return result
 
     if args.out == "-":
+
        def audio_callback(outdata, _a, _b, _c):
            try:
                pcm_data = wav_frames.get(block=False)
                outdata[:, 0] = pcm_data
            except queue.Empty:
                outdata[:] = 0
-        with sd.OutputStream(samplerate=mimi.sample_rate,
-                             blocksize=1920,
-                             channels=1,
-                             callback=audio_callback):
+
+        with sd.OutputStream(
+            samplerate=mimi.sample_rate,
+            blocksize=1920,
+            channels=1,
+            callback=audio_callback,
+        ):
             run()
             time.sleep(3)
             while True:
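The playback path in `scripts/tts_mlx.py` decouples generation from audio output: `_on_frame` pushes decoded PCM into a `queue.Queue`, and the sounddevice callback drains it, writing silence on underrun rather than blocking. Below is a self-contained sketch of that pattern; the sine-wave producer is a stand-in for the TTS decoder, and `blocksize=1920` matches the script (80 ms at 24 kHz).

```python
# Standalone sketch of the queue-based playback pattern used in
# scripts/tts_mlx.py. The sine producer stands in for the TTS frame generator.
import queue
import time

import numpy as np
import sounddevice as sd

SAMPLE_RATE = 24000
BLOCK_SIZE = 1920  # Same block size as the script: 80 ms at 24 kHz.

wav_frames: queue.Queue[np.ndarray] = queue.Queue()


def audio_callback(outdata, _frames, _time, _status):
    try:
        outdata[:, 0] = wav_frames.get(block=False)
    except queue.Empty:
        outdata[:] = 0  # Underrun: play silence rather than blocking.


if __name__ == "__main__":
    # Producer: two seconds of a 440 Hz tone, chopped into fixed-size blocks.
    t = np.arange(SAMPLE_RATE * 2) / SAMPLE_RATE
    tone = 0.2 * np.sin(2 * np.pi * 440 * t).astype(np.float32)
    for i in range(0, len(tone), BLOCK_SIZE):
        chunk = tone[i : i + BLOCK_SIZE]
        if len(chunk) == BLOCK_SIZE:
            wav_frames.put(chunk)
    with sd.OutputStream(
        samplerate=SAMPLE_RATE,
        blocksize=BLOCK_SIZE,
        channels=1,
        callback=audio_callback,
    ):
        time.sleep(2.5)
```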
help="The voice to use, relative to the voice repo root. " + f"See {DEFAULT_DSM_TTS_VOICE_REPO}", + ) + parser.add_argument( + "--url", + help="The URL of the server to which to send the audio", + default="ws://127.0.0.1:8080", + ) + parser.add_argument("--api-key", default="public_token") + args = parser.parse_args() + + params = {"voice": args.voice, "format": "PcmMessagePack"} + uri = f"{args.url}/api/tts_streaming?{urlencode(params)}" + print(uri) + + # TODO: stream the text instead of sending it all at once + if args.inp == "-": + if sys.stdin.isatty(): # Interactive + print("Enter text to synthesize (Ctrl+D to end input):") + text_to_tts = sys.stdin.read().strip() + else: + with open(args.inp, "r") as fobj: + text_to_tts = fobj.read().strip() + + headers = {"kyutai-api-key": args.api_key} + + async with websockets.connect(uri, additional_headers=headers) as websocket: + await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts})) + await websocket.send(msgpack.packb({"type": "Eos"})) + + output_queue = asyncio.Queue() + receive_task = asyncio.create_task(receive_messages(websocket, output_queue)) + output_audio_task = asyncio.create_task(output_audio(args.out, output_queue)) + await asyncio.gather(receive_task, output_audio_task) + + +if __name__ == "__main__": + asyncio.run(websocket_client())