Streaming output for the pytorch tts example.
This commit is contained in:
parent
6c1e9f12cf
commit
b0e1c2a6b4
|
|
@ -11,20 +11,14 @@ import argparse
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import queue
|
||||||
import sphn
|
import sphn
|
||||||
|
import time
|
||||||
import torch
|
import torch
|
||||||
from moshi.models.loaders import CheckpointInfo
|
from moshi.models.loaders import CheckpointInfo
|
||||||
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
|
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
|
||||||
|
|
||||||
|
|
||||||
def play_audio(audio: np.ndarray, sample_rate: int):
|
|
||||||
# Requires the Portaudio library which might not be available in all environments.
|
|
||||||
import sounddevice as sd
|
|
||||||
|
|
||||||
with sd.OutputStream(samplerate=sample_rate, blocksize=1920, channels=1):
|
|
||||||
sd.play(audio, sample_rate)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Run Kyutai TTS using the PyTorch implementation"
|
description="Run Kyutai TTS using the PyTorch implementation"
|
||||||
|
|
@ -57,6 +51,7 @@ def main():
|
||||||
tts_model = TTSModel.from_checkpoint_info(
|
tts_model = TTSModel.from_checkpoint_info(
|
||||||
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
|
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
|
||||||
)
|
)
|
||||||
|
tts_model.mimi.streaming_forever(batch_size=1)
|
||||||
|
|
||||||
if args.inp == "-":
|
if args.inp == "-":
|
||||||
if sys.stdin.isatty(): # Interactive
|
if sys.stdin.isatty(): # Interactive
|
||||||
|
|
@ -75,22 +70,41 @@ def main():
|
||||||
[voice_path], cfg_coef=2.0
|
[voice_path], cfg_coef=2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Generating audio...")
|
|
||||||
# This doesn't do streaming generation, but the model allows it. For now, see Rust
|
|
||||||
# example.
|
|
||||||
result = tts_model.generate([entries], [condition_attributes])
|
|
||||||
|
|
||||||
frames = torch.cat(result.frames, dim=-1)
|
|
||||||
audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]
|
|
||||||
with torch.no_grad():
|
|
||||||
audios = tts_model.mimi.decode(audio_tokens)
|
|
||||||
|
|
||||||
if args.out == "-":
|
if args.out == "-":
|
||||||
print("Playing audio...")
|
# Stream the audio to the speakers using sounddevice.
|
||||||
play_audio(audios[0][0].cpu().numpy(), tts_model.mimi.sample_rate)
|
import sounddevice as sd
|
||||||
|
pcms = queue.Queue()
|
||||||
|
|
||||||
|
def _on_frame(frame):
|
||||||
|
if (frame != -1).all():
|
||||||
|
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
|
||||||
|
pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
|
||||||
|
|
||||||
|
def audio_callback(outdata, _a, _b, _c):
|
||||||
|
try:
|
||||||
|
pcm_data = pcms.get(block=False)
|
||||||
|
outdata[:, 0] = pcm_data
|
||||||
|
except queue.Empty:
|
||||||
|
outdata[:] = 0
|
||||||
|
with sd.OutputStream(samplerate=tts_model.mimi.sample_rate,
|
||||||
|
blocksize=1920,
|
||||||
|
channels=1,
|
||||||
|
callback=audio_callback):
|
||||||
|
tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)
|
||||||
|
time.sleep(3)
|
||||||
|
while True:
|
||||||
|
if pcms.qsize() == 0:
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
else:
|
else:
|
||||||
sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
|
result = tts_model.generate([entries], [condition_attributes])
|
||||||
print(f"Audio saved to {args.out}")
|
with torch.no_grad():
|
||||||
|
pcms = []
|
||||||
|
for frame in result.frames[tts_model.delay_steps:]:
|
||||||
|
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
|
||||||
|
pcms.append(np.clip(pcm[0, 0], -1, 1))
|
||||||
|
pcm = np.concatenate(pcms, axis=-1)
|
||||||
|
sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user