Add attempt at interactive playback

This commit is contained in:
Vaclav Volhejn 2025-07-02 16:21:10 +02:00
parent 3de0606614
commit 31b425cb6f

View File

@ -4,17 +4,37 @@
# "moshi @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi",
# "torch",
# "sphn",
# "sounddevice",
# ]
# ///
import argparse
import sys
import numpy as np
import sphn
import torch
from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
def audio_to_int16(audio: np.ndarray) -> np.ndarray:
if audio.dtype == np.int16:
return audio
elif audio.dtype == np.float32:
# Multiply by 32767 and not 32768 so that int16 doesn't overflow.
return (np.clip(audio, -1, 1) * 32767).astype(np.int16)
else:
raise TypeError(f"Unsupported audio data type: {audio.dtype}")
def play_audio(audio: np.ndarray, sample_rate: int):
# Requires the Portaudio library which might not be available in all environments.
import sounddevice as sd
with sd.OutputStream(samplerate=sample_rate, blocksize=1920, channels=1):
sd.play(audio, sample_rate)
def main():
parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the PyTorch implementation"
@ -49,6 +69,7 @@ def main():
)
if args.inp == "-":
if sys.stdin.isatty(): # Interactive
print("Enter text to synthesize (Ctrl+D to end input):")
text = sys.stdin.read().strip()
else:
@ -73,6 +94,10 @@ def main():
with torch.no_grad():
audios = tts_model.mimi.decode(audio_tokens)
if args.out == "-":
print("Playing audio...")
play_audio(audios[0][0].cpu().numpy(), tts_model.mimi.sample_rate)
else:
sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
print(f"Audio saved to {args.out}")