Compare commits

...

10 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Václav Volhejn | a55e94a54a | Address review comments | 2025-06-26 16:36:32 +02:00 |
| Václav Volhejn | 332f4109c2 | Merge branch 'main' of github.com:kyutai-labs/delayed-streams-modeling into vv/clean-up-examples-2 | 2025-06-26 15:36:41 +02:00 |
| Václav Volhejn | bf458a9cb6 | Add audio samples | 2025-06-25 19:13:12 +02:00 |
| Václav Volhejn | 9ba717e547 | Remove unused variable | 2025-06-25 19:08:52 +02:00 |
| Václav Volhejn | 3e4384492e | Allow visualizing VAD | 2025-06-25 19:08:35 +02:00 |
| Václav Volhejn | b036dc2f68 | Add auth note | 2025-06-25 17:54:30 +02:00 |
| Václav Volhejn | 15c8c305ad | Simplify | 2025-06-25 17:50:53 +02:00 |
| Václav Volhejn | 8c18018b6f | Link to colab notebook via github | 2025-06-25 11:27:33 +02:00 |
| Václav Volhejn | bb0bdbf697 | Fix references to scripts, add implementations overview | 2025-06-25 11:15:46 +02:00 |
| Václav Volhejn | 7b818c2636 | Rename examples and add pre-commit | 2025-06-25 10:50:14 +02:00 |
8 changed files with 133 additions and 119 deletions

.gitignore

@@ -192,5 +192,3 @@ cython_debug/
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
bria.mp3
sample_fr_hibiki_crepes.mp3

README.md

@@ -48,12 +48,6 @@ Here is how to choose which one to use:
MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon.
If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
You can retrieve the sample files used in the following snippets via:
```bash
wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3
```
### PyTorch implementation
<a href="https://huggingface.co/kyutai/stt-2.6b-en" target="_blank" style="margin: 2px;">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
@@ -70,12 +64,12 @@ This requires the [moshi package](https://pypi.org/project/moshi/)
with version 0.2.6 or later, which can be installed via pip.
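For example, a minimal sketch of that install step (the version pin is just the lower bound mentioned above):
```bash
pip install "moshi>=0.2.6"
```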
```bash
python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en audio/bria.mp3
```
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step and run directly:
```bash
uvx --with moshi python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
uvx --with moshi python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en audio/bria.mp3
```
Additionally, we provide two scripts that highlight different usage scenarios. The first script illustrates how to extract word-level timestamps from the model's outputs:
@@ -84,7 +78,7 @@ Additionally, we provide two scripts that highlight different usage scenarios. T
uv run \
scripts/transcribe_from_file_via_pytorch.py \
--hf-repo kyutai/stt-2.6b-en \
--file bria.mp3
--file audio/bria.mp3
```
The second script can be used to run a model on an existing Hugging Face dataset and calculate its performance metrics:
@@ -130,7 +124,7 @@ uv run scripts/transcribe_from_mic_via_rust_server.py
We also provide a script for transcribing from an audio file.
```bash
uv run scripts/transcribe_from_file_via_rust_server.py bria.mp3
uv run scripts/transcribe_from_file_via_rust_server.py audio/bria.mp3
```
The script limits the decoding speed to simulate real-time processing of the audio.
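Judging by the `--rtf` flag added to the script further down in this diff, feeding the audio faster than real time should look roughly like this sketch (the value 4 is only an illustration):
```bash
# --rtf is the real-time factor; the script's default is 1.01
uv run scripts/transcribe_from_file_via_rust_server.py audio/bria.mp3 --rtf 4
```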
@@ -147,7 +141,7 @@ A standalone Rust example script is provided in the `stt-rs` directory in this r
This can be used as follows:
```bash
cd stt-rs
cargo run --features cuda -r -- bria.mp3
cargo run --features cuda -r -- audio/bria.mp3
```
You can get the timestamps by adding the `--timestamps` flag, and see the output
of the semantic VAD by adding the `--vad` flag.
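For instance, both flags can presumably be combined with the command above:
```bash
cargo run --features cuda -r -- audio/bria.mp3 --timestamps --vad
```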
@@ -164,12 +158,12 @@ This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/)
with version 0.2.6 or later, which can be installed via pip.
```bash
python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx bria.mp3 --temp 0
python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx audio/bria.mp3 --temp 0
```
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step and run directly:
```bash
uvx --with moshi-mlx python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx bria.mp3 --temp 0
uvx --with moshi-mlx python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx audio/bria.mp3 --temp 0
```
It will install the moshi-mlx package in a temporary environment and run the speech-to-text.

audio/bria.mp3 (new binary file, content not shown)

A second binary file (content not shown).

Server config file (filename not shown)

@@ -1,7 +1,7 @@
static_dir = "./static/"
log_dir = "$HOME/tmp/tts-logs"
instance_name = "tts"
authorized_ids = ["open_token"]
authorized_ids = ["public_token"]
[modules.asr]
path = "/api/asr-streaming"

Second server config file (filename not shown)

@@ -1,7 +1,7 @@
static_dir = "./static/"
log_dir = "$HOME/tmp/tts-logs"
instance_name = "tts"
authorized_ids = ["open_token"]
authorized_ids = ["public_token"]
[modules.asr]
path = "/api/asr-streaming"

scripts/transcribe_from_file_via_rust_server.py

@@ -9,43 +9,36 @@
# ///
import argparse
import asyncio
import json
import time
import msgpack
import numpy as np
import sphn
import websockets
# Desired audio properties
TARGET_SAMPLE_RATE = 24000
TARGET_CHANNELS = 1 # Mono
HEADERS = {"kyutai-api-key": "open_token"}
all_text = []
transcript = []
finished = False
SAMPLE_RATE = 24000
FRAME_SIZE = 1920 # Send data in chunks
def load_and_process_audio(file_path):
"""Load an MP3 file, resample to 24kHz, convert to mono, and extract PCM float32 data."""
pcm_data, _ = sphn.read(file_path, sample_rate=TARGET_SAMPLE_RATE)
pcm_data, _ = sphn.read(file_path, sample_rate=SAMPLE_RATE)
return pcm_data[0]
async def receive_messages(websocket):
global all_text
global transcript
global finished
try:
transcript = []
async for message in websocket:
data = msgpack.unpackb(message, raw=False)
if data["type"] == "Step":
# This message contains the signal from the semantic VAD, and tells us how
# much audio the server has already processed. We don't use either here.
continue
print("received:", data)
if data["type"] == "Word":
all_text.append(data["text"])
print(data["text"], end=" ", flush=True)
transcript.append(
{
"speaker": "SPEAKER_00",
"text": data["text"],
"timestamp": [data["start_time"], data["start_time"]],
}
@@ -54,84 +47,89 @@ async def receive_messages(websocket):
if len(transcript) > 0:
transcript[-1]["timestamp"][1] = data["stop_time"]
if data["type"] == "Marker":
print("Received marker, stopping stream.")
# Received marker, stopping stream
break
except websockets.ConnectionClosed:
print("Connection closed while receiving messages.")
finished = True
return transcript
async def send_messages(websocket, rtf: float):
global finished
audio_data = load_and_process_audio(args.in_file)
try:
async def send_audio(audio: np.ndarray):
await websocket.send(
msgpack.packb(
{"type": "Audio", "pcm": [float(x) for x in audio]},
use_single_float=True,
)
)
# Start with a second of silence.
# This is needed for the 2.6B model for technical reasons.
chunk = {"type": "Audio", "pcm": [0.0] * 24000}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
await send_audio([0.0] * SAMPLE_RATE)
chunk_size = 1920 # Send data in chunks
start_time = time.time()
for i in range(0, len(audio_data), chunk_size):
chunk = {
"type": "Audio",
"pcm": [float(x) for x in audio_data[i : i + chunk_size]],
}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
expected_send_time = start_time + (i + 1) / 24000 / rtf
for i in range(0, len(audio_data), FRAME_SIZE):
await send_audio(audio_data[i : i + FRAME_SIZE])
expected_send_time = start_time + (i + 1) / SAMPLE_RATE / rtf
current_time = time.time()
if current_time < expected_send_time:
await asyncio.sleep(expected_send_time - current_time)
else:
await asyncio.sleep(0.001)
chunk = {"type": "Audio", "pcm": [0.0] * 1920 * 5}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
msg = msgpack.packb(
{"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True
for _ in range(5):
await send_audio([0.0] * SAMPLE_RATE)
# Send a marker to indicate the end of the stream.
await websocket.send(
msgpack.packb({"type": "Marker", "id": 0}, use_single_float=True)
)
await websocket.send(msg)
# We'll get back the marker once the corresponding audio has been transcribed,
# accounting for the delay of the model. That's why we need to send some silence
# after the marker, because the model will not return the marker immediately.
for _ in range(35):
chunk = {"type": "Audio", "pcm": [0.0] * 1920}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
while True:
if finished:
break
await asyncio.sleep(1.0)
# Keep the connection alive as there is a 20s timeout on the rust side.
await websocket.ping()
except websockets.ConnectionClosed:
print("Connection closed while sending messages.")
await send_audio([0.0] * SAMPLE_RATE)
async def stream_audio(url: str, rtf: float):
async def stream_audio(url: str, api_key: str, rtf: float):
"""Stream audio data to a WebSocket server."""
headers = {"kyutai-api-key": api_key}
async with websockets.connect(url, additional_headers=HEADERS) as websocket:
# Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL
async with websockets.connect(url, additional_headers=headers) as websocket:
send_task = asyncio.create_task(send_messages(websocket, rtf))
receive_task = asyncio.create_task(receive_messages(websocket))
await asyncio.gather(send_task, receive_task)
print("exiting")
_, transcript = await asyncio.gather(send_task, receive_task)
return transcript
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("in_file")
parser.add_argument("--transcript")
parser.add_argument(
"--url",
help="The url of the server to which to send the audio",
default="ws://127.0.0.1:8080",
)
parser.add_argument("--rtf", type=float, default=1.01)
parser.add_argument("--api-key", default="public_token")
parser.add_argument(
"--rtf",
type=float,
default=1.01,
help="The real-time factor of how fast to feed in the audio.",
)
args = parser.parse_args()
url = f"{args.url}/api/asr-streaming"
asyncio.run(stream_audio(url, args.rtf))
print(" ".join(all_text))
if args.transcript is not None:
with open(args.transcript, "w") as fobj:
json.dump({"transcript": transcript}, fobj, indent=4)
transcript = asyncio.run(stream_audio(url, args.api_key, args.rtf))
print()
print()
for word in transcript:
print(
f"{word['timestamp'][0]:7.2f} -{word['timestamp'][1]:7.2f} {word['text']}"
)
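As the in-code comment notes, the `kyutai-api-key` header can presumably be replaced by a query parameter. A minimal sketch under that assumption (the function name is illustrative, not part of the script):
```python
import websockets


async def connect_with_url_auth(url: str, api_key: str):
    # Authenticate via the URL query string instead of the `kyutai-api-key` header.
    async with websockets.connect(f"{url}?auth_id={api_key}") as websocket:
        ...  # create the send/receive tasks exactly as in the script above
```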

scripts/transcribe_from_mic_via_rust_server.py

@@ -16,43 +16,59 @@ import numpy as np
import sounddevice as sd
import websockets
# Desired audio properties
TARGET_SAMPLE_RATE = 24000
TARGET_CHANNELS = 1 # Mono
audio_queue = asyncio.Queue()
SAMPLE_RATE = 24000
# The VAD has several prediction heads, each of which tries to determine whether there
# has been a pause of a given length. The lengths are 0.5, 1.0, 2.0, and 3.0 seconds.
# Lower indices predict pauses more aggressively. In Unmute, we use 2.0 seconds = index 2.
PAUSE_PREDICTION_HEAD_INDEX = 2
async def receive_messages(websocket):
async def receive_messages(websocket, show_vad: bool = False):
"""Receive and process messages from the WebSocket server."""
try:
speech_started = False
async for message in websocket:
data = msgpack.unpackb(message, raw=False)
if data["type"] == "Word":
# The Step message only gets sent if the model has semantic VAD available
if data["type"] == "Step" and show_vad:
pause_prediction = data["prs"][PAUSE_PREDICTION_HEAD_INDEX]
if pause_prediction > 0.5 and speech_started:
print("| ", end="", flush=True)
speech_started = False
elif data["type"] == "Word":
print(data["text"], end=" ", flush=True)
speech_started = True
except websockets.ConnectionClosed:
print("Connection closed while receiving messages.")
async def send_messages(websocket):
async def send_messages(websocket, audio_queue):
"""Send audio data from microphone to WebSocket server."""
try:
# Start by draining the queue to avoid lags
while not audio_queue.empty():
await audio_queue.get()
print("Starting the transcription")
while True:
audio_data = await audio_queue.get()
chunk = {"type": "Audio", "pcm": [float(x) for x in audio_data]}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
except websockets.ConnectionClosed:
print("Connection closed while sending messages.")
async def stream_audio(url: str, api_key: str):
async def stream_audio(url: str, api_key: str, show_vad: bool):
"""Stream audio data to a WebSocket server."""
print("Starting microphone recording...")
print("Press Ctrl+C to stop recording")
audio_queue = asyncio.Queue()
loop = asyncio.get_event_loop()
@@ -63,16 +79,19 @@ async def stream_audio(url: str, api_key: str):
# Start audio stream
with sd.InputStream(
samplerate=TARGET_SAMPLE_RATE,
channels=TARGET_CHANNELS,
samplerate=SAMPLE_RATE,
channels=1,
dtype="float32",
callback=audio_callback,
blocksize=1920, # 80ms blocks
):
headers = {"kyutai-api-key": api_key}
# Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL
async with websockets.connect(url, additional_headers=headers) as websocket:
send_task = asyncio.create_task(send_messages(websocket))
receive_task = asyncio.create_task(receive_messages(websocket))
send_task = asyncio.create_task(send_messages(websocket, audio_queue))
receive_task = asyncio.create_task(
receive_messages(websocket, show_vad=show_vad)
)
await asyncio.gather(send_task, receive_task)
@@ -83,18 +102,23 @@ if __name__ == "__main__":
help="The URL of the server to which to send the audio",
default="ws://127.0.0.1:8080",
)
parser.add_argument("--api-key", default="open_token")
parser.add_argument("--api-key", default="public_token")
parser.add_argument(
"--list-devices", action="store_true", help="List available audio devices"
)
parser.add_argument(
"--device", type=int, help="Input device ID (use --list-devices to see options)"
)
parser.add_argument(
"--show-vad",
action="store_true",
help="Visualize the predictions of the semantic voice activity detector with a '|' symbol",
)
args = parser.parse_args()
def handle_sigint(signum, frame):
print("Interrupted by user")
print("Interrupted by user") # Don't complain about KeyboardInterrupt
exit(0)
signal.signal(signal.SIGINT, handle_sigint)
@@ -108,4 +132,4 @@ if __name__ == "__main__":
sd.default.device[0] = args.device # Set input device
url = f"{args.url}/api/asr-streaming"
asyncio.run(stream_audio(url, args.api_key))
asyncio.run(stream_audio(url, args.api_key, args.show_vad))
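With the new flag, visualizing the semantic VAD predictions from the microphone script would presumably be invoked like this:
```bash
# Pauses detected by the VAD are marked with '|' in the printed transcript
uv run scripts/transcribe_from_mic_via_rust_server.py --show-vad
```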