VAD support.

This commit is contained in:
Laurent 2025-07-08 15:43:38 +02:00
parent 846042d0a4
commit dc6f552a47

View File

@@ -25,6 +25,9 @@ if __name__ == "__main__":
parser.add_argument("in_file", help="The file to transcribe.")
parser.add_argument("--max-steps", default=4096)
parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
parser.add_argument(
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
)
args = parser.parse_args()
audio, _ = sphn.read(args.in_file, sample_rate=24000)
@@ -65,10 +68,18 @@ if __name__ == "__main__":
print(f"starting inference {audio.shape}")
audio = mx.concat([mx.array(audio), mx.zeros((1, 48000))], axis=-1)
last_print_was_vad = False
for start_idx in range(0, audio.shape[-1] // 1920 * 1920, 1920):
block = audio[:, None, start_idx:start_idx + 1920]
other_audio_tokens = audio_tokenizer.encode_step(block)
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)
block = audio[:, None, start_idx : start_idx + 1920]
other_audio_tokens = audio_tokenizer.encode_step(block).transpose(0, 2, 1)
if args.vad:
text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
if vad_heads:
pr_vad = vad_heads[2][0, 0, 0].item()
if pr_vad > 0.5 and not last_print_was_vad:
print(" [end of turn detected]")
last_print_was_vad = True
else:
text_token = gen.step(other_audio_tokens[0])
text_token = text_token[0].item()
audio_tokens = gen.last_audio_tokens()
@@ -77,5 +88,5 @@ if __name__ == "__main__":
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("▁", " ")
print(_text, end="", flush=True)
last_print_was_vad = False
print()