Add a text-audio prompt example to README.md + support cutting the prompt transcript.
parent adca7c2731
commit ebbd58dd23

README.md (+16)
````diff
@@ -88,6 +88,22 @@ uv run scripts/evaluate_on_dataset.py \
     --hf-repo kyutai/stt-2.6b-en
 ```
+
+Another example shows how one can provide a text, audio, or text-audio prompt to our STT model:
+```bash
+uv run scripts/transcribe_from_file_via_pytorch_with_prompt.py \
+  --hf-repo kyutai/stt-2.6b-en \
+  --file bria.mp3 \
+  --prompt_file ./audio/loonah.mp3 \
+  --prompt_text "Loonah" \
+  --cut-prompt-transcript
+```
+Produces the transcript of `bria.mp3` using the `Loonah` spelling for the name, instead of the `Luna` used without any prompt:
+```
+In the heart of an ancient forest, where the trees whispered secrets of the past, there lived a peculiar rabbit named Loonah (...)
+```
+
+Please bear in mind that this is an experimental feature and its behavior is very sensitive to the prompt provided.
 
 ### Rust server
 
 <a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;">
````
audio/loona.mp3 (new binary file; content not shown)
scripts/transcribe_from_file_via_pytorch_with_prompt.py

```diff
@@ -14,7 +14,15 @@ import tqdm
 
 
 class PromptHook:
-    def __init__(self, tokenizer, prefix, padding_tokens=(0, 3,)):
+    def __init__(
+        self,
+        tokenizer,
+        prefix,
+        padding_tokens=(
+            0,
+            3,
+        ),
+    ):
         self.tokenizer = tokenizer
         self.prefix_enforce = deque(self.tokenizer.encode(prefix))
         self.padding_tokens = padding_tokens
```
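This hunk only reformats the constructor of `PromptHook`; the class body that actually applies the prompt lies outside the diff. As a rough orientation, a hook like this could work along the following lines; the `on_text_token` method, its signature, and its call site are assumptions for illustration, not part of the commit:

```python
from collections import deque


class PromptHookSketch:
    """Hypothetical sketch of a prompt hook; only __init__ matches the diff."""

    def __init__(self, tokenizer, prefix, padding_tokens=(0, 3)):
        self.tokenizer = tokenizer
        # Text-prompt tokens that should be forced into the decoded stream.
        self.prefix_enforce = deque(self.tokenizer.encode(prefix))
        self.padding_tokens = padding_tokens

    def on_text_token(self, text_token: int) -> int:
        # Assumption: the decoding loop calls this once per emitted text token.
        # Padding frames pass through; otherwise, while enforced prefix tokens
        # remain, they override the model's own predictions.
        if text_token in self.padding_tokens:
            return text_token
        if self.prefix_enforce:
            return self.prefix_enforce.popleft()
        return text_token
```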
```diff
@@ -102,10 +110,12 @@ def main(args):
     chain = [itertools.repeat(silence_chunk, n_prefix_chunks)]
 
     if audio_prompt is not None:
-        chain.append(torch.split(audio_prompt[:, None], mimi.frame_size, dim=-1))
+        chain.append(torch.split(audio_prompt[:, None, :], mimi.frame_size, dim=-1))
+        # adding a bit (0.8s) of silence to separate prompt and the actual audio
+        chain.append(itertools.repeat(silence_chunk, 10))
 
     chain += [
-        torch.split(audio[:, None], mimi.frame_size, dim=-1),
+        torch.split(audio[:, None, :], mimi.frame_size, dim=-1),
         itertools.repeat(silence_chunk, n_suffix_chunks),
     ]
 
```
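The hardcoded `10` separator chunks match the "0.8s" in the new comment if each `silence_chunk` corresponds to one Mimi frame at a 12.5 Hz frame rate; that frame rate is an assumption here, not something stated in the hunk:

```python
# Sanity check for the 0.8 s figure, assuming a 12.5 Hz Mimi frame rate.
frame_rate = 12.5        # frames per second (assumption, not from the diff)
n_separator_chunks = 10  # value hardcoded in chain.append(...) above
print(n_separator_chunks / frame_rate)  # -> 0.8 seconds of separating silence
```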
```diff
@@ -121,9 +131,22 @@ def main(args):
 
     utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
     text_tokens = utterance_tokens.cpu().view(-1)
+
+    # if we have an audio prompt and we don't want to have it in the transcript,
+    # we should cut the corresponding number of frames from the output tokens.
+    # However, there is also some amount of padding that happens before it
+    # due to silence_prefix and audio_delay. Normally it is ignored in detokenization,
+    # but now we should account for it to find the position of the prompt transcript.
+    if args.cut_prompt_transcript and audio_prompt is not None:
+        prompt_frames = audio_prompt.shape[1] // mimi.frame_size
+        no_prompt_offset_seconds = audio_delay_seconds + audio_silence_prefix_seconds
+        no_prompt_offset = int(no_prompt_offset_seconds * mimi.frame_rate)
+        text_tokens = text_tokens[prompt_frames + no_prompt_offset:]
+
     text = tokenizer.decode(
         text_tokens[text_tokens > padding_token_id].numpy().tolist()
     )
 
     print(text)
```
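To make the offset arithmetic concrete, here is an illustrative calculation; the 24 kHz sample rate, 12.5 Hz frame rate, prompt length, and delay values are assumptions chosen for the example, not values read from the commit:

```python
# Illustrative only: how many leading text tokens the cut removes.
sample_rate = 24_000                        # Mimi sample rate (assumption)
frame_rate = 12.5                           # Mimi frame rate (assumption)
frame_size = int(sample_rate / frame_rate)  # 1920 samples per frame

prompt_samples = int(4.0 * sample_rate)     # hypothetical 4 s audio prompt
prompt_frames = prompt_samples // frame_size             # = 50 frames

audio_silence_prefix_seconds = 1.0          # hypothetical prefix padding
audio_delay_seconds = 2.5                   # hypothetical model delay
no_prompt_offset = int(
    (audio_silence_prefix_seconds + audio_delay_seconds) * frame_rate
)                                           # = 43 frames

# text_tokens[prompt_frames + no_prompt_offset:] drops the first 93 tokens.
print(prompt_frames + no_prompt_offset)     # -> 93
```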
```diff
@@ -144,7 +167,11 @@ if __name__ == "__main__":
         required=False,
         help="Text of the prompt.",
     )
+    parser.add_argument(
+        "--cut-prompt-transcript",
+        action="store_true",
+        help="Cut the prompt from the output transcript",
+    )
     parser.add_argument(
         "--hf-repo", type=str, help="HF repo to load the STT model from. "
     )
```