Add timestamps to the rust example.

This commit is contained in:
laurent 2025-06-19 09:57:53 +02:00
parent 957edae092
commit 91fb68acc4

View File

@ -18,6 +18,10 @@ struct Args {
/// Run the model on cpu. /// Run the model on cpu.
#[arg(long)] #[arg(long)]
cpu: bool, cpu: bool,
/// Display word level timestamps.
#[arg(long)]
timestamps: bool,
} }
fn device(cpu: bool) -> Result<Device> { fn device(cpu: bool) -> Result<Device> {
@ -32,6 +36,11 @@ fn device(cpu: bool) -> Result<Device> {
} }
} }
#[derive(Debug, serde::Deserialize)]
struct SttConfig {
audio_delay_seconds: f64,
}
#[derive(Debug, serde::Deserialize)] #[derive(Debug, serde::Deserialize)]
struct Config { struct Config {
mimi_name: String, mimi_name: String,
@ -45,6 +54,7 @@ struct Config {
num_heads: usize, num_heads: usize,
num_layers: usize, num_layers: usize,
causal: bool, causal: bool,
stt_config: SttConfig,
} }
impl Config { impl Config {
@ -89,11 +99,12 @@ impl Config {
struct Model { struct Model {
state: moshi::asr::State, state: moshi::asr::State,
text_tokenizer: sentencepiece::SentencePieceProcessor, text_tokenizer: sentencepiece::SentencePieceProcessor,
timestamps: bool,
dev: Device, dev: Device,
} }
impl Model { impl Model {
fn load_from_hf(hf_repo: &str, dev: &Device) -> Result<Self> { fn load_from_hf(hf_repo: &str, timestamps: bool, dev: &Device) -> Result<Self> {
let dtype = dev.bf16_default_to_f32(); let dtype = dev.bf16_default_to_f32();
// Retrieve the model files from the Hugging Face Hub // Retrieve the model files from the Hugging Face Hub
@ -113,10 +124,12 @@ impl Model {
&config.model_config(), &config.model_config(),
moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm), moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
)?; )?;
let state = moshi::asr::State::new(1, 0, 0., audio_tokenizer, lm)?; let asr_delay_in_tokens = (config.stt_config.audio_delay_seconds * 12.5) as usize;
let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?;
Ok(Model { Ok(Model {
state, state,
text_tokenizer, text_tokenizer,
timestamps,
dev: dev.clone(), dev: dev.clone(),
}) })
} }
@ -124,39 +137,43 @@ impl Model {
fn run(&mut self, pcm: &[f32]) -> Result<()> { fn run(&mut self, pcm: &[f32]) -> Result<()> {
use std::io::Write; use std::io::Write;
let mut last_word = None;
for pcm in pcm.chunks(1920) { for pcm in pcm.chunks(1920) {
let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?; let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?; let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?;
let mut prev_text_token = 0;
for asr_msg in asr_msgs.iter() { for asr_msg in asr_msgs.iter() {
match asr_msg { match asr_msg {
moshi::asr::AsrMsg::Step { .. } | moshi::asr::AsrMsg::EndWord { .. } => {} moshi::asr::AsrMsg::Step { .. } => {}
moshi::asr::AsrMsg::Word { tokens, .. } => { moshi::asr::AsrMsg::EndWord { stop_time, .. } => {
for &text_token in tokens.iter() { if self.timestamps {
let s = { if let Some((word, start_time)) = last_word.take() {
let prev_ids = println!("[{start_time:5.2}-{stop_time:5.2}] {word}");
self.text_tokenizer.decode_piece_ids(&[prev_text_token]); }
let ids = self }
.text_tokenizer }
.decode_piece_ids(&[prev_text_token, text_token]); moshi::asr::AsrMsg::Word {
prev_text_token = text_token; tokens, start_time, ..
prev_ids.and_then(|prev_ids| { } => {
ids.map(|ids| { let word = self
if ids.len() > prev_ids.len() { .text_tokenizer
ids[prev_ids.len()..].to_string() .decode_piece_ids(tokens)
} else { .unwrap_or_else(|_| String::new());
String::new() if !self.timestamps {
} print!(" {word}");
})
})?
};
print!("{s}");
std::io::stdout().flush()? std::io::stdout().flush()?
} else {
if let Some((word, prev_start_time)) = last_word.take() {
println!("[{prev_start_time:5.2}-{start_time:5.2}] {word}");
}
last_word = Some((word, *start_time));
} }
} }
} }
} }
} }
if let Some((word, start_time)) = last_word.take() {
println!("[{start_time:5.2}- ] {word}");
}
println!(); println!();
Ok(()) Ok(())
} }
@ -177,7 +194,7 @@ fn main() -> Result<()> {
// Add some silence at the end to ensure all the audio is processed. // Add some silence at the end to ensure all the audio is processed.
pcm.resize(pcm.len() + 1920 * 32, 0.0); pcm.resize(pcm.len() + 1920 * 32, 0.0);
println!("Loading model from repository: {}", args.hf_repo); println!("Loading model from repository: {}", args.hf_repo);
let mut model = Model::load_from_hf(&args.hf_repo, &device)?; let mut model = Model::load_from_hf(&args.hf_repo, args.timestamps, &device)?;
println!("Running inference"); println!("Running inference");
model.run(&pcm)?; model.run(&pcm)?;
Ok(()) Ok(())