Refactor Rust server examples (#19)

* Rename examples and add pre-commit

* Fix references to scripts, add implementations overview

* Link to colab notebook via github

* Simplify

* Add auth note

* Allow visualizing VAD

* Remove unused variable

* Add audio samples

* Address review comments
Author: Václav Volhejn, 2025-06-26 16:51:43 +02:00, committed by GitHub
Parent: 96eef33c4c
Commit: 0112245ef7
8 changed files with 133 additions and 119 deletions

.gitignore (2 deletions)

@@ -192,5 +192,3 @@ cython_debug/
 # refer to https://docs.cursor.com/context/ignore-files
 .cursorignore
 .cursorindexingignore
-bria.mp3
-sample_fr_hibiki_crepes.mp3

README

@@ -48,12 +48,6 @@ Here is how to choose which one to use:
 MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon.
 If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
 
-You can retrieve the sample files used in the following snippets via:
-```bash
-wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
-wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3
-```
-
 ### PyTorch implementation
 <a href="https://huggingface.co/kyutai/stt-2.6b-en" target="_blank" style="margin: 2px;">
 <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
@@ -70,12 +64,12 @@ This requires the [moshi package](https://pypi.org/project/moshi/)
 with version 0.2.6 or later, which can be installed via pip.
 
 ```bash
-python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
+python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en audio/bria.mp3
 ```
 
 If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step and run directly:
 ```bash
-uvx --with moshi python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
+uvx --with moshi python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en audio/bria.mp3
 ```
 
 Additionally, we provide two scripts that highlight different usage scenarios. The first script illustrates how to extract word-level timestamps from the model's outputs:
@@ -84,7 +78,7 @@ Additionally, we provide two scripts that highlight different usage scenarios.
 uv run \
   scripts/transcribe_from_file_via_pytorch.py \
   --hf-repo kyutai/stt-2.6b-en \
-  --file bria.mp3
+  --file audio/bria.mp3
 ```
 
 The second script can be used to run a model on an existing Hugging Face dataset and calculate its performance metrics:
@@ -130,7 +124,7 @@ uv run scripts/transcribe_from_mic_via_rust_server.py
 We also provide a script for transcribing from an audio file.
 
 ```bash
-uv run scripts/transcribe_from_file_via_rust_server.py bria.mp3
+uv run scripts/transcribe_from_file_via_rust_server.py audio/bria.mp3
 ```
 
 The script limits the decoding speed to simulate real-time processing of the audio.
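
The pacing mentioned above comes from the script's `--rtf` (real-time factor) flag, which is visible in the rewritten script later in this commit; presumably the audio can be fed in faster than real time by raising it. An illustrative run (the factor of 4 is arbitrary):

```bash
# Feed the file at roughly 4x real time instead of the default 1.01
uv run scripts/transcribe_from_file_via_rust_server.py audio/bria.mp3 --rtf 4
```
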
@@ -147,7 +141,7 @@ A standalone Rust example script is provided in the `stt-rs` directory
 This can be used as follows:
 ```bash
 cd stt-rs
-cargo run --features cuda -r -- bria.mp3
+cargo run --features cuda -r -- audio/bria.mp3
 ```
 You can get the timestamps by adding the `--timestamps` flag, and see the output
 of the semantic VAD by adding the `--vad` flag.
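
Since `--timestamps` and `--vad` are independent switches, they can presumably be combined in a single run; a sketch using the bundled sample and a CUDA build:

```bash
cd stt-rs
# Print word-level timestamps and semantic-VAD output for the bundled sample
cargo run --features cuda -r -- audio/bria.mp3 --timestamps --vad
```
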
@@ -164,12 +158,12 @@ This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/)
 with version 0.2.6 or later, which can be installed via pip.
 
 ```bash
-python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx bria.mp3 --temp 0
+python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx audio/bria.mp3 --temp 0
 ```
 
 If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step and run directly:
 ```bash
-uvx --with moshi-mlx python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx bria.mp3 --temp 0
+uvx --with moshi-mlx python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx audio/bria.mp3 --temp 0
 ```
 
 It will install the moshi package in a temporary environment and run the speech-to-text.

audio/bria.mp3 (new binary file, not shown)

(second binary file, not shown)

(server config file, name not shown)

@@ -1,7 +1,7 @@
 static_dir = "./static/"
 log_dir = "$HOME/tmp/tts-logs"
 instance_name = "tts"
-authorized_ids = ["open_token"]
+authorized_ids = ["public_token"]
 
 [modules.asr]
 path = "/api/asr-streaming"

(second server config file, name not shown)

@@ -1,7 +1,7 @@
 static_dir = "./static/"
 log_dir = "$HOME/tmp/tts-logs"
 instance_name = "tts"
-authorized_ids = ["open_token"]
+authorized_ids = ["public_token"]
 
 [modules.asr]
 path = "/api/asr-streaming"
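
For reference, `authorized_ids` is what the example clients must present: the updated scripts send the token in a `kyutai-api-key` header (or, per the comment in the scripts, as an `?auth_id=...` query parameter on the WebSocket URL), so the value passed on the command line has to match one of the IDs above. An illustrative client invocation against this config, assuming the server is running locally on the scripts' default port:

```bash
# --api-key must match one of the authorized_ids configured above
uv run scripts/transcribe_from_mic_via_rust_server.py --url ws://127.0.0.1:8080 --api-key public_token
```
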

scripts/transcribe_from_file_via_rust_server.py

@@ -9,129 +9,127 @@
 # ///
 import argparse
 import asyncio
-import json
 import time
 
 import msgpack
+import numpy as np
 import sphn
 import websockets
 
-# Desired audio properties
-TARGET_SAMPLE_RATE = 24000
-TARGET_CHANNELS = 1  # Mono
-HEADERS = {"kyutai-api-key": "open_token"}
-
-all_text = []
-transcript = []
-finished = False
+SAMPLE_RATE = 24000
+FRAME_SIZE = 1920  # Send data in chunks
 
 
 def load_and_process_audio(file_path):
     """Load an MP3 file, resample to 24kHz, convert to mono, and extract PCM float32 data."""
-    pcm_data, _ = sphn.read(file_path, sample_rate=TARGET_SAMPLE_RATE)
+    pcm_data, _ = sphn.read(file_path, sample_rate=SAMPLE_RATE)
     return pcm_data[0]
 
 
 async def receive_messages(websocket):
-    global all_text
-    global transcript
-    global finished
-    try:
-        async for message in websocket:
-            data = msgpack.unpackb(message, raw=False)
-            if data["type"] == "Step":
-                continue
-            print("received:", data)
-            if data["type"] == "Word":
-                all_text.append(data["text"])
-                transcript.append(
-                    {
-                        "speaker": "SPEAKER_00",
-                        "text": data["text"],
-                        "timestamp": [data["start_time"], data["start_time"]],
-                    }
-                )
-            if data["type"] == "EndWord":
-                if len(transcript) > 0:
-                    transcript[-1]["timestamp"][1] = data["stop_time"]
-            if data["type"] == "Marker":
-                print("Received marker, stopping stream.")
-                break
-    except websockets.ConnectionClosed:
-        print("Connection closed while receiving messages.")
-    finished = True
+    transcript = []
+
+    async for message in websocket:
+        data = msgpack.unpackb(message, raw=False)
+        if data["type"] == "Step":
+            # This message contains the signal from the semantic VAD, and tells us how
+            # much audio the server has already processed. We don't use either here.
+            continue
+        if data["type"] == "Word":
+            print(data["text"], end=" ", flush=True)
+            transcript.append(
+                {
+                    "text": data["text"],
+                    "timestamp": [data["start_time"], data["start_time"]],
+                }
+            )
+        if data["type"] == "EndWord":
+            if len(transcript) > 0:
+                transcript[-1]["timestamp"][1] = data["stop_time"]
+        if data["type"] == "Marker":
+            # Received marker, stopping stream
+            break
+
+    return transcript
 
 
 async def send_messages(websocket, rtf: float):
-    global finished
     audio_data = load_and_process_audio(args.in_file)
-    try:
-        # Start with a second of silence.
-        # This is needed for the 2.6B model for technical reasons.
-        chunk = {"type": "Audio", "pcm": [0.0] * 24000}
-        msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
-        await websocket.send(msg)
-
-        chunk_size = 1920  # Send data in chunks
-        start_time = time.time()
-        for i in range(0, len(audio_data), chunk_size):
-            chunk = {
-                "type": "Audio",
-                "pcm": [float(x) for x in audio_data[i : i + chunk_size]],
-            }
-            msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
-            await websocket.send(msg)
-            expected_send_time = start_time + (i + 1) / 24000 / rtf
-            current_time = time.time()
-            if current_time < expected_send_time:
-                await asyncio.sleep(expected_send_time - current_time)
-            else:
-                await asyncio.sleep(0.001)
-        chunk = {"type": "Audio", "pcm": [0.0] * 1920 * 5}
-        msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
-        await websocket.send(msg)
-        msg = msgpack.packb(
-            {"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True
-        )
-        await websocket.send(msg)
-        for _ in range(35):
-            chunk = {"type": "Audio", "pcm": [0.0] * 1920}
-            msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
-            await websocket.send(msg)
-        while True:
-            if finished:
-                break
-            await asyncio.sleep(1.0)
-            # Keep the connection alive as there is a 20s timeout on the rust side.
-            await websocket.ping()
-    except websockets.ConnectionClosed:
-        print("Connection closed while sending messages.")
+
+    async def send_audio(audio: np.ndarray):
+        await websocket.send(
+            msgpack.packb(
+                {"type": "Audio", "pcm": [float(x) for x in audio]},
+                use_single_float=True,
+            )
+        )
+
+    # Start with a second of silence.
+    # This is needed for the 2.6B model for technical reasons.
+    await send_audio([0.0] * SAMPLE_RATE)
+
+    start_time = time.time()
+    for i in range(0, len(audio_data), FRAME_SIZE):
+        await send_audio(audio_data[i : i + FRAME_SIZE])
+
+        expected_send_time = start_time + (i + 1) / SAMPLE_RATE / rtf
+        current_time = time.time()
+        if current_time < expected_send_time:
+            await asyncio.sleep(expected_send_time - current_time)
+        else:
+            await asyncio.sleep(0.001)
+
+    for _ in range(5):
+        await send_audio([0.0] * SAMPLE_RATE)
+
+    # Send a marker to indicate the end of the stream.
+    await websocket.send(
+        msgpack.packb({"type": "Marker", "id": 0}, use_single_float=True)
+    )
+
+    # We'll get back the marker once the corresponding audio has been transcribed,
+    # accounting for the delay of the model. That's why we need to send some silence
+    # after the marker, because the model will not return the marker immediately.
+    for _ in range(35):
+        await send_audio([0.0] * SAMPLE_RATE)
 
 
-async def stream_audio(url: str, rtf: float):
+async def stream_audio(url: str, api_key: str, rtf: float):
     """Stream audio data to a WebSocket server."""
+    headers = {"kyutai-api-key": api_key}
 
-    async with websockets.connect(url, additional_headers=HEADERS) as websocket:
+    # Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL
+    async with websockets.connect(url, additional_headers=headers) as websocket:
         send_task = asyncio.create_task(send_messages(websocket, rtf))
         receive_task = asyncio.create_task(receive_messages(websocket))
-        await asyncio.gather(send_task, receive_task)
-
-    print("exiting")
+        _, transcript = await asyncio.gather(send_task, receive_task)
+
+    return transcript
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("in_file")
-    parser.add_argument("--transcript")
     parser.add_argument(
         "--url",
         help="The url of the server to which to send the audio",
         default="ws://127.0.0.1:8080",
     )
-    parser.add_argument("--rtf", type=float, default=1.01)
+    parser.add_argument("--api-key", default="public_token")
+    parser.add_argument(
+        "--rtf",
+        type=float,
+        default=1.01,
+        help="The real-time factor of how fast to feed in the audio.",
+    )
     args = parser.parse_args()
 
     url = f"{args.url}/api/asr-streaming"
-    asyncio.run(stream_audio(url, args.rtf))
-    print(" ".join(all_text))
-    if args.transcript is not None:
-        with open(args.transcript, "w") as fobj:
-            json.dump({"transcript": transcript}, fobj, indent=4)
+    transcript = asyncio.run(stream_audio(url, args.api_key, args.rtf))
+
+    print()
+    print()
+    for word in transcript:
+        print(
+            f"{word['timestamp'][0]:7.2f} -{word['timestamp'][1]:7.2f} {word['text']}"
+        )
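
The rewritten script prints the timestamped transcript to stdout instead of writing a JSON file (the old `--transcript` flag is gone), so a plain shell redirect is enough if a file is still wanted; an illustrative run, where the output filename is arbitrary:

```bash
uv run scripts/transcribe_from_file_via_rust_server.py audio/bria.mp3 \
    --api-key public_token > bria_transcript.txt
```
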

scripts/transcribe_from_mic_via_rust_server.py

@@ -16,43 +16,59 @@ import numpy as np
 import sounddevice as sd
 import websockets
 
-# Desired audio properties
-TARGET_SAMPLE_RATE = 24000
-TARGET_CHANNELS = 1  # Mono
-
-audio_queue = asyncio.Queue()
+SAMPLE_RATE = 24000
+
+# The VAD has several prediction heads, each of which tries to determine whether there
+# has been a pause of a given length. The lengths are 0.5, 1.0, 2.0, and 3.0 seconds.
+# Lower indices predict pauses more aggressively. In Unmute, we use 2.0 seconds = index 2.
+PAUSE_PREDICTION_HEAD_INDEX = 2
 
 
-async def receive_messages(websocket):
+async def receive_messages(websocket, show_vad: bool = False):
     """Receive and process messages from the WebSocket server."""
     try:
+        speech_started = False
         async for message in websocket:
             data = msgpack.unpackb(message, raw=False)
-            if data["type"] == "Word":
+
+            # The Step message only gets sent if the model has semantic VAD available
+            if data["type"] == "Step" and show_vad:
+                pause_prediction = data["prs"][PAUSE_PREDICTION_HEAD_INDEX]
+                if pause_prediction > 0.5 and speech_started:
+                    print("| ", end="", flush=True)
+                    speech_started = False
+            elif data["type"] == "Word":
                 print(data["text"], end=" ", flush=True)
+                speech_started = True
     except websockets.ConnectionClosed:
         print("Connection closed while receiving messages.")
 
 
-async def send_messages(websocket):
+async def send_messages(websocket, audio_queue):
     """Send audio data from microphone to WebSocket server."""
     try:
         # Start by draining the queue to avoid lags
         while not audio_queue.empty():
             await audio_queue.get()
         print("Starting the transcription")
         while True:
             audio_data = await audio_queue.get()
             chunk = {"type": "Audio", "pcm": [float(x) for x in audio_data]}
             msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
             await websocket.send(msg)
     except websockets.ConnectionClosed:
         print("Connection closed while sending messages.")
 
 
-async def stream_audio(url: str, api_key: str):
+async def stream_audio(url: str, api_key: str, show_vad: bool):
     """Stream audio data to a WebSocket server."""
     print("Starting microphone recording...")
     print("Press Ctrl+C to stop recording")
+
+    audio_queue = asyncio.Queue()
     loop = asyncio.get_event_loop()

@@ -63,16 +79,19 @@ async def stream_audio(url: str, api_key: str):
     # Start audio stream
     with sd.InputStream(
-        samplerate=TARGET_SAMPLE_RATE,
-        channels=TARGET_CHANNELS,
+        samplerate=SAMPLE_RATE,
+        channels=1,
         dtype="float32",
         callback=audio_callback,
         blocksize=1920,  # 80ms blocks
     ):
         headers = {"kyutai-api-key": api_key}
+        # Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL
         async with websockets.connect(url, additional_headers=headers) as websocket:
-            send_task = asyncio.create_task(send_messages(websocket))
-            receive_task = asyncio.create_task(receive_messages(websocket))
+            send_task = asyncio.create_task(send_messages(websocket, audio_queue))
+            receive_task = asyncio.create_task(
+                receive_messages(websocket, show_vad=show_vad)
+            )
             await asyncio.gather(send_task, receive_task)

@@ -83,18 +102,23 @@ if __name__ == "__main__":
         help="The URL of the server to which to send the audio",
         default="ws://127.0.0.1:8080",
     )
-    parser.add_argument("--api-key", default="open_token")
+    parser.add_argument("--api-key", default="public_token")
     parser.add_argument(
         "--list-devices", action="store_true", help="List available audio devices"
     )
     parser.add_argument(
         "--device", type=int, help="Input device ID (use --list-devices to see options)"
     )
+    parser.add_argument(
+        "--show-vad",
+        action="store_true",
+        help="Visualize the predictions of the semantic voice activity detector with a '|' symbol",
+    )
     args = parser.parse_args()
 
     def handle_sigint(signum, frame):
-        print("Interrupted by user")
+        print("Interrupted by user")  # Don't complain about KeyboardInterrupt
         exit(0)
 
     signal.signal(signal.SIGINT, handle_sigint)

@@ -108,4 +132,4 @@ if __name__ == "__main__":
         sd.default.device[0] = args.device  # Set input device
 
     url = f"{args.url}/api/asr-streaming"
-    asyncio.run(stream_audio(url, args.api_key))
+    asyncio.run(stream_audio(url, args.api_key, args.show_vad))
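
To exercise the new `--show-vad` flag end to end, one plausible sequence is to list the capture devices first and then stream from one of them; the device index below is illustrative:

```bash
# Pick an input device, then stream from it with pause markers from the semantic VAD
uv run scripts/transcribe_from_mic_via_rust_server.py --list-devices
uv run scripts/transcribe_from_mic_via_rust_server.py --device 1 --show-vad
```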