Compare commits

...

2 Commits

Author SHA1 Message Date
laurent
221f3cbad8 Avoid the config override for the extra-heads. 2025-07-08 12:18:28 +02:00
laurent
c9c4e0a573 Add some VAD to the pytorch speech-to-text example. 2025-07-08 11:28:54 +02:00

View File

@@ -4,7 +4,7 @@
# "julius",
# "librosa",
# "soundfile",
# "moshi",
# "moshi==0.2.9",
# ]
# ///
@@ -20,8 +20,8 @@ import math
import julius
import moshi.models
import sphn
import time
import torch
import tqdm
@dataclasses.dataclass
@@ -171,14 +171,35 @@ def main(args):
itertools.repeat(silence_chunk, n_suffix_chunks),
)
start_time = time.time()
nchunks = 0
last_print_was_vad = False
with mimi.streaming(1), lm_gen.streaming(1):
for audio_chunk in tqdm.tqdm(chunks):
for audio_chunk in chunks:
nchunks += 1
audio_tokens = mimi.encode(audio_chunk)
text_tokens = lm_gen.step(audio_tokens)
if text_tokens is not None:
text_tokens_accum.append(text_tokens)
if args.vad:
text_tokens, vad_heads = lm_gen.step_with_extra_heads(audio_tokens)
if vad_heads:
pr_vad = vad_heads[2][0, 0, 0].cpu().item()
if pr_vad > 0.5 and not last_print_was_vad:
print(" [end of turn detected]")
last_print_was_vad = True
else:
text_tokens = lm_gen.step(audio_tokens)
text_token = text_tokens[0, 0, 0].cpu().item()
if text_token not in (0, 3):
_text = tokenizer.id_to_piece(text_tokens[0, 0, 0].cpu().item()) # type: ignore
_text = _text.replace("▁", " ")
print(_text, end="", flush=True)
last_print_was_vad = False
text_tokens_accum.append(text_tokens)
utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
dt = time.time() - start_time
print(
f"\nprocessed {nchunks} chunks in {dt:.2f} seconds, steps per second: {nchunks / dt:.2f}"
)
timed_text = tokens_to_timestamped_text(
utterance_tokens,
tokenizer,
@@ -209,6 +230,9 @@ if __name__ == "__main__":
parser.add_argument(
"--config-path", type=str, help="Path to a local config file.", default=None
)
parser.add_argument(
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
)
parser.add_argument(
"--device",
type=str,