Add the MLX TTS example.
This commit is contained in:
parent
c4ef93770a
commit
aa06a44fd4
137
scripts/tts_mlx.py
Normal file
137
scripts/tts_mlx.py
Normal file
|
|
@ -0,0 +1,137 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "huggingface_hub",
|
||||||
|
# "moshi_mlx",
|
||||||
|
# "numpy",
|
||||||
|
# "sounddevice",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import sentencepiece
|
||||||
|
import sphn
|
||||||
|
|
||||||
|
from moshi_mlx.client_utils import make_log
|
||||||
|
from moshi_mlx import models
|
||||||
|
from moshi_mlx.utils.loaders import hf_get
|
||||||
|
from moshi_mlx.models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO
|
||||||
|
|
||||||
|
|
||||||
|
def log(level: str, msg: str):
    """Print `msg` to stdout after formatting it with `make_log` for `level`."""
    formatted = make_log(level, msg)
    print(formatted)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point: synthesize speech for a text input.

    Parses the command line, fetches the config, model weights, text
    tokenizer and audio tokenizer through `hf_get`, builds a `TTSModel`,
    generates audio frames for the input text, decodes them to PCM and writes
    the result to `args.out` with `sphn.write_wav`.

    The input text is read from the file `inp`, or from standard input when
    `inp` is "-" (as advertised by the argument's help text).
    """
    # Local import: only needed for the "-" (stdin) input case.
    import sys

    parser = argparse.ArgumentParser(prog='moshi-tts', description='Run Moshi')
    parser.add_argument("inp", type=str, help="Input file, use - for stdin")
    # NOTE(review): the help text mentions "-" for playback, but nothing in
    # this function implements it; `out` is always passed to sphn.write_wav.
    parser.add_argument("out", type=str, help="Output file to generate, use - for playing the audio")
    parser.add_argument("--hf-repo", type=str, default=DEFAULT_DSM_TTS_REPO,
                        help="HF repo in which to look for the pretrained models.")
    parser.add_argument("--voice-repo", default=DEFAULT_DSM_TTS_VOICE_REPO,
                        help="HF repo in which to look for pre-computed voice embeddings.")
    parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav")
    parser.add_argument("--quantize", type=int, help="The quantization to be applied, e.g. 8 for 8 bits.")
    args = parser.parse_args()

    # Fixed seed so that sampling is reproducible from one run to the next.
    mx.random.seed(299792458)

    log("info", "retrieving checkpoints")

    # Keep the config *path* and the parsed config *dict* under separate names.
    # hf_get is applied to the resolved path again, exactly as before;
    # presumably a no-op for an already-resolved path -- TODO confirm.
    config_path = hf_get("config.json", args.hf_repo)
    with open(hf_get(config_path), "r", encoding="utf-8") as fobj:
        raw_config = json.load(fobj)

    mimi_weights = hf_get(raw_config["mimi_name"], args.hf_repo)
    moshi_name = raw_config.get("moshi_name", "model.safetensors")
    moshi_weights = hf_get(moshi_name, args.hf_repo)
    tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
    lm_config = models.LmConfig.from_config_dict(raw_config)
    model = models.Lm(lm_config)
    model.set_dtype(mx.bfloat16)

    log("info", f"loading model weights from {moshi_weights}")
    model.load_pytorch_weights(str(moshi_weights), lm_config, strict=True)

    if args.quantize is not None:
        log("info", f"quantizing model to {args.quantize} bits")
        nn.quantize(model.depformer, bits=args.quantize)
        # Only the attention and gating sub-modules of each transformer layer
        # are quantized.
        for layer in model.transformer.layers:
            nn.quantize(layer.self_attn, bits=args.quantize)
            nn.quantize(layer.gating, bits=args.quantize)

    log("info", f"loading the text tokenizer from {tokenizer}")
    text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer))  # type: ignore

    log("info", f"loading the audio tokenizer {mimi_weights}")
    generated_codebooks = lm_config.generated_codebooks
    audio_tokenizer = models.mimi.Mimi(models.mimi_202407(generated_codebooks))
    audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)

    cfg_coef_conditioning = None
    tts_model = TTSModel(
        model,
        audio_tokenizer,
        text_tokenizer,
        voice_repo=args.voice_repo,
        temp=0.6,
        cfg_coef=1,
        max_padding=8,
        initial_padding=2,
        final_padding=2,
        padding_bonus=0,
        raw_config=raw_config,
    )
    if tts_model.valid_cfg_conditionings:
        # Model was trained with CFG distillation: pass the coefficient as a
        # conditioning attribute and reset the model's own coefficient to 1.
        cfg_coef_conditioning = tts_model.cfg_coef
        tts_model.cfg_coef = 1.
        cfg_is_no_text = False
        cfg_is_no_prefix = False
    else:
        cfg_is_no_text = True
        cfg_is_no_prefix = True
    mimi = tts_model.mimi

    # Honor the documented "-" convention for reading the text from stdin.
    if args.inp == "-":
        text_to_tts = sys.stdin.read().strip()
    else:
        with open(args.inp, "r") as fobj:
            text_to_tts = fobj.read().strip()

    all_entries = [tts_model.prepare_script([text_to_tts])]
    if tts_model.multi_speaker:
        voices = [tts_model.get_voice_path(args.voice)]
    else:
        voices = []
    all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]

    begin = time.time()
    result = tts_model.generate(
        all_entries, all_attributes,
        cfg_is_no_prefix=cfg_is_no_prefix, cfg_is_no_text=cfg_is_no_text)
    frames = mx.concat(result.frames, axis=-1)
    # Duration of generated audio relative to wall-clock time spent generating.
    total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate
    time_taken = time.time() - begin
    total_speed = total_duration / time_taken
    log("info", f"[LM] took {time_taken:.2f}s, total speed {total_speed:.2f}x")

    wav_frames = []
    for frame in result.frames:
        # We are processing frames one by one, although we could group them to improve speed.
        _pcm = tts_model.mimi.decode_step(frame)
        wav_frames.append(_pcm)
    wavs = mx.concat(wav_frames, axis=-1)
    # Trim the waveform to the reported end step plus the final padding,
    # converted from frames to samples.
    end_step = result.end_steps[0]
    wav_length = int((mimi.sample_rate * (end_step + tts_model.final_padding) / mimi.frame_rate))
    wav = wavs[0, :, :wav_length]
    sphn.write_wav(args.out, np.array(mx.clip(wav, -1, 1)), mimi.sample_rate)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||||
Loading…
Reference in New Issue
Block a user