Merge stuff.
This commit is contained in:
commit
9ba4e88553
30
README.md
30
README.md
|
|
@ -239,7 +239,35 @@ and just prefix the command above with `uvx --with moshi`.
|
||||||
<details>
|
<details>
|
||||||
<summary>Rust server</summary>
|
<summary>Rust server</summary>
|
||||||
|
|
||||||
Example coming soon.
|
|
||||||
|
The Rust implementation provides a server that can process multiple streaming
|
||||||
|
queries in parallel.
|
||||||
|
|
||||||
|
In order to run the server, install the [moshi-server
|
||||||
|
crate](https://crates.io/crates/moshi-server) via the following command. The
|
||||||
|
server code can be found in the
|
||||||
|
[kyutai-labs/moshi](https://github.com/kyutai-labs/moshi/tree/main/rust/moshi-server)
|
||||||
|
repository.
|
||||||
|
```bash
|
||||||
|
cargo install --features cuda moshi-server
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
Then the server can be started via the following command using the config file
|
||||||
|
from this repository.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
moshi-server worker --config configs/config-tts.toml
|
||||||
|
```
|
||||||
|
|
||||||
|
Once the server has started you can connect to it using our script as follows:
|
||||||
|
```bash
|
||||||
|
# From stdin, plays audio immediately
|
||||||
|
echo "Hey, how are you?" | python scripts/tts_rust_server.py - -
|
||||||
|
|
||||||
|
# From text file to audio file
|
||||||
|
python scripts/tts_rust_server.py text_to_say.txt audio_output.wav
|
||||||
|
```
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
|
|
|
||||||
20
configs/config-tts.toml
Normal file
20
configs/config-tts.toml
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
static_dir = "./static/"
|
||||||
|
log_dir = "$HOME/tmp/tts-logs"
|
||||||
|
instance_name = "tts"
|
||||||
|
authorized_ids = ["public_token"]
|
||||||
|
|
||||||
|
[modules.tts_py]
|
||||||
|
type = "Py"
|
||||||
|
path = "/api/tts_streaming"
|
||||||
|
text_tokenizer_file = "hf://kyutai/unmute/test_en_fr_audio_8000.model"
|
||||||
|
batch_size = 8 # Adjust to your GPU memory capacity
|
||||||
|
text_bos_token = 1
|
||||||
|
|
||||||
|
[modules.tts_py.py]
|
||||||
|
log_folder = "$HOME/tmp/moshi-server-logs"
|
||||||
|
voice_folder = "hf-snapshot://kyutai/tts-voices/**/*.safetensors"
|
||||||
|
default_voice = "unmute-prod-website/default_voice.wav"
|
||||||
|
cfg_coef = 2.0
|
||||||
|
cfg_is_no_text = true
|
||||||
|
padding_between = 1
|
||||||
|
n_q = 24
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
# requires-python = ">=3.12"
|
# requires-python = ">=3.12"
|
||||||
# dependencies = [
|
# dependencies = [
|
||||||
# "huggingface_hub",
|
# "huggingface_hub",
|
||||||
# "moshi_mlx>=0.2.8",
|
# "moshi_mlx @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi_mlx&subdirectory=moshi_mlx",
|
||||||
# "numpy",
|
# "numpy",
|
||||||
# "sounddevice",
|
# "sounddevice",
|
||||||
# ]
|
# ]
|
||||||
|
|
@ -14,19 +14,20 @@ import queue
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
import numpy as np
|
||||||
import sentencepiece
|
import sentencepiece
|
||||||
import sphn
|
|
||||||
import time
|
|
||||||
|
|
||||||
import sounddevice as sd
|
import sounddevice as sd
|
||||||
|
import sphn
|
||||||
from moshi_mlx.client_utils import make_log
|
|
||||||
from moshi_mlx import models
|
from moshi_mlx import models
|
||||||
|
from moshi_mlx.client_utils import make_log
|
||||||
|
from moshi_mlx.models.tts import (
|
||||||
|
DEFAULT_DSM_TTS_REPO,
|
||||||
|
DEFAULT_DSM_TTS_VOICE_REPO,
|
||||||
|
TTSModel,
|
||||||
|
)
|
||||||
from moshi_mlx.utils.loaders import hf_get
|
from moshi_mlx.utils.loaders import hf_get
|
||||||
from moshi_mlx.models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO
|
|
||||||
|
|
||||||
|
|
||||||
def log(level: str, msg: str):
|
def log(level: str, msg: str):
|
||||||
|
|
@ -34,15 +35,32 @@ def log(level: str, msg: str):
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(prog='moshi-tts', description='Run Moshi')
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run Kyutai TTS using the PyTorch implementation"
|
||||||
|
)
|
||||||
parser.add_argument("inp", type=str, help="Input file, use - for stdin")
|
parser.add_argument("inp", type=str, help="Input file, use - for stdin")
|
||||||
parser.add_argument("out", type=str, help="Output file to generate, use - for playing the audio")
|
parser.add_argument(
|
||||||
parser.add_argument("--hf-repo", type=str, default=DEFAULT_DSM_TTS_REPO,
|
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||||
help="HF repo in which to look for the pretrained models.")
|
)
|
||||||
parser.add_argument("--voice-repo", default=DEFAULT_DSM_TTS_VOICE_REPO,
|
parser.add_argument(
|
||||||
help="HF repo in which to look for pre-computed voice embeddings.")
|
"--hf-repo",
|
||||||
parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav")
|
type=str,
|
||||||
parser.add_argument("--quantize", type=int, help="The quantization to be applied, e.g. 8 for 8 bits.")
|
default=DEFAULT_DSM_TTS_REPO,
|
||||||
|
help="HF repo in which to look for the pretrained models.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice-repo",
|
||||||
|
default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||||
|
help="HF repo in which to look for pre-computed voice embeddings.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--quantize",
|
||||||
|
type=int,
|
||||||
|
help="The quantization to be applied, e.g. 8 for 8 bits.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
mx.random.seed(299792458)
|
mx.random.seed(299792458)
|
||||||
|
|
@ -96,7 +114,7 @@ def main():
|
||||||
if tts_model.valid_cfg_conditionings:
|
if tts_model.valid_cfg_conditionings:
|
||||||
# Model was trained with CFG distillation.
|
# Model was trained with CFG distillation.
|
||||||
cfg_coef_conditioning = tts_model.cfg_coef
|
cfg_coef_conditioning = tts_model.cfg_coef
|
||||||
tts_model.cfg_coef = 1.
|
tts_model.cfg_coef = 1.0
|
||||||
cfg_is_no_text = False
|
cfg_is_no_text = False
|
||||||
cfg_is_no_prefix = False
|
cfg_is_no_prefix = False
|
||||||
else:
|
else:
|
||||||
|
|
@ -118,13 +136,16 @@ def main():
|
||||||
voices = [tts_model.get_voice_path(args.voice)]
|
voices = [tts_model.get_voice_path(args.voice)]
|
||||||
else:
|
else:
|
||||||
voices = []
|
voices = []
|
||||||
all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
|
all_attributes = [
|
||||||
|
tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
|
||||||
|
]
|
||||||
|
|
||||||
wav_frames = queue.Queue()
|
wav_frames = queue.Queue()
|
||||||
def _on_audio_hook(audio_tokens):
|
|
||||||
if (audio_tokens == -1).any():
|
def _on_frame(frame):
|
||||||
|
if (frame == -1).any():
|
||||||
return
|
return
|
||||||
_pcm = tts_model.mimi.decode_step(audio_tokens[None, :, None])
|
_pcm = tts_model.mimi.decode_step(frame[:, :, None])
|
||||||
_pcm = np.array(mx.clip(_pcm[0, 0], -1, 1))
|
_pcm = np.array(mx.clip(_pcm[0, 0], -1, 1))
|
||||||
wav_frames.put_nowait(_pcm)
|
wav_frames.put_nowait(_pcm)
|
||||||
|
|
||||||
|
|
@ -136,7 +157,7 @@ def main():
|
||||||
all_attributes,
|
all_attributes,
|
||||||
cfg_is_no_prefix=cfg_is_no_prefix,
|
cfg_is_no_prefix=cfg_is_no_prefix,
|
||||||
cfg_is_no_text=cfg_is_no_text,
|
cfg_is_no_text=cfg_is_no_text,
|
||||||
on_audio_hook=_on_audio_hook,
|
on_frame=_on_frame,
|
||||||
)
|
)
|
||||||
frames = mx.concat(result.frames, axis=-1)
|
frames = mx.concat(result.frames, axis=-1)
|
||||||
total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate
|
total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate
|
||||||
|
|
@ -146,16 +167,20 @@ def main():
|
||||||
return result
|
return result
|
||||||
|
|
||||||
if args.out == "-":
|
if args.out == "-":
|
||||||
|
|
||||||
def audio_callback(outdata, _a, _b, _c):
|
def audio_callback(outdata, _a, _b, _c):
|
||||||
try:
|
try:
|
||||||
pcm_data = wav_frames.get(block=False)
|
pcm_data = wav_frames.get(block=False)
|
||||||
outdata[:, 0] = pcm_data
|
outdata[:, 0] = pcm_data
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
outdata[:] = 0
|
outdata[:] = 0
|
||||||
with sd.OutputStream(samplerate=mimi.sample_rate,
|
|
||||||
|
with sd.OutputStream(
|
||||||
|
samplerate=mimi.sample_rate,
|
||||||
blocksize=1920,
|
blocksize=1920,
|
||||||
channels=1,
|
channels=1,
|
||||||
callback=audio_callback):
|
callback=audio_callback,
|
||||||
|
):
|
||||||
run()
|
run()
|
||||||
time.sleep(3)
|
time.sleep(3)
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -163,6 +188,7 @@ def main():
|
||||||
break
|
break
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
else:
|
else:
|
||||||
|
run()
|
||||||
frames = []
|
frames = []
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
138
scripts/tts_rust_server.py
Normal file
138
scripts/tts_rust_server.py
Normal file
|
|
@ -0,0 +1,138 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "msgpack",
|
||||||
|
# "numpy",
|
||||||
|
# "sphn",
|
||||||
|
# "websockets",
|
||||||
|
# "sounddevice",
|
||||||
|
# "tqdm",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
import numpy as np
|
||||||
|
import sounddevice as sd
|
||||||
|
import sphn
|
||||||
|
import tqdm
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
SAMPLE_RATE = 24000
|
||||||
|
|
||||||
|
TTS_TEXT = "Hello, this is a test of the moshi text to speech system, this should result in some nicely sounding generated voice."
|
||||||
|
DEFAULT_DSM_TTS_VOICE_REPO = "kyutai/tts-voices"
|
||||||
|
AUTH_TOKEN = "public_token"
|
||||||
|
|
||||||
|
|
||||||
|
async def receive_messages(websocket: websockets.ClientConnection, output_queue):
    """Forward audio received on ``websocket`` to ``output_queue``.

    Every incoming message is decoded with msgpack. Messages whose ``type``
    is ``"Audio"`` have their ``pcm`` payload converted to a float32 NumPy
    array and put on the queue; messages of any other type are ignored. A
    tqdm progress bar is advanced once per whole second of audio received,
    computed from ``SAMPLE_RATE``. Once the websocket iteration ends,
    ``None`` is queued to mark the end of the audio.
    """
    total_samples = 0
    reported_seconds = 0

    with tqdm.tqdm(desc="Receiving audio", unit=" seconds generated") as progress:
        async for raw_message in websocket:
            message = msgpack.unpackb(raw_message)
            if message["type"] != "Audio":
                continue

            samples = message["pcm"]
            await output_queue.put(np.array(samples).astype(np.float32))

            # Only touch the progress bar when a new whole second is reached.
            total_samples += len(samples)
            elapsed_seconds = total_samples // SAMPLE_RATE
            if elapsed_seconds > reported_seconds:
                progress.update(elapsed_seconds - reported_seconds)
                reported_seconds = elapsed_seconds

    print("End of audio.")
    await output_queue.put(None)  # Signal end of audio
|
||||||
|
|
||||||
|
|
||||||
|
async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None]):
    """Consume PCM chunks from ``output_queue`` and play or save them.

    Args:
        out: ``"-"`` to play the audio through a ``sounddevice`` output
            stream, otherwise the path of the WAV file to write.
        output_queue: Queue of PCM chunks; a ``None`` item marks the end of
            the audio.

    Raises:
        ValueError: When writing to a file and no audio chunk was received
            before the end-of-audio marker.
    """
    if out == "-":
        should_exit = False

        def audio_callback(outdata, _a, _b, _c):
            # Passed to sd.OutputStream as its callback. It never waits on
            # the queue: when nothing is available it outputs silence.
            nonlocal should_exit

            try:
                pcm_data = output_queue.get_nowait()
            except asyncio.QueueEmpty:
                outdata[:] = 0
                return

            if pcm_data is None:
                # End-of-audio marker: emit silence and let the loop below stop.
                should_exit = True
                outdata[:] = 0
            else:
                # NOTE(review): assumes each chunk has exactly `blocksize`
                # (1920) samples — confirm against the server output.
                outdata[:, 0] = pcm_data

        with sd.OutputStream(
            samplerate=SAMPLE_RATE,
            blocksize=1920,
            channels=1,
            callback=audio_callback,
        ):
            # Keep the stream open until the callback has seen the marker.
            while not should_exit:
                await asyncio.sleep(1)
        return

    frames = []
    while (item := await output_queue.get()) is not None:
        frames.append(item)

    if not frames:
        # Same exception type np.concatenate would raise, with a clear cause.
        raise ValueError("No audio was received from the server, nothing to save.")

    # np.concatenate is available in every NumPy release, whereas the
    # np.concat alias only exists from NumPy 2.0 onwards.
    sphn.write_wav(out, np.concatenate(frames, -1), SAMPLE_RATE)
    print(f"Saved audio to {out}")
|
||||||
|
|
||||||
|
|
||||||
|
async def websocket_client():
    """Send text to the TTS streaming API and play or save the audio.

    Parses the command line, reads the text to synthesize from a file or
    from stdin, then opens a websocket to ``/api/tts_streaming``. The whole
    text is sent as a single ``Text`` message followed by ``Eos``, after
    which the audio is received and output concurrently.
    """
    parser = argparse.ArgumentParser(description="Use the TTS streaming API")
    parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
    parser.add_argument(
        "out", type=str, help="Output file to generate, use - for playing the audio"
    )
    parser.add_argument(
        "--voice",
        default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
        help="The voice to use, relative to the voice repo root. "
        f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
    )
    parser.add_argument(
        "--url",
        help="The URL of the server to which to send the audio",
        default="ws://127.0.0.1:8080",
    )
    # Use the module-level constant instead of repeating the literal token.
    parser.add_argument("--api-key", default=AUTH_TOKEN)
    args = parser.parse_args()

    params = {"voice": args.voice, "format": "PcmMessagePack"}
    uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
    print(uri)

    # TODO: stream the text instead of sending it all at once
    if args.inp == "-":
        if sys.stdin.isatty():  # Interactive
            print("Enter text to synthesize (Ctrl+D to end input):")
        text_to_tts = sys.stdin.read().strip()
    else:
        # Read as UTF-8 explicitly so non-ASCII text is not mangled on
        # platforms whose default locale encoding is something else.
        with open(args.inp, "r", encoding="utf-8") as fobj:
            text_to_tts = fobj.read().strip()

    headers = {"kyutai-api-key": args.api_key}

    async with websockets.connect(uri, additional_headers=headers) as websocket:
        await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts}))
        await websocket.send(msgpack.packb({"type": "Eos"}))

        # Receive and output concurrently so playback can start while the
        # server is still generating.
        output_queue = asyncio.Queue()
        receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
        output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
        await asyncio.gather(receive_task, output_audio_task)


if __name__ == "__main__":
    asyncio.run(websocket_client())
|
||||||
Loading…
Reference in New Issue
Block a user