Use the audio prefix in the rust inference.
This commit is contained in:
parent
91fb68acc4
commit
35c4ea47d8
|
|
@ -101,6 +101,7 @@ This can be used as follows:
|
|||
cd stt-rs
|
||||
cargo run --features cuda -r -- bria.mp3
|
||||
```
|
||||
You can get the timestamps by adding the `--timestamps` flag.
|
||||
|
||||
### MLX implementation
|
||||
<a href="https://huggingface.co/kyutai/stt-2.6b-en-mlx" target="_blank" style="margin: 2px;">
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ fn device(cpu: bool) -> Result<Device> {
|
|||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct SttConfig {
|
||||
audio_silence_prefix_seconds: f64,
|
||||
audio_delay_seconds: f64,
|
||||
}
|
||||
|
||||
|
|
@ -100,6 +101,7 @@ struct Model {
|
|||
state: moshi::asr::State,
|
||||
text_tokenizer: sentencepiece::SentencePieceProcessor,
|
||||
timestamps: bool,
|
||||
config: Config,
|
||||
dev: Device,
|
||||
}
|
||||
|
||||
|
|
@ -128,15 +130,26 @@ impl Model {
|
|||
let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?;
|
||||
Ok(Model {
|
||||
state,
|
||||
config,
|
||||
text_tokenizer,
|
||||
timestamps,
|
||||
dev: dev.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn run(&mut self, pcm: &[f32]) -> Result<()> {
|
||||
fn run(&mut self, mut pcm: Vec<f32>) -> Result<()> {
|
||||
use std::io::Write;
|
||||
|
||||
// Add the silence prefix to the audio.
|
||||
if self.config.stt_config.audio_silence_prefix_seconds > 0.0 {
|
||||
let silence_len =
|
||||
(self.config.stt_config.audio_silence_prefix_seconds * 24000.0) as usize;
|
||||
pcm.splice(0..0, vec![0.0; silence_len]);
|
||||
}
|
||||
// Add some silence at the end to ensure all the audio is processed.
|
||||
let suffix = (self.config.stt_config.audio_delay_seconds * 24000.0) as usize;
|
||||
pcm.resize(pcm.len() + suffix + 24000, 0.0);
|
||||
|
||||
let mut last_word = None;
|
||||
for pcm in pcm.chunks(1920) {
|
||||
let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
|
||||
|
|
@ -186,16 +199,14 @@ fn main() -> Result<()> {
|
|||
|
||||
println!("Loading audio file from: {}", args.in_file);
|
||||
let (pcm, sample_rate) = kaudio::pcm_decode(args.in_file)?;
|
||||
let mut pcm = if sample_rate != 24_000 {
|
||||
let pcm = if sample_rate != 24_000 {
|
||||
kaudio::resample(&pcm, sample_rate as usize, 24_000)?
|
||||
} else {
|
||||
pcm
|
||||
};
|
||||
// Add some silence at the end to ensure all the audio is processed.
|
||||
pcm.resize(pcm.len() + 1920 * 32, 0.0);
|
||||
println!("Loading model from repository: {}", args.hf_repo);
|
||||
let mut model = Model::load_from_hf(&args.hf_repo, args.timestamps, &device)?;
|
||||
println!("Running inference");
|
||||
model.run(&pcm)?;
|
||||
model.run(pcm)?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user