"""An example script that illustrates how one can prompt Kyutai STT models."""

import argparse
import dataclasses
import itertools
import math
from collections import deque

import julius
import moshi.models
import sphn
import torch
import tqdm


class PromptHook:
    """Constrain the generated text stream so that it starts with a prefix.

    The prefix is tokenized once and stored in a queue. While the queue is
    non-empty, ``on_logits`` masks out every logit except the padding tokens
    and the next expected prefix token, and ``on_token`` pops that token from
    the queue once it has been produced. Once the queue is empty both hooks
    are no-ops.

    Args:
        tokenizer: Object exposing ``encode(text)`` that returns an iterable
            of token ids.
        prefix: Text the transcript is forced to begin with.
        padding_tokens: Token ids that stay allowed while the prefix is being
            enforced. NOTE(review): the default ``(0, 3)`` presumably matches
            the model's padding ids (``main`` reads ``text_padding_token_id``
            with a default of 3) -- confirm against the model config.
    """

    def __init__(self, tokenizer, prefix, padding_tokens=(0, 3,)):
        self.tokenizer = tokenizer
        # Queue of prefix token ids that still have to be emitted, in order.
        self.prefix_enforce = deque(self.tokenizer.encode(prefix))
        self.padding_tokens = padding_tokens

    def on_token(self, token):
        """Advance the prefix queue according to the token that was produced.

        Args:
            token: Single-element tensor holding the produced token id.

        Raises:
            AssertionError: If the token is neither a padding token nor the
                next expected prefix token.
        """
        if not self.prefix_enforce:
            return
        token = token.item()
        if token in self.padding_tokens:
            # Padding does not consume any part of the prefix.
            return
        expected = self.prefix_enforce[0]
        if token != expected:
            # Raised explicitly instead of `assert False`, which would be
            # stripped when Python runs with -O.
            raise AssertionError(
                f"Prompt enforcement failed: got token {token}, expected "
                f"{expected} or one of the padding tokens "
                f"{tuple(self.padding_tokens)}."
            )
        self.prefix_enforce.popleft()

    def on_logits(self, logits):
        """Mask ``logits`` in place so only allowed tokens remain selectable.

        Allowed tokens are the padding tokens and the next expected prefix
        token; all other entries along the last dimension are set to ``-inf``.

        Args:
            logits: Tensor whose last dimension indexes token ids. It is
                modified in place.
        """
        if not self.prefix_enforce:
            return
        allowed = torch.zeros_like(logits, dtype=torch.bool)
        for padding_token in self.padding_tokens:
            allowed[..., padding_token] = True
        allowed[..., self.prefix_enforce[0]] = True
        logits.masked_fill_(~allowed, float("-inf"))


def main(args: argparse.Namespace) -> None:
    """Transcribe ``args.file``, optionally primed with an audio/text prompt.

    The audio fed to the model is, in order: a run of silence, the optional
    audio prompt (``args.prompt_file``), the file to transcribe, and a second
    run of silence. When ``args.prompt_text`` is given, a ``PromptHook``
    forces the text stream to start with that text. The decoded transcript is
    printed to stdout.

    Args:
        args: Parsed command-line arguments (see the parser built under
            ``if __name__ == "__main__"``).
    """
    info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
        args.hf_repo,
        moshi_weights=args.moshi_weight,
        mimi_weights=args.mimi_weight,
        tokenizer=args.tokenizer,
        config_path=args.config_path,
    )

    mimi = info.get_mimi(device=args.device)
    tokenizer = info.get_text_tokenizer()
    lm = info.get_moshi(
        device=args.device,
        dtype=torch.bfloat16,
    )

    if args.prompt_text:
        prompt_hook = PromptHook(tokenizer, args.prompt_text)
        lm_gen = moshi.models.LMGen(
            lm,
            temp=0,
            temp_text=0.0,
            on_text_hook=prompt_hook.on_token,
            on_text_logits_hook=prompt_hook.on_logits,
        )
    else:
        lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)

    audio_silence_prefix_seconds = info.stt_config.get(
        "audio_silence_prefix_seconds", 1.0
    )
    audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
    padding_token_id = info.raw_config.get("text_padding_token_id", 3)

    def _load_and_process(path):
        """Load ``path`` as mono audio at Mimi's sample rate, frame-aligned."""
        audio, input_sample_rate = sphn.read(path)
        # Average over the first axis to downmix to a single channel.
        audio = torch.from_numpy(audio).to(args.device).mean(axis=0, keepdim=True)
        audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate)
        # Zero-pad at the end so the length is a multiple of the frame size.
        to_pad = -audio.shape[-1] % mimi.frame_size
        if to_pad:
            audio = torch.nn.functional.pad(audio, (0, to_pad))
        return audio

    n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate)
    n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate)
    silence_chunk = torch.zeros(
        (1, 1, mimi.frame_size), dtype=torch.float32, device=args.device
    )

    audio = _load_and_process(args.file)
    if args.prompt_file:
        audio_prompt = _load_and_process(args.prompt_file)
    else:
        audio_prompt = None

    # Split every waveform into frame-sized chunks along the time axis.
    if audio_prompt is not None:
        prompt_chunks = torch.split(audio_prompt[:, None], mimi.frame_size, dim=-1)
    else:
        prompt_chunks = ()
    audio_chunks = torch.split(audio[:, None], mimi.frame_size, dim=-1)

    chunks = itertools.chain(
        itertools.repeat(silence_chunk, n_prefix_chunks),
        prompt_chunks,
        audio_chunks,
        itertools.repeat(silence_chunk, n_suffix_chunks),
    )
    # Known up front, so the progress bar can show completion and an ETA.
    n_chunks = (
        n_prefix_chunks + len(prompt_chunks) + len(audio_chunks) + n_suffix_chunks
    )

    text_tokens_accum = []
    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad(), mimi.streaming(1), lm_gen.streaming(1):
        for audio_chunk in tqdm.tqdm(chunks, total=n_chunks):
            audio_tokens = mimi.encode(audio_chunk)
            text_tokens = lm_gen.step(audio_tokens)
            # `step` may return None, in which case there is nothing to keep.
            if text_tokens is not None:
                text_tokens_accum.append(text_tokens)

    if not text_tokens_accum:
        # torch.concat rejects an empty list; no tokens means no transcript.
        print("")
        return

    utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
    text_tokens = utterance_tokens.cpu().view(-1)
    # Only ids strictly greater than the padding id are decoded.
    text = tokenizer.decode(
        text_tokens[text_tokens > padding_token_id].numpy().tolist()
    )
    print(text)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Example streaming STT w/ a prompt.")
    parser.add_argument(
        "--file",
        required=True,
        help="File to transcribe.",
    )
    # The hyphenated spellings match the other flags; the underscore
    # spellings are kept as aliases so existing invocations keep working.
    parser.add_argument(
        "--prompt-file",
        "--prompt_file",
        dest="prompt_file",
        required=False,
        help="Audio of the prompt.",
    )
    parser.add_argument(
        "--prompt-text",
        "--prompt_text",
        dest="prompt_text",
        required=False,
        help="Text of the prompt.",
    )

    parser.add_argument(
        "--hf-repo", type=str, help="HF repo to load the STT model from. "
    )
    parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
    parser.add_argument(
        "--moshi-weight", type=str, help="Path to a local checkpoint file."
    )
    parser.add_argument(
        "--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
    )
    parser.add_argument(
        "--config-path", type=str, help="Path to a local config file.", default=None
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device on which to run, defaults to 'cuda'.",
    )
    args = parser.parse_args()

    main(args)