Add some VAD to the pytorch speech-to-text example. (#68)

This commit is contained in:
Laurent Mazare 2025-07-08 11:30:34 +02:00 committed by GitHub
parent cafac63222
commit 12dbe36b0b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,7 +4,7 @@
# "julius",
# "librosa",
# "soundfile",
# "moshi",
# "moshi==0.2.9",
# ]
# ///
@ -20,8 +20,8 @@ import math
import julius
import moshi.models
import sphn
import time
import torch
import tqdm
@dataclasses.dataclass
@ -138,9 +138,13 @@ def main(args):
mimi = info.get_mimi(device=args.device)
tokenizer = info.get_text_tokenizer()
lm_kwargs_overrides = {}
if args.vad:
lm_kwargs_overrides = {"extra_heads_num_heads": 4}
lm = info.get_moshi(
device=args.device,
dtype=torch.bfloat16,
lm_kwargs_overrides=lm_kwargs_overrides,
)
lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
@ -171,14 +175,35 @@ def main(args):
itertools.repeat(silence_chunk, n_suffix_chunks),
)
start_time = time.time()
nchunks = 0
last_print_was_vad = False
with mimi.streaming(1), lm_gen.streaming(1):
for audio_chunk in tqdm.tqdm(chunks):
for audio_chunk in chunks:
nchunks += 1
audio_tokens = mimi.encode(audio_chunk)
if args.vad:
text_tokens, vad_heads = lm_gen.step_with_extra_heads(audio_tokens)
if vad_heads:
pr_vad = vad_heads[2][0, 0, 0].cpu().item()
if pr_vad > 0.5 and not last_print_was_vad:
print(" [end of turn detected]")
last_print_was_vad = True
else:
text_tokens = lm_gen.step(audio_tokens)
if text_tokens is not None:
text_token = text_tokens[0, 0, 0].cpu().item()
if text_token not in (0, 3):
_text = tokenizer.id_to_piece(text_tokens[0, 0, 0].cpu().item()) # type: ignore
_text = _text.replace("", " ")
print(_text, end="", flush=True)
last_print_was_vad = False
text_tokens_accum.append(text_tokens)
utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
dt = time.time() - start_time
print(
f"\nprocessed {nchunks} chunks in {dt:.2f} seconds, steps per second: {nchunks / dt:.2f}"
)
timed_text = tokens_to_timestamped_text(
utterance_tokens,
tokenizer,
@ -209,6 +234,9 @@ if __name__ == "__main__":
parser.add_argument(
"--config-path", type=str, help="Path to a local config file.", default=None
)
parser.add_argument(
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
)
parser.add_argument(
"--device",
type=str,