Add tts_rust_server.py example
This commit is contained in:
parent
da83b4b63f
commit
df36dfd918
30
README.md
30
README.md
|
|
@ -222,7 +222,35 @@ and just prefix the command above with `uvx --with moshi`.
|
||||||
|
|
||||||
### Rust server
|
### Rust server
|
||||||
|
|
||||||
Example coming soon.
|
|
||||||
|
The Rust implementation provides a server that can process multiple streaming
|
||||||
|
queries in parallel.
|
||||||
|
|
||||||
|
In order to run the server, install the [moshi-server
|
||||||
|
crate](https://crates.io/crates/moshi-server) via the following command. The
|
||||||
|
server code can be found in the
|
||||||
|
[kyutai-labs/moshi](https://github.com/kyutai-labs/moshi/tree/main/rust/moshi-server)
|
||||||
|
repository.
|
||||||
|
```bash
|
||||||
|
cargo install --features cuda moshi-server
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
Then the server can be started via the following command using the config file
|
||||||
|
from this repository.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
moshi-server worker --config configs/config-tts.toml
|
||||||
|
```
|
||||||
|
|
||||||
|
Once the server has started you can connect to it using our script as follows:
|
||||||
|
```bash
|
||||||
|
# From stdin, plays audio immediately
|
||||||
|
echo "Hey, how are you?" | python scripts/tts_rust_server.py - -
|
||||||
|
|
||||||
|
# From text file to audio file
|
||||||
|
python scripts/tts_rust_server.py text_to_say.txt audio_output.wav
|
||||||
|
```
|
||||||
|
|
||||||
### MLX implementation
|
### MLX implementation
|
||||||
|
|
||||||
|
|
|
||||||
20
configs/config-tts.toml
Normal file
20
configs/config-tts.toml
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
static_dir = "./static/"
|
||||||
|
log_dir = "$HOME/tmp/tts-logs"
|
||||||
|
instance_name = "tts"
|
||||||
|
authorized_ids = ["public_token"]
|
||||||
|
|
||||||
|
[modules.tts_py]
|
||||||
|
type = "Py"
|
||||||
|
path = "/api/tts_streaming"
|
||||||
|
text_tokenizer_file = "hf://kyutai/unmute/test_en_fr_audio_8000.model"
|
||||||
|
batch_size = 8 # Adjust to your GPU memory capacity
|
||||||
|
text_bos_token = 1
|
||||||
|
|
||||||
|
[modules.tts_py.py]
|
||||||
|
log_folder = "$HOME/tmp/moshi-server-logs"
|
||||||
|
voice_folder = "hf-snapshot://kyutai/tts-voices/**/*.safetensors"
|
||||||
|
default_voice = "unmute-prod-website/default_voice.wav"
|
||||||
|
cfg_coef = 2.0
|
||||||
|
cfg_is_no_text = true
|
||||||
|
padding_between = 1
|
||||||
|
n_q = 24
|
||||||
149
scripts/tts_rust_server.py
Normal file
149
scripts/tts_rust_server.py
Normal file
|
|
@ -0,0 +1,149 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "msgpack",
|
||||||
|
# "numpy",
|
||||||
|
# "sphn",
|
||||||
|
# "websockets",
|
||||||
|
# "sounddevice",
|
||||||
|
# "tqdm",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
import numpy as np
|
||||||
|
import sounddevice as sd
|
||||||
|
import sphn
|
||||||
|
import tqdm
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
SAMPLE_RATE = 24000
|
||||||
|
|
||||||
|
TTS_TEXT = "Hello, this is a test of the moshi text to speech system, this should result in some nicely sounding generated voice."
|
||||||
|
DEFAULT_DSM_TTS_REPO = "kyutai/tts-1.6b-en_fr"
|
||||||
|
DEFAULT_DSM_TTS_VOICE_REPO = "kyutai/tts-voices"
|
||||||
|
AUTH_TOKEN = "public_token"
|
||||||
|
|
||||||
|
|
||||||
|
async def receive_messages(websocket: websockets.ClientConnection, output_queue):
|
||||||
|
with tqdm.tqdm(desc="Receiving audio", unit=" seconds generated") as pbar:
|
||||||
|
accumulated_samples = 0
|
||||||
|
last_seconds = 0
|
||||||
|
|
||||||
|
async for message_bytes in websocket:
|
||||||
|
msg = msgpack.unpackb(message_bytes)
|
||||||
|
|
||||||
|
if msg["type"] == "Audio":
|
||||||
|
pcm = np.array(msg["pcm"]).astype(np.float32)
|
||||||
|
await output_queue.put(pcm)
|
||||||
|
|
||||||
|
accumulated_samples += len(msg["pcm"])
|
||||||
|
current_seconds = accumulated_samples // SAMPLE_RATE
|
||||||
|
if current_seconds > last_seconds:
|
||||||
|
pbar.update(current_seconds - last_seconds)
|
||||||
|
last_seconds = current_seconds
|
||||||
|
|
||||||
|
print("End of audio.")
|
||||||
|
await output_queue.put(None) # Signal end of audio
|
||||||
|
|
||||||
|
|
||||||
|
async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None]):
|
||||||
|
if out == "-":
|
||||||
|
should_exit = False
|
||||||
|
|
||||||
|
def audio_callback(outdata, _a, _b, _c):
|
||||||
|
nonlocal should_exit
|
||||||
|
|
||||||
|
try:
|
||||||
|
pcm_data = output_queue.get_nowait()
|
||||||
|
if pcm_data is not None:
|
||||||
|
outdata[:, 0] = pcm_data
|
||||||
|
else:
|
||||||
|
should_exit = True
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
outdata[:] = 0
|
||||||
|
|
||||||
|
with sd.OutputStream(
|
||||||
|
samplerate=SAMPLE_RATE,
|
||||||
|
blocksize=1920,
|
||||||
|
channels=1,
|
||||||
|
callback=audio_callback,
|
||||||
|
):
|
||||||
|
while True:
|
||||||
|
if should_exit:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
else:
|
||||||
|
frames = []
|
||||||
|
while True:
|
||||||
|
item = await output_queue.get()
|
||||||
|
if item is None:
|
||||||
|
break
|
||||||
|
frames.append(item)
|
||||||
|
|
||||||
|
sphn.write_wav(out, np.concat(frames, -1), SAMPLE_RATE)
|
||||||
|
print(f"Saved audio to {out}")
|
||||||
|
|
||||||
|
|
||||||
|
async def websocket_client():
|
||||||
|
parser = argparse.ArgumentParser(description="Use the TTS streaming API")
|
||||||
|
parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
|
||||||
|
parser.add_argument(
|
||||||
|
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-repo",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_DSM_TTS_REPO,
|
||||||
|
help="HF repo in which to look for the pretrained models.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice-repo",
|
||||||
|
default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||||
|
help="HF repo in which to look for pre-computed voice embeddings.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice",
|
||||||
|
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
|
||||||
|
help="The voice to use, relative to the voice repo root. "
|
||||||
|
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--url",
|
||||||
|
help="The URL of the server to which to send the audio",
|
||||||
|
default="ws://127.0.0.1:8080",
|
||||||
|
)
|
||||||
|
parser.add_argument("--api-key", default="public_token")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
params = {"voice": args.voice, "format": "PcmMessagePack"}
|
||||||
|
uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
|
||||||
|
print(uri)
|
||||||
|
|
||||||
|
# TODO: stream the text instead of sending it all at once
|
||||||
|
if args.inp == "-":
|
||||||
|
if sys.stdin.isatty(): # Interactive
|
||||||
|
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||||
|
text_to_tts = sys.stdin.read().strip()
|
||||||
|
else:
|
||||||
|
with open(args.inp, "r") as fobj:
|
||||||
|
text_to_tts = fobj.read().strip()
|
||||||
|
|
||||||
|
headers = {"kyutai-api-key": args.api_key}
|
||||||
|
|
||||||
|
async with websockets.connect(uri, additional_headers=headers) as websocket:
|
||||||
|
await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts}))
|
||||||
|
await websocket.send(msgpack.packb({"type": "Eos"}))
|
||||||
|
|
||||||
|
output_queue = asyncio.Queue()
|
||||||
|
receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
|
||||||
|
output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
|
||||||
|
await asyncio.gather(receive_task, output_audio_task)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(websocket_client())
|
||||||
Loading…
Reference in New Issue
Block a user