From 7b5a01dfbad36df145380b31004edadb072007ff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?V=C3=A1clav=20Volhejn?=
<8401624+vvolhejn@users.noreply.github.com>
Date: Thu, 26 Jun 2025 09:26:11 +0200
Subject: [PATCH] Rename examples and add pre-commit (#16)
* Rename examples and add pre-commit
* Fix references to scripts, add implementations overview
* Link to colab notebook via github
---
.pre-commit-config.yaml | 22 ++
README.md | 52 +++-
...treaming_stt.py => evaluate_on_dataset.py} | 23 +-
...py => transcribe_from_file_via_pytorch.py} | 10 +-
...> transcribe_from_file_via_rust_server.py} | 45 ++--
...-mic.py => transcribe_from_mic_via_mlx.py} | 15 +-
...=> transcribe_from_mic_via_rust_server.py} | 33 ++-
transcribe_via_pytorch.ipynb | 240 ++++++++++++++++++
8 files changed, 374 insertions(+), 66 deletions(-)
create mode 100644 .pre-commit-config.yaml
rename scripts/{streaming_stt.py => evaluate_on_dataset.py} (98%)
rename scripts/{streaming_stt_timestamps.py => transcribe_from_file_via_pytorch.py} (100%)
rename scripts/{asr-streaming-query.py => transcribe_from_file_via_rust_server.py} (78%)
rename scripts/{mlx-mic.py => transcribe_from_mic_via_mlx.py} (97%)
rename scripts/{mic-query.py => transcribe_from_mic_via_rust_server.py} (86%)
create mode 100644 transcribe_via_pytorch.ipynb
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..e2f0230
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,22 @@
+repos:
+ # Get rid of Jupyter Notebook output because we don't want to keep it in Git
+ - repo: https://github.com/kynan/nbstripout
+ rev: 0.8.1
+ hooks:
+ - id: nbstripout
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v5.0.0
+ hooks:
+ - id: check-added-large-files
+ args: ["--maxkb=2048"]
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ # Ruff version.
+ rev: v0.11.7
+ hooks:
+ # Run the linter.
+ - id: ruff
+ types_or: [python, pyi] # Don't run on `jupyter` files
+ args: [--fix]
+ # Run the formatter.
+ - id: ruff-format
+ types_or: [python, pyi] # Don't run on `jupyter` files
diff --git a/README.md b/README.md
index 99ed1bd..5b79d1a 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
-
+
@@ -33,6 +33,21 @@ These speech-to-text models have several advantages:
can be used to detect when the user is speaking. This is especially useful
for building voice agents.
+### Implementations overview
+
+We provide different implementations of Kyutai STT for different use cases.
+Here is how to choose which one to use:
+
+- **PyTorch: for research and tinkering.**
+ If you want to call the model from Python for research or experimentation, use our PyTorch implementation.
+- **Rust: for production.**
+ If you want to serve Kyutai STT in a production setting, use our Rust server.
+ Our robust Rust server provides streaming access to the model over websockets.
+ We use this server to run [Unmute](https://unmute.sh/); on a L40S GPU, we can serve 64 simultaneous connections at a real-time factor of 3x.
+- **MLX: for on-device inference on iPhone and Mac.**
+ MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon.
+ If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
+
You can retrieve the sample files used in the following snippets via:
```bash
wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
@@ -43,10 +58,14 @@ wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hib
-
+
+For an example of how to use the model in a way where you can directly stream in PyTorch tensors,
+[see our Colab notebook](https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/transcribe_via_pytorch.ipynb).
+
+If you just want to run the model on a file, you can use `moshi.run_inference`.
This requires the [moshi package](https://pypi.org/project/moshi/)
with version 0.2.6 or later, which can be installed via pip.
@@ -58,25 +77,25 @@ If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the install
```bash
uvx --with moshi python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
```
-It will install the moshi package in a temporary environment and run the speech-to-text.
Additionally, we provide two scripts that highlight different usage scenarios. The first script illustrates how to extract word-level timestamps from the model's outputs:
```bash
uv run \
- scripts/streaming_stt_timestamps.py \
+ scripts/transcribe_from_file_via_pytorch.py \
--hf-repo kyutai/stt-2.6b-en \
--file bria.mp3
```
The second script can be used to run a model on an existing Hugging Face dataset and calculate its performance metrics:
```bash
-uv run scripts/streaming_stt.py \
+uv run scripts/evaluate_on_dataset.py \
--dataset meanwhile \
--hf-repo kyutai/stt-2.6b-en
```
### Rust server
+
@@ -104,15 +123,19 @@ and for `kyutai/stt-2.6b-en`, use `configs/config-stt-en-hf.toml`,
moshi-server worker --config configs/config-stt-en_fr-hf.toml
```
-Once the server has started you can run a streaming inference with the following
-script.
+Once the server has started you can transcribe audio from your microphone with the following script.
```bash
-uv run scripts/asr-streaming-query.py bria.mp3
+uv run scripts/transcribe_from_mic_via_rust_server.py
+```
+
+We also provide a script for transcribing from an audio file.
+```bash
+uv run scripts/transcribe_from_file_via_rust_server.py bria.mp3
```
The script limits the decoding speed to simulates real-time processing of the audio.
Faster processing can be triggered by setting
-the real-time factor, e.g. `--rtf 500` will process
+the real-time factor, e.g. `--rtf 1000` will process
the data as fast as possible.
### Rust standalone
@@ -166,3 +189,14 @@ Note that parts of this code is based on [AudioCraft](https://github.com/faceboo
the MIT license.
The weights for the speech-to-text models are released under the CC-BY 4.0 license.
+
+## Developing
+
+Install the [pre-commit hooks](https://pre-commit.com/) by running:
+
+```bash
+pip install pre-commit
+pre-commit install
+```
+
+If you're using `uv`, you can replace the two commands with `uvx pre-commit install`.
\ No newline at end of file
diff --git a/scripts/streaming_stt.py b/scripts/evaluate_on_dataset.py
similarity index 98%
rename from scripts/streaming_stt.py
rename to scripts/evaluate_on_dataset.py
index bffe294..3bef8aa 100644
--- a/scripts/streaming_stt.py
+++ b/scripts/evaluate_on_dataset.py
@@ -41,18 +41,17 @@ uv run scripts/streaming_stt.py \
# Rev16 === cer: 6.57% wer: 10.08% corpus_wer: 11.43% RTF = 40.34
# Earnings21 === cer: 5.73% wer: 9.84% corpus_wer: 10.38% RTF = 73.15
-import dataclasses
-import julius
-import jiwer
-from datasets import load_dataset, Dataset
-from whisper.normalizers import EnglishTextNormalizer
import argparse
-
-import torch
-import moshi.models
-import tqdm
+import dataclasses
import time
+import jiwer
+import julius
+import moshi.models
+import torch
+import tqdm
+from datasets import Dataset, load_dataset
+from whisper.normalizers import EnglishTextNormalizer
_NORMALIZER = EnglishTextNormalizer()
@@ -120,9 +119,9 @@ class AsrMetrics:
self.num_sequences += 1
def compute(self) -> dict:
- assert (
- self.num_sequences > 0
- ), "Unable to compute with total number of comparisons <= 0" # type: ignore
+ assert self.num_sequences > 0, (
+ "Unable to compute with total number of comparisons <= 0"
+ ) # type: ignore
return {
"cer": (self.cer_sum / self.num_sequences),
"wer": (self.wer_sum / self.num_sequences),
diff --git a/scripts/streaming_stt_timestamps.py b/scripts/transcribe_from_file_via_pytorch.py
similarity index 100%
rename from scripts/streaming_stt_timestamps.py
rename to scripts/transcribe_from_file_via_pytorch.py
index 7e0970b..e941da8 100644
--- a/scripts/streaming_stt_timestamps.py
+++ b/scripts/transcribe_from_file_via_pytorch.py
@@ -19,15 +19,15 @@ uv run scripts/streaming_stt_timestamps.py \
```
"""
-import itertools
-import dataclasses
-import julius
-import sphn
import argparse
+import dataclasses
+import itertools
import math
-import torch
+import julius
import moshi.models
+import sphn
+import torch
import tqdm
diff --git a/scripts/asr-streaming-query.py b/scripts/transcribe_from_file_via_rust_server.py
similarity index 78%
rename from scripts/asr-streaming-query.py
rename to scripts/transcribe_from_file_via_rust_server.py
index 0527556..3e90b79 100644
--- a/scripts/asr-streaming-query.py
+++ b/scripts/transcribe_from_file_via_rust_server.py
@@ -10,17 +10,16 @@
import argparse
import asyncio
import json
-import msgpack
-import sphn
-import struct
import time
-import numpy as np
+import msgpack
+import sphn
import websockets
# Desired audio properties
TARGET_SAMPLE_RATE = 24000
TARGET_CHANNELS = 1 # Mono
+HEADERS = {"kyutai-api-key": "open_token"}
all_text = []
transcript = []
finished = False
@@ -44,11 +43,13 @@ async def receive_messages(websocket):
print("received:", data)
if data["type"] == "Word":
all_text.append(data["text"])
- transcript.append({
- "speaker": "SPEAKER_00",
- "text": data["text"],
- "timestamp": [data["start_time"], data["start_time"]],
- })
+ transcript.append(
+ {
+ "speaker": "SPEAKER_00",
+ "text": data["text"],
+ "timestamp": [data["start_time"], data["start_time"]],
+ }
+ )
if data["type"] == "EndWord":
if len(transcript) > 0:
transcript[-1]["timestamp"][1] = data["stop_time"]
@@ -64,15 +65,19 @@ async def send_messages(websocket, rtf: float):
global finished
audio_data = load_and_process_audio(args.in_file)
try:
- # Start with a second of silence
- chunk = { "type": "Audio", "pcm": [0.0] * 24000 }
+ # Start with a second of silence.
+ # This is needed for the 2.6B model for technical reasons.
+ chunk = {"type": "Audio", "pcm": [0.0] * 24000}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
chunk_size = 1920 # Send data in chunks
start_time = time.time()
for i in range(0, len(audio_data), chunk_size):
- chunk = { "type": "Audio", "pcm": [float(x) for x in audio_data[i : i + chunk_size]] }
+ chunk = {
+ "type": "Audio",
+ "pcm": [float(x) for x in audio_data[i : i + chunk_size]],
+ }
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
expected_send_time = start_time + (i + 1) / 24000 / rtf
@@ -81,13 +86,15 @@ async def send_messages(websocket, rtf: float):
await asyncio.sleep(expected_send_time - current_time)
else:
await asyncio.sleep(0.001)
- chunk = { "type": "Audio", "pcm": [0.0] * 1920 * 5 }
+ chunk = {"type": "Audio", "pcm": [0.0] * 1920 * 5}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
- msg = msgpack.packb({"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True)
+ msg = msgpack.packb(
+ {"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True
+ )
await websocket.send(msg)
for _ in range(35):
- chunk = { "type": "Audio", "pcm": [0.0] * 1920 }
+ chunk = {"type": "Audio", "pcm": [0.0] * 1920}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
while True:
@@ -100,11 +107,10 @@ async def send_messages(websocket, rtf: float):
print("Connection closed while sending messages.")
-async def stream_audio(url: str, rtf: float, api_key: str):
+async def stream_audio(url: str, rtf: float):
"""Stream audio data to a WebSocket server."""
- headers = {"kyutai-api-key": api_key}
- async with websockets.connect(url, additional_headers=headers) as websocket:
+ async with websockets.connect(url, additional_headers=HEADERS) as websocket:
send_task = asyncio.create_task(send_messages(websocket, rtf))
receive_task = asyncio.create_task(receive_messages(websocket))
await asyncio.gather(send_task, receive_task)
@@ -115,7 +121,6 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("in_file")
parser.add_argument("--transcript")
- parser.add_argument("--api-key", default="open_token")
parser.add_argument(
"--url",
help="The url of the server to which to send the audio",
@@ -125,7 +130,7 @@ if __name__ == "__main__":
args = parser.parse_args()
url = f"{args.url}/api/asr-streaming"
- asyncio.run(stream_audio(url, args.rtf, args.api_key))
+ asyncio.run(stream_audio(url, args.rtf))
print(" ".join(all_text))
if args.transcript is not None:
with open(args.transcript, "w") as fobj:
diff --git a/scripts/mlx-mic.py b/scripts/transcribe_from_mic_via_mlx.py
similarity index 97%
rename from scripts/mlx-mic.py
rename to scripts/transcribe_from_mic_via_mlx.py
index 1816efe..e8792e2 100644
--- a/scripts/mlx-mic.py
+++ b/scripts/transcribe_from_mic_via_mlx.py
@@ -11,19 +11,16 @@
# ///
import argparse
-from dataclasses import dataclass
import json
-import numpy as np
import queue
-import sounddevice as sd
-from huggingface_hub import hf_hub_download
import mlx.core as mx
import mlx.nn as nn
-from moshi_mlx import models, utils
import rustymimi
import sentencepiece
-
+import sounddevice as sd
+from huggingface_hub import hf_hub_download
+from moshi_mlx import models, utils
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -69,6 +66,7 @@ if __name__ == "__main__":
)
block_queue = queue.Queue()
+
def audio_callback(indata, _frames, _time, _status):
block_queue.put(indata.copy())
@@ -84,7 +82,9 @@ if __name__ == "__main__":
block = block_queue.get()
block = block[None, :, 0]
other_audio_tokens = audio_tokenizer.encode_step(block[None, 0:1])
- other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[:, :, :other_codebooks]
+ other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[
+ :, :, :other_codebooks
+ ]
text_token = gen.step(other_audio_tokens[0])
text_token = text_token[0].item()
audio_tokens = gen.last_audio_tokens()
@@ -93,4 +93,3 @@ if __name__ == "__main__":
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("▁", " ")
print(_text, end="", flush=True)
-
diff --git a/scripts/mic-query.py b/scripts/transcribe_from_mic_via_rust_server.py
similarity index 86%
rename from scripts/mic-query.py
rename to scripts/transcribe_from_mic_via_rust_server.py
index ff41d59..85bad93 100644
--- a/scripts/mic-query.py
+++ b/scripts/transcribe_from_mic_via_rust_server.py
@@ -9,9 +9,9 @@
# ///
import argparse
import asyncio
-import msgpack
import signal
+import msgpack
import numpy as np
import sounddevice as sd
import websockets
@@ -21,6 +21,7 @@ TARGET_SAMPLE_RATE = 24000
TARGET_CHANNELS = 1 # Mono
audio_queue = asyncio.Queue()
+
async def receive_messages(websocket):
"""Receive and process messages from the WebSocket server."""
try:
@@ -47,22 +48,26 @@ async def send_messages(websocket):
except websockets.ConnectionClosed:
print("Connection closed while sending messages.")
+
async def stream_audio(url: str, api_key: str):
"""Stream audio data to a WebSocket server."""
print("Starting microphone recording...")
print("Press Ctrl+C to stop recording")
loop = asyncio.get_event_loop()
- def audio_callback(indata, frames, time, status):
- loop.call_soon_threadsafe(audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy())
- # Start audio stream
+ def audio_callback(indata, frames, time, status):
+ loop.call_soon_threadsafe(
+ audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy()
+ )
+
+ # Start audio stream
with sd.InputStream(
samplerate=TARGET_SAMPLE_RATE,
channels=TARGET_CHANNELS,
- dtype='float32',
+ dtype="float32",
callback=audio_callback,
- blocksize=1920 # 80ms blocks
+ blocksize=1920, # 80ms blocks
):
headers = {"kyutai-api-key": api_key}
async with websockets.connect(url, additional_headers=headers) as websocket:
@@ -79,11 +84,15 @@ if __name__ == "__main__":
default="ws://127.0.0.1:8080",
)
parser.add_argument("--api-key", default="open_token")
- parser.add_argument("--list-devices", action="store_true", help="List available audio devices")
- parser.add_argument("--device", type=int, help="Input device ID (use --list-devices to see options)")
-
+ parser.add_argument(
+ "--list-devices", action="store_true", help="List available audio devices"
+ )
+ parser.add_argument(
+ "--device", type=int, help="Input device ID (use --list-devices to see options)"
+ )
+
args = parser.parse_args()
-
+
def handle_sigint(signum, frame):
print("Interrupted by user")
exit(0)
@@ -94,9 +103,9 @@ if __name__ == "__main__":
print("Available audio devices:")
print(sd.query_devices())
exit(0)
-
+
if args.device is not None:
sd.default.device[0] = args.device # Set input device
-
+
url = f"{args.url}/api/asr-streaming"
asyncio.run(stream_audio(url, args.api_key))
diff --git a/transcribe_via_pytorch.ipynb b/transcribe_via_pytorch.ipynb
new file mode 100644
index 0000000..4210d64
--- /dev/null
+++ b/transcribe_via_pytorch.ipynb
@@ -0,0 +1,240 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "gJEMjPgeI-rw",
+ "outputId": "7491c067-b1be-4505-b3f5-19ba4c00a593"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install moshi"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "CA4K5iDFJcqJ",
+ "outputId": "b609843a-a193-4729-b099-5f8780532333"
+ },
+ "outputs": [],
+ "source": [
+ "!wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "VA3Haix3IZ8Q"
+ },
+ "outputs": [],
+ "source": [
+ "from dataclasses import dataclass\n",
+ "import time\n",
+ "import sentencepiece\n",
+ "import sphn\n",
+ "import textwrap\n",
+ "import torch\n",
+ "\n",
+ "from moshi.models import loaders, MimiModel, LMModel, LMGen"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "9AK5zBMTI9bw"
+ },
+ "outputs": [],
+ "source": [
+ "@dataclass\n",
+ "class InferenceState:\n",
+ " mimi: MimiModel\n",
+ " text_tokenizer: sentencepiece.SentencePieceProcessor\n",
+ " lm_gen: LMGen\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " mimi: MimiModel,\n",
+ " text_tokenizer: sentencepiece.SentencePieceProcessor,\n",
+ " lm: LMModel,\n",
+ " batch_size: int,\n",
+ " device: str | torch.device,\n",
+ " ):\n",
+ " self.mimi = mimi\n",
+ " self.text_tokenizer = text_tokenizer\n",
+ " self.lm_gen = LMGen(lm, temp=0, temp_text=0, use_sampling=False)\n",
+ " self.device = device\n",
+ " self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)\n",
+ " self.batch_size = batch_size\n",
+ " self.mimi.streaming_forever(batch_size)\n",
+ " self.lm_gen.streaming_forever(batch_size)\n",
+ "\n",
+ " def run(self, in_pcms: torch.Tensor):\n",
+ " device = self.lm_gen.lm_model.device\n",
+ " ntokens = 0\n",
+ " first_frame = True\n",
+ " chunks = [\n",
+ " c\n",
+ " for c in in_pcms.split(self.frame_size, dim=2)\n",
+ " if c.shape[-1] == self.frame_size\n",
+ " ]\n",
+ " start_time = time.time()\n",
+ " all_text = []\n",
+ " for chunk in chunks:\n",
+ " codes = self.mimi.encode(chunk)\n",
+ " if first_frame:\n",
+ " # Ensure that the first slice of codes is properly seen by the transformer\n",
+ " # as otherwise the first slice is replaced by the initial tokens.\n",
+ " tokens = self.lm_gen.step(codes)\n",
+ " first_frame = False\n",
+ " tokens = self.lm_gen.step(codes)\n",
+ " if tokens is None:\n",
+ " continue\n",
+ " assert tokens.shape[1] == 1\n",
+ " one_text = tokens[0, 0].cpu()\n",
+ " if one_text.item() not in [0, 3]:\n",
+ " text = self.text_tokenizer.id_to_piece(one_text.item())\n",
+ " text = text.replace(\"▁\", \" \")\n",
+ " all_text.append(text)\n",
+ " ntokens += 1\n",
+ " dt = time.time() - start_time\n",
+ " print(\n",
+ " f\"processed {ntokens} steps in {dt:.0f}s, {1000 * dt / ntokens:.2f}ms/step\"\n",
+ " )\n",
+ " return \"\".join(all_text)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 353,
+ "referenced_widgets": [
+ "0a5f6f887e2b4cd1990a0e9ec0153ed9",
+ "f7893826fcba4bdc87539589d669249b",
+ "8805afb12c484781be85082ff02dad13",
+ "97679c0d9ab44bed9a3456f2fcb541fd",
+ "d73c0321bed54a52b5e1da0a7788e32a",
+ "d67be13a920d4fc89e5570b5b29fc1d2",
+ "6b377c2d7bf945fb89e46c39d246a332",
+ "b82ff365c78e41ad8094b46daf79449d",
+ "477aa7fa82dc42d5bce6f1743c45d626",
+ "cbd288510c474430beb66f346f382c45",
+ "aafc347cdf28428ea6a7abe5b46b726f",
+ "fca09acd5d0d45468c8b04bfb2de7646",
+ "79e35214b51b4a9e9b3f7144b0b34f7b",
+ "89e9a37f69904bd48b954d627bff6687",
+ "57028789c78248a7b0ad4f031c9545c9",
+ "1150fcb427994c2984d4d0f4e4745fe5",
+ "e24b1fc52f294f849019c9b3befb613f",
+ "8724878682cf4c3ca992667c45009398",
+ "36a22c977d5242008871310133b7d2af",
+ "5b3683cad5cb4877b43fadd003edf97f",
+ "703f98272e4d469d8f27f5a465715dd8",
+ "9dbe02ef5fac41cfaee3d02946e65c88",
+ "37faa87ad03a4271992c21ce6a629e18",
+ "570c547e48cd421b814b2c5e028e4c0b",
+ "b173768580fc4c0a8e3abf272e4c363a",
+ "e57d1620f0a9427b85d8b4885ef4e8e3",
+ "5dd4474df70743498b616608182714dd",
+ "cc907676a65f4ad1bf68a77b4a00e89b",
+ "a34abc3b118e4305951a466919c28ff6",
+ "a77ccfcdb90146c7a63b4b2d232bc494",
+ "f7313e6e3a27475993cab3961d6ae363",
+ "39b47fad9c554839868fe9e4bbf7def2",
+ "14e9511ea0bd44c49f0cf3abf1a6d40e",
+ "a4ea8e0c4cac4d5e88b7e3f527e4fe90",
+ "571afc0f4b2840c9830d6b5a307ed1f9",
+ "6ec593cab5b64f0ea638bb175b9daa5c",
+ "77a52aed00ae408bb24524880e19ec8a",
+ "0b2de4b29b4b44fe9d96361a40c793d0",
+ "3c5b5fb1a5ac468a89c1058bd90cfb58",
+ "e53e0a2a240e43cfa562c89b3d703dea",
+ "35966343cf9249ef8bc028a0d5c5f97d",
+ "e36a37e0d41c47ccb8bc6d56c19fb17c",
+ "279ccf7de43847a1a6579c9182a46cc8",
+ "41b5d6ab0b7d43c790a55f125c0e7494"
+ ]
+ },
+ "id": "UsQJdAgkLp9n",
+ "outputId": "9b7131c3-69c5-4323-8312-2ce7621d8869"
+ },
+ "outputs": [],
+ "source": [
+ "device = \"cuda\"\n",
+ "# Use the en+fr low latency model, an alternative is kyutai/stt-2.6b-en\n",
+ "checkpoint_info = loaders.CheckpointInfo.from_hf_repo(\"kyutai/stt-1b-en_fr\")\n",
+ "mimi = checkpoint_info.get_mimi(device=device)\n",
+ "text_tokenizer = checkpoint_info.get_text_tokenizer()\n",
+ "lm = checkpoint_info.get_moshi(device=device)\n",
+ "in_pcms, _ = sphn.read(\"sample_fr_hibiki_crepes.mp3\", sample_rate=mimi.sample_rate)\n",
+ "in_pcms = torch.from_numpy(in_pcms).to(device=device)\n",
+ "\n",
+ "stt_config = checkpoint_info.stt_config\n",
+ "pad_left = int(stt_config.get(\"audio_silence_prefix_seconds\", 0.0) * 24000)\n",
+ "pad_right = int((stt_config.get(\"audio_delay_seconds\", 0.0) + 1.0) * 24000)\n",
+ "in_pcms = torch.nn.functional.pad(in_pcms, (pad_left, pad_right), mode=\"constant\")\n",
+ "in_pcms = in_pcms[None, 0:1].expand(1, -1, -1)\n",
+ "\n",
+ "state = InferenceState(mimi, text_tokenizer, lm, batch_size=1, device=device)\n",
+ "text = state.run(in_pcms)\n",
+ "print(textwrap.fill(text, width=100))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 75
+ },
+ "id": "CIAXs9oaPrtj",
+ "outputId": "94cc208c-2454-4dd4-a64e-d79025144af5"
+ },
+ "outputs": [],
+ "source": [
+ "from IPython.display import Audio\n",
+ "\n",
+ "Audio(\"sample_fr_hibiki_crepes.mp3\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qkUZ6CBKOdTa"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "L4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}