Print the duration of the audio generated so far. (#107)

This commit is contained in:
Laurent Mazare 2025-08-04 09:24:31 +02:00 committed by GitHub
parent 07729ed47e
commit 09468c239a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -78,6 +78,7 @@ def main():
condition_attributes = tts_model.make_condition_attributes( condition_attributes = tts_model.make_condition_attributes(
[voice_path], cfg_coef=2.0 [voice_path], cfg_coef=2.0
) )
_frames_cnt = 0
if args.out == "-": if args.out == "-":
# Stream the audio to the speakers using sounddevice. # Stream the audio to the speakers using sounddevice.
@ -86,9 +87,12 @@ def main():
pcms = queue.Queue() pcms = queue.Queue()
def _on_frame(frame): def _on_frame(frame):
nonlocal _frames_cnt
if (frame != -1).all(): if (frame != -1).all():
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy() pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.put_nowait(np.clip(pcm[0, 0], -1, 1)) pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
_frames_cnt += 1
print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
def audio_callback(outdata, _a, _b, _c): def audio_callback(outdata, _a, _b, _c):
try: try:
@ -113,7 +117,16 @@ def main():
break break
time.sleep(1) time.sleep(1)
else: else:
result = tts_model.generate([entries], [condition_attributes])
def _on_frame(frame):
nonlocal _frames_cnt
if (frame != -1).all():
_frames_cnt += 1
print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
result = tts_model.generate(
[entries], [condition_attributes], on_frame=_on_frame
)
with tts_model.mimi.streaming(1), torch.no_grad(): with tts_model.mimi.streaming(1), torch.no_grad():
pcms = [] pcms = []
for frame in result.frames[tts_model.delay_steps :]: for frame in result.frames[tts_model.delay_steps :]: