diff --git a/scripts/tts_pytorch.py b/scripts/tts_pytorch.py index e95c563..0e6e6dc 100644 --- a/scripts/tts_pytorch.py +++ b/scripts/tts_pytorch.py @@ -4,17 +4,37 @@ # "moshi @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi", # "torch", # "sphn", +# "sounddevice", # ] # /// import argparse import sys +import numpy as np import sphn import torch from moshi.models.loaders import CheckpointInfo from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel +def audio_to_int16(audio: np.ndarray) -> np.ndarray: + if audio.dtype == np.int16: + return audio + elif audio.dtype == np.float32: + # Multiply by 32767 and not 32768 so that int16 doesn't overflow. + return (np.clip(audio, -1, 1) * 32767).astype(np.int16) + else: + raise TypeError(f"Unsupported audio data type: {audio.dtype}") + + +def play_audio(audio: np.ndarray, sample_rate: int): + # Requires the Portaudio library which might not be available in all environments. + import sounddevice as sd + + with sd.OutputStream(samplerate=sample_rate, blocksize=1920, channels=1): + sd.play(audio, sample_rate) + + def main(): parser = argparse.ArgumentParser( description="Run Kyutai TTS using the PyTorch implementation" @@ -49,7 +69,8 @@ def main(): ) if args.inp == "-": - print("Enter text to synthesize (Ctrl+D to end input):") + if sys.stdin.isatty(): # Interactive + print("Enter text to synthesize (Ctrl+D to end input):") text = sys.stdin.read().strip() else: with open(args.inp, "r") as fobj: @@ -73,8 +94,12 @@ def main(): with torch.no_grad(): audios = tts_model.mimi.decode(audio_tokens) - sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate) - print(f"Audio saved to {args.out}") + if args.out == "-": + print("Playing audio...") + play_audio(audios[0][0].cpu().numpy(), tts_model.mimi.sample_rate) + else: + sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate) + print(f"Audio saved to {args.out}") if __name__ == "__main__":