Use the audio silence prefix in the Rust inference.

This commit is contained in:
laurent 2025-06-19 10:12:03 +02:00
parent 91fb68acc4
commit 35c4ea47d8
2 changed files with 17 additions and 5 deletions

View File

@ -101,6 +101,7 @@ This can be used as follows:
cd stt-rs
cargo run --features cuda -r -- bria.mp3
```
You can get the timestamps by adding the `--timestamps` flag.
### MLX implementation
<a href="https://huggingface.co/kyutai/stt-2.6b-en-mlx" target="_blank" style="margin: 2px;">

View File

@ -38,6 +38,7 @@ fn device(cpu: bool) -> Result<Device> {
/// Speech-to-text timing settings, deserialized with serde.
///
/// Both values are durations in seconds; `Model::run` reads them (via
/// `config.stt_config`) and converts them to sample counts by multiplying by 24000.
#[derive(Debug, serde::Deserialize)]
struct SttConfig {
/// Length of the silence (zero samples) that `Model::run` inserts in front of the
/// audio. The prefix is only added when this value is strictly positive.
audio_silence_prefix_seconds: f64,
/// Length of the silence (zero samples) that `Model::run` appends after the audio,
/// in addition to a fixed extra 24000 samples, so that all the audio is processed.
audio_delay_seconds: f64,
}
@ -100,6 +101,7 @@ struct Model {
state: moshi::asr::State,
text_tokenizer: sentencepiece::SentencePieceProcessor,
timestamps: bool,
config: Config,
dev: Device,
}
@ -128,15 +130,26 @@ impl Model {
let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?;
Ok(Model {
state,
config,
text_tokenizer,
timestamps,
dev: dev.clone(),
})
}
fn run(&mut self, pcm: &[f32]) -> Result<()> {
fn run(&mut self, mut pcm: Vec<f32>) -> Result<()> {
use std::io::Write;
// Add the silence prefix to the audio.
if self.config.stt_config.audio_silence_prefix_seconds > 0.0 {
let silence_len =
(self.config.stt_config.audio_silence_prefix_seconds * 24000.0) as usize;
pcm.splice(0..0, vec![0.0; silence_len]);
}
// Add some silence at the end to ensure all the audio is processed.
let suffix = (self.config.stt_config.audio_delay_seconds * 24000.0) as usize;
pcm.resize(pcm.len() + suffix + 24000, 0.0);
let mut last_word = None;
for pcm in pcm.chunks(1920) {
let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
@ -186,16 +199,14 @@ fn main() -> Result<()> {
println!("Loading audio file from: {}", args.in_file);
let (pcm, sample_rate) = kaudio::pcm_decode(args.in_file)?;
let mut pcm = if sample_rate != 24_000 {
let pcm = if sample_rate != 24_000 {
kaudio::resample(&pcm, sample_rate as usize, 24_000)?
} else {
pcm
};
// Add some silence at the end to ensure all the audio is processed.
pcm.resize(pcm.len() + 1920 * 32, 0.0);
println!("Loading model from repository: {}", args.hf_repo);
let mut model = Model::load_from_hf(&args.hf_repo, args.timestamps, &device)?;
println!("Running inference");
model.run(&pcm)?;
model.run(pcm)?;
Ok(())
}