Add attempt at interactive playback

This commit is contained in:
Vaclav Volhejn 2025-07-02 16:21:10 +02:00
parent 3de0606614
commit 31b425cb6f

View File

@ -4,17 +4,37 @@
# "moshi @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi", # "moshi @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi",
# "torch", # "torch",
# "sphn", # "sphn",
# "sounddevice",
# ] # ]
# /// # ///
import argparse import argparse
import sys import sys
import numpy as np
import sphn import sphn
import torch import torch
from moshi.models.loaders import CheckpointInfo from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
def audio_to_int16(audio: np.ndarray) -> np.ndarray:
    """Convert an audio array to 16-bit signed integer PCM samples.

    Args:
        audio: Samples as ``int16`` or as any NumPy floating-point dtype.
            Floating-point samples are treated as nominally lying in
            ``[-1, 1]``; values outside that range are clipped.

    Returns:
        An ``int16`` array. ``int16`` input is returned as-is (same object,
        no copy); floating-point input yields a new array.

    Raises:
        TypeError: If ``audio`` has any other dtype (e.g. ``int32``).
    """
    if audio.dtype == np.int16:
        return audio
    if np.issubdtype(audio.dtype, np.floating):
        # Promote to at least float32 before scaling: float16 cannot represent
        # 32767 exactly (it rounds to 32768), which would overflow int16.
        # float32 and wider inputs keep their own precision.
        work_dtype = np.result_type(audio.dtype, np.float32)
        samples = audio.astype(work_dtype, copy=False)
        # Multiply by 32767 and not 32768 so that int16 doesn't overflow.
        return (np.clip(samples, -1, 1) * 32767).astype(np.int16)
    raise TypeError(f"Unsupported audio data type: {audio.dtype}")
def play_audio(audio: np.ndarray, sample_rate: int):
    """Play ``audio`` on the default output device and wait until it finishes.

    Args:
        audio: Samples to play, passed unchanged to ``sounddevice.play``.
        sample_rate: Sampling rate of ``audio`` in Hz.
    """
    # Requires the Portaudio library which might not be available in all environments.
    import sounddevice as sd

    # sd.play() opens and manages its own output stream, so no separate
    # OutputStream is needed. It is non-blocking by default; blocking=True makes
    # it return only after playback has finished, so the caller (and the
    # process) cannot exit while audio is still playing.
    sd.play(audio, sample_rate, blocking=True)
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the PyTorch implementation" description="Run Kyutai TTS using the PyTorch implementation"
@ -49,7 +69,8 @@ def main():
) )
if args.inp == "-": if args.inp == "-":
print("Enter text to synthesize (Ctrl+D to end input):") if sys.stdin.isatty(): # Interactive
print("Enter text to synthesize (Ctrl+D to end input):")
text = sys.stdin.read().strip() text = sys.stdin.read().strip()
else: else:
with open(args.inp, "r") as fobj: with open(args.inp, "r") as fobj:
@ -73,8 +94,12 @@ def main():
with torch.no_grad(): with torch.no_grad():
audios = tts_model.mimi.decode(audio_tokens) audios = tts_model.mimi.decode(audio_tokens)
sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate) if args.out == "-":
print(f"Audio saved to {args.out}") print("Playing audio...")
play_audio(audios[0][0].cpu().numpy(), tts_model.mimi.sample_rate)
else:
sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
print(f"Audio saved to {args.out}")
if __name__ == "__main__": if __name__ == "__main__":