Add attempt at interactive playback
This commit is contained in:
parent
3de0606614
commit
31b425cb6f
|
|
@ -4,17 +4,37 @@
|
|||
# "moshi @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi",
|
||||
# "torch",
|
||||
# "sphn",
|
||||
# "sounddevice",
|
||||
# ]
|
||||
# ///
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import sphn
|
||||
import torch
|
||||
from moshi.models.loaders import CheckpointInfo
|
||||
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
|
||||
|
||||
|
||||
def audio_to_int16(audio: np.ndarray) -> np.ndarray:
    """Convert an audio array to 16-bit signed integer PCM samples.

    Args:
        audio: Samples as ``int16`` (returned unchanged, same object) or as
            any floating-point dtype, in which case values are clipped to
            ``[-1, 1]`` before scaling.

    Returns:
        An ``int16`` array with the same shape as ``audio``.

    Raises:
        TypeError: If ``audio`` is neither ``int16`` nor a floating-point
            array.
    """
    if audio.dtype == np.int16:
        return audio

    if np.issubdtype(audio.dtype, np.floating):
        # Scale in at least float32: in float16 the product 1.0 * 32767
        # rounds up to 32768, which would wrap around when cast to int16.
        work_dtype = np.result_type(audio.dtype, np.float32)
        clipped = np.clip(audio.astype(work_dtype, copy=False), -1, 1)
        # Multiply by 32767 and not 32768 so that int16 doesn't overflow.
        return (clipped * 32767).astype(np.int16)

    raise TypeError(f"Unsupported audio data type: {audio.dtype}")
|
||||
|
||||
|
||||
def play_audio(audio: np.ndarray, sample_rate: int):
    """Play ``audio`` on the default output device and block until it ends.

    Args:
        audio: Samples to play, handed to ``sounddevice.play`` as-is.
        sample_rate: Sampling rate of ``audio`` in Hz.
    """
    # Requires the Portaudio library which might not be available in all environments.
    import sounddevice as sd

    # sd.play() opens and manages its own output stream, so no separate
    # OutputStream is needed. It returns immediately, so wait for playback
    # to finish; otherwise the caller may exit before anything is heard.
    sd.play(audio, sample_rate)
    sd.wait()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run Kyutai TTS using the PyTorch implementation"
|
||||
|
|
@ -49,7 +69,8 @@ def main():
|
|||
)
|
||||
|
||||
if args.inp == "-":
|
||||
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||
if sys.stdin.isatty(): # Interactive
|
||||
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||
text = sys.stdin.read().strip()
|
||||
else:
|
||||
with open(args.inp, "r") as fobj:
|
||||
|
|
@ -73,8 +94,12 @@ def main():
|
|||
with torch.no_grad():
|
||||
audios = tts_model.mimi.decode(audio_tokens)
|
||||
|
||||
sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
|
||||
print(f"Audio saved to {args.out}")
|
||||
if args.out == "-":
|
||||
print("Playing audio...")
|
||||
play_audio(audios[0][0].cpu().numpy(), tts_model.mimi.sample_rate)
|
||||
else:
|
||||
sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
|
||||
print(f"Audio saved to {args.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user