Streaming output for the pytorch tts example. (#33)

* Streaming output for the pytorch tts example.

* Run the pre-commit hooks.
This commit is contained in:
Laurent Mazare 2025-07-03 11:05:06 +02:00 committed by GitHub
parent d3bed09f9a
commit 5f8e924176
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 55 additions and 72 deletions

View File

@ -1,7 +1,6 @@
"""An example script that illustrates how one can prompt Kyutai STT models.""" """An example script that illustrates how one can prompt Kyutai STT models."""
import argparse import argparse
import dataclasses
import itertools import itertools
import math import math
from collections import deque from collections import deque

View File

@ -11,20 +11,14 @@ import argparse
import sys import sys
import numpy as np import numpy as np
import queue
import sphn import sphn
import time
import torch import torch
from moshi.models.loaders import CheckpointInfo from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
def play_audio(audio: np.ndarray, sample_rate: int):
# Requires the Portaudio library which might not be available in all environments.
import sounddevice as sd
with sd.OutputStream(samplerate=sample_rate, blocksize=1920, channels=1):
sd.play(audio, sample_rate)
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the PyTorch implementation" description="Run Kyutai TTS using the PyTorch implementation"
@ -57,6 +51,7 @@ def main():
tts_model = TTSModel.from_checkpoint_info( tts_model = TTSModel.from_checkpoint_info(
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
) )
tts_model.mimi.streaming_forever(batch_size=1)
if args.inp == "-": if args.inp == "-":
if sys.stdin.isatty(): # Interactive if sys.stdin.isatty(): # Interactive
@ -75,22 +70,45 @@ def main():
[voice_path], cfg_coef=2.0 [voice_path], cfg_coef=2.0
) )
print("Generating audio...")
# This doesn't do streaming generation, but the model allows it. For now, see Rust
# example.
result = tts_model.generate([entries], [condition_attributes])
frames = torch.cat(result.frames, dim=-1)
audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]
with torch.no_grad():
audios = tts_model.mimi.decode(audio_tokens)
if args.out == "-": if args.out == "-":
print("Playing audio...") # Stream the audio to the speakers using sounddevice.
play_audio(audios[0][0].cpu().numpy(), tts_model.mimi.sample_rate) import sounddevice as sd
pcms = queue.Queue()
def _on_frame(frame):
if (frame != -1).all():
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
def audio_callback(outdata, _a, _b, _c):
try:
pcm_data = pcms.get(block=False)
outdata[:, 0] = pcm_data
except queue.Empty:
outdata[:] = 0
with sd.OutputStream(
samplerate=tts_model.mimi.sample_rate,
blocksize=1920,
channels=1,
callback=audio_callback,
):
tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)
time.sleep(3)
while True:
if pcms.qsize() == 0:
break
time.sleep(1)
else: else:
sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate) result = tts_model.generate([entries], [condition_attributes])
print(f"Audio saved to {args.out}") with torch.no_grad():
pcms = []
for frame in result.frames[tts_model.delay_steps :]:
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.append(np.clip(pcm[0, 0], -1, 1))
pcm = np.concatenate(pcms, axis=-1)
sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
if __name__ == "__main__": if __name__ == "__main__":

File diff suppressed because one or more lines are too long