Run the pre-commit hooks.

This commit is contained in:
laurent 2025-07-03 11:04:24 +02:00
parent 9ba4e88553
commit c970b86088
3 changed files with 24 additions and 55 deletions

View File

@@ -1,7 +1,6 @@
"""An example script that illustrates how one can prompt Kyutai STT models.""" """An example script that illustrates how one can prompt Kyutai STT models."""
import argparse import argparse
import dataclasses
import itertools import itertools
import math import math
from collections import deque from collections import deque

View File

@@ -73,6 +73,7 @@ def main():
if args.out == "-": if args.out == "-":
# Stream the audio to the speakers using sounddevice. # Stream the audio to the speakers using sounddevice.
import sounddevice as sd import sounddevice as sd
pcms = queue.Queue() pcms = queue.Queue()
def _on_frame(frame): def _on_frame(frame):
@@ -86,10 +87,13 @@ def main():
outdata[:, 0] = pcm_data outdata[:, 0] = pcm_data
except queue.Empty: except queue.Empty:
outdata[:] = 0 outdata[:] = 0
with sd.OutputStream(samplerate=tts_model.mimi.sample_rate,
with sd.OutputStream(
samplerate=tts_model.mimi.sample_rate,
blocksize=1920, blocksize=1920,
channels=1, channels=1,
callback=audio_callback): callback=audio_callback,
):
tts_model.generate([entries], [condition_attributes], on_frame=_on_frame) tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)
time.sleep(3) time.sleep(3)
while True: while True:
@@ -100,7 +104,7 @@ def main():
result = tts_model.generate([entries], [condition_attributes]) result = tts_model.generate([entries], [condition_attributes])
with torch.no_grad(): with torch.no_grad():
pcms = [] pcms = []
for frame in result.frames[tts_model.delay_steps:]: for frame in result.frames[tts_model.delay_steps :]:
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy() pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.append(np.clip(pcm[0, 0], -1, 1)) pcms.append(np.clip(pcm[0, 0], -1, 1))
pcm = np.concatenate(pcms, axis=-1) pcm = np.concatenate(pcms, axis=-1)

File diff suppressed because one or more lines are too long