Compare commits
10 Commits
main
...
vv/clean-u
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a55e94a54a | ||
|
|
332f4109c2 | ||
|
|
bf458a9cb6 | ||
|
|
9ba717e547 | ||
|
|
3e4384492e | ||
|
|
b036dc2f68 | ||
|
|
15c8c305ad | ||
|
|
8c18018b6f | ||
|
|
bb0bdbf697 | ||
|
|
7b818c2636 |
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -192,5 +192,3 @@ cython_debug/
|
|||
# refer to https://docs.cursor.com/context/ignore-files
|
||||
.cursorignore
|
||||
.cursorindexingignore
|
||||
bria.mp3
|
||||
sample_fr_hibiki_crepes.mp3
|
||||
|
|
|
|||
20
README.md
20
README.md
|
|
@ -48,12 +48,6 @@ Here is how to choose which one to use:
|
|||
MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon.
|
||||
If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
|
||||
|
||||
You can retrieve the sample files used in the following snippets via:
|
||||
```bash
|
||||
wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
|
||||
wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3
|
||||
```
|
||||
|
||||
### PyTorch implementation
|
||||
<a href="https://huggingface.co/kyutai/stt-2.6b-en" target="_blank" style="margin: 2px;">
|
||||
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
|
||||
|
|
@ -70,12 +64,12 @@ This requires the [moshi package](https://pypi.org/project/moshi/)
|
|||
with version 0.2.6 or later, which can be installed via pip.
|
||||
|
||||
```bash
|
||||
python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
|
||||
python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en audio/bria.mp3
|
||||
```
|
||||
|
||||
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step and run directly:
|
||||
```bash
|
||||
uvx --with moshi python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
|
||||
uvx --with moshi python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en audio/bria.mp3
|
||||
```
|
||||
|
||||
Additionally, we provide two scripts that highlight different usage scenarios. The first script illustrates how to extract word-level timestamps from the model's outputs:
|
||||
|
|
@ -84,7 +78,7 @@ Additionally, we provide two scripts that highlight different usage scenarios. T
|
|||
uv run \
|
||||
scripts/transcribe_from_file_via_pytorch.py \
|
||||
--hf-repo kyutai/stt-2.6b-en \
|
||||
--file bria.mp3
|
||||
--file audio/bria.mp3
|
||||
```
|
||||
|
||||
The second script can be used to run a model on an existing Hugging Face dataset and calculate its performance metrics:
|
||||
|
|
@ -130,7 +124,7 @@ uv run scripts/transcribe_from_mic_via_rust_server.py
|
|||
|
||||
We also provide a script for transcribing from an audio file.
|
||||
```bash
|
||||
uv run scripts/transcribe_from_file_via_rust_server.py bria.mp3
|
||||
uv run scripts/transcribe_from_file_via_rust_server.py audio/bria.mp3
|
||||
```
|
||||
|
||||
The script limits the decoding speed to simulates real-time processing of the audio.
|
||||
|
|
@ -147,7 +141,7 @@ A standalone Rust example script is provided in the `stt-rs` directory in this r
|
|||
This can be used as follows:
|
||||
```bash
|
||||
cd stt-rs
|
||||
cargo run --features cuda -r -- bria.mp3
|
||||
cargo run --features cuda -r -- audio/bria.mp3
|
||||
```
|
||||
You can get the timestamps by adding the `--timestamps` flag, and see the output
|
||||
of the semantic VAD by adding the `--vad` flag.
|
||||
|
|
@ -164,12 +158,12 @@ This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/)
|
|||
with version 0.2.6 or later, which can be installed via pip.
|
||||
|
||||
```bash
|
||||
python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx bria.mp3 --temp 0
|
||||
python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx audio/bria.mp3 --temp 0
|
||||
```
|
||||
|
||||
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step and run directly:
|
||||
```bash
|
||||
uvx --with moshi-mlx python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx bria.mp3 --temp 0
|
||||
uvx --with moshi-mlx python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx audio/bria.mp3 --temp 0
|
||||
```
|
||||
It will install the moshi package in a temporary environment and run the speech-to-text.
|
||||
|
||||
|
|
|
|||
BIN
audio/bria.mp3
Normal file
BIN
audio/bria.mp3
Normal file
Binary file not shown.
BIN
audio/sample_fr_hibiki_crepes.mp3
Normal file
BIN
audio/sample_fr_hibiki_crepes.mp3
Normal file
Binary file not shown.
|
|
@ -1,7 +1,7 @@
|
|||
static_dir = "./static/"
|
||||
log_dir = "$HOME/tmp/tts-logs"
|
||||
instance_name = "tts"
|
||||
authorized_ids = ["open_token"]
|
||||
authorized_ids = ["public_token"]
|
||||
|
||||
[modules.asr]
|
||||
path = "/api/asr-streaming"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
static_dir = "./static/"
|
||||
log_dir = "$HOME/tmp/tts-logs"
|
||||
instance_name = "tts"
|
||||
authorized_ids = ["open_token"]
|
||||
authorized_ids = ["public_token"]
|
||||
|
||||
[modules.asr]
|
||||
path = "/api/asr-streaming"
|
||||
|
|
|
|||
|
|
@ -9,129 +9,127 @@
|
|||
# ///
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
import sphn
|
||||
import websockets
|
||||
|
||||
# Desired audio properties
|
||||
TARGET_SAMPLE_RATE = 24000
|
||||
TARGET_CHANNELS = 1 # Mono
|
||||
HEADERS = {"kyutai-api-key": "open_token"}
|
||||
all_text = []
|
||||
transcript = []
|
||||
finished = False
|
||||
SAMPLE_RATE = 24000
|
||||
FRAME_SIZE = 1920 # Send data in chunks
|
||||
|
||||
|
||||
def load_and_process_audio(file_path):
|
||||
"""Load an MP3 file, resample to 24kHz, convert to mono, and extract PCM float32 data."""
|
||||
pcm_data, _ = sphn.read(file_path, sample_rate=TARGET_SAMPLE_RATE)
|
||||
pcm_data, _ = sphn.read(file_path, sample_rate=SAMPLE_RATE)
|
||||
return pcm_data[0]
|
||||
|
||||
|
||||
async def receive_messages(websocket):
|
||||
global all_text
|
||||
global transcript
|
||||
global finished
|
||||
try:
|
||||
async for message in websocket:
|
||||
data = msgpack.unpackb(message, raw=False)
|
||||
if data["type"] == "Step":
|
||||
continue
|
||||
print("received:", data)
|
||||
if data["type"] == "Word":
|
||||
all_text.append(data["text"])
|
||||
transcript.append(
|
||||
{
|
||||
"speaker": "SPEAKER_00",
|
||||
"text": data["text"],
|
||||
"timestamp": [data["start_time"], data["start_time"]],
|
||||
}
|
||||
)
|
||||
if data["type"] == "EndWord":
|
||||
if len(transcript) > 0:
|
||||
transcript[-1]["timestamp"][1] = data["stop_time"]
|
||||
if data["type"] == "Marker":
|
||||
print("Received marker, stopping stream.")
|
||||
break
|
||||
except websockets.ConnectionClosed:
|
||||
print("Connection closed while receiving messages.")
|
||||
finished = True
|
||||
transcript = []
|
||||
|
||||
async for message in websocket:
|
||||
data = msgpack.unpackb(message, raw=False)
|
||||
if data["type"] == "Step":
|
||||
# This message contains the signal from the semantic VAD, and tells us how
|
||||
# much audio the server has already processed. We don't use either here.
|
||||
continue
|
||||
if data["type"] == "Word":
|
||||
print(data["text"], end=" ", flush=True)
|
||||
transcript.append(
|
||||
{
|
||||
"text": data["text"],
|
||||
"timestamp": [data["start_time"], data["start_time"]],
|
||||
}
|
||||
)
|
||||
if data["type"] == "EndWord":
|
||||
if len(transcript) > 0:
|
||||
transcript[-1]["timestamp"][1] = data["stop_time"]
|
||||
if data["type"] == "Marker":
|
||||
# Received marker, stopping stream
|
||||
break
|
||||
|
||||
return transcript
|
||||
|
||||
|
||||
async def send_messages(websocket, rtf: float):
|
||||
global finished
|
||||
audio_data = load_and_process_audio(args.in_file)
|
||||
try:
|
||||
# Start with a second of silence.
|
||||
# This is needed for the 2.6B model for technical reasons.
|
||||
chunk = {"type": "Audio", "pcm": [0.0] * 24000}
|
||||
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
||||
await websocket.send(msg)
|
||||
|
||||
chunk_size = 1920 # Send data in chunks
|
||||
start_time = time.time()
|
||||
for i in range(0, len(audio_data), chunk_size):
|
||||
chunk = {
|
||||
"type": "Audio",
|
||||
"pcm": [float(x) for x in audio_data[i : i + chunk_size]],
|
||||
}
|
||||
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
||||
await websocket.send(msg)
|
||||
expected_send_time = start_time + (i + 1) / 24000 / rtf
|
||||
current_time = time.time()
|
||||
if current_time < expected_send_time:
|
||||
await asyncio.sleep(expected_send_time - current_time)
|
||||
else:
|
||||
await asyncio.sleep(0.001)
|
||||
chunk = {"type": "Audio", "pcm": [0.0] * 1920 * 5}
|
||||
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
||||
await websocket.send(msg)
|
||||
msg = msgpack.packb(
|
||||
{"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True
|
||||
async def send_audio(audio: np.ndarray):
|
||||
await websocket.send(
|
||||
msgpack.packb(
|
||||
{"type": "Audio", "pcm": [float(x) for x in audio]},
|
||||
use_single_float=True,
|
||||
)
|
||||
)
|
||||
await websocket.send(msg)
|
||||
for _ in range(35):
|
||||
chunk = {"type": "Audio", "pcm": [0.0] * 1920}
|
||||
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
||||
await websocket.send(msg)
|
||||
while True:
|
||||
if finished:
|
||||
break
|
||||
await asyncio.sleep(1.0)
|
||||
# Keep the connection alive as there is a 20s timeout on the rust side.
|
||||
await websocket.ping()
|
||||
except websockets.ConnectionClosed:
|
||||
print("Connection closed while sending messages.")
|
||||
|
||||
# Start with a second of silence.
|
||||
# This is needed for the 2.6B model for technical reasons.
|
||||
await send_audio([0.0] * SAMPLE_RATE)
|
||||
|
||||
start_time = time.time()
|
||||
for i in range(0, len(audio_data), FRAME_SIZE):
|
||||
await send_audio(audio_data[i : i + FRAME_SIZE])
|
||||
|
||||
expected_send_time = start_time + (i + 1) / SAMPLE_RATE / rtf
|
||||
current_time = time.time()
|
||||
if current_time < expected_send_time:
|
||||
await asyncio.sleep(expected_send_time - current_time)
|
||||
else:
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
for _ in range(5):
|
||||
await send_audio([0.0] * SAMPLE_RATE)
|
||||
|
||||
# Send a marker to indicate the end of the stream.
|
||||
await websocket.send(
|
||||
msgpack.packb({"type": "Marker", "id": 0}, use_single_float=True)
|
||||
)
|
||||
|
||||
# We'll get back the marker once the corresponding audio has been transcribed,
|
||||
# accounting for the delay of the model. That's why we need to send some silence
|
||||
# after the marker, because the model will not return the marker immediately.
|
||||
for _ in range(35):
|
||||
await send_audio([0.0] * SAMPLE_RATE)
|
||||
|
||||
|
||||
async def stream_audio(url: str, rtf: float):
|
||||
async def stream_audio(url: str, api_key: str, rtf: float):
|
||||
"""Stream audio data to a WebSocket server."""
|
||||
headers = {"kyutai-api-key": api_key}
|
||||
|
||||
async with websockets.connect(url, additional_headers=HEADERS) as websocket:
|
||||
# Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL
|
||||
async with websockets.connect(url, additional_headers=headers) as websocket:
|
||||
send_task = asyncio.create_task(send_messages(websocket, rtf))
|
||||
receive_task = asyncio.create_task(receive_messages(websocket))
|
||||
await asyncio.gather(send_task, receive_task)
|
||||
print("exiting")
|
||||
_, transcript = await asyncio.gather(send_task, receive_task)
|
||||
|
||||
return transcript
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("in_file")
|
||||
parser.add_argument("--transcript")
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
help="The url of the server to which to send the audio",
|
||||
default="ws://127.0.0.1:8080",
|
||||
)
|
||||
parser.add_argument("--rtf", type=float, default=1.01)
|
||||
parser.add_argument("--api-key", default="public_token")
|
||||
parser.add_argument(
|
||||
"--rtf",
|
||||
type=float,
|
||||
default=1.01,
|
||||
help="The real-time factor of how fast to feed in the audio.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
url = f"{args.url}/api/asr-streaming"
|
||||
asyncio.run(stream_audio(url, args.rtf))
|
||||
print(" ".join(all_text))
|
||||
if args.transcript is not None:
|
||||
with open(args.transcript, "w") as fobj:
|
||||
json.dump({"transcript": transcript}, fobj, indent=4)
|
||||
transcript = asyncio.run(stream_audio(url, args.api_key, args.rtf))
|
||||
|
||||
print()
|
||||
print()
|
||||
for word in transcript:
|
||||
print(
|
||||
f"{word['timestamp'][0]:7.2f} -{word['timestamp'][1]:7.2f} {word['text']}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -16,43 +16,59 @@ import numpy as np
|
|||
import sounddevice as sd
|
||||
import websockets
|
||||
|
||||
# Desired audio properties
|
||||
TARGET_SAMPLE_RATE = 24000
|
||||
TARGET_CHANNELS = 1 # Mono
|
||||
audio_queue = asyncio.Queue()
|
||||
SAMPLE_RATE = 24000
|
||||
|
||||
# The VAD has several prediction heads, each of which tries to determine whether there
|
||||
# has been a pause of a given length. The lengths are 0.5, 1.0, 2.0, and 3.0 seconds.
|
||||
# Lower indices predict pauses more aggressively. In Unmute, we use 2.0 seconds = index 2.
|
||||
PAUSE_PREDICTION_HEAD_INDEX = 2
|
||||
|
||||
|
||||
async def receive_messages(websocket):
|
||||
async def receive_messages(websocket, show_vad: bool = False):
|
||||
"""Receive and process messages from the WebSocket server."""
|
||||
try:
|
||||
speech_started = False
|
||||
async for message in websocket:
|
||||
data = msgpack.unpackb(message, raw=False)
|
||||
if data["type"] == "Word":
|
||||
|
||||
# The Step message only gets sent if the model has semantic VAD available
|
||||
if data["type"] == "Step" and show_vad:
|
||||
pause_prediction = data["prs"][PAUSE_PREDICTION_HEAD_INDEX]
|
||||
if pause_prediction > 0.5 and speech_started:
|
||||
print("| ", end="", flush=True)
|
||||
speech_started = False
|
||||
|
||||
elif data["type"] == "Word":
|
||||
print(data["text"], end=" ", flush=True)
|
||||
speech_started = True
|
||||
except websockets.ConnectionClosed:
|
||||
print("Connection closed while receiving messages.")
|
||||
|
||||
|
||||
async def send_messages(websocket):
|
||||
async def send_messages(websocket, audio_queue):
|
||||
"""Send audio data from microphone to WebSocket server."""
|
||||
try:
|
||||
# Start by draining the queue to avoid lags
|
||||
while not audio_queue.empty():
|
||||
await audio_queue.get()
|
||||
|
||||
print("Starting the transcription")
|
||||
|
||||
while True:
|
||||
audio_data = await audio_queue.get()
|
||||
chunk = {"type": "Audio", "pcm": [float(x) for x in audio_data]}
|
||||
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
||||
await websocket.send(msg)
|
||||
|
||||
except websockets.ConnectionClosed:
|
||||
print("Connection closed while sending messages.")
|
||||
|
||||
|
||||
async def stream_audio(url: str, api_key: str):
|
||||
async def stream_audio(url: str, api_key: str, show_vad: bool):
|
||||
"""Stream audio data to a WebSocket server."""
|
||||
print("Starting microphone recording...")
|
||||
print("Press Ctrl+C to stop recording")
|
||||
audio_queue = asyncio.Queue()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
|
@ -63,16 +79,19 @@ async def stream_audio(url: str, api_key: str):
|
|||
|
||||
# Start audio stream
|
||||
with sd.InputStream(
|
||||
samplerate=TARGET_SAMPLE_RATE,
|
||||
channels=TARGET_CHANNELS,
|
||||
samplerate=SAMPLE_RATE,
|
||||
channels=1,
|
||||
dtype="float32",
|
||||
callback=audio_callback,
|
||||
blocksize=1920, # 80ms blocks
|
||||
):
|
||||
headers = {"kyutai-api-key": api_key}
|
||||
# Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL
|
||||
async with websockets.connect(url, additional_headers=headers) as websocket:
|
||||
send_task = asyncio.create_task(send_messages(websocket))
|
||||
receive_task = asyncio.create_task(receive_messages(websocket))
|
||||
send_task = asyncio.create_task(send_messages(websocket, audio_queue))
|
||||
receive_task = asyncio.create_task(
|
||||
receive_messages(websocket, show_vad=show_vad)
|
||||
)
|
||||
await asyncio.gather(send_task, receive_task)
|
||||
|
||||
|
||||
|
|
@ -83,18 +102,23 @@ if __name__ == "__main__":
|
|||
help="The URL of the server to which to send the audio",
|
||||
default="ws://127.0.0.1:8080",
|
||||
)
|
||||
parser.add_argument("--api-key", default="open_token")
|
||||
parser.add_argument("--api-key", default="public_token")
|
||||
parser.add_argument(
|
||||
"--list-devices", action="store_true", help="List available audio devices"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=int, help="Input device ID (use --list-devices to see options)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show-vad",
|
||||
action="store_true",
|
||||
help="Visualize the predictions of the semantic voice activity detector with a '|' symbol",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def handle_sigint(signum, frame):
|
||||
print("Interrupted by user")
|
||||
print("Interrupted by user") # Don't complain about KeyboardInterrupt
|
||||
exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, handle_sigint)
|
||||
|
|
@ -108,4 +132,4 @@ if __name__ == "__main__":
|
|||
sd.default.device[0] = args.device # Set input device
|
||||
|
||||
url = f"{args.url}/api/asr-streaming"
|
||||
asyncio.run(stream_audio(url, args.api_key))
|
||||
asyncio.run(stream_audio(url, args.api_key, args.show_vad))
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user