diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..e0300ab
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,3 @@
+{
+    "python.analysis.typeCheckingMode": "standard"
+}
diff --git a/scripts/tts_pytorch.py b/scripts/tts_pytorch.py
new file mode 100644
index 0000000..0e6e6dc
--- /dev/null
+++ b/scripts/tts_pytorch.py
@@ -0,0 +1,110 @@
+# /// script
+# requires-python = ">=3.12"
+# dependencies = [
+#     "moshi @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi",
+#     "torch",
+#     "sphn",
+#     "sounddevice",
+# ]
+# ///
+import argparse
+import sys
+
+import numpy as np
+import sphn
+import torch
+from moshi.models.loaders import CheckpointInfo
+from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
+
+
+def audio_to_int16(audio: np.ndarray) -> np.ndarray:
+    if audio.dtype == np.int16:
+        return audio
+    elif audio.dtype == np.float32:
+        # Multiply by 32767 and not 32768 so that int16 doesn't overflow.
+        return (np.clip(audio, -1, 1) * 32767).astype(np.int16)
+    else:
+        raise TypeError(f"Unsupported audio data type: {audio.dtype}")
+
+
+def play_audio(audio: np.ndarray, sample_rate: int):
+    # Requires the PortAudio library, which might not be available in all environments.
+    import sounddevice as sd
+
+    # sd.play returns immediately, so block until playback completes.
+    sd.play(audio, sample_rate)
+    sd.wait()
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Run Kyutai TTS using the PyTorch implementation"
+    )
+    parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
+    parser.add_argument(
+        "out", type=str, help="Output file to generate, use - for playing the audio."
+    )
+    parser.add_argument(
+        "--hf-repo",
+        type=str,
+        default=DEFAULT_DSM_TTS_REPO,
+        help="HF repo in which to look for the pretrained models.",
+    )
+    parser.add_argument(
+        "--voice-repo",
+        default=DEFAULT_DSM_TTS_VOICE_REPO,
+        help="HF repo in which to look for pre-computed voice embeddings.",
+    )
+    parser.add_argument(
+        "--voice",
+        default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
+        help="The voice to use, relative to the voice repo root. "
+        f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
+    )
+    args = parser.parse_args()
+
+    print("Loading model...")
+    checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
+    tts_model = TTSModel.from_checkpoint_info(
+        checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
+    )
+
+    if args.inp == "-":
+        if sys.stdin.isatty():  # Interactive
+            print("Enter text to synthesize (Ctrl+D to end input):")
+        text = sys.stdin.read().strip()
+    else:
+        with open(args.inp, "r") as fobj:
+            text = fobj.read().strip()
+
+    # You could also generate multiple audios at once by passing a list of texts.
+    entries = tts_model.prepare_script([text], padding_between=1)
+    voice_path = tts_model.get_voice_path(args.voice)
+    # CFG coef goes here because the model was trained with CFG distillation,
+    # so it's not _actually_ doing CFG at inference time.
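+    # (With distillation, the guidance scale is just another conditioning
+    # input, so one forward pass replaces the two passes of classical CFG.)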
+    condition_attributes = tts_model.make_condition_attributes(
+        [voice_path], cfg_coef=2.0
+    )
+
+    print("Generating audio...")
+    # This doesn't do streaming generation.
+    result = tts_model.generate([entries], [condition_attributes])
+
+    frames = torch.cat(result.frames, dim=-1)
+    # Keep only the audio codebooks and drop the initial delay frames.
+    audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]
+    with torch.no_grad():
+        audios = tts_model.mimi.decode(audio_tokens)
+
+    if args.out == "-":
+        print("Playing audio...")
+        play_audio(audios[0][0].cpu().numpy(), tts_model.mimi.sample_rate)
+    else:
+        sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
+        print(f"Audio saved to {args.out}")
+
+
+if __name__ == "__main__":
+    main()
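Usage note: the `# /// script` block at the top of scripts/tts_pytorch.py is PEP 723
inline metadata, so a PEP 723-aware runner such as uv can resolve the listed
dependencies and run the script directly:

    echo "Hello from Kyutai TTS." | uv run scripts/tts_pytorch.py - out.wav
    uv run scripts/tts_pytorch.py input.txt -

The model is loaded on a hard-coded torch.device("cuda"). For machines without a
GPU, a minimal fallback sketch (my own assumption, not part of the patch; half
precision is generally only worthwhile on CUDA, hence the dtype switch):

    # Fall back to CPU and full precision when CUDA is unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.half if device.type == "cuda" else torch.float32
    tts_model = TTSModel.from_checkpoint_info(
        checkpoint_info, n_q=32, temp=0.6, device=device, dtype=dtype
    )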