VAD support in the mlx-stt example that uses the microphone.
parent 952319de90
commit baf0c75bba
@@ -2,7 +2,7 @@
 # requires-python = ">=3.12"
 # dependencies = [
 #     "huggingface_hub",
-#     "moshi_mlx",
+#     "moshi_mlx==0.2.10",
 #     "numpy",
 #     "rustymimi",
 #     "sentencepiece",
@@ -26,6 +26,9 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--max-steps", default=4096)
     parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
+    parser.add_argument(
+        "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
+    )
     args = parser.parse_args()
 
     lm_config = hf_hub_download(args.hf_repo, "config.json")
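Note: `action="store_true"` makes `--vad` a plain on/off switch, so VAD stays disabled unless the flag is passed. A minimal sketch of just that argparse behavior (nothing model-specific is assumed):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
)
print(parser.parse_args([]).vad)         # False: VAD is off by default
print(parser.parse_args(["--vad"]).vad)  # True: the flag switches it on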
@@ -45,6 +48,9 @@ if __name__ == "__main__":
         nn.quantize(model, bits=8, group_size=64)
 
     print(f"loading model weights from {moshi_weights}")
-    model.load_weights(moshi_weights, strict=True)
+    if args.hf_repo.endswith("-candle"):
+        model.load_pytorch_weights(moshi_weights, lm_config, strict=True)
+    else:
+        model.load_weights(moshi_weights, strict=True)
 
     print(f"loading the text tokenizer from {tokenizer}")
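Note: the new branch picks a loader from the repo name. Repos whose id ends in "-candle" are assumed here to ship checkpoints in the PyTorch tensor layout, hence `load_pytorch_weights` together with the config; plain MLX repos keep the direct `load_weights` path. The same logic as a standalone helper (the function name is hypothetical; `model`, `moshi_weights`, and `lm_config` are the objects the script already builds):

def load_stt_weights(model, hf_repo: str, moshi_weights: str, lm_config: str) -> None:
    if hf_repo.endswith("-candle"):
        # assumption: "-candle" repos store PyTorch-layout checkpoints
        model.load_pytorch_weights(moshi_weights, lm_config, strict=True)
    else:
        # MLX-native weights load directly
        model.load_weights(moshi_weights, strict=True)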
@@ -71,6 +77,7 @@ if __name__ == "__main__":
         block_queue.put(indata.copy())
 
     print("recording audio from microphone, speak to get your words transcribed")
+    last_print_was_vad = False
     with sd.InputStream(
         channels=1,
         dtype="float32",
@@ -85,6 +92,14 @@ if __name__ == "__main__":
             other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[
                 :, :, :other_codebooks
             ]
-            text_token = gen.step(other_audio_tokens[0])
+            if args.vad:
+                text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
+                if vad_heads:
+                    pr_vad = vad_heads[2][0, 0, 0].item()
+                    if pr_vad > 0.5 and not last_print_was_vad:
+                        print(" [end of turn detected]")
+                        last_print_was_vad = True
+            else:
+                text_token = gen.step(other_audio_tokens[0])
             text_token = text_token[0].item()
             audio_tokens = gen.last_audio_tokens()
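Note: with `--vad`, each step goes through `gen.step_with_extra_heads`, and `vad_heads[2][0, 0, 0].item()` reads one scalar probability from the third extra head (the indices presumably select batch 0, step 0, output 0). Combined with the `last_print_was_vad` flag initialized above, the marker prints at most once per detected pause. The same check pulled out into a helper for readability (a sketch; the helper itself is hypothetical, while the head index and the 0.5 threshold come straight from the diff):

def end_of_turn_detected(vad_heads, last_print_was_vad: bool) -> bool:
    # Returns the updated last_print_was_vad flag.
    if vad_heads:
        pr_vad = vad_heads[2][0, 0, 0].item()  # end-of-turn probability
        if pr_vad > 0.5 and not last_print_was_vad:
            print(" [end of turn detected]")
            return True  # suppress repeats until fresh text arrives
    return last_print_was_vad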
@@ -93,3 +108,4 @@ if __name__ == "__main__":
             _text = text_tokenizer.id_to_piece(text_token)  # type: ignore
             _text = _text.replace("▁", " ")
             print(_text, end="", flush=True)
+            last_print_was_vad = False
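Note: resetting `last_print_was_vad` whenever a text piece is printed re-arms the end-of-turn marker, so the flag acts as a one-bit debounce. A toy trace with made-up probabilities (simplified: the real loop derives text and the VAD signal from the same model step):

last_print_was_vad = False
for pr_vad, text in [(0.1, "hello"), (0.9, ""), (0.95, ""), (0.2, " again")]:
    if pr_vad > 0.5 and not last_print_was_vad:
        print(" [end of turn detected]")
        last_print_was_vad = True
    elif text:
        print(text, end="", flush=True)
        last_print_was_vad = False
# the marker prints once even though two consecutive steps exceed 0.5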