Allow for playing the audio in a streaming way.
parent 61206d78c8
commit 236df99003
@@ -2,16 +2,16 @@
 # requires-python = ">=3.12"
 # dependencies = [
 #     "huggingface_hub",
-#     "moshi_mlx",
+#     "moshi_mlx>=0.2.8",
 #     "numpy",
 #     "sounddevice",
 # ]
 # ///
 
 import argparse
-from dataclasses import dataclass
 import json
 from pathlib import Path
+import queue
 import time
 
 import numpy as np

@@ -104,6 +104,7 @@ def main():
     cfg_is_no_prefix = True
     mimi = tts_model.mimi
 
+    log("info", f"reading input from {args.inp}")
     with open(args.inp, "r") as fobj:
         text_to_tts = fobj.read().strip()
 

@@ -114,39 +115,56 @@ def main():
     voices = []
     all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
 
-    begin = time.time()
-    result = tts_model.generate(
-        all_entries, all_attributes,
-        cfg_is_no_prefix=cfg_is_no_prefix, cfg_is_no_text=cfg_is_no_text)
-    frames = mx.concat(result.frames, axis=-1)
-    total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate
-    time_taken = time.time() - begin
-    total_speed = total_duration / time_taken
-    log("info", f"[LM] took {time_taken:.2f}s, total speed {total_speed:.2f}x")
+    wav_frames = queue.Queue()
+    def _on_audio_hook(audio_tokens):
+        if (audio_tokens == -1).any():
+            return
+        _pcm = tts_model.mimi.decode_step(audio_tokens[None, :, None])
+        _pcm = np.array(mx.clip(_pcm[0, 0], -1, 1))
+        wav_frames.put_nowait(_pcm)
+
+    def run():
+        log("info", "starting the inference loop")
+        begin = time.time()
+        result = tts_model.generate(
+            all_entries,
+            all_attributes,
+            cfg_is_no_prefix=cfg_is_no_prefix,
+            cfg_is_no_text=cfg_is_no_text,
+            on_audio_hook=_on_audio_hook,
+        )
+        frames = mx.concat(result.frames, axis=-1)
+        total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate
+        time_taken = time.time() - begin
+        total_speed = total_duration / time_taken
+        log("info", f"[LM] took {time_taken:.2f}s, total speed {total_speed:.2f}x")
+        return result
 
-    wav_frames = []
-    for frame in result.frames:
-        # We are processing frames one by one, although we could group them to improve speed.
-        _pcm = tts_model.mimi.decode_step(frame)
-        wav_frames.append(_pcm)
     if args.out == "-":
-        cnt = [0]
         def audio_callback(outdata, _a, _b, _c):
-            if cnt[0] < len(wav_frames):
-                outdata[:, 0] = wav_frames[cnt[0]][0, 0]
-                cnt[0] += 1
-            else:
+            try:
+                pcm_data = wav_frames.get(block=False)
+                outdata[:, 0] = pcm_data
+            except queue.Empty:
                 outdata[:] = 0
         with sd.OutputStream(samplerate=mimi.sample_rate,
                              blocksize=1920,
                              channels=1,
                              callback=audio_callback):
-            time.sleep(10)
+            run()
+            time.sleep(3)
+            while True:
+                if wav_frames.qsize() == 0:
+                    break
+                time.sleep(1)
     else:
-        wavs = mx.concat(wav_frames, axis=-1)
-        end_step = result.end_steps[0]
-        wav_length = int((mimi.sample_rate * (end_step + tts_model.final_padding) / mimi.frame_rate))
-        wav = np.array(mx.clip(wavs[0, :, :wav_length], -1, 1))
+        frames = []
+        while True:
+            try:
+                frames.append(wav_frames.get_nowait())
+            except queue.Empty:
+                break
+        wav = np.concat(frames, -1)
         sphn.write_wav(args.out, wav, mimi.sample_rate)
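
For context, the playback side of this change is a producer/consumer pattern: generation pushes decoded PCM blocks into a queue.Queue from the on_audio_hook callback, and a sounddevice.OutputStream callback drains one block per audio period, emitting silence whenever the queue is empty. Below is a minimal, self-contained sketch of that pattern using only numpy and sounddevice; the sine-wave producer, the sample rate, and the single-threaded ordering (producer runs before the stream opens) are stand-in assumptions, not values or behavior taken from this repository.

import queue
import time

import numpy as np
import sounddevice as sd

SAMPLE_RATE = 24000  # assumed sample rate; the script above uses mimi.sample_rate
BLOCK_SIZE = 1920    # 80 ms per block at 24 kHz, matching the blocksize used above

pcm_queue = queue.Queue()


def producer():
    # Stand-in for the on_audio_hook: push fixed-size PCM blocks as they become
    # available. Here it is a 440 Hz sine wave; in the script it is Mimi-decoded audio.
    t = np.arange(BLOCK_SIZE) / SAMPLE_RATE
    for k in range(50):  # roughly 4 seconds of audio
        offset = k * BLOCK_SIZE / SAMPLE_RATE
        block = 0.2 * np.sin(2 * np.pi * 440.0 * (t + offset))
        pcm_queue.put_nowait(block.astype(np.float32))


def audio_callback(outdata, frames, _time, _status):
    # Consumer: play one queued block per callback, fall back to silence when
    # the producer has not caught up yet.
    try:
        outdata[:, 0] = pcm_queue.get(block=False)
    except queue.Empty:
        outdata[:] = 0


producer()
with sd.OutputStream(samplerate=SAMPLE_RATE, blocksize=BLOCK_SIZE,
                     channels=1, callback=audio_callback):
    # Keep the stream open until the queue has been drained, then give the
    # final block time to finish playing.
    while pcm_queue.qsize() > 0:
        time.sleep(0.1)
    time.sleep(0.5)

In the script itself the producer and consumer overlap: run() is called inside the OutputStream context, so blocks are played as soon as the hook enqueues them rather than after generation finishes.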