From 31b425cb6f43528f0241f150617b5423b9159305 Mon Sep 17 00:00:00 2001
From: Vaclav Volhejn
Date: Wed, 2 Jul 2025 16:21:10 +0200
Subject: [PATCH] Add attempt at interactive playback

---
 scripts/tts_pytorch.py | 31 ++++++++++++++++++++++++++++---
 1 file changed, 28 insertions(+), 3 deletions(-)

diff --git a/scripts/tts_pytorch.py b/scripts/tts_pytorch.py
index e95c563..0e6e6dc 100644
--- a/scripts/tts_pytorch.py
+++ b/scripts/tts_pytorch.py
@@ -4,17 +4,37 @@
 #     "moshi @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi",
 #     "torch",
 #     "sphn",
+#     "sounddevice",
 # ]
 # ///
 import argparse
 import sys
 
+import numpy as np
 import sphn
 import torch
 from moshi.models.loaders import CheckpointInfo
 from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
 
 
+def audio_to_int16(audio: np.ndarray) -> np.ndarray:
+    if audio.dtype == np.int16:
+        return audio
+    elif audio.dtype == np.float32:
+        # Multiply by 32767 and not 32768 so that int16 doesn't overflow.
+        return (np.clip(audio, -1, 1) * 32767).astype(np.int16)
+    else:
+        raise TypeError(f"Unsupported audio data type: {audio.dtype}")
+
+
+def play_audio(audio: np.ndarray, sample_rate: int):
+    """Play audio and block until done, so the script doesn't exit mid-playback."""
+    # Requires the Portaudio library which might not be available in all environments.
+    import sounddevice as sd
+
+    sd.play(audio, sample_rate, blocking=True)
+
+
 def main():
     parser = argparse.ArgumentParser(
         description="Run Kyutai TTS using the PyTorch implementation"
@@ -49,7 +69,8 @@ def main():
     )
 
     if args.inp == "-":
-        print("Enter text to synthesize (Ctrl+D to end input):")
+        if sys.stdin.isatty():  # Interactive
+            print("Enter text to synthesize (Ctrl+D to end input):")
         text = sys.stdin.read().strip()
     else:
         with open(args.inp, "r") as fobj:
@@ -73,8 +94,12 @@ def main():
     with torch.no_grad():
         audios = tts_model.mimi.decode(audio_tokens)
 
-    sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
-    print(f"Audio saved to {args.out}")
+    if args.out == "-":
+        print("Playing audio...")
+        play_audio(audios[0][0].cpu().numpy(), tts_model.mimi.sample_rate)
+    else:
+        sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
+        print(f"Audio saved to {args.out}")
 
 
 if __name__ == "__main__":