Use the audio prefix in the Rust inference.

This commit is contained in:
laurent 2025-06-19 10:12:03 +02:00
parent 91fb68acc4
commit 35c4ea47d8
2 changed files with 17 additions and 5 deletions

View File

@ -101,6 +101,7 @@ This can be used as follows:
cd stt-rs cd stt-rs
cargo run --features cuda -r -- bria.mp3 cargo run --features cuda -r -- bria.mp3
``` ```
You can get the timestamps by adding the `--timestamps` flag.
### MLX implementation ### MLX implementation
<a href="https://huggingface.co/kyutai/stt-2.6b-en-mlx" target="_blank" style="margin: 2px;"> <a href="https://huggingface.co/kyutai/stt-2.6b-en-mlx" target="_blank" style="margin: 2px;">

View File

@ -38,6 +38,7 @@ fn device(cpu: bool) -> Result<Device> {
#[derive(Debug, serde::Deserialize)] #[derive(Debug, serde::Deserialize)]
struct SttConfig { struct SttConfig {
audio_silence_prefix_seconds: f64,
audio_delay_seconds: f64, audio_delay_seconds: f64,
} }
@ -100,6 +101,7 @@ struct Model {
state: moshi::asr::State, state: moshi::asr::State,
text_tokenizer: sentencepiece::SentencePieceProcessor, text_tokenizer: sentencepiece::SentencePieceProcessor,
timestamps: bool, timestamps: bool,
config: Config,
dev: Device, dev: Device,
} }
@ -128,15 +130,26 @@ impl Model {
let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?; let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?;
Ok(Model { Ok(Model {
state, state,
config,
text_tokenizer, text_tokenizer,
timestamps, timestamps,
dev: dev.clone(), dev: dev.clone(),
}) })
} }
fn run(&mut self, pcm: &[f32]) -> Result<()> { fn run(&mut self, mut pcm: Vec<f32>) -> Result<()> {
use std::io::Write; use std::io::Write;
// Add the silence prefix to the audio.
if self.config.stt_config.audio_silence_prefix_seconds > 0.0 {
let silence_len =
(self.config.stt_config.audio_silence_prefix_seconds * 24000.0) as usize;
pcm.splice(0..0, vec![0.0; silence_len]);
}
// Add some silence at the end to ensure all the audio is processed.
let suffix = (self.config.stt_config.audio_delay_seconds * 24000.0) as usize;
pcm.resize(pcm.len() + suffix + 24000, 0.0);
let mut last_word = None; let mut last_word = None;
for pcm in pcm.chunks(1920) { for pcm in pcm.chunks(1920) {
let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?; let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
@ -186,16 +199,14 @@ fn main() -> Result<()> {
println!("Loading audio file from: {}", args.in_file); println!("Loading audio file from: {}", args.in_file);
let (pcm, sample_rate) = kaudio::pcm_decode(args.in_file)?; let (pcm, sample_rate) = kaudio::pcm_decode(args.in_file)?;
let mut pcm = if sample_rate != 24_000 { let pcm = if sample_rate != 24_000 {
kaudio::resample(&pcm, sample_rate as usize, 24_000)? kaudio::resample(&pcm, sample_rate as usize, 24_000)?
} else { } else {
pcm pcm
}; };
// Add some silence at the end to ensure all the audio is processed.
pcm.resize(pcm.len() + 1920 * 32, 0.0);
println!("Loading model from repository: {}", args.hf_repo); println!("Loading model from repository: {}", args.hf_repo);
let mut model = Model::load_from_hf(&args.hf_repo, args.timestamps, &device)?; let mut model = Model::load_from_hf(&args.hf_repo, args.timestamps, &device)?;
println!("Running inference"); println!("Running inference");
model.run(&pcm)?; model.run(pcm)?;
Ok(()) Ok(())
} }