VAD support in the mlx-stt example that uses the microphone.

Laurent 2025-07-08 16:08:32 +02:00
parent 952319de90
commit baf0c75bba


@@ -2,7 +2,7 @@
# requires-python = ">=3.12"
# dependencies = [
# "huggingface_hub",
# "moshi_mlx",
# "moshi_mlx==0.2.10",
# "numpy",
# "rustymimi",
# "sentencepiece",
@@ -26,6 +26,9 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-steps", default=4096)
    parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
    parser.add_argument(
        "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
    )
    args = parser.parse_args()
    lm_config = hf_hub_download(args.hf_repo, "config.json")
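For reference, `action="store_true"` makes `args.vad` default to False and flip to True when the flag is given; a self-contained check (the file name in the comment is hypothetical):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
)

# e.g. `python stt_from_mic.py --vad` on the command line
print(parser.parse_args([]).vad)         # False: flag absent
print(parser.parse_args(["--vad"]).vad)  # True: flag present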
@@ -45,6 +48,9 @@ if __name__ == "__main__":
        nn.quantize(model, bits=8, group_size=64)
    print(f"loading model weights from {moshi_weights}")
    if args.hf_repo.endswith("-candle"):
        model.load_pytorch_weights(moshi_weights, lm_config, strict=True)
    else:
        model.load_weights(moshi_weights, strict=True)
    print(f"loading the text tokenizer from {tokenizer}")
@@ -71,6 +77,7 @@ if __name__ == "__main__":
        block_queue.put(indata.copy())
    print("recording audio from microphone, speak to get your words transcribed")
    last_print_was_vad = False
    with sd.InputStream(
        channels=1,
        dtype="float32",
@@ -85,6 +92,14 @@ if __name__ == "__main__":
            other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[
                :, :, :other_codebooks
            ]
            if args.vad:
                text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
                if vad_heads:
                    pr_vad = vad_heads[2][0, 0, 0].item()
                    if pr_vad > 0.5 and not last_print_was_vad:
                        print(" [end of turn detected]")
                        last_print_was_vad = True
            else:
                text_token = gen.step(other_audio_tokens[0])
            text_token = text_token[0].item()
            audio_tokens = gen.last_audio_tokens()
@@ -93,3 +108,4 @@ if __name__ == "__main__":
            _text = text_tokenizer.id_to_piece(text_token)  # type: ignore
            _text = _text.replace("▁", " ")
            print(_text, end="", flush=True)
            last_print_was_vad = False
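To make the new control flow easier to follow: when --vad is passed, the generator is stepped with `step_with_extra_heads`, the probability from the third extra head is thresholded at 0.5, and `last_print_was_vad` works as a latch so the end-of-turn message prints once per pause rather than on every frame, with the latch reset whenever a text token is printed (the last line above). A standalone sketch of that latch, using made-up per-frame probabilities:

def end_of_turn_step(pr_vad: float, latched: bool, threshold: float = 0.5) -> bool:
    """Print the marker on a rising edge only; return the new latch state."""
    if pr_vad > threshold and not latched:
        print(" [end of turn detected]")
        return True
    return latched

latched = False
for pr_vad in [0.1, 0.2, 0.9, 0.95, 0.3]:  # hypothetical VAD head outputs
    latched = end_of_turn_step(pr_vad, latched)  # prints exactly once, at 0.9
    # in the real loop the latch is instead reset when a text token is printed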