Add attempt at interactive playback
This commit is contained in:
parent
3de0606614
commit
31b425cb6f
|
|
@ -4,17 +4,37 @@
|
||||||
# "moshi @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi",
|
# "moshi @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi",
|
||||||
# "torch",
|
# "torch",
|
||||||
# "sphn",
|
# "sphn",
|
||||||
|
# "sounddevice",
|
||||||
# ]
|
# ]
|
||||||
# ///
|
# ///
|
||||||
import argparse
|
import argparse
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import sphn
|
import sphn
|
||||||
import torch
|
import torch
|
||||||
from moshi.models.loaders import CheckpointInfo
|
from moshi.models.loaders import CheckpointInfo
|
||||||
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
|
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
|
||||||
|
|
||||||
|
|
||||||
|
def audio_to_int16(audio: np.ndarray) -> np.ndarray:
|
||||||
|
if audio.dtype == np.int16:
|
||||||
|
return audio
|
||||||
|
elif audio.dtype == np.float32:
|
||||||
|
# Multiply by 32767 and not 32768 so that int16 doesn't overflow.
|
||||||
|
return (np.clip(audio, -1, 1) * 32767).astype(np.int16)
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unsupported audio data type: {audio.dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def play_audio(audio: np.ndarray, sample_rate: int):
|
||||||
|
# Requires the Portaudio library which might not be available in all environments.
|
||||||
|
import sounddevice as sd
|
||||||
|
|
||||||
|
with sd.OutputStream(samplerate=sample_rate, blocksize=1920, channels=1):
|
||||||
|
sd.play(audio, sample_rate)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Run Kyutai TTS using the PyTorch implementation"
|
description="Run Kyutai TTS using the PyTorch implementation"
|
||||||
|
|
@ -49,7 +69,8 @@ def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.inp == "-":
|
if args.inp == "-":
|
||||||
print("Enter text to synthesize (Ctrl+D to end input):")
|
if sys.stdin.isatty(): # Interactive
|
||||||
|
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||||
text = sys.stdin.read().strip()
|
text = sys.stdin.read().strip()
|
||||||
else:
|
else:
|
||||||
with open(args.inp, "r") as fobj:
|
with open(args.inp, "r") as fobj:
|
||||||
|
|
@ -73,8 +94,12 @@ def main():
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
audios = tts_model.mimi.decode(audio_tokens)
|
audios = tts_model.mimi.decode(audio_tokens)
|
||||||
|
|
||||||
sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
|
if args.out == "-":
|
||||||
print(f"Audio saved to {args.out}")
|
print("Playing audio...")
|
||||||
|
play_audio(audios[0][0].cpu().numpy(), tts_model.mimi.sample_rate)
|
||||||
|
else:
|
||||||
|
sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
|
||||||
|
print(f"Audio saved to {args.out}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user