diff --git a/scripts/tts_mlx.py b/scripts/tts_mlx.py
new file mode 100644
index 0000000..95c59e4
--- /dev/null
+++ b/scripts/tts_mlx.py
@@ -0,0 +1,137 @@
+# /// script
+# requires-python = ">=3.12"
+# dependencies = [
+#     "huggingface_hub",
+#     "moshi_mlx",
+#     "numpy",
+#     "sounddevice",
+# ]
+# ///
+
+import argparse
+import json
+import sys
+import time
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+import sentencepiece
+import sounddevice as sd
+import sphn
+
+from moshi_mlx import models
+from moshi_mlx.client_utils import make_log
+from moshi_mlx.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
+from moshi_mlx.utils.loaders import hf_get
+
+
+def log(level: str, msg: str):
+    print(make_log(level, msg))
+
+
+def main():
+    parser = argparse.ArgumentParser(prog="moshi-tts", description="Run Moshi TTS")
+    parser.add_argument("inp", type=str, help="Input file, use - for stdin")
+    parser.add_argument("out", type=str, help="Output file to generate, use - for playing the audio")
+    parser.add_argument("--hf-repo", type=str, default=DEFAULT_DSM_TTS_REPO,
+                        help="HF repo in which to look for the pretrained models.")
+    parser.add_argument("--voice-repo", default=DEFAULT_DSM_TTS_VOICE_REPO,
+                        help="HF repo in which to look for pre-computed voice embeddings.")
+    parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav")
+    parser.add_argument("--quantize", type=int,
+                        help="The quantization to be applied, e.g. 8 for 8 bits.")
+    args = parser.parse_args()
+
+    mx.random.seed(299792458)
+
+    log("info", "retrieving checkpoints")
+
+    raw_config = hf_get("config.json", args.hf_repo)
+    with open(hf_get(raw_config), "r") as fobj:
+        raw_config = json.load(fobj)
+
+    mimi_weights = hf_get(raw_config["mimi_name"], args.hf_repo)
+    moshi_name = raw_config.get("moshi_name", "model.safetensors")
+    moshi_weights = hf_get(moshi_name, args.hf_repo)
+    tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
+    lm_config = models.LmConfig.from_config_dict(raw_config)
+    model = models.Lm(lm_config)
+    model.set_dtype(mx.bfloat16)
+
+    log("info", f"loading model weights from {moshi_weights}")
+    model.load_pytorch_weights(str(moshi_weights), lm_config, strict=True)
+
+    if args.quantize is not None:
+        log("info", f"quantizing model to {args.quantize} bits")
+        nn.quantize(model.depformer, bits=args.quantize)
+        for layer in model.transformer.layers:
+            nn.quantize(layer.self_attn, bits=args.quantize)
+            nn.quantize(layer.gating, bits=args.quantize)
+
+    log("info", f"loading the text tokenizer from {tokenizer}")
+    text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer))  # type: ignore
+
+    log("info", f"loading the audio tokenizer {mimi_weights}")
+    generated_codebooks = lm_config.generated_codebooks
+    audio_tokenizer = models.mimi.Mimi(models.mimi_202407(generated_codebooks))
+    audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)
+
+    cfg_coef_conditioning = None
+    tts_model = TTSModel(
+        model,
+        audio_tokenizer,
+        text_tokenizer,
+        voice_repo=args.voice_repo,
+        temp=0.6,
+        cfg_coef=1,
+        max_padding=8,
+        initial_padding=2,
+        final_padding=2,
+        padding_bonus=0,
+        raw_config=raw_config,
+    )
+    if tts_model.valid_cfg_conditionings:
+        # The model was trained with CFG distillation: the CFG coefficient is
+        # passed as a conditioning attribute rather than applied at sampling
+        # time, so the runtime coefficient is reset to 1.
+        cfg_coef_conditioning = tts_model.cfg_coef
+        tts_model.cfg_coef = 1.0
+        cfg_is_no_text = False
+        cfg_is_no_prefix = False
+    else:
+        # Regular CFG: the unconditioned branch drops the text and the prefix.
+        cfg_is_no_text = True
+        cfg_is_no_prefix = True
+    mimi = tts_model.mimi
+
+    if args.inp == "-":
+        text_to_tts = sys.stdin.read().strip()
+    else:
+        with open(args.inp, "r") as fobj:
+            text_to_tts = fobj.read().strip()
+
+    all_entries = [tts_model.prepare_script([text_to_tts])]
+    if tts_model.multi_speaker:
+        voices = [tts_model.get_voice_path(args.voice)]
+    else:
+        voices = []
+    all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
+
+    begin = time.time()
+    result = tts_model.generate(
+        all_entries, all_attributes,
+        cfg_is_no_prefix=cfg_is_no_prefix, cfg_is_no_text=cfg_is_no_text)
+    frames = mx.concat(result.frames, axis=-1)
+    # One generation step per Mimi frame, summed over the batch dimension.
+    total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate
+    time_taken = time.time() - begin
+    total_speed = total_duration / time_taken
+    log("info", f"[LM] took {time_taken:.2f}s, total speed {total_speed:.2f}x")
+
+    wav_frames = []
+    for frame in result.frames:
+        # We are processing frames one by one, although we could group them to improve speed.
+        _pcm = tts_model.mimi.decode_step(frame)
+        wav_frames.append(_pcm)
+    wavs = mx.concat(wav_frames, axis=-1)
+    # Trim the waveform at the detected end of speech plus the final padding.
+    end_step = result.end_steps[0]
+    wav_length = int(mimi.sample_rate * (end_step + tts_model.final_padding) / mimi.frame_rate)
+    wav = wavs[0, :, :wav_length]
+    pcm = np.array(mx.clip(wav, -1, 1))
+    if args.out == "-":
+        # Play the audio directly instead of writing a file; sounddevice
+        # expects (samples, channels), so the (channels, samples) array is transposed.
+        sd.play(pcm.T, samplerate=mimi.sample_rate)
+        sd.wait()
+    else:
+        sphn.write_wav(args.out, pcm, mimi.sample_rate)
+
+
+if __name__ == "__main__":
+    main()
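
A quick way to exercise the script, as a sketch: `uv run` reads the PEP 723 metadata
block at the top of the file and resolves the dependencies, and the input file name
below is made up for the example.

    echo "Hello from Moshi." > input.txt
    uv run scripts/tts_mlx.py input.txt out.wav
    uv run scripts/tts_mlx.py input.txt - --quantize 8

The second invocation plays the generated audio through sounddevice instead of
writing a wav file.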