Add tts_pytorch.py

This commit is contained in:
Vaclav Volhejn 2025-07-02 16:08:46 +02:00
parent c4ef93770a
commit 3de0606614
2 changed files with 84 additions and 0 deletions

3
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,3 @@
{
"python.analysis.typeCheckingMode": "standard"
}

81
scripts/tts_pytorch.py Normal file
View File

@ -0,0 +1,81 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "moshi @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi",
# "torch",
# "sphn",
# ]
# ///
import argparse
import sys
import sphn
import torch
from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
def main():
parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the PyTorch implementation"
)
parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
parser.add_argument(
"out", type=str, help="Output file to generate, use - for playing the audio"
)
parser.add_argument(
"--hf-repo",
type=str,
default=DEFAULT_DSM_TTS_REPO,
help="HF repo in which to look for the pretrained models.",
)
parser.add_argument(
"--voice-repo",
default=DEFAULT_DSM_TTS_VOICE_REPO,
help="HF repo in which to look for pre-computed voice embeddings.",
)
parser.add_argument(
"--voice",
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
help="The voice to use, relative to the voice repo root. "
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
)
args = parser.parse_args()
print("Loading model...")
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
tts_model = TTSModel.from_checkpoint_info(
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
)
if args.inp == "-":
print("Enter text to synthesize (Ctrl+D to end input):")
text = sys.stdin.read().strip()
else:
with open(args.inp, "r") as fobj:
text = fobj.read().strip()
# You could also generate multiple audios at once by passing a list of texts.
entries = tts_model.prepare_script([text], padding_between=1)
voice_path = tts_model.get_voice_path(args.voice)
# CFG coef goes here because the model was trained with CFG distillation,
# so it's not _actually_ doing CFG at inference time.
condition_attributes = tts_model.make_condition_attributes(
[voice_path], cfg_coef=2.0
)
print("Generating audio...")
# This doesn't do streaming generation,
result = tts_model.generate([entries], [condition_attributes])
frames = torch.cat(result.frames, dim=-1)
audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]
with torch.no_grad():
audios = tts_model.mimi.decode(audio_tokens)
sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
print(f"Audio saved to {args.out}")
if __name__ == "__main__":
main()