VAD support in the mlx-stt example that uses the microphone.
This commit is contained in:
parent
952319de90
commit
baf0c75bba
|
|
@ -2,7 +2,7 @@
|
|||
# requires-python = ">=3.12"
|
||||
# dependencies = [
|
||||
# "huggingface_hub",
|
||||
# "moshi_mlx",
|
||||
# "moshi_mlx==0.2.10",
|
||||
# "numpy",
|
||||
# "rustymimi",
|
||||
# "sentencepiece",
|
||||
|
|
@ -26,6 +26,9 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--max-steps", default=4096)
|
||||
parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
|
||||
parser.add_argument(
|
||||
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
lm_config = hf_hub_download(args.hf_repo, "config.json")
|
||||
|
|
@ -45,6 +48,9 @@ if __name__ == "__main__":
|
|||
nn.quantize(model, bits=8, group_size=64)
|
||||
|
||||
print(f"loading model weights from {moshi_weights}")
|
||||
if args.hf_repo.endswith("-candle"):
|
||||
model.load_pytorch_weights(moshi_weights, lm_config, strict=True)
|
||||
else:
|
||||
model.load_weights(moshi_weights, strict=True)
|
||||
|
||||
print(f"loading the text tokenizer from {tokenizer}")
|
||||
|
|
@ -71,6 +77,7 @@ if __name__ == "__main__":
|
|||
block_queue.put(indata.copy())
|
||||
|
||||
print("recording audio from microphone, speak to get your words transcribed")
|
||||
last_print_was_vad = False
|
||||
with sd.InputStream(
|
||||
channels=1,
|
||||
dtype="float32",
|
||||
|
|
@ -85,6 +92,14 @@ if __name__ == "__main__":
|
|||
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[
|
||||
:, :, :other_codebooks
|
||||
]
|
||||
if args.vad:
|
||||
text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
|
||||
if vad_heads:
|
||||
pr_vad = vad_heads[2][0, 0, 0].item()
|
||||
if pr_vad > 0.5 and not last_print_was_vad:
|
||||
print(" [end of turn detected]")
|
||||
last_print_was_vad = True
|
||||
else:
|
||||
text_token = gen.step(other_audio_tokens[0])
|
||||
text_token = text_token[0].item()
|
||||
audio_tokens = gen.last_audio_tokens()
|
||||
|
|
@ -93,3 +108,4 @@ if __name__ == "__main__":
|
|||
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
|
||||
_text = _text.replace("▁", " ")
|
||||
print(_text, end="", flush=True)
|
||||
last_print_was_vad = False
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user