Add the rust indications.
This commit is contained in:
parent
e53c94d6b6
commit
581e1d338b
48
README.md
48
README.md
|
|
@ -3,18 +3,58 @@ Delayed Streams Modeling (DSM) is a flexible formulation for streaming, multimod
|
||||||
|
|
||||||
## Speech To Text
|
## Speech To Text
|
||||||
|
|
||||||
### PyTorch implementation
|
### Leaderboard model
|
||||||
|
The leaderboard model handles english only, it has ~2.6B parameters.
|
||||||
|
|
||||||
|
#### PyTorch implementation
|
||||||
|
[[Hugging Face]](https://huggingface.co/kyutai/stt)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m moshi.run_inference --hf-repo kyutai/stt input.mp3
|
# wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
|
||||||
|
python -m moshi.run_inference --hf-repo kyutai/stt bria.mp3
|
||||||
```
|
```
|
||||||
|
|
||||||
### MLX implementation
|
#### MLX implementation
|
||||||
|
[[Hugging Face]](https://huggingface.co/kyutai/stt-mlx)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m moshi_mlx.run_inference --hf-repo kyutai/stt-mlx ~/tmp/bria-24khz.mp3 --temp 0
|
# wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
|
||||||
|
python -m moshi_mlx.run_inference --hf-repo kyutai/stt-mlx bria.mp3 --temp 0
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Rust implementation
|
||||||
|
[[Hugging Face]](https://huggingface.co/kyutai/stt-candle)
|
||||||
|
|
||||||
|
The Rust implementation provides a server that can process multiple streaming
|
||||||
|
queries in parallel. Dependening on the amount of memory on your GPU, you may
|
||||||
|
have to adjust the batch size from the config file. For a L40S GPU, a batch size
|
||||||
|
of 64 works well.
|
||||||
|
|
||||||
|
In order to run the server, install the `moshi-server` crate via the following
|
||||||
|
command. The server code can be found in the
|
||||||
|
[kyutai-labs/moshi](https://github.com/kyutai-labs/moshi/tree/main/rust/moshi-server)
|
||||||
|
repository.
|
||||||
|
```bash
|
||||||
|
cargo install --features cuda moshi-server
|
||||||
|
```
|
||||||
|
|
||||||
|
Then the server can be started via the following command using the config file
|
||||||
|
from this repository.
|
||||||
|
```bash
|
||||||
|
moshi-server worker --config configs/config-stt-hf.toml
|
||||||
|
```
|
||||||
|
|
||||||
|
Once the server has started you can run a streaming inference with the following
|
||||||
|
script.
|
||||||
|
```bash
|
||||||
|
# wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
|
||||||
|
uv run scripts/asr-streaming-query.py bria.mp3
|
||||||
|
```
|
||||||
|
|
||||||
|
The script simulates some real-time processing of the audio. Faster processing
|
||||||
|
can be triggered by setting the real-time factor, e.g. `--rtf 500` will process
|
||||||
|
the data as fast as possible.
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
The present code is provided under the MIT license for the Python parts, and Apache license for the Rust backend.
|
The present code is provided under the MIT license for the Python parts, and Apache license for the Rust backend.
|
||||||
|
|
|
||||||
42
configs/config-stt-hf.toml
Normal file
42
configs/config-stt-hf.toml
Normal file
|
|
@ -0,0 +1,42 @@
|
||||||
|
static_dir = "./static/"
|
||||||
|
log_dir = "$HOME/tmp/tts-logs"
|
||||||
|
instance_name = "tts"
|
||||||
|
authorized_ids = ["open_token"]
|
||||||
|
|
||||||
|
[modules.asr]
|
||||||
|
path = "/api/asr-streaming"
|
||||||
|
type = "BatchedAsr"
|
||||||
|
lm_model_file = "hf://kyutai/stt-candle/model.safetensors"
|
||||||
|
text_tokenizer_file = "hf://kyutai/stt-candle/tokenizer_en_audio_4000.model"
|
||||||
|
audio_tokenizer_file = "hf://kyutai/stt-candle/mimi-pytorch-e351c8d8@125.safetensors"
|
||||||
|
asr_delay_in_tokens = 6
|
||||||
|
batch_size = 16
|
||||||
|
conditioning_learnt_padding = true
|
||||||
|
temperature = 0
|
||||||
|
|
||||||
|
[modules.asr.model]
|
||||||
|
audio_vocab_size = 2049
|
||||||
|
text_in_vocab_size = 4001
|
||||||
|
text_out_vocab_size = 4000
|
||||||
|
audio_codebooks = 32
|
||||||
|
|
||||||
|
[modules.asr.model.transformer]
|
||||||
|
d_model = 2048
|
||||||
|
num_heads = 32
|
||||||
|
num_layers = 48
|
||||||
|
dim_feedforward = 8192
|
||||||
|
causal = true
|
||||||
|
norm_first = true
|
||||||
|
bias_ff = false
|
||||||
|
bias_attn = false
|
||||||
|
context = 375
|
||||||
|
max_period = 100000
|
||||||
|
use_conv_block = false
|
||||||
|
use_conv_bias = true
|
||||||
|
gating = "silu"
|
||||||
|
norm = "RmsNorm"
|
||||||
|
positional_embedding = "Rope"
|
||||||
|
conv_layout = false
|
||||||
|
conv_kernel_size = 3
|
||||||
|
kv_repeat = 1
|
||||||
|
max_seq_len = 40960
|
||||||
131
scripts/asr-streaming-query.py
Normal file
131
scripts/asr-streaming-query.py
Normal file
|
|
@ -0,0 +1,131 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "msgpack",
|
||||||
|
# "numpy",
|
||||||
|
# "sphn",
|
||||||
|
# "websockets",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import msgpack
|
||||||
|
import sphn
|
||||||
|
import struct
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
# Desired audio properties
|
||||||
|
TARGET_SAMPLE_RATE = 24000
|
||||||
|
TARGET_CHANNELS = 1 # Mono
|
||||||
|
HEADERS = {"kyutai-api-key": "open_token"}
|
||||||
|
all_text = []
|
||||||
|
transcript = []
|
||||||
|
finished = False
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_process_audio(file_path):
|
||||||
|
"""Load an MP3 file, resample to 24kHz, convert to mono, and extract PCM float32 data."""
|
||||||
|
pcm_data, _ = sphn.read(file_path, sample_rate=TARGET_SAMPLE_RATE)
|
||||||
|
return pcm_data[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def receive_messages(websocket):
|
||||||
|
global all_text
|
||||||
|
global transcript
|
||||||
|
global finished
|
||||||
|
try:
|
||||||
|
async for message in websocket:
|
||||||
|
data = msgpack.unpackb(message, raw=False)
|
||||||
|
if data["type"] == "Step":
|
||||||
|
continue
|
||||||
|
print("received:", data)
|
||||||
|
if data["type"] == "Word":
|
||||||
|
all_text.append(data["text"])
|
||||||
|
transcript.append({
|
||||||
|
"speaker": "SPEAKER_00",
|
||||||
|
"text": data["text"],
|
||||||
|
"timestamp": [data["start_time"], data["start_time"]],
|
||||||
|
})
|
||||||
|
if data["type"] == "EndWord":
|
||||||
|
if len(transcript) > 0:
|
||||||
|
transcript[-1]["timestamp"][1] = data["stop_time"]
|
||||||
|
if data["type"] == "Marker":
|
||||||
|
print("Received marker, stopping stream.")
|
||||||
|
break
|
||||||
|
except websockets.ConnectionClosed:
|
||||||
|
print("Connection closed while receiving messages.")
|
||||||
|
finished = True
|
||||||
|
|
||||||
|
|
||||||
|
async def send_messages(websocket, rtf: float):
|
||||||
|
global finished
|
||||||
|
audio_data = load_and_process_audio(args.in_file)
|
||||||
|
try:
|
||||||
|
# Start with a second of silence
|
||||||
|
chunk = { "type": "Audio", "pcm": [0.0] * 24000 }
|
||||||
|
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
||||||
|
await websocket.send(msg)
|
||||||
|
|
||||||
|
chunk_size = 1920 # Send data in chunks
|
||||||
|
start_time = time.time()
|
||||||
|
for i in range(0, len(audio_data), chunk_size):
|
||||||
|
chunk = { "type": "Audio", "pcm": [float(x) for x in audio_data[i : i + chunk_size]] }
|
||||||
|
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
||||||
|
await websocket.send(msg)
|
||||||
|
expected_send_time = start_time + (i + 1) / 24000 / rtf
|
||||||
|
current_time = time.time()
|
||||||
|
if current_time < expected_send_time:
|
||||||
|
await asyncio.sleep(expected_send_time - current_time)
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(0.001)
|
||||||
|
chunk = { "type": "Audio", "pcm": [0.0] * 1920 * 5 }
|
||||||
|
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
||||||
|
await websocket.send(msg)
|
||||||
|
msg = msgpack.packb({"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True)
|
||||||
|
await websocket.send(msg)
|
||||||
|
for _ in range(35):
|
||||||
|
chunk = { "type": "Audio", "pcm": [0.0] * 1920 }
|
||||||
|
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
||||||
|
await websocket.send(msg)
|
||||||
|
while True:
|
||||||
|
if finished:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(1.0)
|
||||||
|
# Keep the connection alive as there is a 20s timeout on the rust side.
|
||||||
|
await websocket.ping()
|
||||||
|
except websockets.ConnectionClosed:
|
||||||
|
print("Connection closed while sending messages.")
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_audio(url: str, rtf: float):
|
||||||
|
"""Stream audio data to a WebSocket server."""
|
||||||
|
|
||||||
|
async with websockets.connect(url, additional_headers=HEADERS) as websocket:
|
||||||
|
send_task = asyncio.create_task(send_messages(websocket, rtf))
|
||||||
|
receive_task = asyncio.create_task(receive_messages(websocket))
|
||||||
|
await asyncio.gather(send_task, receive_task)
|
||||||
|
print("exiting")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("in_file")
|
||||||
|
parser.add_argument("--transcript")
|
||||||
|
parser.add_argument(
|
||||||
|
"--url",
|
||||||
|
help="The url of the server to which to send the audio",
|
||||||
|
default="ws://127.0.0.1:8080",
|
||||||
|
)
|
||||||
|
parser.add_argument("--rtf", type=float, default=1.01)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
url = f"{args.url}/api/asr-streaming"
|
||||||
|
asyncio.run(stream_audio(url, args.rtf))
|
||||||
|
print(" ".join(all_text))
|
||||||
|
if args.transcript is not None:
|
||||||
|
with open(args.transcript, "w") as fobj:
|
||||||
|
json.dump({"transcript": transcript}, fobj, indent=4)
|
||||||
Loading…
Reference in New Issue
Block a user