Add a VAD example in the rust codebase.
This commit is contained in:
parent
3282de0559
commit
d473deddaf
|
|
@ -110,7 +110,8 @@ This can be used as follows:
|
||||||
cd stt-rs
|
cd stt-rs
|
||||||
cargo run --features cuda -r -- bria.mp3
|
cargo run --features cuda -r -- bria.mp3
|
||||||
```
|
```
|
||||||
You can get the timestamps by adding the `--timestamps` flag.
|
You can get the timestamps by adding the `--timestamps` flag, and see the output
|
||||||
|
of the semantic VAD by adding the `--vad` flag.
|
||||||
|
|
||||||
### MLX implementation
|
### MLX implementation
|
||||||
<a href="https://huggingface.co/kyutai/stt-2.6b-en-mlx" target="_blank" style="margin: 2px;">
|
<a href="https://huggingface.co/kyutai/stt-2.6b-en-mlx" target="_blank" style="margin: 2px;">
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,10 @@ struct Args {
|
||||||
/// Display word level timestamps.
|
/// Display word level timestamps.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
timestamps: bool,
|
timestamps: bool,
|
||||||
|
|
||||||
|
/// Display the level of voice activity detection (VAD).
|
||||||
|
#[arg(long)]
|
||||||
|
vad: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn device(cpu: bool) -> Result<Device> {
|
fn device(cpu: bool) -> Result<Device> {
|
||||||
|
|
@ -59,7 +63,7 @@ struct Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
fn model_config(&self) -> moshi::lm::Config {
|
fn model_config(&self, vad: bool) -> moshi::lm::Config {
|
||||||
let lm_cfg = moshi::transformer::Config {
|
let lm_cfg = moshi::transformer::Config {
|
||||||
d_model: self.dim,
|
d_model: self.dim,
|
||||||
num_heads: self.num_heads,
|
num_heads: self.num_heads,
|
||||||
|
|
@ -84,6 +88,14 @@ impl Config {
|
||||||
max_seq_len: 4096 * 4,
|
max_seq_len: 4096 * 4,
|
||||||
shared_cross_attn: false,
|
shared_cross_attn: false,
|
||||||
};
|
};
|
||||||
|
let extra_heads = if vad {
|
||||||
|
Some(moshi::lm::ExtraHeadsConfig {
|
||||||
|
num_heads: 4,
|
||||||
|
dim: 6,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
moshi::lm::Config {
|
moshi::lm::Config {
|
||||||
transformer: lm_cfg,
|
transformer: lm_cfg,
|
||||||
depformer: None,
|
depformer: None,
|
||||||
|
|
@ -92,7 +104,7 @@ impl Config {
|
||||||
text_out_vocab_size: self.text_card,
|
text_out_vocab_size: self.text_card,
|
||||||
audio_codebooks: self.n_q,
|
audio_codebooks: self.n_q,
|
||||||
conditioners: Default::default(),
|
conditioners: Default::default(),
|
||||||
extra_heads: None,
|
extra_heads,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -101,17 +113,18 @@ struct Model {
|
||||||
state: moshi::asr::State,
|
state: moshi::asr::State,
|
||||||
text_tokenizer: sentencepiece::SentencePieceProcessor,
|
text_tokenizer: sentencepiece::SentencePieceProcessor,
|
||||||
timestamps: bool,
|
timestamps: bool,
|
||||||
|
vad: bool,
|
||||||
config: Config,
|
config: Config,
|
||||||
dev: Device,
|
dev: Device,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
fn load_from_hf(hf_repo: &str, timestamps: bool, dev: &Device) -> Result<Self> {
|
fn load_from_hf(args: &Args, dev: &Device) -> Result<Self> {
|
||||||
let dtype = dev.bf16_default_to_f32();
|
let dtype = dev.bf16_default_to_f32();
|
||||||
|
|
||||||
// Retrieve the model files from the Hugging Face Hub
|
// Retrieve the model files from the Hugging Face Hub
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let repo = api.model(hf_repo.to_string());
|
let repo = api.model(args.hf_repo.to_string());
|
||||||
let config_file = repo.get("config.json")?;
|
let config_file = repo.get("config.json")?;
|
||||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(&config_file)?)?;
|
let config: Config = serde_json::from_str(&std::fs::read_to_string(&config_file)?)?;
|
||||||
let tokenizer_file = repo.get(&config.tokenizer_name)?;
|
let tokenizer_file = repo.get(&config.tokenizer_name)?;
|
||||||
|
|
@ -123,7 +136,7 @@ impl Model {
|
||||||
unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], dtype, dev)? };
|
unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], dtype, dev)? };
|
||||||
let audio_tokenizer = moshi::mimi::load(mimi_file.to_str().unwrap(), Some(32), dev)?;
|
let audio_tokenizer = moshi::mimi::load(mimi_file.to_str().unwrap(), Some(32), dev)?;
|
||||||
let lm = moshi::lm::LmModel::new(
|
let lm = moshi::lm::LmModel::new(
|
||||||
&config.model_config(),
|
&config.model_config(args.vad),
|
||||||
moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
|
moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
|
||||||
)?;
|
)?;
|
||||||
let asr_delay_in_tokens = (config.stt_config.audio_delay_seconds * 12.5) as usize;
|
let asr_delay_in_tokens = (config.stt_config.audio_delay_seconds * 12.5) as usize;
|
||||||
|
|
@ -132,7 +145,8 @@ impl Model {
|
||||||
state,
|
state,
|
||||||
config,
|
config,
|
||||||
text_tokenizer,
|
text_tokenizer,
|
||||||
timestamps,
|
timestamps: args.timestamps,
|
||||||
|
vad: args.vad,
|
||||||
dev: dev.clone(),
|
dev: dev.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -151,13 +165,26 @@ impl Model {
|
||||||
pcm.resize(pcm.len() + suffix + 24000, 0.0);
|
pcm.resize(pcm.len() + suffix + 24000, 0.0);
|
||||||
|
|
||||||
let mut last_word = None;
|
let mut last_word = None;
|
||||||
|
let mut printed_eot = false;
|
||||||
for pcm in pcm.chunks(1920) {
|
for pcm in pcm.chunks(1920) {
|
||||||
let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
|
let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
|
||||||
let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?;
|
let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?;
|
||||||
for asr_msg in asr_msgs.iter() {
|
for asr_msg in asr_msgs.iter() {
|
||||||
match asr_msg {
|
match asr_msg {
|
||||||
moshi::asr::AsrMsg::Step { .. } => {}
|
moshi::asr::AsrMsg::Step { prs, .. } => {
|
||||||
|
// prs is the probability of voice activity for different time horizons.
|
||||||
|
// The first element is the most recent time horizon.
|
||||||
|
if self.vad && prs[2][0] > 0.5 && !printed_eot {
|
||||||
|
printed_eot = true;
|
||||||
|
if !self.timestamps {
|
||||||
|
print!(" <endofturn pr={}>", prs[2][0]);
|
||||||
|
} else {
|
||||||
|
println!("<endofturn pr={}>", prs[2][0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
moshi::asr::AsrMsg::EndWord { stop_time, .. } => {
|
moshi::asr::AsrMsg::EndWord { stop_time, .. } => {
|
||||||
|
printed_eot = false;
|
||||||
if self.timestamps {
|
if self.timestamps {
|
||||||
if let Some((word, start_time)) = last_word.take() {
|
if let Some((word, start_time)) = last_word.take() {
|
||||||
println!("[{start_time:5.2}-{stop_time:5.2}] {word}");
|
println!("[{start_time:5.2}-{stop_time:5.2}] {word}");
|
||||||
|
|
@ -167,6 +194,7 @@ impl Model {
|
||||||
moshi::asr::AsrMsg::Word {
|
moshi::asr::AsrMsg::Word {
|
||||||
tokens, start_time, ..
|
tokens, start_time, ..
|
||||||
} => {
|
} => {
|
||||||
|
printed_eot = false;
|
||||||
let word = self
|
let word = self
|
||||||
.text_tokenizer
|
.text_tokenizer
|
||||||
.decode_piece_ids(tokens)
|
.decode_piece_ids(tokens)
|
||||||
|
|
@ -198,14 +226,14 @@ fn main() -> Result<()> {
|
||||||
println!("Using device: {:?}", device);
|
println!("Using device: {:?}", device);
|
||||||
|
|
||||||
println!("Loading audio file from: {}", args.in_file);
|
println!("Loading audio file from: {}", args.in_file);
|
||||||
let (pcm, sample_rate) = kaudio::pcm_decode(args.in_file)?;
|
let (pcm, sample_rate) = kaudio::pcm_decode(&args.in_file)?;
|
||||||
let pcm = if sample_rate != 24_000 {
|
let pcm = if sample_rate != 24_000 {
|
||||||
kaudio::resample(&pcm, sample_rate as usize, 24_000)?
|
kaudio::resample(&pcm, sample_rate as usize, 24_000)?
|
||||||
} else {
|
} else {
|
||||||
pcm
|
pcm
|
||||||
};
|
};
|
||||||
println!("Loading model from repository: {}", args.hf_repo);
|
println!("Loading model from repository: {}", args.hf_repo);
|
||||||
let mut model = Model::load_from_hf(&args.hf_repo, args.timestamps, &device)?;
|
let mut model = Model::load_from_hf(&args, &device)?;
|
||||||
println!("Running inference");
|
println!("Running inference");
|
||||||
model.run(pcm)?;
|
model.run(pcm)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user