Streaming output for the pytorch tts example. (#33)

* Streaming output for the pytorch tts example.

* Run the pre-commit hooks.
Laurent Mazare 2025-07-03 11:05:06 +02:00 committed by GitHub
parent d3bed09f9a
commit 5f8e924176
3 changed files with 55 additions and 72 deletions
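
The diffs below replace whole-utterance decoding with per-frame streaming: each frame is decoded as soon as the model emits it and the PCM is handed to the sound card through a queue. The standalone sketch below shows just that producer/consumer pattern, with a synthetic sine-wave source standing in for the TTS model; the 24 kHz sample rate and 1920-sample block size mirror the values used in the diff, and sounddevice (backed by the PortAudio library) is assumed to be installed.

import queue
import time

import numpy as np
import sounddevice as sd  # needs the PortAudio library at runtime

SAMPLE_RATE = 24_000  # mirrors tts_model.mimi.sample_rate in the example
FRAME = 1920  # block size used in the diff: one frame per audio callback

pcms = queue.Queue()


def producer(n_frames=50):
    # Stand-in for the TTS model: emit FRAME samples of a 440 Hz tone at a time.
    t = np.arange(n_frames * FRAME) / SAMPLE_RATE
    tone = (0.2 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)
    for i in range(n_frames):
        pcms.put_nowait(tone[i * FRAME : (i + 1) * FRAME])
        time.sleep(FRAME / SAMPLE_RATE)  # pretend generation runs in real time


def audio_callback(outdata, frames, _time, _status):
    # Runs on the audio thread: play the next queued frame, or silence if none is ready.
    try:
        outdata[:, 0] = pcms.get(block=False)
    except queue.Empty:
        outdata[:] = 0


with sd.OutputStream(
    samplerate=SAMPLE_RATE, blocksize=FRAME, channels=1, callback=audio_callback
):
    producer()
    while not pcms.empty():  # let the callback drain whatever is left
        time.sleep(0.1)

The audio callback runs on PortAudio's thread, so the queue is the only shared state; when the producer falls behind, the callback plays silence instead of blocking.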

View File

@@ -1,7 +1,6 @@
"""An example script that illustrates how one can prompt Kyutai STT models."""
import argparse
import dataclasses
import itertools
import math
from collections import deque

View File

@@ -11,20 +11,14 @@ import argparse
import sys

import numpy as np
+import queue
import sphn
+import time
import torch
from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel

-def play_audio(audio: np.ndarray, sample_rate: int):
-    # Requires the Portaudio library which might not be available in all environments.
-    import sounddevice as sd
-
-    with sd.OutputStream(samplerate=sample_rate, blocksize=1920, channels=1):
-        sd.play(audio, sample_rate)

def main():
    parser = argparse.ArgumentParser(
        description="Run Kyutai TTS using the PyTorch implementation"
@@ -57,6 +51,7 @@ def main():
    tts_model = TTSModel.from_checkpoint_info(
        checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
    )
+    tts_model.mimi.streaming_forever(batch_size=1)

    if args.inp == "-":
        if sys.stdin.isatty():  # Interactive
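
Putting Mimi into streaming mode here is presumably what lets the per-frame decode calls added further down (inside _on_frame) carry the decoder state across frames instead of treating each frame as an independent clip; batch_size=1 matches the single entries/condition_attributes pair generated below.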
@@ -75,22 +70,45 @@ def main():
        [voice_path], cfg_coef=2.0
    )

    print("Generating audio...")
-    # This doesn't do streaming generation, but the model allows it. For now, see Rust
-    # example.
-    result = tts_model.generate([entries], [condition_attributes])
-    frames = torch.cat(result.frames, dim=-1)
-    audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]
-    with torch.no_grad():
-        audios = tts_model.mimi.decode(audio_tokens)
    if args.out == "-":
-        print("Playing audio...")
-        play_audio(audios[0][0].cpu().numpy(), tts_model.mimi.sample_rate)
+        # Stream the audio to the speakers using sounddevice.
+        import sounddevice as sd
+
+        pcms = queue.Queue()
+
+        def _on_frame(frame):
+            if (frame != -1).all():
+                pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
+                pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
+
+        def audio_callback(outdata, _a, _b, _c):
+            try:
+                pcm_data = pcms.get(block=False)
+                outdata[:, 0] = pcm_data
+            except queue.Empty:
+                outdata[:] = 0
+
+        with sd.OutputStream(
+            samplerate=tts_model.mimi.sample_rate,
+            blocksize=1920,
+            channels=1,
+            callback=audio_callback,
+        ):
+            tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)
+            time.sleep(3)
+            while True:
+                if pcms.qsize() == 0:
+                    break
+                time.sleep(1)
    else:
-        sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
-        print(f"Audio saved to {args.out}")
+        result = tts_model.generate([entries], [condition_attributes])
+        with torch.no_grad():
+            pcms = []
+            for frame in result.frames[tts_model.delay_steps :]:
+                pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
+                pcms.append(np.clip(pcm[0, 0], -1, 1))
+            pcm = np.concatenate(pcms, axis=-1)
+        sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
if __name__ == "__main__":

File diff suppressed because one or more lines are too long
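
The file-output branch above ends by clipping each decoded chunk to [-1, 1], concatenating the chunks along time, and writing the result with sphn.write_wav. A toy, self-contained version of that tail step, with random chunks standing in for the decoder output (sphn assumed installed, 24 kHz and 1920-sample frames assumed as above):

import numpy as np
import sphn

SAMPLE_RATE = 24_000  # stands in for tts_model.mimi.sample_rate

# Dummy per-frame PCM chunks in place of the tts_model.mimi.decode output.
rng = np.random.default_rng(0)
chunks = [
    np.clip(rng.normal(scale=0.1, size=1920).astype(np.float32), -1, 1)
    for _ in range(25)
]

pcm = np.concatenate(chunks, axis=-1)  # one contiguous waveform: 25 * 1920 samples = 2 s
sphn.write_wav("out.wav", pcm, SAMPLE_RATE)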