diff --git a/README.md b/README.md index 364d6c2..e571542 100644 --- a/README.md +++ b/README.md @@ -222,7 +222,35 @@ and just prefix the command above with `uvx --with moshi`. ### Rust server -Example coming soon. + +The Rust implementation provides a server that can process multiple streaming +queries in parallel. + +In order to run the server, install the [moshi-server +crate](https://crates.io/crates/moshi-server) via the following command. The +server code can be found in the +[kyutai-labs/moshi](https://github.com/kyutai-labs/moshi/tree/main/rust/moshi-server) +repository. +```bash +cargo install --features cuda moshi-server +``` + + +Then the server can be started via the following command using the config file +from this repository. + +```bash +moshi-server worker --config configs/config-tts.toml +``` + +Once the server has started you can connect to it using our script as follows: +```bash +# From stdin, plays audio immediately +echo "Hey, how are you?" | python scripts/tts_rust_server.py - - + +# From text file to audio file +python scripts/tts_rust_server.py text_to_say.txt audio_output.wav +``` ### MLX implementation diff --git a/configs/config-tts.toml b/configs/config-tts.toml new file mode 100644 index 0000000..829fc98 --- /dev/null +++ b/configs/config-tts.toml @@ -0,0 +1,20 @@ +static_dir = "./static/" +log_dir = "$HOME/tmp/tts-logs" +instance_name = "tts" +authorized_ids = ["public_token"] + +[modules.tts_py] +type = "Py" +path = "/api/tts_streaming" +text_tokenizer_file = "hf://kyutai/unmute/test_en_fr_audio_8000.model" +batch_size = 8 # Adjust to your GPU memory capacity +text_bos_token = 1 + +[modules.tts_py.py] +log_folder = "$HOME/tmp/moshi-server-logs" +voice_folder = "hf-snapshot://kyutai/tts-voices/**/*.safetensors" +default_voice = "unmute-prod-website/default_voice.wav" +cfg_coef = 2.0 +cfg_is_no_text = true +padding_between = 1 +n_q = 24 diff --git a/scripts/tts_rust_server.py b/scripts/tts_rust_server.py new file mode 100644 index 
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "msgpack",
#     "numpy",
#     "sphn",
#     "websockets",
#     "sounddevice",
#     "tqdm",
# ]
# ///
"""Send text to a TTS websocket endpoint and play or save the returned audio.

The text is read from a file or stdin, sent in a single message, and the PCM
chunks streamed back by the server are either played on the default audio
device or written to a WAV file.
"""
import argparse
import asyncio
import sys
from urllib.parse import urlencode

import msgpack
import numpy as np
import sounddevice as sd
import sphn
import tqdm
import websockets

# Sample rate used both for playback and for the written WAV file.
SAMPLE_RATE = 24000
# Number of samples handed to the audio device per callback (80 ms at 24 kHz).
# NOTE(review): presumably matches the chunk size sent by the server - confirm.
PLAYBACK_BLOCK_SIZE = 1920

TTS_TEXT = "Hello, this is a test of the moshi text to speech system, this should result in some nicely sounding generated voice."
DEFAULT_DSM_TTS_REPO = "kyutai/tts-1.6b-en_fr"
DEFAULT_DSM_TTS_VOICE_REPO = "kyutai/tts-voices"
AUTH_TOKEN = "public_token"


async def receive_messages(websocket: websockets.ClientConnection, output_queue):
    """Forward audio chunks received on ``websocket`` to ``output_queue``.

    Every msgpack message whose ``type`` is ``"Audio"`` has its ``pcm`` field
    converted to a float32 array and queued. Other message types are ignored.
    A ``None`` sentinel is always queued last, even if the connection fails,
    so that the consumer is guaranteed to terminate.

    Args:
        websocket: Open connection to iterate over.
        output_queue: Queue receiving ``np.ndarray`` chunks, then ``None``.
    """
    try:
        with tqdm.tqdm(desc="Receiving audio", unit=" seconds generated") as pbar:
            accumulated_samples = 0
            last_seconds = 0

            async for message_bytes in websocket:
                msg = msgpack.unpackb(message_bytes)
                if msg["type"] != "Audio":
                    continue

                pcm = np.asarray(msg["pcm"], dtype=np.float32)
                await output_queue.put(pcm)

                # Only advance the progress bar on whole seconds of audio.
                accumulated_samples += len(pcm)
                current_seconds = accumulated_samples // SAMPLE_RATE
                if current_seconds > last_seconds:
                    pbar.update(current_seconds - last_seconds)
                    last_seconds = current_seconds

        print("End of audio.")
    finally:
        # Signal end of audio, including when the connection raised.
        await output_queue.put(None)


async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None]):
    """Consume audio chunks from ``output_queue`` until the ``None`` sentinel.

    Args:
        out: ``"-"`` to play the audio on the default output device, otherwise
            the path of the WAV file to write.
        output_queue: Queue of PCM chunks terminated by ``None``.
    """
    if out == "-":
        should_exit = False

        # NOTE(review): this callback is invoked by sounddevice, presumably on
        # its own audio thread, while asyncio.Queue is not documented as
        # thread-safe - confirm whether a thread-safe queue is needed here.
        def audio_callback(outdata, _a, _b, _c):
            nonlocal should_exit

            try:
                pcm_data = output_queue.get_nowait()
            except asyncio.QueueEmpty:
                # Nothing ready yet: play silence rather than stale data.
                outdata[:] = 0
                return

            if pcm_data is None:
                # End of stream: the buffer must still be filled, otherwise
                # whatever it previously contained gets played.
                should_exit = True
                outdata[:] = 0
                return

            # Tolerate chunks that do not match the block size; a shape
            # mismatch would otherwise raise inside the audio callback.
            # Samples beyond the block size are dropped.
            n_samples = min(len(pcm_data), len(outdata))
            outdata[:n_samples, 0] = pcm_data[:n_samples]
            outdata[n_samples:] = 0

        with sd.OutputStream(
            samplerate=SAMPLE_RATE,
            blocksize=PLAYBACK_BLOCK_SIZE,
            channels=1,
            callback=audio_callback,
        ):
            # Keep the stream open until the callback has seen the sentinel.
            while not should_exit:
                await asyncio.sleep(1)
    else:
        frames = []
        while (item := await output_queue.get()) is not None:
            frames.append(item)

        if not frames:
            # Concatenating an empty list raises an unhelpful ValueError.
            print(f"No audio received, nothing written to {out}", file=sys.stderr)
            return

        # np.concatenate is available on every NumPy version, unlike np.concat.
        sphn.write_wav(out, np.concatenate(frames, -1), SAMPLE_RATE)
        print(f"Saved audio to {out}")


async def websocket_client():
    """Parse the command line, send the text to the server and output the audio."""
    parser = argparse.ArgumentParser(description="Use the TTS streaming API")
    parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
    parser.add_argument(
        "out", type=str, help="Output file to generate, use - for playing the audio"
    )
    parser.add_argument(
        "--hf-repo",
        type=str,
        default=DEFAULT_DSM_TTS_REPO,
        help="HF repo in which to look for the pretrained models.",
    )
    parser.add_argument(
        "--voice-repo",
        default=DEFAULT_DSM_TTS_VOICE_REPO,
        help="HF repo in which to look for pre-computed voice embeddings.",
    )
    parser.add_argument(
        "--voice",
        default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
        help="The voice to use, relative to the voice repo root. "
        f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
    )
    parser.add_argument(
        "--url",
        help="The URL of the server to which to send the text",
        default="ws://127.0.0.1:8080",
    )
    parser.add_argument("--api-key", default=AUTH_TOKEN)
    args = parser.parse_args()

    params = {"voice": args.voice, "format": "PcmMessagePack"}
    uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
    print(uri)

    # TODO: stream the text instead of sending it all at once
    if args.inp == "-":
        if sys.stdin.isatty():  # Interactive
            print("Enter text to synthesize (Ctrl+D to end input):")
        text_to_tts = sys.stdin.read().strip()
    else:
        with open(args.inp, "r", encoding="utf-8") as fobj:
            text_to_tts = fobj.read().strip()

    headers = {"kyutai-api-key": args.api_key}

    async with websockets.connect(uri, additional_headers=headers) as websocket:
        await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts}))
        await websocket.send(msgpack.packb({"type": "Eos"}))

        output_queue: asyncio.Queue[np.ndarray | None] = asyncio.Queue()
        receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
        output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
        await asyncio.gather(receive_task, output_audio_task)


if __name__ == "__main__":
    asyncio.run(websocket_client())