Add optional voice activity detection (VAD), enabled with the --vad flag, to the PyTorch speech-to-text example.
This commit is contained in:
parent
cafac63222
commit
c9c4e0a573
|
|
@ -4,7 +4,7 @@
|
||||||
# "julius",
|
# "julius",
|
||||||
# "librosa",
|
# "librosa",
|
||||||
# "soundfile",
|
# "soundfile",
|
||||||
# "moshi",
|
# "moshi==0.2.9",
|
||||||
# ]
|
# ]
|
||||||
# ///
|
# ///
|
||||||
|
|
||||||
|
|
@ -20,8 +20,8 @@ import math
|
||||||
import julius
|
import julius
|
||||||
import moshi.models
|
import moshi.models
|
||||||
import sphn
|
import sphn
|
||||||
|
import time
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
|
@ -138,9 +138,13 @@ def main(args):
|
||||||
|
|
||||||
mimi = info.get_mimi(device=args.device)
|
mimi = info.get_mimi(device=args.device)
|
||||||
tokenizer = info.get_text_tokenizer()
|
tokenizer = info.get_text_tokenizer()
|
||||||
|
lm_kwargs_overrides = {}
|
||||||
|
if args.vad:
|
||||||
|
lm_kwargs_overrides = {"extra_heads_num_heads": 4}
|
||||||
lm = info.get_moshi(
|
lm = info.get_moshi(
|
||||||
device=args.device,
|
device=args.device,
|
||||||
dtype=torch.bfloat16,
|
dtype=torch.bfloat16,
|
||||||
|
lm_kwargs_overrides=lm_kwargs_overrides,
|
||||||
)
|
)
|
||||||
lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
|
lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
|
||||||
|
|
||||||
|
|
@ -171,14 +175,35 @@ def main(args):
|
||||||
itertools.repeat(silence_chunk, n_suffix_chunks),
|
itertools.repeat(silence_chunk, n_suffix_chunks),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
nchunks = 0
|
||||||
|
last_print_was_vad = False
|
||||||
with mimi.streaming(1), lm_gen.streaming(1):
|
with mimi.streaming(1), lm_gen.streaming(1):
|
||||||
for audio_chunk in tqdm.tqdm(chunks):
|
for audio_chunk in chunks:
|
||||||
|
nchunks += 1
|
||||||
audio_tokens = mimi.encode(audio_chunk)
|
audio_tokens = mimi.encode(audio_chunk)
|
||||||
|
if args.vad:
|
||||||
|
text_tokens, vad_heads = lm_gen.step_with_extra_heads(audio_tokens)
|
||||||
|
if vad_heads:
|
||||||
|
pr_vad = vad_heads[2][0, 0, 0].cpu().item()
|
||||||
|
if pr_vad > 0.5 and not last_print_was_vad:
|
||||||
|
print(" [end of turn detected]")
|
||||||
|
last_print_was_vad = True
|
||||||
|
else:
|
||||||
text_tokens = lm_gen.step(audio_tokens)
|
text_tokens = lm_gen.step(audio_tokens)
|
||||||
if text_tokens is not None:
|
text_token = text_tokens[0, 0, 0].cpu().item()
|
||||||
|
if text_token not in (0, 3):
|
||||||
|
_text = tokenizer.id_to_piece(text_tokens[0, 0, 0].cpu().item()) # type: ignore
|
||||||
|
_text = _text.replace("▁", " ")
|
||||||
|
print(_text, end="", flush=True)
|
||||||
|
last_print_was_vad = False
|
||||||
text_tokens_accum.append(text_tokens)
|
text_tokens_accum.append(text_tokens)
|
||||||
|
|
||||||
utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
|
utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
|
||||||
|
dt = time.time() - start_time
|
||||||
|
print(
|
||||||
|
f"\nprocessed {nchunks} chunks in {dt:.2f} seconds, steps per second: {nchunks / dt:.2f}"
|
||||||
|
)
|
||||||
timed_text = tokens_to_timestamped_text(
|
timed_text = tokens_to_timestamped_text(
|
||||||
utterance_tokens,
|
utterance_tokens,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
|
@ -209,6 +234,9 @@ if __name__ == "__main__":
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--config-path", type=str, help="Path to a local config file.", default=None
|
"--config-path", type=str, help="Path to a local config file.", default=None
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--device",
|
"--device",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user