Merge remote-tracking branch 'refs/remotes/origin/main'
Commit: 61d947d1eb

README.md (56 lines changed):
@@ -3,10 +3,17 @@ Delayed Streams Modeling (DSM) is a flexible formulation for streaming, multimodal…
 ## Speech To Text
 
-### English only model
-The main model handles english only, it has ~2.6b parameters.
+DSM can be used to build streaming speech-to-text models. These models can be
+batched for efficiency, return word-level timestamps, and are great for
+interactive applications. We provide two such models; they are characterized
+by their size as well as the delay it takes for audio to be transcribed into
+text:
+
+- An English-only model with ~2.6b parameters, using a 2.5 second delay,
+  `kyutai/stt-2.6b-en`.
+- An English and French model with ~1b parameters, using a 0.5 second delay,
+  `kyutai/stt-1b-en_fr`.
 
-#### PyTorch implementation
+### PyTorch implementation
 [[Hugging Face]](https://huggingface.co/kyutai/stt-2.6b-en)
 <a target="_blank" href="https://colab.research.google.com/drive/1mc0Q-FoHxU2pEvId8rTdS4q1r1zorJhS?usp=sharing">
 <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
@@ -20,7 +27,7 @@ with version 0.2.5 or later, which can be installed via pip.
 python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
 ```
 
-#### MLX implementation
+### MLX implementation
 [[Hugging Face]](https://huggingface.co/kyutai/stt-2.6b-en-mlx)
 
 This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/)
@@ -31,7 +38,17 @@ with version 0.2.5 or later, which can be installed via pip.
 python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx bria.mp3 --temp 0
 ```
 
-#### Rust implementation
+### Rust implementation
+[[Hugging Face]](https://huggingface.co/kyutai/stt-2.6b-en-candle)
+
+A standalone Rust example is provided in the `stt-rs` directory of this repo.
+It can be used as follows:
+```bash
+cd stt-rs
+cargo run --features cuda -r -- bria.mp3
+```
+
+### Rust server
 [[Hugging Face]](https://huggingface.co/kyutai/stt-2.6b-en-candle)
 
 The Rust implementation provides a server that can process multiple streaming
@@ -64,35 +81,6 @@ The script simulates some real-time processing of the audio. Faster processing
 can be triggered by setting the real-time factor, e.g. `--rtf 500` will process
 the data as fast as possible.
 
-### English + French model
-This model has ~1b parameters and supports both English and French.
-
-#### PyTorch implementation
-[[Hugging Face]](https://huggingface.co/kyutai/stt-1b-en_fr)
-
-```bash
-# wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
-python -m moshi.run_inference --hf-repo kyutai/stt-1b-en_fr bria.mp3
-```
-
-#### MLX implementation
-[[Hugging Face]](https://huggingface.co/kyutai/stt-1b-en_fr-mlx)
-
-```bash
-# wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
-python -m moshi_mlx.run_inference --hf-repo kyutai/stt-1b-en_fr-mlx bria.mp3 --temp 0
-```
-
-#### Rust implementation
-[[Hugging Face]](https://huggingface.co/kyutai/stt-1b-en_fr-candle)
-
-The only difference with the en only model is the config file used when
-launching the server.
-```bash
-moshi-server worker --config configs/config-stt-enfr-hf.toml
-```
 
 ## Text To Speech
 
 We're in the process of open-sourcing our TTS models. Check back for updates!
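For a side-by-side usage sketch, the same PyTorch entry point serves both checkpoints; only the `--hf-repo` id changes (both invocations are taken verbatim from the README text above):

```bash
# English-only model, ~2.6b parameters, 2.5 s delay
python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
# English + French model, ~1b parameters, 0.5 s delay
python -m moshi.run_inference --hf-repo kyutai/stt-1b-en_fr bria.mp3
```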
stt-rs/Cargo.lock (3324 lines, generated, new file): diff suppressed because it is too large.

stt-rs/Cargo.toml (30 lines, new file):
```toml
[package]
name = "kyutai-stt-rs"
version = "0.1.0"
edition = "2024"

[dependencies]
anyhow = "1.0"
candle = { version = "0.9.1", package = "candle-core" }
candle-nn = "0.9.1"
clap = { version = "4.4.12", features = ["derive"] }
hf-hub = "0.3.2"
kaudio = "0.2.1"
moshi = "0.6.1"
sentencepiece = "0.11.3"
serde = { version = "1.0.210", features = ["derive"] }
serde_json = "1.0.115"

[features]
default = []
cuda = ["candle/cuda", "candle-nn/cuda"]
cudnn = ["candle/cudnn", "candle-nn/cudnn"]
metal = ["candle/metal", "candle-nn/metal"]

[profile.release]
debug = true

[profile.release-no-debug]
inherits = "release"
debug = false
```
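The `[features]` table above gates the compute backends. A minimal build-and-run sketch: the cuda invocation comes from the README, the `--cpu` flag from `main.rs` below, while the metal line is an assumption based on the declared feature:

```bash
cd stt-rs
cargo run -r -- bria.mp3 --cpu              # force CPU inference
cargo run --features cuda -r -- bria.mp3    # NVIDIA GPU (from the README)
cargo run --features metal -r -- bria.mp3   # assumed: Apple Silicon via the metal feature
```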
stt-rs/src/main.rs (184 lines, new file):
```rust
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

use anyhow::Result;
use candle::{Device, Tensor};
use clap::Parser;

#[derive(Debug, Parser)]
struct Args {
    /// The audio input file, in wav/mp3/ogg/... format.
    in_file: String,

    /// The repo where to get the model from.
    #[arg(long, default_value = "kyutai/stt-1b-en_fr-candle")]
    hf_repo: String,

    /// Run the model on cpu.
    #[arg(long)]
    cpu: bool,
}

/// Pick the best available device: CUDA, then Metal, then CPU.
fn device(cpu: bool) -> Result<Device> {
    if cpu {
        Ok(Device::Cpu)
    } else if candle::utils::cuda_is_available() {
        Ok(Device::new_cuda(0)?)
    } else if candle::utils::metal_is_available() {
        Ok(Device::new_metal(0)?)
    } else {
        Ok(Device::Cpu)
    }
}

#[derive(Debug, serde::Deserialize)]
struct Config {
    mimi_name: String,
    tokenizer_name: String,
    card: usize,
    text_card: usize,
    dim: usize,
    n_q: usize,
    context: usize,
    max_period: f64,
    num_heads: usize,
    num_layers: usize,
    causal: bool,
}

impl Config {
    fn model_config(&self) -> moshi::lm::Config {
        let lm_cfg = moshi::transformer::Config {
            d_model: self.dim,
            num_heads: self.num_heads,
            num_layers: self.num_layers,
            dim_feedforward: self.dim * 4,
            causal: self.causal,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: self.context,
            max_period: self.max_period as usize,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: None,
            gating: Some(candle_nn::Activation::Silu),
            norm: moshi::NormType::RmsNorm,
            positional_embedding: moshi::transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096 * 4,
            shared_cross_attn: false,
        };
        moshi::lm::Config {
            transformer: lm_cfg,
            depformer: None,
            audio_vocab_size: self.card + 1,
            text_in_vocab_size: self.text_card + 1,
            text_out_vocab_size: self.text_card,
            audio_codebooks: self.n_q,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }
}

struct Model {
    state: moshi::asr::State,
    text_tokenizer: sentencepiece::SentencePieceProcessor,
    dev: Device,
}

impl Model {
    fn load_from_hf(hf_repo: &str, dev: &Device) -> Result<Self> {
        let dtype = dev.bf16_default_to_f32();

        // Retrieve the model files from the Hugging Face Hub.
        let api = hf_hub::api::sync::Api::new()?;
        let repo = api.model(hf_repo.to_string());
        let config_file = repo.get("config.json")?;
        let config: Config = serde_json::from_str(&std::fs::read_to_string(&config_file)?)?;
        let tokenizer_file = repo.get(&config.tokenizer_name)?;
        let model_file = repo.get("model.safetensors")?;
        let mimi_file = repo.get(&config.mimi_name)?;

        let text_tokenizer = sentencepiece::SentencePieceProcessor::open(&tokenizer_file)?;
        let vb_lm =
            unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], dtype, dev)? };
        let audio_tokenizer = moshi::mimi::load(mimi_file.to_str().unwrap(), Some(32), dev)?;
        let lm = moshi::lm::LmModel::new(
            &config.model_config(),
            moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
        )?;
        let state = moshi::asr::State::new(1, 0, 0., audio_tokenizer, lm)?;
        Ok(Model {
            state,
            text_tokenizer,
            dev: dev.clone(),
        })
    }

    fn run(&mut self, pcm: &[f32]) -> Result<()> {
        use std::io::Write;

        // Feed the audio to the model in 1920-sample chunks.
        for pcm in pcm.chunks(1920) {
            let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
            let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?;
            let mut prev_text_token = 0;
            for asr_msg in asr_msgs.iter() {
                match asr_msg {
                    moshi::asr::AsrMsg::Step { .. } | moshi::asr::AsrMsg::EndWord { .. } => {}
                    moshi::asr::AsrMsg::Word { tokens, .. } => {
                        for &text_token in tokens.iter() {
                            // Decode incrementally: the new text is the suffix that decoding
                            // [prev, current] adds over decoding [prev] alone.
                            let s = {
                                let prev_ids =
                                    self.text_tokenizer.decode_piece_ids(&[prev_text_token]);
                                let ids = self
                                    .text_tokenizer
                                    .decode_piece_ids(&[prev_text_token, text_token]);
                                prev_text_token = text_token;
                                prev_ids.and_then(|prev_ids| {
                                    ids.map(|ids| {
                                        if ids.len() > prev_ids.len() {
                                            ids[prev_ids.len()..].to_string()
                                        } else {
                                            String::new()
                                        }
                                    })
                                })?
                            };
                            print!("{s}");
                            std::io::stdout().flush()?
                        }
                    }
                }
            }
        }
        println!();
        Ok(())
    }
}

fn main() -> Result<()> {
    let args = Args::parse();
    let device = device(args.cpu)?;
    println!("Using device: {:?}", device);

    println!("Loading audio file from: {}", args.in_file);
    let (pcm, sample_rate) = kaudio::pcm_decode(args.in_file)?;
    // The model expects 24kHz audio; resample if needed.
    let mut pcm = if sample_rate != 24_000 {
        kaudio::resample(&pcm, sample_rate as usize, 24_000)?
    } else {
        pcm
    };
    // Add some silence at the end to ensure all the audio is processed.
    pcm.resize(pcm.len() + 1920 * 32, 0.0);
    println!("Loading model from repository: {}", args.hf_repo);
    let mut model = Model::load_from_hf(&args.hf_repo, &device)?;
    println!("Running inference");
    model.run(&pcm)?;
    Ok(())
}
```
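A quick sanity check on the constants used above; this arithmetic is implied by the code rather than stated in it: chunks of 1920 samples at 24 kHz correspond to 80 ms per `step_pcm` call, so the `1920 * 32` resize appends 2.56 s of trailing silence.

```rust
// Frame arithmetic implied by main.rs (assumed interpretation, not part of the original file).
const SAMPLE_RATE: usize = 24_000; // model sample rate, from the resample call
const FRAME_SIZE: usize = 1920;    // chunk size used in Model::run
const PAD_FRAMES: usize = 32;      // padding factor used in main

fn main() {
    // 1920 / 24000 = 0.08 s -> each step advances 80 ms of audio.
    println!("step: {} ms", 1000 * FRAME_SIZE / SAMPLE_RATE);
    // 1920 * 32 / 24000 = 2.56 s of appended silence.
    println!("padding: {:.2} s", (FRAME_SIZE * PAD_FRAMES) as f64 / SAMPLE_RATE as f64);
}
```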