Compare commits
10 Commits
main
...
vv/clean-u
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a55e94a54a | ||
|
|
332f4109c2 | ||
|
|
bf458a9cb6 | ||
|
|
9ba717e547 | ||
|
|
3e4384492e | ||
|
|
b036dc2f68 | ||
|
|
15c8c305ad | ||
|
|
8c18018b6f | ||
|
|
bb0bdbf697 | ||
|
|
7b818c2636 |
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -192,5 +192,3 @@ cython_debug/
|
||||||
# refer to https://docs.cursor.com/context/ignore-files
|
# refer to https://docs.cursor.com/context/ignore-files
|
||||||
.cursorignore
|
.cursorignore
|
||||||
.cursorindexingignore
|
.cursorindexingignore
|
||||||
bria.mp3
|
|
||||||
sample_fr_hibiki_crepes.mp3
|
|
||||||
|
|
|
||||||
20
README.md
20
README.md
|
|
@ -48,12 +48,6 @@ Here is how to choose which one to use:
|
||||||
MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon.
|
MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon.
|
||||||
If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
|
If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
|
||||||
|
|
||||||
You can retrieve the sample files used in the following snippets via:
|
|
||||||
```bash
|
|
||||||
wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
|
|
||||||
wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3
|
|
||||||
```
|
|
||||||
|
|
||||||
### PyTorch implementation
|
### PyTorch implementation
|
||||||
<a href="https://huggingface.co/kyutai/stt-2.6b-en" target="_blank" style="margin: 2px;">
|
<a href="https://huggingface.co/kyutai/stt-2.6b-en" target="_blank" style="margin: 2px;">
|
||||||
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
|
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
|
||||||
|
|
@ -70,12 +64,12 @@ This requires the [moshi package](https://pypi.org/project/moshi/)
|
||||||
with version 0.2.6 or later, which can be installed via pip.
|
with version 0.2.6 or later, which can be installed via pip.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
|
python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en audio/bria.mp3
|
||||||
```
|
```
|
||||||
|
|
||||||
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step and run directly:
|
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step and run directly:
|
||||||
```bash
|
```bash
|
||||||
uvx --with moshi python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
|
uvx --with moshi python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en audio/bria.mp3
|
||||||
```
|
```
|
||||||
|
|
||||||
Additionally, we provide two scripts that highlight different usage scenarios. The first script illustrates how to extract word-level timestamps from the model's outputs:
|
Additionally, we provide two scripts that highlight different usage scenarios. The first script illustrates how to extract word-level timestamps from the model's outputs:
|
||||||
|
|
@ -84,7 +78,7 @@ Additionally, we provide two scripts that highlight different usage scenarios. T
|
||||||
uv run \
|
uv run \
|
||||||
scripts/transcribe_from_file_via_pytorch.py \
|
scripts/transcribe_from_file_via_pytorch.py \
|
||||||
--hf-repo kyutai/stt-2.6b-en \
|
--hf-repo kyutai/stt-2.6b-en \
|
||||||
--file bria.mp3
|
--file audio/bria.mp3
|
||||||
```
|
```
|
||||||
|
|
||||||
The second script can be used to run a model on an existing Hugging Face dataset and calculate its performance metrics:
|
The second script can be used to run a model on an existing Hugging Face dataset and calculate its performance metrics:
|
||||||
|
|
@ -130,7 +124,7 @@ uv run scripts/transcribe_from_mic_via_rust_server.py
|
||||||
|
|
||||||
We also provide a script for transcribing from an audio file.
|
We also provide a script for transcribing from an audio file.
|
||||||
```bash
|
```bash
|
||||||
uv run scripts/transcribe_from_file_via_rust_server.py bria.mp3
|
uv run scripts/transcribe_from_file_via_rust_server.py audio/bria.mp3
|
||||||
```
|
```
|
||||||
|
|
||||||
The script limits the decoding speed to simulates real-time processing of the audio.
|
The script limits the decoding speed to simulates real-time processing of the audio.
|
||||||
|
|
@ -147,7 +141,7 @@ A standalone Rust example script is provided in the `stt-rs` directory in this r
|
||||||
This can be used as follows:
|
This can be used as follows:
|
||||||
```bash
|
```bash
|
||||||
cd stt-rs
|
cd stt-rs
|
||||||
cargo run --features cuda -r -- bria.mp3
|
cargo run --features cuda -r -- audio/bria.mp3
|
||||||
```
|
```
|
||||||
You can get the timestamps by adding the `--timestamps` flag, and see the output
|
You can get the timestamps by adding the `--timestamps` flag, and see the output
|
||||||
of the semantic VAD by adding the `--vad` flag.
|
of the semantic VAD by adding the `--vad` flag.
|
||||||
|
|
@ -164,12 +158,12 @@ This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/)
|
||||||
with version 0.2.6 or later, which can be installed via pip.
|
with version 0.2.6 or later, which can be installed via pip.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx bria.mp3 --temp 0
|
python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx audio/bria.mp3 --temp 0
|
||||||
```
|
```
|
||||||
|
|
||||||
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step and run directly:
|
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step and run directly:
|
||||||
```bash
|
```bash
|
||||||
uvx --with moshi-mlx python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx bria.mp3 --temp 0
|
uvx --with moshi-mlx python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx audio/bria.mp3 --temp 0
|
||||||
```
|
```
|
||||||
It will install the moshi package in a temporary environment and run the speech-to-text.
|
It will install the moshi package in a temporary environment and run the speech-to-text.
|
||||||
|
|
||||||
|
|
|
||||||
BIN
audio/bria.mp3
Normal file
BIN
audio/bria.mp3
Normal file
Binary file not shown.
BIN
audio/sample_fr_hibiki_crepes.mp3
Normal file
BIN
audio/sample_fr_hibiki_crepes.mp3
Normal file
Binary file not shown.
|
|
@ -1,7 +1,7 @@
|
||||||
static_dir = "./static/"
|
static_dir = "./static/"
|
||||||
log_dir = "$HOME/tmp/tts-logs"
|
log_dir = "$HOME/tmp/tts-logs"
|
||||||
instance_name = "tts"
|
instance_name = "tts"
|
||||||
authorized_ids = ["open_token"]
|
authorized_ids = ["public_token"]
|
||||||
|
|
||||||
[modules.asr]
|
[modules.asr]
|
||||||
path = "/api/asr-streaming"
|
path = "/api/asr-streaming"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
static_dir = "./static/"
|
static_dir = "./static/"
|
||||||
log_dir = "$HOME/tmp/tts-logs"
|
log_dir = "$HOME/tmp/tts-logs"
|
||||||
instance_name = "tts"
|
instance_name = "tts"
|
||||||
authorized_ids = ["open_token"]
|
authorized_ids = ["public_token"]
|
||||||
|
|
||||||
[modules.asr]
|
[modules.asr]
|
||||||
path = "/api/asr-streaming"
|
path = "/api/asr-streaming"
|
||||||
|
|
|
||||||
|
|
@ -9,129 +9,127 @@
|
||||||
# ///
|
# ///
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import msgpack
|
import msgpack
|
||||||
|
import numpy as np
|
||||||
import sphn
|
import sphn
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
# Desired audio properties
|
SAMPLE_RATE = 24000
|
||||||
TARGET_SAMPLE_RATE = 24000
|
FRAME_SIZE = 1920 # Send data in chunks
|
||||||
TARGET_CHANNELS = 1 # Mono
|
|
||||||
HEADERS = {"kyutai-api-key": "open_token"}
|
|
||||||
all_text = []
|
|
||||||
transcript = []
|
|
||||||
finished = False
|
|
||||||
|
|
||||||
|
|
||||||
def load_and_process_audio(file_path):
|
def load_and_process_audio(file_path):
|
||||||
"""Load an MP3 file, resample to 24kHz, convert to mono, and extract PCM float32 data."""
|
"""Load an MP3 file, resample to 24kHz, convert to mono, and extract PCM float32 data."""
|
||||||
pcm_data, _ = sphn.read(file_path, sample_rate=TARGET_SAMPLE_RATE)
|
pcm_data, _ = sphn.read(file_path, sample_rate=SAMPLE_RATE)
|
||||||
return pcm_data[0]
|
return pcm_data[0]
|
||||||
|
|
||||||
|
|
||||||
async def receive_messages(websocket):
|
async def receive_messages(websocket):
|
||||||
global all_text
|
transcript = []
|
||||||
global transcript
|
|
||||||
global finished
|
async for message in websocket:
|
||||||
try:
|
data = msgpack.unpackb(message, raw=False)
|
||||||
async for message in websocket:
|
if data["type"] == "Step":
|
||||||
data = msgpack.unpackb(message, raw=False)
|
# This message contains the signal from the semantic VAD, and tells us how
|
||||||
if data["type"] == "Step":
|
# much audio the server has already processed. We don't use either here.
|
||||||
continue
|
continue
|
||||||
print("received:", data)
|
if data["type"] == "Word":
|
||||||
if data["type"] == "Word":
|
print(data["text"], end=" ", flush=True)
|
||||||
all_text.append(data["text"])
|
transcript.append(
|
||||||
transcript.append(
|
{
|
||||||
{
|
"text": data["text"],
|
||||||
"speaker": "SPEAKER_00",
|
"timestamp": [data["start_time"], data["start_time"]],
|
||||||
"text": data["text"],
|
}
|
||||||
"timestamp": [data["start_time"], data["start_time"]],
|
)
|
||||||
}
|
if data["type"] == "EndWord":
|
||||||
)
|
if len(transcript) > 0:
|
||||||
if data["type"] == "EndWord":
|
transcript[-1]["timestamp"][1] = data["stop_time"]
|
||||||
if len(transcript) > 0:
|
if data["type"] == "Marker":
|
||||||
transcript[-1]["timestamp"][1] = data["stop_time"]
|
# Received marker, stopping stream
|
||||||
if data["type"] == "Marker":
|
break
|
||||||
print("Received marker, stopping stream.")
|
|
||||||
break
|
return transcript
|
||||||
except websockets.ConnectionClosed:
|
|
||||||
print("Connection closed while receiving messages.")
|
|
||||||
finished = True
|
|
||||||
|
|
||||||
|
|
||||||
async def send_messages(websocket, rtf: float):
|
async def send_messages(websocket, rtf: float):
|
||||||
global finished
|
|
||||||
audio_data = load_and_process_audio(args.in_file)
|
audio_data = load_and_process_audio(args.in_file)
|
||||||
try:
|
|
||||||
# Start with a second of silence.
|
|
||||||
# This is needed for the 2.6B model for technical reasons.
|
|
||||||
chunk = {"type": "Audio", "pcm": [0.0] * 24000}
|
|
||||||
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
|
||||||
await websocket.send(msg)
|
|
||||||
|
|
||||||
chunk_size = 1920 # Send data in chunks
|
async def send_audio(audio: np.ndarray):
|
||||||
start_time = time.time()
|
await websocket.send(
|
||||||
for i in range(0, len(audio_data), chunk_size):
|
msgpack.packb(
|
||||||
chunk = {
|
{"type": "Audio", "pcm": [float(x) for x in audio]},
|
||||||
"type": "Audio",
|
use_single_float=True,
|
||||||
"pcm": [float(x) for x in audio_data[i : i + chunk_size]],
|
)
|
||||||
}
|
|
||||||
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
|
||||||
await websocket.send(msg)
|
|
||||||
expected_send_time = start_time + (i + 1) / 24000 / rtf
|
|
||||||
current_time = time.time()
|
|
||||||
if current_time < expected_send_time:
|
|
||||||
await asyncio.sleep(expected_send_time - current_time)
|
|
||||||
else:
|
|
||||||
await asyncio.sleep(0.001)
|
|
||||||
chunk = {"type": "Audio", "pcm": [0.0] * 1920 * 5}
|
|
||||||
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
|
||||||
await websocket.send(msg)
|
|
||||||
msg = msgpack.packb(
|
|
||||||
{"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True
|
|
||||||
)
|
)
|
||||||
await websocket.send(msg)
|
|
||||||
for _ in range(35):
|
# Start with a second of silence.
|
||||||
chunk = {"type": "Audio", "pcm": [0.0] * 1920}
|
# This is needed for the 2.6B model for technical reasons.
|
||||||
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
await send_audio([0.0] * SAMPLE_RATE)
|
||||||
await websocket.send(msg)
|
|
||||||
while True:
|
start_time = time.time()
|
||||||
if finished:
|
for i in range(0, len(audio_data), FRAME_SIZE):
|
||||||
break
|
await send_audio(audio_data[i : i + FRAME_SIZE])
|
||||||
await asyncio.sleep(1.0)
|
|
||||||
# Keep the connection alive as there is a 20s timeout on the rust side.
|
expected_send_time = start_time + (i + 1) / SAMPLE_RATE / rtf
|
||||||
await websocket.ping()
|
current_time = time.time()
|
||||||
except websockets.ConnectionClosed:
|
if current_time < expected_send_time:
|
||||||
print("Connection closed while sending messages.")
|
await asyncio.sleep(expected_send_time - current_time)
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(0.001)
|
||||||
|
|
||||||
|
for _ in range(5):
|
||||||
|
await send_audio([0.0] * SAMPLE_RATE)
|
||||||
|
|
||||||
|
# Send a marker to indicate the end of the stream.
|
||||||
|
await websocket.send(
|
||||||
|
msgpack.packb({"type": "Marker", "id": 0}, use_single_float=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# We'll get back the marker once the corresponding audio has been transcribed,
|
||||||
|
# accounting for the delay of the model. That's why we need to send some silence
|
||||||
|
# after the marker, because the model will not return the marker immediately.
|
||||||
|
for _ in range(35):
|
||||||
|
await send_audio([0.0] * SAMPLE_RATE)
|
||||||
|
|
||||||
|
|
||||||
async def stream_audio(url: str, rtf: float):
|
async def stream_audio(url: str, api_key: str, rtf: float):
|
||||||
"""Stream audio data to a WebSocket server."""
|
"""Stream audio data to a WebSocket server."""
|
||||||
|
headers = {"kyutai-api-key": api_key}
|
||||||
|
|
||||||
async with websockets.connect(url, additional_headers=HEADERS) as websocket:
|
# Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL
|
||||||
|
async with websockets.connect(url, additional_headers=headers) as websocket:
|
||||||
send_task = asyncio.create_task(send_messages(websocket, rtf))
|
send_task = asyncio.create_task(send_messages(websocket, rtf))
|
||||||
receive_task = asyncio.create_task(receive_messages(websocket))
|
receive_task = asyncio.create_task(receive_messages(websocket))
|
||||||
await asyncio.gather(send_task, receive_task)
|
_, transcript = await asyncio.gather(send_task, receive_task)
|
||||||
print("exiting")
|
|
||||||
|
return transcript
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("in_file")
|
parser.add_argument("in_file")
|
||||||
parser.add_argument("--transcript")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--url",
|
"--url",
|
||||||
help="The url of the server to which to send the audio",
|
help="The url of the server to which to send the audio",
|
||||||
default="ws://127.0.0.1:8080",
|
default="ws://127.0.0.1:8080",
|
||||||
)
|
)
|
||||||
parser.add_argument("--rtf", type=float, default=1.01)
|
parser.add_argument("--api-key", default="public_token")
|
||||||
|
parser.add_argument(
|
||||||
|
"--rtf",
|
||||||
|
type=float,
|
||||||
|
default=1.01,
|
||||||
|
help="The real-time factor of how fast to feed in the audio.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
url = f"{args.url}/api/asr-streaming"
|
url = f"{args.url}/api/asr-streaming"
|
||||||
asyncio.run(stream_audio(url, args.rtf))
|
transcript = asyncio.run(stream_audio(url, args.api_key, args.rtf))
|
||||||
print(" ".join(all_text))
|
|
||||||
if args.transcript is not None:
|
print()
|
||||||
with open(args.transcript, "w") as fobj:
|
print()
|
||||||
json.dump({"transcript": transcript}, fobj, indent=4)
|
for word in transcript:
|
||||||
|
print(
|
||||||
|
f"{word['timestamp'][0]:7.2f} -{word['timestamp'][1]:7.2f} {word['text']}"
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -16,43 +16,59 @@ import numpy as np
|
||||||
import sounddevice as sd
|
import sounddevice as sd
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
# Desired audio properties
|
SAMPLE_RATE = 24000
|
||||||
TARGET_SAMPLE_RATE = 24000
|
|
||||||
TARGET_CHANNELS = 1 # Mono
|
# The VAD has several prediction heads, each of which tries to determine whether there
|
||||||
audio_queue = asyncio.Queue()
|
# has been a pause of a given length. The lengths are 0.5, 1.0, 2.0, and 3.0 seconds.
|
||||||
|
# Lower indices predict pauses more aggressively. In Unmute, we use 2.0 seconds = index 2.
|
||||||
|
PAUSE_PREDICTION_HEAD_INDEX = 2
|
||||||
|
|
||||||
|
|
||||||
async def receive_messages(websocket):
|
async def receive_messages(websocket, show_vad: bool = False):
|
||||||
"""Receive and process messages from the WebSocket server."""
|
"""Receive and process messages from the WebSocket server."""
|
||||||
try:
|
try:
|
||||||
|
speech_started = False
|
||||||
async for message in websocket:
|
async for message in websocket:
|
||||||
data = msgpack.unpackb(message, raw=False)
|
data = msgpack.unpackb(message, raw=False)
|
||||||
if data["type"] == "Word":
|
|
||||||
|
# The Step message only gets sent if the model has semantic VAD available
|
||||||
|
if data["type"] == "Step" and show_vad:
|
||||||
|
pause_prediction = data["prs"][PAUSE_PREDICTION_HEAD_INDEX]
|
||||||
|
if pause_prediction > 0.5 and speech_started:
|
||||||
|
print("| ", end="", flush=True)
|
||||||
|
speech_started = False
|
||||||
|
|
||||||
|
elif data["type"] == "Word":
|
||||||
print(data["text"], end=" ", flush=True)
|
print(data["text"], end=" ", flush=True)
|
||||||
|
speech_started = True
|
||||||
except websockets.ConnectionClosed:
|
except websockets.ConnectionClosed:
|
||||||
print("Connection closed while receiving messages.")
|
print("Connection closed while receiving messages.")
|
||||||
|
|
||||||
|
|
||||||
async def send_messages(websocket):
|
async def send_messages(websocket, audio_queue):
|
||||||
"""Send audio data from microphone to WebSocket server."""
|
"""Send audio data from microphone to WebSocket server."""
|
||||||
try:
|
try:
|
||||||
# Start by draining the queue to avoid lags
|
# Start by draining the queue to avoid lags
|
||||||
while not audio_queue.empty():
|
while not audio_queue.empty():
|
||||||
await audio_queue.get()
|
await audio_queue.get()
|
||||||
|
|
||||||
print("Starting the transcription")
|
print("Starting the transcription")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
audio_data = await audio_queue.get()
|
audio_data = await audio_queue.get()
|
||||||
chunk = {"type": "Audio", "pcm": [float(x) for x in audio_data]}
|
chunk = {"type": "Audio", "pcm": [float(x) for x in audio_data]}
|
||||||
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
||||||
await websocket.send(msg)
|
await websocket.send(msg)
|
||||||
|
|
||||||
except websockets.ConnectionClosed:
|
except websockets.ConnectionClosed:
|
||||||
print("Connection closed while sending messages.")
|
print("Connection closed while sending messages.")
|
||||||
|
|
||||||
|
|
||||||
async def stream_audio(url: str, api_key: str):
|
async def stream_audio(url: str, api_key: str, show_vad: bool):
|
||||||
"""Stream audio data to a WebSocket server."""
|
"""Stream audio data to a WebSocket server."""
|
||||||
print("Starting microphone recording...")
|
print("Starting microphone recording...")
|
||||||
print("Press Ctrl+C to stop recording")
|
print("Press Ctrl+C to stop recording")
|
||||||
|
audio_queue = asyncio.Queue()
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
|
@ -63,16 +79,19 @@ async def stream_audio(url: str, api_key: str):
|
||||||
|
|
||||||
# Start audio stream
|
# Start audio stream
|
||||||
with sd.InputStream(
|
with sd.InputStream(
|
||||||
samplerate=TARGET_SAMPLE_RATE,
|
samplerate=SAMPLE_RATE,
|
||||||
channels=TARGET_CHANNELS,
|
channels=1,
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
callback=audio_callback,
|
callback=audio_callback,
|
||||||
blocksize=1920, # 80ms blocks
|
blocksize=1920, # 80ms blocks
|
||||||
):
|
):
|
||||||
headers = {"kyutai-api-key": api_key}
|
headers = {"kyutai-api-key": api_key}
|
||||||
|
# Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL
|
||||||
async with websockets.connect(url, additional_headers=headers) as websocket:
|
async with websockets.connect(url, additional_headers=headers) as websocket:
|
||||||
send_task = asyncio.create_task(send_messages(websocket))
|
send_task = asyncio.create_task(send_messages(websocket, audio_queue))
|
||||||
receive_task = asyncio.create_task(receive_messages(websocket))
|
receive_task = asyncio.create_task(
|
||||||
|
receive_messages(websocket, show_vad=show_vad)
|
||||||
|
)
|
||||||
await asyncio.gather(send_task, receive_task)
|
await asyncio.gather(send_task, receive_task)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -83,18 +102,23 @@ if __name__ == "__main__":
|
||||||
help="The URL of the server to which to send the audio",
|
help="The URL of the server to which to send the audio",
|
||||||
default="ws://127.0.0.1:8080",
|
default="ws://127.0.0.1:8080",
|
||||||
)
|
)
|
||||||
parser.add_argument("--api-key", default="open_token")
|
parser.add_argument("--api-key", default="public_token")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--list-devices", action="store_true", help="List available audio devices"
|
"--list-devices", action="store_true", help="List available audio devices"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--device", type=int, help="Input device ID (use --list-devices to see options)"
|
"--device", type=int, help="Input device ID (use --list-devices to see options)"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--show-vad",
|
||||||
|
action="store_true",
|
||||||
|
help="Visualize the predictions of the semantic voice activity detector with a '|' symbol",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
def handle_sigint(signum, frame):
|
def handle_sigint(signum, frame):
|
||||||
print("Interrupted by user")
|
print("Interrupted by user") # Don't complain about KeyboardInterrupt
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, handle_sigint)
|
signal.signal(signal.SIGINT, handle_sigint)
|
||||||
|
|
@ -108,4 +132,4 @@ if __name__ == "__main__":
|
||||||
sd.default.device[0] = args.device # Set input device
|
sd.default.device[0] = args.device # Set input device
|
||||||
|
|
||||||
url = f"{args.url}/api/asr-streaming"
|
url = f"{args.url}/api/asr-streaming"
|
||||||
asyncio.run(stream_audio(url, args.api_key))
|
asyncio.run(stream_audio(url, args.api_key, args.show_vad))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user