kyutai/scripts/tts_mlx_streaming.py

# /// script
# requires-python = ">=3.12"
# dependencies = [
# "huggingface_hub",
# "moshi_mlx==0.2.12",
# "numpy",
# "sounddevice",
# ]
# ///
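
# Streams text read from stdin through the Kyutai TTS model and either plays
# the audio live or writes it to a file. Hypothetical invocation, assuming uv
# is used to resolve the inline script dependencies above:
#   echo "Hello world" | uv run tts_mlx_streaming.py -
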
import argparse
from dataclasses import dataclass
import json
import queue
import sys
import time
import typing as tp

import mlx.core as mx
import mlx.nn as nn
import numpy as np
import sentencepiece
import sounddevice as sd
import sphn
from moshi_mlx import models
from moshi_mlx.models.generate import LmGen
from moshi_mlx.client_utils import make_log
from moshi_mlx.modules.conditioner import (
    ConditionAttributes,
    ConditionTensor,
    dropout_all_conditions,
)
from moshi_mlx.utils.sampling import Sampler
from moshi_mlx.models.tts import (
    Entry,
    DEFAULT_DSM_TTS_REPO,
    DEFAULT_DSM_TTS_VOICE_REPO,
    TTSModel,
    script_to_entries,
)
from moshi_mlx.utils.loaders import hf_get


def prepare_script(model: TTSModel, script: str, first_turn: bool) -> list[Entry]:
    multi_speaker = first_turn and model.multi_speaker
    return script_to_entries(
        model.tokenizer,
        model.machine.token_ids,
        model.mimi.frame_rate,
        [script],
        multi_speaker=multi_speaker,
        padding_between=1,
    )


def _make_null(
    all_attributes: tp.Sequence[ConditionAttributes],
) -> list[ConditionAttributes]:
    # When using CFG, returns the null conditions.
    return dropout_all_conditions(all_attributes)


@dataclass
class TTSGen:
    tts_model: TTSModel
    attributes: tp.Sequence[ConditionAttributes]
    on_frame: tp.Optional[tp.Callable[[mx.array], None]] = None

    def __post_init__(self):
        tts_model = self.tts_model
        attributes = self.attributes
        self.offset = 0
        self.state = self.tts_model.machine.new_state([])
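
        # With direct CFG, the conditioned and null-conditioned batches run
        # side by side, so the null versions of the attributes are appended.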
        if tts_model.cfg_coef != 1.0:
            if tts_model.valid_cfg_conditionings:
                raise ValueError(
                    "This model does not support direct CFG, but was trained with "
                    "CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
                )
            nulled = _make_null(attributes)
            attributes = list(attributes) + nulled

        assert tts_model.lm.condition_provider is not None
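        # Fold every text condition into a single ConditionTensor; tensor
        # conditions provide at most one cross-attention source.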
        self.ct = None
        self.cross_attention_src = None
        for _attr in attributes:
            for _key, _value in _attr.text.items():
                _ct = tts_model.lm.condition_provider.condition_tensor(_key, _value)
                if self.ct is None:
                    self.ct = _ct
                else:
                    self.ct = ConditionTensor(self.ct.tensor + _ct.tensor)
            for _key, _value in _attr.tensor.items():
                _conditioner = tts_model.lm.condition_provider.conditioners[_key]
                _ca_src = _conditioner.condition(_value)
                if self.cross_attention_src is None:
                    self.cross_attention_src = _ca_src
                else:
                    raise ValueError("multiple cross-attention conditioners")
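
        # Until a codebook's delay (plus the global delay in steps) has
        # elapsed, its audio tokens are forced to the special zero token.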
        def _on_audio_hook(audio_tokens):
            delays = tts_model.lm.delays
            for q in range(audio_tokens.shape[0]):
                delay = delays[q]
                if self.offset < delay + tts_model.delay_steps:
                    audio_tokens[q] = tts_model.machine.token_ids.zero
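
        # Each sampled text token is run through the TTS state machine, which
        # tracks the queued script entries and decides which token is actually
        # emitted in its place.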
        def _on_text_hook(text_tokens):
            tokens = text_tokens.tolist()
            out_tokens = []
            for token in tokens:
                out_token, _ = tts_model.machine.process(self.offset, self.state, token)
                out_tokens.append(out_token)
            text_tokens[:] = mx.array(out_tokens, dtype=mx.int64)

        self.lm_gen = LmGen(
            tts_model.lm,
            max_steps=tts_model.max_gen_length,
            text_sampler=Sampler(temp=tts_model.temp),
            audio_sampler=Sampler(temp=tts_model.temp),
            cfg_coef=tts_model.cfg_coef,
            on_text_hook=_on_text_hook,
            on_audio_hook=_on_audio_hook,
            # TODO(laurent):
            # cfg_is_masked_until=cfg_is_masked_until,
            # cfg_is_no_text=cfg_is_no_text,
        )
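
    # process() keeps stepping while more than `second_stream_ahead` entries
    # are queued; process_last() drains the remaining entries and then pads
    # for the delays so the final audio frames are flushed.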
    def process_last(self):
        while len(self.state.entries) > 0 or self.state.end_step is not None:
            self._step()
        additional_steps = (
            self.tts_model.delay_steps + max(self.tts_model.lm.delays) + 8
        )
        for _ in range(additional_steps):
            self._step()

    def process(self):
        while len(self.state.entries) > self.tts_model.machine.second_stream_ahead:
            self._step()
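
    # One decode step. The codebooks the model consumes but does not generate
    # (n_q - dep_q of them) are fed the constant zero token as input.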
    def _step(self):
        missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
        input_tokens = (
            mx.ones((1, missing), dtype=mx.int64)
            * self.tts_model.machine.token_ids.zero
        )
        self.lm_gen.step(
            input_tokens, ct=self.ct, cross_attention_src=self.cross_attention_src
        )
        frame = self.lm_gen.last_audio_tokens()
        self.offset += 1
        if frame is not None:
            if self.on_frame is not None:
                self.on_frame(frame)

    def append_entry(self, entry):
        self.state.entries.append(entry)


def log(level: str, msg: str):
    print(make_log(level, msg))


def main():
    parser = argparse.ArgumentParser(
        description="Run Kyutai TTS using the MLX implementation"
    )
    parser.add_argument(
        "out", type=str, help="Output file to generate, use - for playing the audio"
    )
    parser.add_argument(
        "--hf-repo",
        type=str,
        default=DEFAULT_DSM_TTS_REPO,
        help="HF repo in which to look for the pretrained models.",
    )
    parser.add_argument(
        "--voice-repo",
        default=DEFAULT_DSM_TTS_VOICE_REPO,
        help="HF repo in which to look for pre-computed voice embeddings.",
    )
    parser.add_argument(
        "--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
    )
    parser.add_argument(
        "--quantize",
        type=int,
        help="The quantization to be applied, e.g. 8 for 8 bits.",
    )
    args = parser.parse_args()

    mx.random.seed(299792458)

    log("info", "retrieving checkpoints")
    raw_config = hf_get("config.json", args.hf_repo)
    with open(hf_get(raw_config), "r") as fobj:
        raw_config = json.load(fobj)
    mimi_weights = hf_get(raw_config["mimi_name"], args.hf_repo)
    moshi_name = raw_config.get("moshi_name", "model.safetensors")
    moshi_weights = hf_get(moshi_name, args.hf_repo)
    tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
    lm_config = models.LmConfig.from_config_dict(raw_config)
    # There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
    # The following line gets around it for now.
    lm_config.transformer.max_seq_len = lm_config.transformer.context

    model = models.Lm(lm_config)
    model.set_dtype(mx.bfloat16)

    log("info", f"loading model weights from {moshi_weights}")
    model.load_pytorch_weights(str(moshi_weights), lm_config, strict=True)

    if args.quantize is not None:
        log("info", f"quantizing model to {args.quantize} bits")
        nn.quantize(model.depformer, bits=args.quantize)
        for layer in model.transformer.layers:
            nn.quantize(layer.self_attn, bits=args.quantize)
            nn.quantize(layer.gating, bits=args.quantize)
log("info", f"loading the text tokenizer from {tokenizer}")
text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer)) # type: ignore
log("info", f"loading the audio tokenizer {mimi_weights}")
generated_codebooks = lm_config.generated_codebooks
audio_tokenizer = models.mimi.Mimi(models.mimi_202407(generated_codebooks))
audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)
cfg_coef_conditioning = None
tts_model = TTSModel(
model,
audio_tokenizer,
text_tokenizer,
voice_repo=args.voice_repo,
temp=0.6,
cfg_coef=1,
max_padding=8,
initial_padding=2,
final_padding=2,
padding_bonus=0,
raw_config=raw_config,
)
    if tts_model.valid_cfg_conditionings:
        # Model was trained with CFG distillation.
        cfg_coef_conditioning = tts_model.cfg_coef
        tts_model.cfg_coef = 1.0
    mimi = tts_model.mimi

    log("info", "reading input from stdin")
    if tts_model.multi_speaker:
        voices = [tts_model.get_voice_path(args.voice)]
    else:
        voices = []
    all_attributes = [
        tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
    ]
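
    # Generated frames are decoded to PCM and pushed onto a queue; frames
    # that still contain -1 tokens come from the warm-up delay and are
    # skipped.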
    wav_frames = queue.Queue()

    def _on_frame(frame):
        if (frame == -1).any():
            return
        _pcm = tts_model.mimi.decode_step(frame[:, :, None])
        _pcm = np.array(mx.clip(_pcm[0, 0], -1, 1))
        wav_frames.put_nowait(_pcm)

    gen = TTSGen(tts_model, all_attributes, on_frame=_on_frame)
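
    # Read the script line by line from stdin, queueing entries and stepping
    # the generator as soon as enough entries are buffered.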
    def run():
        log("info", "starting the inference loop")
        first_turn = True
        for line in sys.stdin:
            entries = prepare_script(tts_model, line.strip(), first_turn=first_turn)
            first_turn = False
            for entry in entries:
                gen.append_entry(entry)
                gen.process()
        gen.process_last()
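
    # With "-" the audio is played live; the 1920-sample blocksize matches
    # one Mimi frame (24 kHz / 12.5 frames per second). Otherwise all frames
    # are drained from the queue and written to a single wav file.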
if args.out == "-":
def audio_callback(outdata, _a, _b, _c):
try:
pcm_data = wav_frames.get(block=False)
outdata[:, 0] = pcm_data
except queue.Empty:
outdata[:] = 0
with sd.OutputStream(
samplerate=mimi.sample_rate,
blocksize=1920,
channels=1,
callback=audio_callback,
):
run()
while True:
if wav_frames.qsize() == 0:
break
time.sleep(1)
else:
run()
frames = []
while True:
try:
frames.append(wav_frames.get_nowait())
except queue.Empty:
break
wav = np.concat(frames, -1)
sphn.write_wav(args.out, wav, mimi.sample_rate)


if __name__ == "__main__":
    main()