kyutai/scripts/stt_from_file_with_prompt_pytorch.py

189 lines
6.1 KiB
Python
Raw Normal View History

"""An example script that illustrates how one can prompt Kyutai STT models."""
import argparse
import dataclasses
import itertools
import math
from collections import deque
import julius
import moshi.models
import sphn
import torch
import tqdm
class PromptHook:
    """Force a text prefix onto the generated text stream.

    The hook keeps a queue of token ids obtained by encoding ``prefix`` with
    ``tokenizer``. While the queue is non-empty, ``on_logits`` masks the text
    logits so that only a padding token or the next queued token can be
    selected, and ``on_token`` pops the queue each time that queued token is
    emitted. Once the queue is empty both callbacks return immediately.

    Args:
        tokenizer: Object exposing ``encode(str)`` that returns an iterable of
            token ids.
        prefix: Text whose tokens must appear, in order, in the output.
        padding_tokens: Token ids that are always allowed while the prefix is
            still being enforced.
    """

    def __init__(self, tokenizer, prefix, padding_tokens=(0, 3)):
        self.tokenizer = tokenizer
        # Remaining prefix tokens, consumed from the left as they are emitted.
        self.prefix_enforce = deque(self.tokenizer.encode(prefix))
        self.padding_tokens = padding_tokens

    def on_token(self, token):
        """Record an emitted text token and advance the prefix queue.

        Args:
            token: Object holding a single token id, read via ``.item()``.

        Raises:
            AssertionError: If the token is neither a padding token nor the
                next expected prefix token.
        """
        if not self.prefix_enforce:
            return

        token = token.item()
        if token in self.padding_tokens:
            return

        expected = self.prefix_enforce[0]
        if token != expected:
            # Raised explicitly (rather than `assert False`) so the check is
            # not stripped when Python runs with -O.
            raise AssertionError(
                f"Unexpected text token {token}: expected {expected} "
                f"or one of the padding tokens {tuple(self.padding_tokens)}."
            )
        self.prefix_enforce.popleft()

    def on_logits(self, logits):
        """Mask ``logits`` in place so only allowed tokens remain finite.

        Allowed tokens are the padding tokens and the next queued prefix
        token; every other entry along the last dimension is set to ``-inf``.

        Args:
            logits: Tensor whose last dimension indexes token ids.
        """
        if not self.prefix_enforce:
            return

        mask = torch.zeros_like(logits, dtype=torch.bool)
        for t in self.padding_tokens:
            mask[..., t] = True
        mask[..., self.prefix_enforce[0]] = True

        # Assign through a slice so the caller's tensor is modified in place.
        logits[:] = torch.where(mask, logits, float("-inf"))
def main(args):
    """Transcribe ``args.file`` with a Kyutai STT model and print the text.

    The audio is fed to the model frame by frame, preceded by a silence
    prefix and, when ``args.prompt_file`` is given, by the prompt audio and a
    short silence separator. When ``args.prompt_text`` is given, a
    ``PromptHook`` is installed on the generator for that text.

    Args:
        args: Parsed command-line namespace. Reads ``hf_repo``,
            ``moshi_weight``, ``mimi_weight``, ``tokenizer``, ``config_path``,
            ``device``, ``file``, ``prompt_file``, ``prompt_text`` and
            ``cut_prompt_transcript``.
    """
    info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
        args.hf_repo,
        moshi_weights=args.moshi_weight,
        mimi_weights=args.mimi_weight,
        tokenizer=args.tokenizer,
        config_path=args.config_path,
    )

    mimi = info.get_mimi(device=args.device)
    tokenizer = info.get_text_tokenizer()
    lm = info.get_moshi(
        device=args.device,
        dtype=torch.bfloat16,
    )

    if args.prompt_text:
        # Install both callbacks of the hook on the generator.
        prompt_hook = PromptHook(tokenizer, args.prompt_text)
        lm_gen = moshi.models.LMGen(
            lm,
            temp=0,
            temp_text=0.0,
            on_text_hook=prompt_hook.on_token,
            on_text_logits_hook=prompt_hook.on_logits,
        )
    else:
        lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)

    audio_silence_prefix_seconds = info.stt_config.get(
        "audio_silence_prefix_seconds", 1.0
    )
    audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
    padding_token_id = info.raw_config.get("text_padding_token_id", 3)

    def _load_and_process(path):
        """Load audio, downmix to one channel, resample, pad to whole frames."""
        audio, input_sample_rate = sphn.read(path)
        audio = torch.from_numpy(audio).to(args.device).mean(axis=0, keepdim=True)
        audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate)
        if audio.shape[-1] % mimi.frame_size != 0:
            # Right-pad with zeros so the length is a multiple of frame_size.
            to_pad = mimi.frame_size - audio.shape[-1] % mimi.frame_size
            audio = torch.nn.functional.pad(audio, (0, to_pad))
        return audio

    n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate)
    n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate)
    silence_chunk = torch.zeros(
        (1, 1, mimi.frame_size), dtype=torch.float32, device=args.device
    )

    audio = _load_and_process(args.file)
    if args.prompt_file:
        audio_prompt = _load_and_process(args.prompt_file)
    else:
        audio_prompt = None

    # Input order: silence prefix, [prompt, separator], audio, silence suffix.
    chain = [itertools.repeat(silence_chunk, n_prefix_chunks)]
    if audio_prompt is not None:
        chain.append(torch.split(audio_prompt[:, None, :], mimi.frame_size, dim=-1))
        # adding a bit (0.8s) of silence to separate prompt and the actual audio
        chain.append(itertools.repeat(silence_chunk, 10))
    chain += [
        torch.split(audio[:, None, :], mimi.frame_size, dim=-1),
        itertools.repeat(silence_chunk, n_suffix_chunks),
    ]
    chunks = itertools.chain(*chain)

    text_tokens_accum = []
    with mimi.streaming(1), lm_gen.streaming(1):
        for audio_chunk in tqdm.tqdm(chunks):
            audio_tokens = mimi.encode(audio_chunk)
            text_tokens = lm_gen.step(audio_tokens)
            # A step may yield no output; only keep the steps that do.
            if text_tokens is not None:
                text_tokens_accum.append(text_tokens)

    utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
    text_tokens = utterance_tokens.cpu().view(-1)

    # if we have an audio prompt and we don't want to have it in the transcript,
    # we should cut the corresponding number of frames from the output tokens.
    # However, there is also some amount of padding that happens before it
    # due to silence_prefix and audio_delay. Normally it is ignored in detokenization,
    # but now we should account for it to find the position of the prompt transcript.
    if args.cut_prompt_transcript and audio_prompt is not None:
        prompt_frames = audio_prompt.shape[1] // mimi.frame_size
        no_prompt_offset_seconds = audio_delay_seconds + audio_silence_prefix_seconds
        no_prompt_offset = int(no_prompt_offset_seconds * mimi.frame_rate)
        text_tokens = text_tokens[prompt_frames + no_prompt_offset :]

    # Only ids strictly greater than the padding id are decoded.
    text = tokenizer.decode(
        text_tokens[text_tokens > padding_token_id].numpy().tolist()
    )

    print(text)
def build_arg_parser():
    """Build the command-line parser for the prompted STT example.

    ``--prompt-file`` and ``--prompt-text`` follow the dashed style of the
    other options; the original underscore spellings (``--prompt_file``,
    ``--prompt_text``) are kept as aliases so existing invocations still work.

    Returns:
        argparse.ArgumentParser: Parser exposing ``file``, ``prompt_file``,
        ``prompt_text``, ``cut_prompt_transcript``, ``hf_repo``,
        ``tokenizer``, ``moshi_weight``, ``mimi_weight``, ``config_path`` and
        ``device``.
    """
    parser = argparse.ArgumentParser(description="Example streaming STT w/ a prompt.")
    parser.add_argument(
        "--file",
        required=True,
        help="File to transcribe.",
    )
    parser.add_argument(
        "--prompt-file",
        "--prompt_file",
        dest="prompt_file",
        required=False,
        help="Audio of the prompt.",
    )
    parser.add_argument(
        "--prompt-text",
        "--prompt_text",
        dest="prompt_text",
        required=False,
        help="Text of the prompt.",
    )
    parser.add_argument(
        "--cut-prompt-transcript",
        action="store_true",
        help="Cut the prompt from the output transcript",
    )
    parser.add_argument(
        "--hf-repo", type=str, help="HF repo to load the STT model from. "
    )
    parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
    parser.add_argument(
        "--moshi-weight", type=str, help="Path to a local checkpoint file."
    )
    parser.add_argument(
        "--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
    )
    parser.add_argument(
        "--config-path", type=str, help="Path to a local config file.", default=None
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device on which to run, defaults to 'cuda'.",
    )
    return parser


if __name__ == "__main__":
    # Parse the command line and run the transcription.
    args = build_arg_parser().parse_args()
    main(args)