STT example w/ prompting (#26)

* STT example w/ prompting

* Text-audio prompt example into README.md + cutting prompt transcript.

* A line in README

* formatting in README

---------

Co-authored-by: Eugene <eugene@kyutai.org>
This commit is contained in:
Eugene Kharitonov 2025-07-02 11:23:11 +02:00 committed by GitHub
parent 395eaeae95
commit c4ef93770a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 213 additions and 0 deletions

View File

@ -87,6 +87,23 @@ uv run scripts/evaluate_on_dataset.py \
--hf-repo kyutai/stt-2.6b-en
```
Another example shows how one can provide a text-, audio-, or text-audio prompt to our STT model:
```bash
uv run scripts/transcribe_from_file_via_pytorch_with_prompt.py \
--hf-repo kyutai/stt-2.6b-en \
--file bria.mp3 \
--prompt_file ./audio/loonah.mp3 \
--prompt_text "Loonah" \
--cut-prompt-transcript
```
Produces the transcript of `bria.mp3` using the `Loonah` spelling for the name, instead of the `Luna` used without any prompt:
```
In the heart of an ancient forest, where the trees whispered secrets of the past, there lived a peculiar rabbit named Loonah (...)
```
Apart from nudging the model for a specific spelling of a word, other potential use-cases include speaker adaptation and steering the model towards a specific formatting style or even a language.
However, please bear in mind that is an experimental feature and its behavior is very sensitive to the prompt provided.
### Rust server
<a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;">

BIN
audio/loona.mp3 Normal file

Binary file not shown.

View File

@ -0,0 +1,196 @@
"""An example script that illustrates how one can prompt Kyutai STT models."""
import argparse
import dataclasses
import itertools
import math
from collections import deque
import julius
import moshi.models
import sphn
import torch
import tqdm
class PromptHook:
def __init__(
self,
tokenizer,
prefix,
padding_tokens=(
0,
3,
),
):
self.tokenizer = tokenizer
self.prefix_enforce = deque(self.tokenizer.encode(prefix))
self.padding_tokens = padding_tokens
def on_token(self, token):
if not self.prefix_enforce:
return
token = token.item()
if token in self.padding_tokens:
pass
elif token == self.prefix_enforce[0]:
self.prefix_enforce.popleft()
else:
assert False
def on_logits(self, logits):
if not self.prefix_enforce:
return
mask = torch.zeros_like(logits, dtype=torch.bool)
for t in self.padding_tokens:
mask[..., t] = True
mask[..., self.prefix_enforce[0]] = True
logits[:] = torch.where(mask, logits, float("-inf"))
def main(args):
info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
args.hf_repo,
moshi_weights=args.moshi_weight,
mimi_weights=args.mimi_weight,
tokenizer=args.tokenizer,
config_path=args.config_path,
)
mimi = info.get_mimi(device=args.device)
tokenizer = info.get_text_tokenizer()
lm = info.get_moshi(
device=args.device,
dtype=torch.bfloat16,
)
if args.prompt_text:
prompt_hook = PromptHook(tokenizer, args.prompt_text)
lm_gen = moshi.models.LMGen(
lm,
temp=0,
temp_text=0.0,
on_text_hook=prompt_hook.on_token,
on_text_logits_hook=prompt_hook.on_logits,
)
else:
lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
audio_silence_prefix_seconds = info.stt_config.get(
"audio_silence_prefix_seconds", 1.0
)
audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
padding_token_id = info.raw_config.get("text_padding_token_id", 3)
def _load_and_process(path):
audio, input_sample_rate = sphn.read(path)
audio = torch.from_numpy(audio).to(args.device).mean(axis=0, keepdim=True)
audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate)
if audio.shape[-1] % mimi.frame_size != 0:
to_pad = mimi.frame_size - audio.shape[-1] % mimi.frame_size
audio = torch.nn.functional.pad(audio, (0, to_pad))
return audio
n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate)
n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate)
silence_chunk = torch.zeros(
(1, 1, mimi.frame_size), dtype=torch.float32, device=args.device
)
audio = _load_and_process(args.file)
if args.prompt_file:
audio_prompt = _load_and_process(args.prompt_file)
else:
audio_prompt = None
chain = [itertools.repeat(silence_chunk, n_prefix_chunks)]
if audio_prompt is not None:
chain.append(torch.split(audio_prompt[:, None, :], mimi.frame_size, dim=-1))
# adding a bit (0.8s) of silence to separate prompt and the actual audio
chain.append(itertools.repeat(silence_chunk, 10))
chain += [
torch.split(audio[:, None, :], mimi.frame_size, dim=-1),
itertools.repeat(silence_chunk, n_suffix_chunks),
]
chunks = itertools.chain(*chain)
text_tokens_accum = []
with mimi.streaming(1), lm_gen.streaming(1):
for audio_chunk in tqdm.tqdm(chunks):
audio_tokens = mimi.encode(audio_chunk)
text_tokens = lm_gen.step(audio_tokens)
if text_tokens is not None:
text_tokens_accum.append(text_tokens)
utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
text_tokens = utterance_tokens.cpu().view(-1)
# if we have an audio prompt and we don't want to have it in the transcript,
# we should cut the corresponding number of frames from the output tokens.
# However, there is also some amount of padding that happens before it
# due to silence_prefix and audio_delay. Normally it is ignored in detokenization,
# but now we should account for it to find the position of the prompt transcript.
if args.cut_prompt_transcript and audio_prompt is not None:
prompt_frames = audio_prompt.shape[1] // mimi.frame_size
no_prompt_offset_seconds = audio_delay_seconds + audio_silence_prefix_seconds
no_prompt_offset = int(no_prompt_offset_seconds * mimi.frame_rate)
text_tokens = text_tokens[prompt_frames + no_prompt_offset:]
text = tokenizer.decode(
text_tokens[text_tokens > padding_token_id].numpy().tolist()
)
print(text)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Example streaming STT w/ a prompt.")
parser.add_argument(
"--file",
required=True,
help="File to transcribe.",
)
parser.add_argument(
"--prompt_file",
required=False,
help="Audio of the prompt.",
)
parser.add_argument(
"--prompt_text",
required=False,
help="Text of the prompt.",
)
parser.add_argument(
"--cut-prompt-transcript",
action="store_true",
help="Cut the prompt from the output transcript",
)
parser.add_argument(
"--hf-repo", type=str, help="HF repo to load the STT model from. "
)
parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
parser.add_argument(
"--moshi-weight", type=str, help="Path to a local checkpoint file."
)
parser.add_argument(
"--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
)
parser.add_argument(
"--config-path", type=str, help="Path to a local config file.", default=None
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device on which to run, defaults to 'cuda'.",
)
args = parser.parse_args()
main(args)