Add VAD (Voice Activity Detection) support.
This commit is contained in:
parent 846042d0a4
commit dc6f552a47
|
|
@@ -25,6 +25,9 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("in_file", help="The file to transcribe.")
|
parser.add_argument("in_file", help="The file to transcribe.")
|
||||||
parser.add_argument("--max-steps", default=4096)
|
parser.add_argument("--max-steps", default=4096)
|
||||||
parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
|
parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
|
||||||
|
parser.add_argument(
|
||||||
|
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
audio, _ = sphn.read(args.in_file, sample_rate=24000)
|
audio, _ = sphn.read(args.in_file, sample_rate=24000)
|
||||||
|
|
@@ -65,11 +68,19 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
print(f"starting inference {audio.shape}")
|
print(f"starting inference {audio.shape}")
|
||||||
audio = mx.concat([mx.array(audio), mx.zeros((1, 48000))], axis=-1)
|
audio = mx.concat([mx.array(audio), mx.zeros((1, 48000))], axis=-1)
|
||||||
|
last_print_was_vad = False
|
||||||
for start_idx in range(0, audio.shape[-1] // 1920 * 1920, 1920):
|
for start_idx in range(0, audio.shape[-1] // 1920 * 1920, 1920):
|
||||||
block = audio[:, None, start_idx:start_idx + 1920]
|
block = audio[:, None, start_idx : start_idx + 1920]
|
||||||
other_audio_tokens = audio_tokenizer.encode_step(block)
|
other_audio_tokens = audio_tokenizer.encode_step(block).transpose(0, 2, 1)
|
||||||
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)
|
if args.vad:
|
||||||
text_token = gen.step(other_audio_tokens[0])
|
text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
|
||||||
|
if vad_heads:
|
||||||
|
pr_vad = vad_heads[2][0, 0, 0].item()
|
||||||
|
if pr_vad > 0.5 and not last_print_was_vad:
|
||||||
|
print(" [end of turn detected]")
|
||||||
|
last_print_was_vad = True
|
||||||
|
else:
|
||||||
|
text_token = gen.step(other_audio_tokens[0])
|
||||||
text_token = text_token[0].item()
|
text_token = text_token[0].item()
|
||||||
audio_tokens = gen.last_audio_tokens()
|
audio_tokens = gen.last_audio_tokens()
|
||||||
_text = None
|
_text = None
|
||||||
|
|
@@ -77,5 +88,5 @@ if __name__ == "__main__":
|
||||||
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
|
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
|
||||||
_text = _text.replace("▁", " ")
|
_text = _text.replace("▁", " ")
|
||||||
print(_text, end="", flush=True)
|
print(_text, end="", flush=True)
|
||||||
|
last_print_was_vad = False
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user