VAD support.
This commit is contained in:
parent
846042d0a4
commit
dc6f552a47
|
|
@ -25,6 +25,9 @@ if __name__ == "__main__":
|
|||
parser.add_argument("in_file", help="The file to transcribe.")
|
||||
parser.add_argument("--max-steps", default=4096)
|
||||
parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
|
||||
parser.add_argument(
|
||||
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
audio, _ = sphn.read(args.in_file, sample_rate=24000)
|
||||
|
|
@ -65,11 +68,19 @@ if __name__ == "__main__":
|
|||
|
||||
print(f"starting inference {audio.shape}")
|
||||
audio = mx.concat([mx.array(audio), mx.zeros((1, 48000))], axis=-1)
|
||||
last_print_was_vad = False
|
||||
for start_idx in range(0, audio.shape[-1] // 1920 * 1920, 1920):
|
||||
block = audio[:, None, start_idx:start_idx + 1920]
|
||||
other_audio_tokens = audio_tokenizer.encode_step(block)
|
||||
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)
|
||||
text_token = gen.step(other_audio_tokens[0])
|
||||
block = audio[:, None, start_idx : start_idx + 1920]
|
||||
other_audio_tokens = audio_tokenizer.encode_step(block).transpose(0, 2, 1)
|
||||
if args.vad:
|
||||
text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
|
||||
if vad_heads:
|
||||
pr_vad = vad_heads[2][0, 0, 0].item()
|
||||
if pr_vad > 0.5 and not last_print_was_vad:
|
||||
print(" [end of turn detected]")
|
||||
last_print_was_vad = True
|
||||
else:
|
||||
text_token = gen.step(other_audio_tokens[0])
|
||||
text_token = text_token[0].item()
|
||||
audio_tokens = gen.last_audio_tokens()
|
||||
_text = None
|
||||
|
|
@ -77,5 +88,5 @@ if __name__ == "__main__":
|
|||
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
|
||||
_text = _text.replace("▁", " ")
|
||||
print(_text, end="", flush=True)
|
||||
last_print_was_vad = False
|
||||
print()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user