# kyutai/scripts/transcribe_from_file_via_pytorch_with_prompt.py
"""An example script that illustrates how one can prompt Kyutai STT models."""
import argparse
import itertools
import math
from collections import deque

import julius
import moshi.models
import sphn
import torch
import tqdm


class PromptHook:
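    """Force generated text to start with a given prefix.

    `on_logits` masks the text logits so that only padding tokens or the next
    expected prefix token can be sampled, and `on_token` advances through the
    prefix as the forced tokens are emitted.
    """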
def __init__(self, tokenizer, prefix, padding_tokens=(0, 3,)):
self.tokenizer = tokenizer
self.prefix_enforce = deque(self.tokenizer.encode(prefix))
self.padding_tokens = padding_tokens

    def on_token(self, token):
if not self.prefix_enforce:
return
token = token.item()
if token in self.padding_tokens:
pass
elif token == self.prefix_enforce[0]:
self.prefix_enforce.popleft()
else:
            assert False, f"unexpected token {token}, expected {self.prefix_enforce[0]}"

    def on_logits(self, logits):
if not self.prefix_enforce:
return
mask = torch.zeros_like(logits, dtype=torch.bool)
for t in self.padding_tokens:
mask[..., t] = True
mask[..., self.prefix_enforce[0]] = True
logits[:] = torch.where(mask, logits, float("-inf"))


def main(args):
info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
args.hf_repo,
moshi_weights=args.moshi_weight,
mimi_weights=args.mimi_weight,
tokenizer=args.tokenizer,
config_path=args.config_path,
)
mimi = info.get_mimi(device=args.device)
tokenizer = info.get_text_tokenizer()
lm = info.get_moshi(
device=args.device,
dtype=torch.bfloat16,
)
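
    # If a text prompt is given, install the hook so that greedy decoding is
    # forced to emit the prompt transcript before transcribing freely.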
if args.prompt_text:
prompt_hook = PromptHook(tokenizer, args.prompt_text)
lm_gen = moshi.models.LMGen(
lm,
        temp=0.0,
temp_text=0.0,
on_text_hook=prompt_hook.on_token,
on_text_logits_hook=prompt_hook.on_logits,
)
else:
        lm_gen = moshi.models.LMGen(lm, temp=0.0, temp_text=0.0)
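
    # The model expects some leading silence and emits text with a fixed delay
    # relative to the audio; both values come from the checkpoint's STT config.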
audio_silence_prefix_seconds = info.stt_config.get(
"audio_silence_prefix_seconds", 1.0
)
audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
padding_token_id = info.raw_config.get("text_padding_token_id", 3)
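
    # Load an audio file, downmix to mono, resample to Mimi's sample rate,
    # and zero-pad to a whole number of Mimi frames.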
def _load_and_process(path):
audio, input_sample_rate = sphn.read(path)
        audio = torch.from_numpy(audio).to(args.device).mean(dim=0, keepdim=True)
audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate)
if audio.shape[-1] % mimi.frame_size != 0:
to_pad = mimi.frame_size - audio.shape[-1] % mimi.frame_size
audio = torch.nn.functional.pad(audio, (0, to_pad))
return audio
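
    # Convert the silence prefix and flush delay from seconds to a number of
    # Mimi frames.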
n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate)
n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate)
silence_chunk = torch.zeros(
(1, 1, mimi.frame_size), dtype=torch.float32, device=args.device
)
audio = _load_and_process(args.file)
if args.prompt_file:
audio_prompt = _load_and_process(args.prompt_file)
else:
audio_prompt = None
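
    # Assemble the frame stream: leading silence, optional prompt audio, the
    # audio to transcribe, then trailing silence so the delayed text for the
    # end of the audio still gets emitted.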
chain = [itertools.repeat(silence_chunk, n_prefix_chunks)]
if audio_prompt is not None:
chain.append(torch.split(audio_prompt[:, None], mimi.frame_size, dim=-1))
chain += [
torch.split(audio[:, None], mimi.frame_size, dim=-1),
itertools.repeat(silence_chunk, n_suffix_chunks),
]
chunks = itertools.chain(*chain)
text_tokens_accum = []
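
    # Stream frame by frame: Mimi encodes each chunk into audio tokens and the
    # LM steps once per frame, returning text tokens (or None while it has
    # nothing to emit yet).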
with mimi.streaming(1), lm_gen.streaming(1):
for audio_chunk in tqdm.tqdm(chunks):
audio_tokens = mimi.encode(audio_chunk)
text_tokens = lm_gen.step(audio_tokens)
if text_tokens is not None:
text_tokens_accum.append(text_tokens)
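
    # Concatenate the per-step text tokens and drop everything at or below the
    # padding token id before detokenizing.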
    utterance_tokens = torch.cat(text_tokens_accum, dim=-1)
text_tokens = utterance_tokens.cpu().view(-1)
text = tokenizer.decode(
text_tokens[text_tokens > padding_token_id].numpy().tolist()
)
print(text)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Example of streaming STT with a prompt.")
parser.add_argument(
"--file",
required=True,
help="File to transcribe.",
)
parser.add_argument(
"--prompt_file",
required=False,
help="Audio of the prompt.",
)
parser.add_argument(
"--prompt_text",
required=False,
help="Text of the prompt.",
)
parser.add_argument(
"--hf-repo", type=str, help="HF repo to load the STT model from. "
)
parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
parser.add_argument(
"--moshi-weight", type=str, help="Path to a local checkpoint file."
)
parser.add_argument(
"--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
)
parser.add_argument(
"--config-path", type=str, help="Path to a local config file.", default=None
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device on which to run, defaults to 'cuda'.",
)
args = parser.parse_args()
main(args)