Merge stuff.
This commit is contained in:
commit
9ba4e88553
30
README.md
30
README.md
|
|
@ -239,7 +239,35 @@ and just prefix the command above with `uvx --with moshi`.
|
|||
<details>
|
||||
<summary>Rust server</summary>
|
||||
|
||||
Example coming soon.
|
||||
|
||||
The Rust implementation provides a server that can process multiple streaming
|
||||
queries in parallel.
|
||||
|
||||
In order to run the server, install the [moshi-server
|
||||
crate](https://crates.io/crates/moshi-server) via the following command. The
|
||||
server code can be found in the
|
||||
[kyutai-labs/moshi](https://github.com/kyutai-labs/moshi/tree/main/rust/moshi-server)
|
||||
repository.
|
||||
```bash
|
||||
cargo install --features cuda moshi-server
|
||||
```
|
||||
|
||||
|
||||
Then the server can be started via the following command using the config file
|
||||
from this repository.
|
||||
|
||||
```bash
|
||||
moshi-server worker --config configs/config-tts.toml
|
||||
```
|
||||
|
||||
Once the server has started you can connect to it using our script as follows:
|
||||
```bash
|
||||
# From stdin, plays audio immediately
|
||||
echo "Hey, how are you?" | python scripts/tts_rust_server.py - -
|
||||
|
||||
# From text file to audio file
|
||||
python scripts/tts_rust_server.py text_to_say.txt audio_output.wav
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
|
|
|||
20
configs/config-tts.toml
Normal file
20
configs/config-tts.toml
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
static_dir = "./static/"
|
||||
log_dir = "$HOME/tmp/tts-logs"
|
||||
instance_name = "tts"
|
||||
authorized_ids = ["public_token"]
|
||||
|
||||
[modules.tts_py]
|
||||
type = "Py"
|
||||
path = "/api/tts_streaming"
|
||||
text_tokenizer_file = "hf://kyutai/unmute/test_en_fr_audio_8000.model"
|
||||
batch_size = 8 # Adjust to your GPU memory capacity
|
||||
text_bos_token = 1
|
||||
|
||||
[modules.tts_py.py]
|
||||
log_folder = "$HOME/tmp/moshi-server-logs"
|
||||
voice_folder = "hf-snapshot://kyutai/tts-voices/**/*.safetensors"
|
||||
default_voice = "unmute-prod-website/default_voice.wav"
|
||||
cfg_coef = 2.0
|
||||
cfg_is_no_text = true
|
||||
padding_between = 1
|
||||
n_q = 24
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
# requires-python = ">=3.12"
|
||||
# dependencies = [
|
||||
# "huggingface_hub",
|
||||
# "moshi_mlx>=0.2.8",
|
||||
# "moshi_mlx @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi_mlx&subdirectory=moshi_mlx",
|
||||
# "numpy",
|
||||
# "sounddevice",
|
||||
# ]
|
||||
|
|
@ -14,19 +14,20 @@ import queue
|
|||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import sentencepiece
|
||||
import sphn
|
||||
import time
|
||||
|
||||
import sounddevice as sd
|
||||
|
||||
from moshi_mlx.client_utils import make_log
|
||||
import sphn
|
||||
from moshi_mlx import models
|
||||
from moshi_mlx.client_utils import make_log
|
||||
from moshi_mlx.models.tts import (
|
||||
DEFAULT_DSM_TTS_REPO,
|
||||
DEFAULT_DSM_TTS_VOICE_REPO,
|
||||
TTSModel,
|
||||
)
|
||||
from moshi_mlx.utils.loaders import hf_get
|
||||
from moshi_mlx.models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO
|
||||
|
||||
|
||||
def log(level: str, msg: str):
|
||||
|
|
@ -34,15 +35,32 @@ def log(level: str, msg: str):
|
|||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(prog='moshi-tts', description='Run Moshi')
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run Kyutai TTS using the PyTorch implementation"
|
||||
)
|
||||
parser.add_argument("inp", type=str, help="Input file, use - for stdin")
|
||||
parser.add_argument("out", type=str, help="Output file to generate, use - for playing the audio")
|
||||
parser.add_argument("--hf-repo", type=str, default=DEFAULT_DSM_TTS_REPO,
|
||||
help="HF repo in which to look for the pretrained models.")
|
||||
parser.add_argument("--voice-repo", default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||
help="HF repo in which to look for pre-computed voice embeddings.")
|
||||
parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav")
|
||||
parser.add_argument("--quantize", type=int, help="The quantization to be applied, e.g. 8 for 8 bits.")
|
||||
parser.add_argument(
|
||||
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-repo",
|
||||
type=str,
|
||||
default=DEFAULT_DSM_TTS_REPO,
|
||||
help="HF repo in which to look for the pretrained models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--voice-repo",
|
||||
default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||
help="HF repo in which to look for pre-computed voice embeddings.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize",
|
||||
type=int,
|
||||
help="The quantization to be applied, e.g. 8 for 8 bits.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(299792458)
|
||||
|
|
@ -96,7 +114,7 @@ def main():
|
|||
if tts_model.valid_cfg_conditionings:
|
||||
# Model was trained with CFG distillation.
|
||||
cfg_coef_conditioning = tts_model.cfg_coef
|
||||
tts_model.cfg_coef = 1.
|
||||
tts_model.cfg_coef = 1.0
|
||||
cfg_is_no_text = False
|
||||
cfg_is_no_prefix = False
|
||||
else:
|
||||
|
|
@ -118,13 +136,16 @@ def main():
|
|||
voices = [tts_model.get_voice_path(args.voice)]
|
||||
else:
|
||||
voices = []
|
||||
all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
|
||||
all_attributes = [
|
||||
tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
|
||||
]
|
||||
|
||||
wav_frames = queue.Queue()
|
||||
def _on_audio_hook(audio_tokens):
|
||||
if (audio_tokens == -1).any():
|
||||
|
||||
def _on_frame(frame):
|
||||
if (frame == -1).any():
|
||||
return
|
||||
_pcm = tts_model.mimi.decode_step(audio_tokens[None, :, None])
|
||||
_pcm = tts_model.mimi.decode_step(frame[:, :, None])
|
||||
_pcm = np.array(mx.clip(_pcm[0, 0], -1, 1))
|
||||
wav_frames.put_nowait(_pcm)
|
||||
|
||||
|
|
@ -136,7 +157,7 @@ def main():
|
|||
all_attributes,
|
||||
cfg_is_no_prefix=cfg_is_no_prefix,
|
||||
cfg_is_no_text=cfg_is_no_text,
|
||||
on_audio_hook=_on_audio_hook,
|
||||
on_frame=_on_frame,
|
||||
)
|
||||
frames = mx.concat(result.frames, axis=-1)
|
||||
total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate
|
||||
|
|
@ -146,16 +167,20 @@ def main():
|
|||
return result
|
||||
|
||||
if args.out == "-":
|
||||
|
||||
def audio_callback(outdata, _a, _b, _c):
|
||||
try:
|
||||
pcm_data = wav_frames.get(block=False)
|
||||
outdata[:, 0] = pcm_data
|
||||
except queue.Empty:
|
||||
outdata[:] = 0
|
||||
with sd.OutputStream(samplerate=mimi.sample_rate,
|
||||
blocksize=1920,
|
||||
channels=1,
|
||||
callback=audio_callback):
|
||||
|
||||
with sd.OutputStream(
|
||||
samplerate=mimi.sample_rate,
|
||||
blocksize=1920,
|
||||
channels=1,
|
||||
callback=audio_callback,
|
||||
):
|
||||
run()
|
||||
time.sleep(3)
|
||||
while True:
|
||||
|
|
@ -163,6 +188,7 @@ def main():
|
|||
break
|
||||
time.sleep(1)
|
||||
else:
|
||||
run()
|
||||
frames = []
|
||||
while True:
|
||||
try:
|
||||
|
|
|
|||
138
scripts/tts_rust_server.py
Normal file
138
scripts/tts_rust_server.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
# /// script
|
||||
# requires-python = ">=3.12"
|
||||
# dependencies = [
|
||||
# "msgpack",
|
||||
# "numpy",
|
||||
# "sphn",
|
||||
# "websockets",
|
||||
# "sounddevice",
|
||||
# "tqdm",
|
||||
# ]
|
||||
# ///
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
import sounddevice as sd
|
||||
import sphn
|
||||
import tqdm
|
||||
import websockets
|
||||
|
||||
SAMPLE_RATE = 24000
|
||||
|
||||
TTS_TEXT = "Hello, this is a test of the moshi text to speech system, this should result in some nicely sounding generated voice."
|
||||
DEFAULT_DSM_TTS_VOICE_REPO = "kyutai/tts-voices"
|
||||
AUTH_TOKEN = "public_token"
|
||||
|
||||
|
||||
async def receive_messages(websocket: websockets.ClientConnection, output_queue):
|
||||
with tqdm.tqdm(desc="Receiving audio", unit=" seconds generated") as pbar:
|
||||
accumulated_samples = 0
|
||||
last_seconds = 0
|
||||
|
||||
async for message_bytes in websocket:
|
||||
msg = msgpack.unpackb(message_bytes)
|
||||
|
||||
if msg["type"] == "Audio":
|
||||
pcm = np.array(msg["pcm"]).astype(np.float32)
|
||||
await output_queue.put(pcm)
|
||||
|
||||
accumulated_samples += len(msg["pcm"])
|
||||
current_seconds = accumulated_samples // SAMPLE_RATE
|
||||
if current_seconds > last_seconds:
|
||||
pbar.update(current_seconds - last_seconds)
|
||||
last_seconds = current_seconds
|
||||
|
||||
print("End of audio.")
|
||||
await output_queue.put(None) # Signal end of audio
|
||||
|
||||
|
||||
async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None]):
|
||||
if out == "-":
|
||||
should_exit = False
|
||||
|
||||
def audio_callback(outdata, _a, _b, _c):
|
||||
nonlocal should_exit
|
||||
|
||||
try:
|
||||
pcm_data = output_queue.get_nowait()
|
||||
if pcm_data is not None:
|
||||
outdata[:, 0] = pcm_data
|
||||
else:
|
||||
should_exit = True
|
||||
outdata[:] = 0
|
||||
except asyncio.QueueEmpty:
|
||||
outdata[:] = 0
|
||||
|
||||
with sd.OutputStream(
|
||||
samplerate=SAMPLE_RATE,
|
||||
blocksize=1920,
|
||||
channels=1,
|
||||
callback=audio_callback,
|
||||
):
|
||||
while True:
|
||||
if should_exit:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
frames = []
|
||||
while True:
|
||||
item = await output_queue.get()
|
||||
if item is None:
|
||||
break
|
||||
frames.append(item)
|
||||
|
||||
sphn.write_wav(out, np.concat(frames, -1), SAMPLE_RATE)
|
||||
print(f"Saved audio to {out}")
|
||||
|
||||
|
||||
async def websocket_client():
|
||||
parser = argparse.ArgumentParser(description="Use the TTS streaming API")
|
||||
parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
|
||||
parser.add_argument(
|
||||
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--voice",
|
||||
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
|
||||
help="The voice to use, relative to the voice repo root. "
|
||||
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
help="The URL of the server to which to send the audio",
|
||||
default="ws://127.0.0.1:8080",
|
||||
)
|
||||
parser.add_argument("--api-key", default="public_token")
|
||||
args = parser.parse_args()
|
||||
|
||||
params = {"voice": args.voice, "format": "PcmMessagePack"}
|
||||
uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
|
||||
print(uri)
|
||||
|
||||
# TODO: stream the text instead of sending it all at once
|
||||
if args.inp == "-":
|
||||
if sys.stdin.isatty(): # Interactive
|
||||
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||
text_to_tts = sys.stdin.read().strip()
|
||||
else:
|
||||
with open(args.inp, "r") as fobj:
|
||||
text_to_tts = fobj.read().strip()
|
||||
|
||||
headers = {"kyutai-api-key": args.api_key}
|
||||
|
||||
async with websockets.connect(uri, additional_headers=headers) as websocket:
|
||||
await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts}))
|
||||
await websocket.send(msgpack.packb({"type": "Eos"}))
|
||||
|
||||
output_queue = asyncio.Queue()
|
||||
receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
|
||||
output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
|
||||
await asyncio.gather(receive_task, output_audio_task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(websocket_client())
|
||||
Loading…
Reference in New Issue
Block a user