diff --git a/scripts/tts_pytorch.py b/scripts/tts_pytorch.py
index 2fc7052..5478078 100644
--- a/scripts/tts_pytorch.py
+++ b/scripts/tts_pytorch.py
@@ -11,20 +11,14 @@
 import argparse
 import sys
 import numpy as np
+import queue
 import sphn
+import time
 import torch
 from moshi.models.loaders import CheckpointInfo
 from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
 
 
-def play_audio(audio: np.ndarray, sample_rate: int):
-    # Requires the Portaudio library which might not be available in all environments.
-    import sounddevice as sd
-
-    with sd.OutputStream(samplerate=sample_rate, blocksize=1920, channels=1):
-        sd.play(audio, sample_rate)
-
-
 def main():
     parser = argparse.ArgumentParser(
         description="Run Kyutai TTS using the PyTorch implementation"
@@ -57,6 +51,7 @@ def main():
     tts_model = TTSModel.from_checkpoint_info(
         checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
     )
+    tts_model.mimi.streaming_forever(batch_size=1)
     if args.inp == "-":
         if sys.stdin.isatty():  # Interactive
 
@@ -75,22 +70,41 @@ def main():
         [voice_path], cfg_coef=2.0
     )
 
-    print("Generating audio...")
-    # This doesn't do streaming generation, but the model allows it. For now, see Rust
-    # example.
-    result = tts_model.generate([entries], [condition_attributes])
-
-    frames = torch.cat(result.frames, dim=-1)
-    audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]
-    with torch.no_grad():
-        audios = tts_model.mimi.decode(audio_tokens)
-
     if args.out == "-":
-        print("Playing audio...")
-        play_audio(audios[0][0].cpu().numpy(), tts_model.mimi.sample_rate)
+        # Stream the audio to the speakers using sounddevice.
+        import sounddevice as sd
+        pcms = queue.Queue()
+
+        def _on_frame(frame):
+            if (frame != -1).all():
+                pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
+                pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
+
+        def audio_callback(outdata, _a, _b, _c):
+            try:
+                pcm_data = pcms.get(block=False)
+                outdata[:, 0] = pcm_data
+            except queue.Empty:
+                outdata[:] = 0
+        with sd.OutputStream(samplerate=tts_model.mimi.sample_rate,
+                             blocksize=1920,
+                             channels=1,
+                             callback=audio_callback):
+            tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)
+            time.sleep(3)
+            while True:
+                if pcms.qsize() == 0:
+                    break
+                time.sleep(1)
     else:
-        sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
-        print(f"Audio saved to {args.out}")
+        result = tts_model.generate([entries], [condition_attributes])
+        with torch.no_grad():
+            pcms = []
+            for frame in result.frames[tts_model.delay_steps:]:
+                pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
+                pcms.append(np.clip(pcm[0, 0], -1, 1))
+            pcm = np.concatenate(pcms, axis=-1)
+            sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
 
 
 if __name__ == "__main__":
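
Note on the new `args.out == "-"` branch: it is a producer/consumer pipeline, where the `_on_frame` hook decodes each valid Mimi frame to PCM and enqueues it, while the sounddevice output callback dequeues one 1920-sample block per audio period and emits silence on underrun. Below is a minimal self-contained sketch of that pattern, with a hypothetical sine-wave producer standing in for `tts_model.generate(..., on_frame=...)`; the 24 kHz rate and 1920-sample block (one 80 ms Mimi frame) are assumptions mirroring the script, not values taken from the model at runtime.

    # Standalone sketch of the queue-based streaming playback pattern.
    import queue
    import time

    import numpy as np
    import sounddevice as sd

    SAMPLE_RATE = 24000  # assumption: Mimi's sample rate
    BLOCK = 1920         # assumption: one Mimi frame, matching blocksize above

    pcms = queue.Queue()

    def produce(n_blocks: int) -> None:
        # Stand-in producer: pushes fixed-size PCM blocks, as _on_frame does
        # with each decoded frame.
        t = np.arange(n_blocks * BLOCK) / SAMPLE_RATE
        tone = (0.2 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)
        for i in range(n_blocks):
            pcms.put_nowait(tone[i * BLOCK : (i + 1) * BLOCK])

    def callback(outdata, _frames, _time, _status) -> None:
        # Drain one block per audio period; play silence when the producer lags.
        try:
            outdata[:, 0] = pcms.get(block=False)
        except queue.Empty:
            outdata[:] = 0

    produce(25)  # roughly two seconds of audio
    with sd.OutputStream(
        samplerate=SAMPLE_RATE, blocksize=BLOCK, channels=1, callback=callback
    ):
        while not pcms.empty():
            time.sleep(0.1)
        time.sleep(0.2)  # let the final block play out

The same drain-then-exit logic appears in the diff as the `time.sleep(3)` plus `qsize()` polling loop, which keeps the stream open until the queue is empty before the `OutputStream` context manager closes it.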