From 91fb68acc46bdd837c4a7e9b11977fa655e0787d Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 19 Jun 2025 09:57:53 +0200 Subject: [PATCH] Add timestamps to the rust example. --- stt-rs/src/main.rs | 67 +++++++++++++++++++++++++++++----------------- 1 file changed, 42 insertions(+), 25 deletions(-) diff --git a/stt-rs/src/main.rs b/stt-rs/src/main.rs index 2d2a447..31e02ab 100644 --- a/stt-rs/src/main.rs +++ b/stt-rs/src/main.rs @@ -18,6 +18,10 @@ struct Args { /// Run the model on cpu. #[arg(long)] cpu: bool, + + /// Display word level timestamps. + #[arg(long)] + timestamps: bool, } fn device(cpu: bool) -> Result { @@ -32,6 +36,11 @@ fn device(cpu: bool) -> Result { } } +#[derive(Debug, serde::Deserialize)] +struct SttConfig { + audio_delay_seconds: f64, +} + #[derive(Debug, serde::Deserialize)] struct Config { mimi_name: String, @@ -45,6 +54,7 @@ struct Config { num_heads: usize, num_layers: usize, causal: bool, + stt_config: SttConfig, } impl Config { @@ -89,11 +99,12 @@ impl Config { struct Model { state: moshi::asr::State, text_tokenizer: sentencepiece::SentencePieceProcessor, + timestamps: bool, dev: Device, } impl Model { - fn load_from_hf(hf_repo: &str, dev: &Device) -> Result { + fn load_from_hf(hf_repo: &str, timestamps: bool, dev: &Device) -> Result { let dtype = dev.bf16_default_to_f32(); // Retrieve the model files from the Hugging Face Hub @@ -113,10 +124,12 @@ impl Model { &config.model_config(), moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm), )?; - let state = moshi::asr::State::new(1, 0, 0., audio_tokenizer, lm)?; + let asr_delay_in_tokens = (config.stt_config.audio_delay_seconds * 12.5) as usize; + let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?; Ok(Model { state, text_tokenizer, + timestamps, dev: dev.clone(), }) } @@ -124,39 +137,43 @@ impl Model { fn run(&mut self, pcm: &[f32]) -> Result<()> { use std::io::Write; + let mut last_word = None; for pcm in pcm.chunks(1920) { let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?; let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?; - let mut prev_text_token = 0; for asr_msg in asr_msgs.iter() { match asr_msg { - moshi::asr::AsrMsg::Step { .. } | moshi::asr::AsrMsg::EndWord { .. } => {} - moshi::asr::AsrMsg::Word { tokens, .. } => { - for &text_token in tokens.iter() { - let s = { - let prev_ids = - self.text_tokenizer.decode_piece_ids(&[prev_text_token]); - let ids = self - .text_tokenizer - .decode_piece_ids(&[prev_text_token, text_token]); - prev_text_token = text_token; - prev_ids.and_then(|prev_ids| { - ids.map(|ids| { - if ids.len() > prev_ids.len() { - ids[prev_ids.len()..].to_string() - } else { - String::new() - } - }) - })? - }; - print!("{s}"); + moshi::asr::AsrMsg::Step { .. } => {} + moshi::asr::AsrMsg::EndWord { stop_time, .. } => { + if self.timestamps { + if let Some((word, start_time)) = last_word.take() { + println!("[{start_time:5.2}-{stop_time:5.2}] {word}"); + } + } + } + moshi::asr::AsrMsg::Word { + tokens, start_time, .. + } => { + let word = self + .text_tokenizer + .decode_piece_ids(tokens) + .unwrap_or_else(|_| String::new()); + if !self.timestamps { + print!(" {word}"); std::io::stdout().flush()? + } else { + if let Some((word, prev_start_time)) = last_word.take() { + println!("[{prev_start_time:5.2}-{start_time:5.2}] {word}"); + } + last_word = Some((word, *start_time)); } } } } } + if let Some((word, start_time)) = last_word.take() { + println!("[{start_time:5.2}- ] {word}"); + } println!(); Ok(()) } @@ -177,7 +194,7 @@ fn main() -> Result<()> { // Add some silence at the end to ensure all the audio is processed. pcm.resize(pcm.len() + 1920 * 32, 0.0); println!("Loading model from repository: {}", args.hf_repo); - let mut model = Model::load_from_hf(&args.hf_repo, &device)?; + let mut model = Model::load_from_hf(&args.hf_repo, args.timestamps, &device)?; println!("Running inference"); model.run(&pcm)?; Ok(())