Use the audio prefix in the rust inference.
This commit is contained in:
parent
91fb68acc4
commit
35c4ea47d8
|
|
@ -101,6 +101,7 @@ This can be used as follows:
|
||||||
cd stt-rs
|
cd stt-rs
|
||||||
cargo run --features cuda -r -- bria.mp3
|
cargo run --features cuda -r -- bria.mp3
|
||||||
```
|
```
|
||||||
|
You can get the timestamps by adding the `--timestamps` flag.
|
||||||
|
|
||||||
### MLX implementation
|
### MLX implementation
|
||||||
<a href="https://huggingface.co/kyutai/stt-2.6b-en-mlx" target="_blank" style="margin: 2px;">
|
<a href="https://huggingface.co/kyutai/stt-2.6b-en-mlx" target="_blank" style="margin: 2px;">
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ fn device(cpu: bool) -> Result<Device> {
|
||||||
|
|
||||||
#[derive(Debug, serde::Deserialize)]
|
#[derive(Debug, serde::Deserialize)]
|
||||||
struct SttConfig {
|
struct SttConfig {
|
||||||
|
audio_silence_prefix_seconds: f64,
|
||||||
audio_delay_seconds: f64,
|
audio_delay_seconds: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -100,6 +101,7 @@ struct Model {
|
||||||
state: moshi::asr::State,
|
state: moshi::asr::State,
|
||||||
text_tokenizer: sentencepiece::SentencePieceProcessor,
|
text_tokenizer: sentencepiece::SentencePieceProcessor,
|
||||||
timestamps: bool,
|
timestamps: bool,
|
||||||
|
config: Config,
|
||||||
dev: Device,
|
dev: Device,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -128,15 +130,26 @@ impl Model {
|
||||||
let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?;
|
let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?;
|
||||||
Ok(Model {
|
Ok(Model {
|
||||||
state,
|
state,
|
||||||
|
config,
|
||||||
text_tokenizer,
|
text_tokenizer,
|
||||||
timestamps,
|
timestamps,
|
||||||
dev: dev.clone(),
|
dev: dev.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run(&mut self, pcm: &[f32]) -> Result<()> {
|
fn run(&mut self, mut pcm: Vec<f32>) -> Result<()> {
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
|
// Add the silence prefix to the audio.
|
||||||
|
if self.config.stt_config.audio_silence_prefix_seconds > 0.0 {
|
||||||
|
let silence_len =
|
||||||
|
(self.config.stt_config.audio_silence_prefix_seconds * 24000.0) as usize;
|
||||||
|
pcm.splice(0..0, vec![0.0; silence_len]);
|
||||||
|
}
|
||||||
|
// Add some silence at the end to ensure all the audio is processed.
|
||||||
|
let suffix = (self.config.stt_config.audio_delay_seconds * 24000.0) as usize;
|
||||||
|
pcm.resize(pcm.len() + suffix + 24000, 0.0);
|
||||||
|
|
||||||
let mut last_word = None;
|
let mut last_word = None;
|
||||||
for pcm in pcm.chunks(1920) {
|
for pcm in pcm.chunks(1920) {
|
||||||
let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
|
let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
|
||||||
|
|
@ -186,16 +199,14 @@ fn main() -> Result<()> {
|
||||||
|
|
||||||
println!("Loading audio file from: {}", args.in_file);
|
println!("Loading audio file from: {}", args.in_file);
|
||||||
let (pcm, sample_rate) = kaudio::pcm_decode(args.in_file)?;
|
let (pcm, sample_rate) = kaudio::pcm_decode(args.in_file)?;
|
||||||
let mut pcm = if sample_rate != 24_000 {
|
let pcm = if sample_rate != 24_000 {
|
||||||
kaudio::resample(&pcm, sample_rate as usize, 24_000)?
|
kaudio::resample(&pcm, sample_rate as usize, 24_000)?
|
||||||
} else {
|
} else {
|
||||||
pcm
|
pcm
|
||||||
};
|
};
|
||||||
// Add some silence at the end to ensure all the audio is processed.
|
|
||||||
pcm.resize(pcm.len() + 1920 * 32, 0.0);
|
|
||||||
println!("Loading model from repository: {}", args.hf_repo);
|
println!("Loading model from repository: {}", args.hf_repo);
|
||||||
let mut model = Model::load_from_hf(&args.hf_repo, args.timestamps, &device)?;
|
let mut model = Model::load_from_hf(&args.hf_repo, args.timestamps, &device)?;
|
||||||
println!("Running inference");
|
println!("Running inference");
|
||||||
model.run(&pcm)?;
|
model.run(pcm)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user