Streaming output for the PyTorch TTS example.

This commit is contained in:
laurent 2025-07-03 08:37:08 +02:00
parent 6c1e9f12cf
commit b0e1c2a6b4

View File

@ -11,20 +11,14 @@ import argparse
import sys import sys
import numpy as np import numpy as np
import queue
import sphn import sphn
import time
import torch import torch
from moshi.models.loaders import CheckpointInfo from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
def play_audio(audio: np.ndarray, sample_rate: int) -> None:
    """Play `audio` on the default output device and block until it is done.

    Args:
        audio: Samples to play; passed unchanged to `sounddevice.play`.
        sample_rate: Sampling rate handed to `sounddevice.play`.
    """
    # Requires the Portaudio library which might not be available in all environments.
    import sounddevice as sd

    # `sd.play` manages its own output stream and returns immediately, so a
    # separately opened `OutputStream` is unnecessary. Without `sd.wait()` the
    # caller could return (and the process exit) before playback has finished.
    sd.play(audio, sample_rate)
    sd.wait()
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the PyTorch implementation" description="Run Kyutai TTS using the PyTorch implementation"
@ -57,6 +51,7 @@ def main():
tts_model = TTSModel.from_checkpoint_info( tts_model = TTSModel.from_checkpoint_info(
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
) )
tts_model.mimi.streaming_forever(batch_size=1)
if args.inp == "-": if args.inp == "-":
if sys.stdin.isatty(): # Interactive if sys.stdin.isatty(): # Interactive
@ -75,22 +70,41 @@ def main():
[voice_path], cfg_coef=2.0 [voice_path], cfg_coef=2.0
) )
print("Generating audio...")
# This doesn't do streaming generation, but the model allows it. For now, see Rust
# example.
result = tts_model.generate([entries], [condition_attributes])
frames = torch.cat(result.frames, dim=-1)
audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]
with torch.no_grad():
audios = tts_model.mimi.decode(audio_tokens)
if args.out == "-": if args.out == "-":
print("Playing audio...") # Stream the audio to the speakers using sounddevice.
play_audio(audios[0][0].cpu().numpy(), tts_model.mimi.sample_rate) import sounddevice as sd
pcms = queue.Queue()
def _on_frame(frame):
    """Decode one generated frame and queue the resulting PCM chunk.

    Frames in which any entry equals -1 are ignored (presumably positions
    that have not been generated yet -- TODO confirm against the model code).
    """
    if (frame == -1).any():
        return
    # The first entry along dim 1 is dropped before decoding
    # (presumably the non-audio stream -- TODO confirm).
    audio_codes = frame[:, 1:, :]
    decoded = tts_model.mimi.decode(audio_codes)
    samples = decoded.cpu().numpy()
    # Keep only index [0, 0] of the decoded output, clamped to [-1, 1].
    pcms.put_nowait(np.clip(samples[0, 0], -1, 1))
def audio_callback(outdata, _a, _b, _c):
    """Stream callback: emit the next queued PCM chunk, or zeros if none is ready.

    Only `outdata` is used; the remaining positional arguments supplied by
    the output stream are ignored.
    """
    try:
        chunk = pcms.get_nowait()
    except queue.Empty:
        # Nothing decoded yet: fill the whole buffer with zeros.
        outdata[:] = 0
    else:
        outdata[:, 0] = chunk
with sd.OutputStream(samplerate=tts_model.mimi.sample_rate,
blocksize=1920,
channels=1,
callback=audio_callback):
tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)
time.sleep(3)
while True:
if pcms.qsize() == 0:
break
time.sleep(1)
else: else:
sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate) result = tts_model.generate([entries], [condition_attributes])
print(f"Audio saved to {args.out}") with torch.no_grad():
pcms = []
for frame in result.frames[tts_model.delay_steps:]:
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.append(np.clip(pcm[0, 0], -1, 1))
pcm = np.concatenate(pcms, axis=-1)
sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
if __name__ == "__main__": if __name__ == "__main__":