diff --git a/README.md b/README.md
index b546f3b..b65b470 100644
--- a/README.md
+++ b/README.md
@@ -88,6 +88,22 @@ uv run scripts/evaluate_on_dataset.py \
   --hf-repo kyutai/stt-2.6b-en
 ```
 
+Another example shows how to provide a text, audio, or combined text-and-audio prompt to our STT model:
+```bash
+uv run scripts/transcribe_from_file_via_pytorch_with_prompt.py \
+  --hf-repo kyutai/stt-2.6b-en \
+  --file bria.mp3 \
+  --prompt_file ./audio/loona.mp3 \
+  --prompt_text "Loonah" \
+  --cut-prompt-transcript
+```
+This produces the transcript of `bria.mp3` with the `Loonah` spelling for the name, instead of the `Luna` spelling used when no prompt is given:
+```
+In the heart of an ancient forest, where the trees whispered secrets of the past, there lived a peculiar rabbit named Loonah (...)
+```
+Please bear in mind that this is an experimental feature and that its behavior is very sensitive to the prompt provided.
+
+
 ### Rust server
 
 
diff --git a/audio/loona.mp3 b/audio/loona.mp3
new file mode 100644
index 0000000..997cc31
Binary files /dev/null and b/audio/loona.mp3 differ
diff --git a/scripts/transcribe_from_file_via_pytorch_with_prompt.py b/scripts/transcribe_from_file_via_pytorch_with_prompt.py
index 833bb8d..5861116 100644
--- a/scripts/transcribe_from_file_via_pytorch_with_prompt.py
+++ b/scripts/transcribe_from_file_via_pytorch_with_prompt.py
@@ -14,7 +14,15 @@ import tqdm
 
 
 class PromptHook:
-    def __init__(self, tokenizer, prefix, padding_tokens=(0, 3,)):
+    def __init__(
+        self,
+        tokenizer,
+        prefix,
+        padding_tokens=(
+            0,
+            3,
+        ),
+    ):
         self.tokenizer = tokenizer
         self.prefix_enforce = deque(self.tokenizer.encode(prefix))
         self.padding_tokens = padding_tokens
@@ -102,10 +110,12 @@ def main(args):
     chain = [itertools.repeat(silence_chunk, n_prefix_chunks)]
 
     if audio_prompt is not None:
-        chain.append(torch.split(audio_prompt[:, None], mimi.frame_size, dim=-1))
+        chain.append(torch.split(audio_prompt[:, None, :], mimi.frame_size, dim=-1))
+        # add a bit (0.8 s) of silence to separate the prompt from the actual audio
+        chain.append(itertools.repeat(silence_chunk, 10))
 
     chain += [
-        torch.split(audio[:, None], mimi.frame_size, dim=-1),
+        torch.split(audio[:, None, :], mimi.frame_size, dim=-1),
         itertools.repeat(silence_chunk, n_suffix_chunks),
     ]
 
@@ -121,9 +131,22 @@ def main(args):
 
     utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
 
     text_tokens = utterance_tokens.cpu().view(-1)
+
+    # If we have an audio prompt and we don't want it in the transcript,
+    # we should cut the corresponding number of frames from the output tokens.
+    # However, there is also some amount of padding before it, due to
+    # silence_prefix and audio_delay. Normally it is ignored in detokenization,
+    # but here we must account for it to find the position of the prompt transcript.
+    if args.cut_prompt_transcript and audio_prompt is not None:
+        prompt_frames = audio_prompt.shape[1] // mimi.frame_size
+        no_prompt_offset_seconds = audio_delay_seconds + audio_silence_prefix_seconds
+        no_prompt_offset = int(no_prompt_offset_seconds * mimi.frame_rate)
+        text_tokens = text_tokens[prompt_frames + no_prompt_offset:]
+
     text = tokenizer.decode(
         text_tokens[text_tokens > padding_token_id].numpy().tolist()
     )
+
     print(text)
 
@@ -144,7 +167,11 @@ if __name__ == "__main__":
         required=False,
         help="Text of the prompt.",
     )
-
+    parser.add_argument(
+        "--cut-prompt-transcript",
+        action="store_true",
+        help="Cut the prompt from the output transcript.",
+    )
     parser.add_argument(
         "--hf-repo", type=str, help="HF repo to load the STT model from. "
    )
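
For reviewers, here is the frame arithmetic behind `--cut-prompt-transcript` as a standalone sketch. The STT model emits one text token per audio frame, so dropping the prompt's transcript means skipping the tokens for the prompt frames plus those for the leading silence prefix and audio delay. The concrete values below are assumptions for illustration: the 12.5 Hz frame rate follows from the diff's "10 frames = 0.8 s" comment, the 1920-sample frame at 24 kHz is assumed, and the prefix, delay, and prompt durations are hypothetical.

```python
# Standalone sketch of the token-cut arithmetic; all values are illustrative.
frame_rate = 12.5   # Mimi frames per second (10 frames == 0.8 s in the diff)
frame_size = 1920   # samples per frame at 24 kHz (assumed)

audio_silence_prefix_seconds = 1.0  # hypothetical prefix length
audio_delay_seconds = 2.5           # hypothetical delay length
prompt_samples = 4 * 24_000         # hypothetical 4 s audio prompt

# One text token is emitted per audio frame, so skip:
prompt_frames = prompt_samples // frame_size  # frames spent on the prompt -> 50
no_prompt_offset = int(
    (audio_delay_seconds + audio_silence_prefix_seconds) * frame_rate
)  # frames of leading padding -> 43

print(prompt_frames + no_prompt_offset)  # 93 tokens cut from the transcript
```

Note that the 0.8 s separator silence appended after the prompt is deliberately not part of the cut: its frames decode to padding tokens, which the existing `text_tokens > padding_token_id` filter removes anyway.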
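The diff only reformats `PromptHook.__init__`, so the enforcement logic itself is not visible here. As a rough illustration of the idea only (not the actual moshi hook interface, whose method names and call sites this diff does not show), a prefix-enforcing hook boils down to holding the encoded prompt text in a deque and substituting queued tokens for the model's non-padding predictions until the queue drains:

```python
from collections import deque


class PrefixEnforcer:
    """Illustrative sketch only: the real PromptHook plugs into moshi's
    decoding loop, whose interface this diff does not show."""

    def __init__(self, encode, prefix, padding_tokens=(0, 3)):
        # Text tokens the transcript is still required to start with.
        self.prefix_enforce = deque(encode(prefix))
        self.padding_tokens = padding_tokens

    def __call__(self, sampled_token: int) -> int:
        # Padding tokens pass through so text stays aligned with the audio;
        # any other prediction is overridden by the next enforced token.
        if sampled_token in self.padding_tokens or not self.prefix_enforce:
            return sampled_token
        return self.prefix_enforce.popleft()


# Toy usage with a fake character-level encoder (hypothetical, demo only).
hook = PrefixEnforcer(lambda s: [ord(c) for c in s], "Loonah")
sampled = [0, 42, 0, 7, 99, 65]  # pretend stream; 0 acts as padding here
print([hook(t) for t in sampled])  # -> [0, 76, 0, 111, 111, 110]
```

This also explains why the feature is sensitive to the prompt: the enforced text tokens condition all subsequent free generation, so a prompt whose text and audio do not match well can degrade the rest of the transcript.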