# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "moshi @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi",
#     "torch",
#     "sphn",
# ]
# ///
"""Run Kyutai TTS using the PyTorch implementation.

Reads text from a file (or stdin when the input argument is ``-``) and
synthesizes it to a WAV file with a Moshi DSM TTS checkpoint.
"""
import argparse
import sys

import sphn
import torch
from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel


def main():
    """Parse CLI arguments, load the TTS model, and synthesize the input text."""
    parser = argparse.ArgumentParser(
        description="Run Kyutai TTS using the PyTorch implementation"
    )
    parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
    parser.add_argument(
        "out", type=str, help="Output file to generate, use - for playing the audio"
    )
    parser.add_argument(
        "--hf-repo",
        type=str,
        default=DEFAULT_DSM_TTS_REPO,
        help="HF repo in which to look for the pretrained models.",
    )
    parser.add_argument(
        "--voice-repo",
        default=DEFAULT_DSM_TTS_VOICE_REPO,
        help="HF repo in which to look for pre-computed voice embeddings.",
    )
    parser.add_argument(
        "--voice",
        default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
        help="The voice to use, relative to the voice repo root. "
        f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
    )
    # Backward-compatible addition: previously the device was hard-coded to
    # "cuda", which crashed on CPU-only machines. The default preserves the
    # old behavior whenever CUDA is available.
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run inference on (default: cuda if available, else cpu).",
    )
    args = parser.parse_args()

    device = torch.device(args.device)
    # Half precision is only reliably supported (and faster) on CUDA; fall
    # back to float32 elsewhere. On CUDA this matches the original behavior.
    dtype = torch.half if device.type == "cuda" else torch.float32

    print("Loading model...")
    checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
    tts_model = TTSModel.from_checkpoint_info(
        checkpoint_info, n_q=32, temp=0.6, device=device, dtype=dtype
    )

    if args.inp == "-":
        print("Enter text to synthesize (Ctrl+D to end input):")
        text = sys.stdin.read().strip()
    else:
        with open(args.inp, "r") as fobj:
            text = fobj.read().strip()

    # You could also generate multiple audios at once by passing a list of texts.
    entries = tts_model.prepare_script([text], padding_between=1)
    voice_path = tts_model.get_voice_path(args.voice)
    # CFG coef goes here because the model was trained with CFG distillation,
    # so it's not _actually_ doing CFG at inference time.
    condition_attributes = tts_model.make_condition_attributes(
        [voice_path], cfg_coef=2.0
    )

    print("Generating audio...")
    # This doesn't do streaming generation.
    result = tts_model.generate([entries], [condition_attributes])

    # Drop the text stream (audio_offset) and the acoustic-delay prefix
    # (delay_steps) to keep only valid audio tokens.
    frames = torch.cat(result.frames, dim=-1)
    audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]
    with torch.no_grad():
        audios = tts_model.mimi.decode(audio_tokens)

    # NOTE(review): the help text advertises `-` for the output to play the
    # audio directly, but playback is not implemented — a file literally
    # named "-" would be written. TODO: implement playback or reject `-`.
    sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
    print(f"Audio saved to {args.out}")


if __name__ == "__main__":
    main()