From 35c4ea47d8143cdffaee6678bccf213aa48ae2e6 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 19 Jun 2025 10:12:03 +0200 Subject: [PATCH] Use the audio prefix in the rust inference. --- README.md | 1 + stt-rs/src/main.rs | 21 ++++++++++++++++----- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index ccf749f..22e580d 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ This can be used as follows: cd stt-rs cargo run --features cuda -r -- bria.mp3 ``` +You can get the timestamps by adding the `--timestamps` flag. ### MLX implementation diff --git a/stt-rs/src/main.rs b/stt-rs/src/main.rs index 31e02ab..9e67ac0 100644 --- a/stt-rs/src/main.rs +++ b/stt-rs/src/main.rs @@ -38,6 +38,7 @@ fn device(cpu: bool) -> Result { #[derive(Debug, serde::Deserialize)] struct SttConfig { + audio_silence_prefix_seconds: f64, audio_delay_seconds: f64, } @@ -100,6 +101,7 @@ struct Model { state: moshi::asr::State, text_tokenizer: sentencepiece::SentencePieceProcessor, timestamps: bool, + config: Config, dev: Device, } @@ -128,15 +130,26 @@ impl Model { let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?; Ok(Model { state, + config, text_tokenizer, timestamps, dev: dev.clone(), }) } - fn run(&mut self, pcm: &[f32]) -> Result<()> { + fn run(&mut self, mut pcm: Vec) -> Result<()> { use std::io::Write; + // Add the silence prefix to the audio. + if self.config.stt_config.audio_silence_prefix_seconds > 0.0 { + let silence_len = + (self.config.stt_config.audio_silence_prefix_seconds * 24000.0) as usize; + pcm.splice(0..0, vec![0.0; silence_len]); + } + // Add some silence at the end to ensure all the audio is processed. + let suffix = (self.config.stt_config.audio_delay_seconds * 24000.0) as usize; + pcm.resize(pcm.len() + suffix + 24000, 0.0); + let mut last_word = None; for pcm in pcm.chunks(1920) { let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?; @@ -186,16 +199,14 @@ fn main() -> Result<()> { println!("Loading audio file from: {}", args.in_file); let (pcm, sample_rate) = kaudio::pcm_decode(args.in_file)?; - let mut pcm = if sample_rate != 24_000 { + let pcm = if sample_rate != 24_000 { kaudio::resample(&pcm, sample_rate as usize, 24_000)? } else { pcm }; - // Add some silence at the end to ensure all the audio is processed. - pcm.resize(pcm.len() + 1920 * 32, 0.0); println!("Loading model from repository: {}", args.hf_repo); let mut model = Model::load_from_hf(&args.hf_repo, args.timestamps, &device)?; println!("Running inference"); - model.run(&pcm)?; + model.run(pcm)?; Ok(()) }