diff --git a/README.md b/README.md
index e3c8a18..927582b 100644
--- a/README.md
+++ b/README.md
@@ -87,6 +87,23 @@ uv run scripts/evaluate_on_dataset.py \
   --hf-repo kyutai/stt-2.6b-en
 ```
 
+Another example shows how one can provide a text, an audio, or a combined text-and-audio prompt to our STT model:
+```bash
+uv run scripts/transcribe_from_file_via_pytorch_with_prompt.py \
+  --hf-repo kyutai/stt-2.6b-en \
+  --file bria.mp3 \
+  --prompt_file ./audio/loona.mp3 \
+  --prompt_text "Loonah" \
+  --cut-prompt-transcript
+```
+This produces the transcript of `bria.mp3` using the `Loonah` spelling for the name, instead of the `Luna` spelling used when no prompt is given:
+```
+In the heart of an ancient forest, where the trees whispered secrets of the past, there lived a peculiar rabbit named Loonah (...)
+```
+
+Apart from nudging the model towards a specific spelling of a word, other potential use cases include speaker adaptation and steering the model towards a specific formatting style or even a language.
+However, please bear in mind that this is an experimental feature and its behavior is very sensitive to the prompt provided.
+
 ### Rust server
 
diff --git a/audio/loona.mp3 b/audio/loona.mp3
new file mode 100644
index 0000000..997cc31
Binary files /dev/null and b/audio/loona.mp3 differ
diff --git a/scripts/transcribe_from_file_via_pytorch_with_prompt.py b/scripts/transcribe_from_file_via_pytorch_with_prompt.py
new file mode 100644
index 0000000..5861116
--- /dev/null
+++ b/scripts/transcribe_from_file_via_pytorch_with_prompt.py
@@ -0,0 +1,196 @@
+"""An example script that illustrates how one can prompt Kyutai STT models."""
+
+import argparse
+import dataclasses
+import itertools
+import math
+from collections import deque
+
+import julius
+import moshi.models
+import sphn
+import torch
+import tqdm
+
+
+class PromptHook:
+    """Forces the transcript to start with a given text prefix by masking the text logits."""
+
+    def __init__(
+        self,
+        tokenizer,
+        prefix,
+        padding_tokens=(
+            0,
+            3,
+        ),
+    ):
+        self.tokenizer = tokenizer
+        self.prefix_enforce = deque(self.tokenizer.encode(prefix))
+        self.padding_tokens = padding_tokens
+
+    def on_token(self, token):
+        if not self.prefix_enforce:
+            return
+
+        token = token.item()
+
+        if token in self.padding_tokens:
+            pass
+        elif token == self.prefix_enforce[0]:
+            # The model emitted the next enforced token; move on to the following one.
+            self.prefix_enforce.popleft()
+        else:
+            # Unreachable: on_logits only allows padding tokens or the next prefix token.
+            assert False
+
+    def on_logits(self, logits):
+        if not self.prefix_enforce:
+            return
+
+        # While part of the prefix remains, only allow padding tokens
+        # or the next token of the enforced prefix.
+        mask = torch.zeros_like(logits, dtype=torch.bool)
+        for t in self.padding_tokens:
+            mask[..., t] = True
+        mask[..., self.prefix_enforce[0]] = True
+
+        logits[:] = torch.where(mask, logits, float("-inf"))
+
+
+def main(args):
+    info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
+        args.hf_repo,
+        moshi_weights=args.moshi_weight,
+        mimi_weights=args.mimi_weight,
+        tokenizer=args.tokenizer,
+        config_path=args.config_path,
+    )
+
+    mimi = info.get_mimi(device=args.device)
+    tokenizer = info.get_text_tokenizer()
+    lm = info.get_moshi(
+        device=args.device,
+        dtype=torch.bfloat16,
+    )
+
+    if args.prompt_text:
+        prompt_hook = PromptHook(tokenizer, args.prompt_text)
+        lm_gen = moshi.models.LMGen(
+            lm,
+            temp=0,
+            temp_text=0.0,
+            on_text_hook=prompt_hook.on_token,
+            on_text_logits_hook=prompt_hook.on_logits,
+        )
+    else:
+        lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
+
+    audio_silence_prefix_seconds = info.stt_config.get(
+        "audio_silence_prefix_seconds", 1.0
+    )
+    audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
+    padding_token_id = info.raw_config.get("text_padding_token_id", 3)
+
+    def _load_and_process(path):
+        audio, input_sample_rate = sphn.read(path)
+        audio = torch.from_numpy(audio).to(args.device).mean(axis=0, keepdim=True)
+        audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate)
+        if audio.shape[-1] % mimi.frame_size != 0:
+            to_pad = mimi.frame_size - audio.shape[-1] % mimi.frame_size
+            audio = torch.nn.functional.pad(audio, (0, to_pad))
+        return audio
+
+    n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate)
+    n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate)
+    silence_chunk = torch.zeros(
+        (1, 1, mimi.frame_size), dtype=torch.float32, device=args.device
+    )
+
+    audio = _load_and_process(args.file)
+    if args.prompt_file:
+        audio_prompt = _load_and_process(args.prompt_file)
+    else:
+        audio_prompt = None
+
+    chain = [itertools.repeat(silence_chunk, n_prefix_chunks)]
+
+    if audio_prompt is not None:
+        chain.append(torch.split(audio_prompt[:, None, :], mimi.frame_size, dim=-1))
+        # Add a bit (0.8s) of silence to separate the prompt from the actual audio.
+        chain.append(itertools.repeat(silence_chunk, 10))
+
+    chain += [
+        torch.split(audio[:, None, :], mimi.frame_size, dim=-1),
+        itertools.repeat(silence_chunk, n_suffix_chunks),
+    ]
+
+    chunks = itertools.chain(*chain)
+
+    text_tokens_accum = []
+    with mimi.streaming(1), lm_gen.streaming(1):
+        for audio_chunk in tqdm.tqdm(chunks):
+            audio_tokens = mimi.encode(audio_chunk)
+            text_tokens = lm_gen.step(audio_tokens)
+            if text_tokens is not None:
+                text_tokens_accum.append(text_tokens)
+
+    utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
+    text_tokens = utterance_tokens.cpu().view(-1)
+
+    # If we have an audio prompt and don't want it in the transcript,
+    # we cut the corresponding number of frames from the output tokens.
+    # However, there is also some padding before it due to the silence prefix
+    # and the audio delay. Normally it is ignored during detokenization,
+    # but here we must account for it to find the position of the prompt transcript.
+    if args.cut_prompt_transcript and audio_prompt is not None:
+        prompt_frames = audio_prompt.shape[1] // mimi.frame_size
+        no_prompt_offset_seconds = audio_delay_seconds + audio_silence_prefix_seconds
+        no_prompt_offset = int(no_prompt_offset_seconds * mimi.frame_rate)
+        text_tokens = text_tokens[prompt_frames + no_prompt_offset :]
+
+    text = tokenizer.decode(
+        text_tokens[text_tokens > padding_token_id].numpy().tolist()
+    )
+
+    print(text)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Example streaming STT w/ a prompt.")
+    parser.add_argument(
+        "--file",
+        required=True,
+        help="File to transcribe.",
+    )
+    parser.add_argument(
+        "--prompt_file",
+        required=False,
+        help="Audio of the prompt.",
+    )
+    parser.add_argument(
+        "--prompt_text",
+        required=False,
+        help="Text of the prompt.",
+    )
+    parser.add_argument(
+        "--cut-prompt-transcript",
+        action="store_true",
+        help="Cut the prompt from the output transcript.",
+    )
+    parser.add_argument(
+        "--hf-repo", type=str, help="HF repo to load the STT model from."
+    )
+    parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
+    parser.add_argument(
+        "--moshi-weight", type=str, help="Path to a local checkpoint file."
+    )
+    parser.add_argument(
+        "--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
+    )
+    parser.add_argument(
+        "--config-path", type=str, help="Path to a local config file.", default=None
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="Device on which to run, defaults to 'cuda'.",
+    )
+    args = parser.parse_args()
+
+    main(args)