Add some VAD to the pytorch speech-to-text example.
This commit is contained in:
parent
cafac63222
commit
c9c4e0a573
|
|
@ -4,7 +4,7 @@
|
|||
# "julius",
|
||||
# "librosa",
|
||||
# "soundfile",
|
||||
# "moshi",
|
||||
# "moshi==0.2.9",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
|
|
@ -20,8 +20,8 @@ import math
|
|||
import julius
|
||||
import moshi.models
|
||||
import sphn
|
||||
import time
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
|
@ -138,9 +138,13 @@ def main(args):
|
|||
|
||||
mimi = info.get_mimi(device=args.device)
|
||||
tokenizer = info.get_text_tokenizer()
|
||||
lm_kwargs_overrides = {}
|
||||
if args.vad:
|
||||
lm_kwargs_overrides = {"extra_heads_num_heads": 4}
|
||||
lm = info.get_moshi(
|
||||
device=args.device,
|
||||
dtype=torch.bfloat16,
|
||||
lm_kwargs_overrides=lm_kwargs_overrides,
|
||||
)
|
||||
lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
|
||||
|
||||
|
|
@ -171,14 +175,35 @@ def main(args):
|
|||
itertools.repeat(silence_chunk, n_suffix_chunks),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
nchunks = 0
|
||||
last_print_was_vad = False
|
||||
with mimi.streaming(1), lm_gen.streaming(1):
|
||||
for audio_chunk in tqdm.tqdm(chunks):
|
||||
for audio_chunk in chunks:
|
||||
nchunks += 1
|
||||
audio_tokens = mimi.encode(audio_chunk)
|
||||
text_tokens = lm_gen.step(audio_tokens)
|
||||
if text_tokens is not None:
|
||||
text_tokens_accum.append(text_tokens)
|
||||
if args.vad:
|
||||
text_tokens, vad_heads = lm_gen.step_with_extra_heads(audio_tokens)
|
||||
if vad_heads:
|
||||
pr_vad = vad_heads[2][0, 0, 0].cpu().item()
|
||||
if pr_vad > 0.5 and not last_print_was_vad:
|
||||
print(" [end of turn detected]")
|
||||
last_print_was_vad = True
|
||||
else:
|
||||
text_tokens = lm_gen.step(audio_tokens)
|
||||
text_token = text_tokens[0, 0, 0].cpu().item()
|
||||
if text_token not in (0, 3):
|
||||
_text = tokenizer.id_to_piece(text_tokens[0, 0, 0].cpu().item()) # type: ignore
|
||||
_text = _text.replace("▁", " ")
|
||||
print(_text, end="", flush=True)
|
||||
last_print_was_vad = False
|
||||
text_tokens_accum.append(text_tokens)
|
||||
|
||||
utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
|
||||
dt = time.time() - start_time
|
||||
print(
|
||||
f"\nprocessed {nchunks} chunks in {dt:.2f} seconds, steps per second: {nchunks / dt:.2f}"
|
||||
)
|
||||
timed_text = tokens_to_timestamped_text(
|
||||
utterance_tokens,
|
||||
tokenizer,
|
||||
|
|
@ -209,6 +234,9 @@ if __name__ == "__main__":
|
|||
parser.add_argument(
|
||||
"--config-path", type=str, help="Path to a local config file.", default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user