diff --git a/README.md b/README.md
index 55b9c7a..48224f4 100644
--- a/README.md
+++ b/README.md
@@ -110,7 +110,8 @@ This can be used as follows:
 cd stt-rs
 cargo run --features cuda -r -- bria.mp3
 ```
-You can get the timestamps by adding the `--timestamps` flag.
+You can get the timestamps by adding the `--timestamps` flag, and see the output
+of the semantic VAD by adding the `--vad` flag.
 
 ### MLX implementation
 
diff --git a/stt-rs/src/main.rs b/stt-rs/src/main.rs
index 9e67ac0..43d8243 100644
--- a/stt-rs/src/main.rs
+++ b/stt-rs/src/main.rs
@@ -22,6 +22,10 @@ struct Args {
     /// Display word level timestamps.
     #[arg(long)]
     timestamps: bool,
+
+    /// Display the level of voice activity detection (VAD).
+    #[arg(long)]
+    vad: bool,
 }
 
 fn device(cpu: bool) -> Result<Device> {
@@ -59,7 +63,7 @@ struct Config {
 }
 
 impl Config {
-    fn model_config(&self) -> moshi::lm::Config {
+    fn model_config(&self, vad: bool) -> moshi::lm::Config {
         let lm_cfg = moshi::transformer::Config {
             d_model: self.dim,
             num_heads: self.num_heads,
@@ -84,6 +88,14 @@ impl Config {
             max_seq_len: 4096 * 4,
             shared_cross_attn: false,
         };
+        let extra_heads = if vad {
+            Some(moshi::lm::ExtraHeadsConfig {
+                num_heads: 4,
+                dim: 6,
+            })
+        } else {
+            None
+        };
         moshi::lm::Config {
             transformer: lm_cfg,
             depformer: None,
@@ -92,7 +104,7 @@ impl Config {
             text_out_vocab_size: self.text_card,
             audio_codebooks: self.n_q,
             conditioners: Default::default(),
-            extra_heads: None,
+            extra_heads,
         }
     }
 }
@@ -101,17 +113,18 @@ struct Model {
     state: moshi::asr::State,
     text_tokenizer: sentencepiece::SentencePieceProcessor,
     timestamps: bool,
+    vad: bool,
     config: Config,
     dev: Device,
 }
 
 impl Model {
-    fn load_from_hf(hf_repo: &str, timestamps: bool, dev: &Device) -> Result<Self> {
+    fn load_from_hf(args: &Args, dev: &Device) -> Result<Self> {
         let dtype = dev.bf16_default_to_f32();
 
         // Retrieve the model files from the Hugging Face Hub
         let api = hf_hub::api::sync::Api::new()?;
-        let repo = api.model(hf_repo.to_string());
+        let repo = api.model(args.hf_repo.to_string());
         let config_file = repo.get("config.json")?;
         let config: Config = serde_json::from_str(&std::fs::read_to_string(&config_file)?)?;
         let tokenizer_file = repo.get(&config.tokenizer_name)?;
@@ -123,7 +136,7 @@ impl Model {
             unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], dtype, dev)? };
         let audio_tokenizer = moshi::mimi::load(mimi_file.to_str().unwrap(), Some(32), dev)?;
         let lm = moshi::lm::LmModel::new(
-            &config.model_config(),
+            &config.model_config(args.vad),
             moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
         )?;
         let asr_delay_in_tokens = (config.stt_config.audio_delay_seconds * 12.5) as usize;
@@ -132,7 +145,8 @@ impl Model {
             state,
             config,
             text_tokenizer,
-            timestamps,
+            timestamps: args.timestamps,
+            vad: args.vad,
             dev: dev.clone(),
         })
     }
@@ -151,13 +165,26 @@ impl Model {
         pcm.resize(pcm.len() + suffix + 24000, 0.0);
 
         let mut last_word = None;
+        let mut printed_eot = false;
         for pcm in pcm.chunks(1920) {
             let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
             let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?;
             for asr_msg in asr_msgs.iter() {
                 match asr_msg {
-                    moshi::asr::AsrMsg::Step { .. } => {}
+                    moshi::asr::AsrMsg::Step { prs, .. } => {
+                        // prs is the probability of voice activity for different time horizons.
+                        // The first element is the most recent time horizon.
+                        if self.vad && prs[2][0] > 0.5 && !printed_eot {
+                            printed_eot = true;
+                            if !self.timestamps {
+                                print!(" <endofturn pr={}>", prs[2][0]);
+                            } else {
+                                println!("<endofturn pr={}>", prs[2][0]);
+                            }
+                        }
+                    }
                     moshi::asr::AsrMsg::EndWord { stop_time, .. } => {
+                        printed_eot = false;
                         if self.timestamps {
                             if let Some((word, start_time)) = last_word.take() {
                                 println!("[{start_time:5.2}-{stop_time:5.2}] {word}");
@@ -167,6 +194,7 @@ impl Model {
                     moshi::asr::AsrMsg::Word {
                         tokens, start_time, ..
                     } => {
+                        printed_eot = false;
                         let word = self
                             .text_tokenizer
                             .decode_piece_ids(tokens)
@@ -198,14 +226,14 @@ fn main() -> Result<()> {
     println!("Using device: {:?}", device);
 
     println!("Loading audio file from: {}", args.in_file);
-    let (pcm, sample_rate) = kaudio::pcm_decode(args.in_file)?;
+    let (pcm, sample_rate) = kaudio::pcm_decode(&args.in_file)?;
     let pcm = if sample_rate != 24_000 {
         kaudio::resample(&pcm, sample_rate as usize, 24_000)?
     } else {
         pcm
     };
     println!("Loading model from repository: {}", args.hf_repo);
-    let mut model = Model::load_from_hf(&args.hf_repo, args.timestamps, &device)?;
+    let mut model = Model::load_from_hf(&args, &device)?;
     println!("Running inference");
     model.run(pcm)?;
     Ok(())
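With the patch applied, the new flag slots into the same invocation the README already documents; a sketch of a combined run, assuming the `bria.mp3` sample file from the README is present:

```bash
cd stt-rs
cargo run --features cuda -r -- bria.mp3 --timestamps --vad
```

With `--vad`, an end-of-turn marker is printed whenever `prs[2][0]` crosses 0.5; the `printed_eot` flag suppresses repeated markers until the next `Word` or `EndWord` event resets it, so each detected pause is reported once.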