Compare commits
3 Commits
main...vv/clean-u

| Author | SHA1 | Date |
|---|---|---|
|  | 8c18018b6f |  |
|  | bb0bdbf697 |  |
|  | 7b818c2636 |  |
.pre-commit-config.yaml (new file, 22 lines added)
@@ -0,0 +1,22 @@
+repos:
+  # Get rid of Jupyter Notebook output because we don't want to keep it in Git
+  - repo: https://github.com/kynan/nbstripout
+    rev: 0.8.1
+    hooks:
+      - id: nbstripout
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v5.0.0
+    hooks:
+      - id: check-added-large-files
+        args: ["--maxkb=2048"]
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.11.7
+    hooks:
+      # Run the linter.
+      - id: ruff
+        types_or: [python, pyi] # Don't run on `jupyter` files
+        args: [--fix]
+      # Run the formatter.
+      - id: ruff-format
+        types_or: [python, pyi] # Don't run on `jupyter` files
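For context on the `nbstripout` hook above: it rewrites staged notebooks with their outputs removed, so generated output never lands in Git. A minimal sketch of that effect (not the hook's actual implementation; the notebook path is just an example):

```python
import json

# Example path; nbstripout normally operates on the notebooks staged in Git.
path = "transcribe_via_pytorch.ipynb"

with open(path) as f:
    nb = json.load(f)

# Drop the parts we don't want to keep in Git: outputs and execution counts.
for cell in nb.get("cells", []):
    if cell.get("cell_type") == "code":
        cell["outputs"] = []
        cell["execution_count"] = None

with open(path, "w") as f:
    json.dump(nb, f, indent=1)
    f.write("\n")
```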
README.md (52 lines changed)
@@ -1,7 +1,7 @@
 <a href="https://huggingface.co/collections/kyutai/speech-to-text-685403682cf8a23ab9466886" target="_blank" style="margin: 2px;">
   <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-KyutaiSTT-blue" style="display: inline-block; vertical-align: middle;"/>
 </a>
-<a target="_blank" href="https://colab.research.google.com/drive/1mc0Q-FoHxU2pEvId8rTdS4q1r1zorJhS?usp=sharing">
+<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/transcribe_via_pytorch.ipynb">
   <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
 </a>
@@ -33,6 +33,21 @@ These speech-to-text models have several advantages:
 can be used to detect when the user is speaking. This is especially useful
 for building voice agents.
 
+### Implementations overview
+
+We provide different implementations of Kyutai STT for different use cases.
+Here is how to choose which one to use:
+
+- **PyTorch: for research and tinkering.**
+  If you want to call the model from Python for research or experimentation, use our PyTorch implementation.
+- **Rust: for production.**
+  If you want to serve Kyutai STT in a production setting, use our Rust server.
+  Our robust Rust server provides streaming access to the model over websockets.
+  We use this server to run [Unmute](https://unmute.sh/); on an L40S GPU, we can serve 64 simultaneous connections at a real-time factor of 3x.
+- **MLX: for on-device inference on iPhone and Mac.**
+  MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon.
+  If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
+
 You can retrieve the sample files used in the following snippets via:
 ```bash
 wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
@@ -43,10 +58,14 @@ wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hib
 <a href="https://huggingface.co/kyutai/stt-2.6b-en" target="_blank" style="margin: 2px;">
   <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
 </a>
-<a target="_blank" href="https://colab.research.google.com/drive/1mc0Q-FoHxU2pEvId8rTdS4q1r1zorJhS?usp=sharing">
+<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/transcribe_via_pytorch.ipynb">
   <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
 </a>
 
+For an example of how to use the model in a way where you can directly stream in PyTorch tensors,
+[see our Colab notebook](https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/transcribe_via_pytorch.ipynb).
+
 If you just want to run the model on a file, you can use `moshi.run_inference`.
 This requires the [moshi package](https://pypi.org/project/moshi/)
 with version 0.2.6 or later, which can be installed via pip.
@@ -58,25 +77,25 @@ If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the install
 ```bash
 uvx --with moshi python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
 ```
 It will install the moshi package in a temporary environment and run the speech-to-text.
 
 Additionally, we provide two scripts that highlight different usage scenarios. The first script illustrates how to extract word-level timestamps from the model's outputs:
 
 ```bash
 uv run \
-  scripts/streaming_stt_timestamps.py \
+  scripts/transcribe_from_file_via_pytorch.py \
   --hf-repo kyutai/stt-2.6b-en \
   --file bria.mp3
 ```
 
 The second script can be used to run a model on an existing Hugging Face dataset and calculate its performance metrics:
 ```bash
-uv run scripts/streaming_stt.py \
+uv run scripts/evaluate_on_dataset.py \
   --dataset meanwhile \
   --hf-repo kyutai/stt-2.6b-en
 ```
 
 ### Rust server
 
 <a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;">
   <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
 </a>
@@ -104,15 +123,19 @@ and for `kyutai/stt-2.6b-en`, use `configs/config-stt-en-hf.toml`,
 moshi-server worker --config configs/config-stt-en_fr-hf.toml
 ```
 
-Once the server has started you can run a streaming inference with the following
-script.
+Once the server has started you can transcribe audio from your microphone with the following script.
 ```bash
-uv run scripts/asr-streaming-query.py bria.mp3
+uv run scripts/transcribe_from_mic_via_rust_server.py
 ```
 
+We also provide a script for transcribing from an audio file.
+```bash
+uv run scripts/transcribe_from_file_via_rust_server.py bria.mp3
+```
+
 The script limits the decoding speed to simulate real-time processing of the audio.
 Faster processing can be triggered by setting
-the real-time factor, e.g. `--rtf 500` will process
+the real-time factor, e.g. `--rtf 1000` will process
 the data as fast as possible.
 
 ### Rust standalone
@@ -166,3 +189,14 @@ Note that parts of this code is based on [AudioCraft](https://github.com/faceboo
 the MIT license.
 
 The weights for the speech-to-text models are released under the CC-BY 4.0 license.
+
+## Developing
+
+Install the [pre-commit hooks](https://pre-commit.com/) by running:
+
+```bash
+pip install pre-commit
+pre-commit install
+```
+
+If you're using `uv`, you can replace the two commands with `uvx pre-commit install`.
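A note on the `--rtf` flag discussed in the README changes above: the client paces its audio chunks so that `--rtf 1` matches real time, while a large value such as `--rtf 1000` sends the data essentially as fast as possible. A minimal sketch of that pacing, assuming the 24 kHz sample rate and 1920-sample (80 ms) chunks used by the scripts (the `send_paced` helper itself is made up for illustration):

```python
import time

SAMPLE_RATE = 24000  # samples per second expected by the server
CHUNK_SIZE = 1920    # 80 ms of audio per message


def send_paced(pcm: list[float], rtf: float) -> None:
    """Send `pcm` in CHUNK_SIZE pieces, sleeping so the stream runs at `rtf` times real time."""
    start = time.time()
    for i in range(0, len(pcm), CHUNK_SIZE):
        chunk = pcm[i : i + CHUNK_SIZE]
        print(f"sending {len(chunk)} samples")  # stand-in for the websocket send
        expected = start + (i + CHUNK_SIZE) / SAMPLE_RATE / rtf
        delay = expected - time.time()
        if delay > 0:
            time.sleep(delay)


# Two seconds of silence: at rtf=1 this takes about 2 s to send, at rtf=1000 it is near-instant.
send_paced([0.0] * SAMPLE_RATE * 2, rtf=1000.0)
```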
@@ -41,18 +41,17 @@ uv run scripts/streaming_stt.py \
 # Rev16 === cer: 6.57% wer: 10.08% corpus_wer: 11.43% RTF = 40.34
 # Earnings21 === cer: 5.73% wer: 9.84% corpus_wer: 10.38% RTF = 73.15
 
-import dataclasses
-import julius
-import jiwer
-from datasets import load_dataset, Dataset
-from whisper.normalizers import EnglishTextNormalizer
 import argparse
-
-import torch
-import moshi.models
-import tqdm
+import dataclasses
+import time
+
+import jiwer
+import julius
+import moshi.models
+import torch
+import tqdm
+from datasets import Dataset, load_dataset
+from whisper.normalizers import EnglishTextNormalizer
 
 _NORMALIZER = EnglishTextNormalizer()
 
@@ -120,9 +119,9 @@ class AsrMetrics:
         self.num_sequences += 1
 
     def compute(self) -> dict:
-        assert (
-            self.num_sequences > 0
-        ), "Unable to compute with total number of comparisons <= 0"  # type: ignore
+        assert self.num_sequences > 0, (
+            "Unable to compute with total number of comparisons <= 0"
+        )  # type: ignore
         return {
             "cer": (self.cer_sum / self.num_sequences),
             "wer": (self.wer_sum / self.num_sequences),
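The metrics in this evaluation script (cer/wer/corpus_wer) come from `jiwer` applied to text cleaned by Whisper's `EnglishTextNormalizer`, as the imports above show. A minimal sketch of that combination, with made-up strings (requires the `jiwer` and `openai-whisper` packages):

```python
import jiwer
from whisper.normalizers import EnglishTextNormalizer

normalizer = EnglishTextNormalizer()

# Normalize both sides before scoring so that casing, punctuation and
# spelled-out numbers do not count as errors.
reference = normalizer("Mr. Smith bought two apples.")
hypothesis = normalizer("mister smith bought 2 apples")

print("wer:", jiwer.wer(reference, hypothesis))
print("cer:", jiwer.cer(reference, hypothesis))
```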
@@ -19,15 +19,15 @@ uv run scripts/streaming_stt_timestamps.py \
 ```
 """
 
-import itertools
-import dataclasses
-import julius
-import sphn
 import argparse
+import dataclasses
+import itertools
 import math
 
-import torch
+import julius
 import moshi.models
+import sphn
+import torch
 import tqdm
 
 
@@ -10,17 +10,16 @@
 import argparse
 import asyncio
 import json
-import msgpack
-import sphn
-import struct
 import time
 
-import numpy as np
+import msgpack
+import sphn
 import websockets
 
 # Desired audio properties
 TARGET_SAMPLE_RATE = 24000
 TARGET_CHANNELS = 1  # Mono
+HEADERS = {"kyutai-api-key": "open_token"}
 all_text = []
 transcript = []
 finished = False
@@ -44,11 +43,13 @@ async def receive_messages(websocket):
             print("received:", data)
             if data["type"] == "Word":
                 all_text.append(data["text"])
-                transcript.append({
+                transcript.append(
+                    {
                         "speaker": "SPEAKER_00",
                         "text": data["text"],
                         "timestamp": [data["start_time"], data["start_time"]],
-                })
+                    }
+                )
             if data["type"] == "EndWord":
                 if len(transcript) > 0:
                     transcript[-1]["timestamp"][1] = data["stop_time"]
@@ -64,15 +65,19 @@ async def send_messages(websocket, rtf: float):
     global finished
     audio_data = load_and_process_audio(args.in_file)
     try:
-        # Start with a second of silence
-        chunk = { "type": "Audio", "pcm": [0.0] * 24000 }
+        # Start with a second of silence.
+        # This is needed for the 2.6B model for technical reasons.
+        chunk = {"type": "Audio", "pcm": [0.0] * 24000}
         msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
         await websocket.send(msg)
 
         chunk_size = 1920  # Send data in chunks
         start_time = time.time()
         for i in range(0, len(audio_data), chunk_size):
-            chunk = { "type": "Audio", "pcm": [float(x) for x in audio_data[i : i + chunk_size]] }
+            chunk = {
+                "type": "Audio",
+                "pcm": [float(x) for x in audio_data[i : i + chunk_size]],
+            }
             msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
             await websocket.send(msg)
             expected_send_time = start_time + (i + 1) / 24000 / rtf
@@ -81,13 +86,15 @@ async def send_messages(websocket, rtf: float):
                 await asyncio.sleep(expected_send_time - current_time)
             else:
                 await asyncio.sleep(0.001)
-        chunk = { "type": "Audio", "pcm": [0.0] * 1920 * 5 }
+        chunk = {"type": "Audio", "pcm": [0.0] * 1920 * 5}
         msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
         await websocket.send(msg)
-        msg = msgpack.packb({"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True)
+        msg = msgpack.packb(
+            {"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True
+        )
         await websocket.send(msg)
         for _ in range(35):
-            chunk = { "type": "Audio", "pcm": [0.0] * 1920 }
+            chunk = {"type": "Audio", "pcm": [0.0] * 1920}
             msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
             await websocket.send(msg)
         while True:
@@ -100,11 +107,10 @@ async def send_messages(websocket, rtf: float):
         print("Connection closed while sending messages.")
 
 
-async def stream_audio(url: str, rtf: float, api_key: str):
+async def stream_audio(url: str, rtf: float):
     """Stream audio data to a WebSocket server."""
 
-    headers = {"kyutai-api-key": api_key}
-    async with websockets.connect(url, additional_headers=headers) as websocket:
+    async with websockets.connect(url, additional_headers=HEADERS) as websocket:
         send_task = asyncio.create_task(send_messages(websocket, rtf))
         receive_task = asyncio.create_task(receive_messages(websocket))
         await asyncio.gather(send_task, receive_task)
@@ -115,7 +121,6 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("in_file")
    parser.add_argument("--transcript")
-   parser.add_argument("--api-key", default="open_token")
    parser.add_argument(
        "--url",
        help="The url of the server to which to send the audio",
@@ -125,7 +130,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     url = f"{args.url}/api/asr-streaming"
-    asyncio.run(stream_audio(url, args.rtf, args.api_key))
+    asyncio.run(stream_audio(url, args.rtf))
     print(" ".join(all_text))
     if args.transcript is not None:
         with open(args.transcript, "w") as fobj:
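For readers writing their own client against the Rust server: the script above assembles word-level timestamps from the server's msgpack-encoded "Word" and "EndWord" events. A standalone sketch of that bookkeeping, with the two events constructed locally rather than received over a websocket:

```python
import msgpack

# Locally constructed stand-ins for messages the server would send.
events = [
    msgpack.packb({"type": "Word", "text": "hello", "start_time": 0.48}, use_bin_type=True),
    msgpack.packb({"type": "EndWord", "stop_time": 0.72}, use_bin_type=True),
]

transcript = []
for raw in events:
    data = msgpack.unpackb(raw, raw=False)
    if data["type"] == "Word":
        # A new word opens an entry whose end time is provisionally its start time.
        transcript.append({"text": data["text"], "timestamp": [data["start_time"], data["start_time"]]})
    elif data["type"] == "EndWord" and transcript:
        # The matching EndWord closes the most recent entry.
        transcript[-1]["timestamp"][1] = data["stop_time"]

print(transcript)  # [{'text': 'hello', 'timestamp': [0.48, 0.72]}]
```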
@@ -11,19 +11,16 @@
 # ///
 
 import argparse
-from dataclasses import dataclass
 import json
-import numpy as np
 import queue
-import sounddevice as sd
-
-from huggingface_hub import hf_hub_download
 import mlx.core as mx
 import mlx.nn as nn
-from moshi_mlx import models, utils
 import rustymimi
 import sentencepiece
+import sounddevice as sd
+from huggingface_hub import hf_hub_download
+from moshi_mlx import models, utils
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -69,6 +66,7 @@ if __name__ == "__main__":
     )
 
     block_queue = queue.Queue()
+
     def audio_callback(indata, _frames, _time, _status):
         block_queue.put(indata.copy())
 
@@ -84,7 +82,9 @@ if __name__ == "__main__":
         block = block_queue.get()
         block = block[None, :, 0]
         other_audio_tokens = audio_tokenizer.encode_step(block[None, 0:1])
-        other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[:, :, :other_codebooks]
+        other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[
+            :, :, :other_codebooks
+        ]
         text_token = gen.step(other_audio_tokens[0])
         text_token = text_token[0].item()
         audio_tokens = gen.last_audio_tokens()
@@ -93,4 +93,3 @@ if __name__ == "__main__":
             _text = text_tokenizer.id_to_piece(text_token)  # type: ignore
             _text = _text.replace("▁", " ")
             print(_text, end="", flush=True)
-
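The MLX script above relies on a simple capture pattern: a `sounddevice` callback pushes 1920-sample (80 ms) float32 blocks at 24 kHz onto a queue, and the main loop pops them for encoding. A minimal sketch of just that pattern, assuming a working input device (the five-block limit is arbitrary):

```python
import queue

import sounddevice as sd

SAMPLE_RATE = 24000
BLOCK_SIZE = 1920  # 80 ms per block

block_queue = queue.Queue()


def audio_callback(indata, _frames, _time, _status):
    # Copy, because sounddevice reuses the buffer between callbacks.
    block_queue.put(indata.copy())


with sd.InputStream(
    samplerate=SAMPLE_RATE,
    channels=1,
    dtype="float32",
    blocksize=BLOCK_SIZE,
    callback=audio_callback,
):
    for _ in range(5):  # grab roughly 0.4 s of audio and report block shapes
        block = block_queue.get()
        print(block.shape)  # (1920, 1)
```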
@@ -9,9 +9,9 @@
 # ///
 import argparse
 import asyncio
-import msgpack
 import signal
 
+import msgpack
 import numpy as np
 import sounddevice as sd
 import websockets
@@ -21,6 +21,7 @@ TARGET_SAMPLE_RATE = 24000
 TARGET_CHANNELS = 1  # Mono
 audio_queue = asyncio.Queue()
 
+
 async def receive_messages(websocket):
     """Receive and process messages from the WebSocket server."""
     try:
@@ -47,22 +48,26 @@ async def send_messages(websocket):
     except websockets.ConnectionClosed:
         print("Connection closed while sending messages.")
 
 
 async def stream_audio(url: str, api_key: str):
     """Stream audio data to a WebSocket server."""
     print("Starting microphone recording...")
     print("Press Ctrl+C to stop recording")
 
     loop = asyncio.get_event_loop()
 
     def audio_callback(indata, frames, time, status):
-        loop.call_soon_threadsafe(audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy())
+        loop.call_soon_threadsafe(
+            audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy()
+        )
 
     # Start audio stream
     with sd.InputStream(
         samplerate=TARGET_SAMPLE_RATE,
         channels=TARGET_CHANNELS,
-        dtype='float32',
+        dtype="float32",
         callback=audio_callback,
-        blocksize=1920  # 80ms blocks
+        blocksize=1920,  # 80ms blocks
     ):
         headers = {"kyutai-api-key": api_key}
         async with websockets.connect(url, additional_headers=headers) as websocket:
@@ -79,8 +84,12 @@ if __name__ == "__main__":
         default="ws://127.0.0.1:8080",
     )
     parser.add_argument("--api-key", default="open_token")
-    parser.add_argument("--list-devices", action="store_true", help="List available audio devices")
-    parser.add_argument("--device", type=int, help="Input device ID (use --list-devices to see options)")
+    parser.add_argument(
+        "--list-devices", action="store_true", help="List available audio devices"
+    )
+    parser.add_argument(
+        "--device", type=int, help="Input device ID (use --list-devices to see options)"
+    )
 
     args = parser.parse_args()
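The microphone client above has to hand data from the `sounddevice` callback thread to asyncio, which is why it goes through `loop.call_soon_threadsafe` rather than touching the `asyncio.Queue` directly. A self-contained sketch of that bridge, with a plain thread standing in for the audio callback:

```python
import asyncio
import threading
import time


async def main():
    loop = asyncio.get_running_loop()
    queue = asyncio.Queue()

    # Simulates the sounddevice callback, which runs on a non-asyncio thread:
    # call_soon_threadsafe is the safe way to hand its data to the event loop.
    def producer():
        for i in range(3):
            time.sleep(0.1)
            loop.call_soon_threadsafe(queue.put_nowait, f"block {i}")

    threading.Thread(target=producer, daemon=True).start()
    for _ in range(3):
        print(await queue.get())


asyncio.run(main())
```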
transcribe_via_pytorch.ipynb (new file, 240 lines added)
@@ -0,0 +1,240 @@
+{
+  "cells": [
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "gJEMjPgeI-rw",
+        "outputId": "7491c067-b1be-4505-b3f5-19ba4c00a593"
+      },
+      "outputs": [],
+      "source": [
+        "!pip install moshi"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "CA4K5iDFJcqJ",
+        "outputId": "b609843a-a193-4729-b099-5f8780532333"
+      },
+      "outputs": [],
+      "source": [
+        "!wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "VA3Haix3IZ8Q"
+      },
+      "outputs": [],
+      "source": [
+        "from dataclasses import dataclass\n",
+        "import time\n",
+        "import sentencepiece\n",
+        "import sphn\n",
+        "import textwrap\n",
+        "import torch\n",
+        "\n",
+        "from moshi.models import loaders, MimiModel, LMModel, LMGen"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "9AK5zBMTI9bw"
+      },
+      "outputs": [],
+      "source": [
+        "@dataclass\n",
+        "class InferenceState:\n",
+        "    mimi: MimiModel\n",
+        "    text_tokenizer: sentencepiece.SentencePieceProcessor\n",
+        "    lm_gen: LMGen\n",
+        "\n",
+        "    def __init__(\n",
+        "        self,\n",
+        "        mimi: MimiModel,\n",
+        "        text_tokenizer: sentencepiece.SentencePieceProcessor,\n",
+        "        lm: LMModel,\n",
+        "        batch_size: int,\n",
+        "        device: str | torch.device,\n",
+        "    ):\n",
+        "        self.mimi = mimi\n",
+        "        self.text_tokenizer = text_tokenizer\n",
+        "        self.lm_gen = LMGen(lm, temp=0, temp_text=0, use_sampling=False)\n",
+        "        self.device = device\n",
+        "        self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)\n",
+        "        self.batch_size = batch_size\n",
+        "        self.mimi.streaming_forever(batch_size)\n",
+        "        self.lm_gen.streaming_forever(batch_size)\n",
+        "\n",
+        "    def run(self, in_pcms: torch.Tensor):\n",
+        "        device = self.lm_gen.lm_model.device\n",
+        "        ntokens = 0\n",
+        "        first_frame = True\n",
+        "        chunks = [\n",
+        "            c\n",
+        "            for c in in_pcms.split(self.frame_size, dim=2)\n",
+        "            if c.shape[-1] == self.frame_size\n",
+        "        ]\n",
+        "        start_time = time.time()\n",
+        "        all_text = []\n",
+        "        for chunk in chunks:\n",
+        "            codes = self.mimi.encode(chunk)\n",
+        "            if first_frame:\n",
+        "                # Ensure that the first slice of codes is properly seen by the transformer\n",
+        "                # as otherwise the first slice is replaced by the initial tokens.\n",
+        "                tokens = self.lm_gen.step(codes)\n",
+        "                first_frame = False\n",
+        "            tokens = self.lm_gen.step(codes)\n",
+        "            if tokens is None:\n",
+        "                continue\n",
+        "            assert tokens.shape[1] == 1\n",
+        "            one_text = tokens[0, 0].cpu()\n",
+        "            if one_text.item() not in [0, 3]:\n",
+        "                text = self.text_tokenizer.id_to_piece(one_text.item())\n",
+        "                text = text.replace(\"▁\", \" \")\n",
+        "                all_text.append(text)\n",
+        "            ntokens += 1\n",
+        "        dt = time.time() - start_time\n",
+        "        print(\n",
+        "            f\"processed {ntokens} steps in {dt:.0f}s, {1000 * dt / ntokens:.2f}ms/step\"\n",
+        "        )\n",
+        "        return \"\".join(all_text)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 353,
+          "referenced_widgets": [
+            "0a5f6f887e2b4cd1990a0e9ec0153ed9",
+            "f7893826fcba4bdc87539589d669249b",
+            "8805afb12c484781be85082ff02dad13",
+            "97679c0d9ab44bed9a3456f2fcb541fd",
+            "d73c0321bed54a52b5e1da0a7788e32a",
+            "d67be13a920d4fc89e5570b5b29fc1d2",
+            "6b377c2d7bf945fb89e46c39d246a332",
+            "b82ff365c78e41ad8094b46daf79449d",
+            "477aa7fa82dc42d5bce6f1743c45d626",
+            "cbd288510c474430beb66f346f382c45",
+            "aafc347cdf28428ea6a7abe5b46b726f",
+            "fca09acd5d0d45468c8b04bfb2de7646",
+            "79e35214b51b4a9e9b3f7144b0b34f7b",
+            "89e9a37f69904bd48b954d627bff6687",
+            "57028789c78248a7b0ad4f031c9545c9",
+            "1150fcb427994c2984d4d0f4e4745fe5",
+            "e24b1fc52f294f849019c9b3befb613f",
+            "8724878682cf4c3ca992667c45009398",
+            "36a22c977d5242008871310133b7d2af",
+            "5b3683cad5cb4877b43fadd003edf97f",
+            "703f98272e4d469d8f27f5a465715dd8",
+            "9dbe02ef5fac41cfaee3d02946e65c88",
+            "37faa87ad03a4271992c21ce6a629e18",
+            "570c547e48cd421b814b2c5e028e4c0b",
+            "b173768580fc4c0a8e3abf272e4c363a",
+            "e57d1620f0a9427b85d8b4885ef4e8e3",
+            "5dd4474df70743498b616608182714dd",
+            "cc907676a65f4ad1bf68a77b4a00e89b",
+            "a34abc3b118e4305951a466919c28ff6",
+            "a77ccfcdb90146c7a63b4b2d232bc494",
+            "f7313e6e3a27475993cab3961d6ae363",
+            "39b47fad9c554839868fe9e4bbf7def2",
+            "14e9511ea0bd44c49f0cf3abf1a6d40e",
+            "a4ea8e0c4cac4d5e88b7e3f527e4fe90",
+            "571afc0f4b2840c9830d6b5a307ed1f9",
+            "6ec593cab5b64f0ea638bb175b9daa5c",
+            "77a52aed00ae408bb24524880e19ec8a",
+            "0b2de4b29b4b44fe9d96361a40c793d0",
+            "3c5b5fb1a5ac468a89c1058bd90cfb58",
+            "e53e0a2a240e43cfa562c89b3d703dea",
+            "35966343cf9249ef8bc028a0d5c5f97d",
+            "e36a37e0d41c47ccb8bc6d56c19fb17c",
+            "279ccf7de43847a1a6579c9182a46cc8",
+            "41b5d6ab0b7d43c790a55f125c0e7494"
+          ]
+        },
+        "id": "UsQJdAgkLp9n",
+        "outputId": "9b7131c3-69c5-4323-8312-2ce7621d8869"
+      },
+      "outputs": [],
+      "source": [
+        "device = \"cuda\"\n",
+        "# Use the en+fr low latency model, an alternative is kyutai/stt-2.6b-en\n",
+        "checkpoint_info = loaders.CheckpointInfo.from_hf_repo(\"kyutai/stt-1b-en_fr\")\n",
+        "mimi = checkpoint_info.get_mimi(device=device)\n",
+        "text_tokenizer = checkpoint_info.get_text_tokenizer()\n",
+        "lm = checkpoint_info.get_moshi(device=device)\n",
+        "in_pcms, _ = sphn.read(\"sample_fr_hibiki_crepes.mp3\", sample_rate=mimi.sample_rate)\n",
+        "in_pcms = torch.from_numpy(in_pcms).to(device=device)\n",
+        "\n",
+        "stt_config = checkpoint_info.stt_config\n",
+        "pad_left = int(stt_config.get(\"audio_silence_prefix_seconds\", 0.0) * 24000)\n",
+        "pad_right = int((stt_config.get(\"audio_delay_seconds\", 0.0) + 1.0) * 24000)\n",
+        "in_pcms = torch.nn.functional.pad(in_pcms, (pad_left, pad_right), mode=\"constant\")\n",
+        "in_pcms = in_pcms[None, 0:1].expand(1, -1, -1)\n",
+        "\n",
+        "state = InferenceState(mimi, text_tokenizer, lm, batch_size=1, device=device)\n",
+        "text = state.run(in_pcms)\n",
+        "print(textwrap.fill(text, width=100))"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 75
+        },
+        "id": "CIAXs9oaPrtj",
+        "outputId": "94cc208c-2454-4dd4-a64e-d79025144af5"
+      },
+      "outputs": [],
+      "source": [
+        "from IPython.display import Audio\n",
+        "\n",
+        "Audio(\"sample_fr_hibiki_crepes.mp3\")"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "qkUZ6CBKOdTa"
+      },
+      "outputs": [],
+      "source": []
+    }
+  ],
+  "metadata": {
+    "accelerator": "GPU",
+    "colab": {
+      "gpuType": "L4",
+      "provenance": []
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "name": "python3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+}