Add timestamps to the rust example.

This commit is contained in:
laurent 2025-06-19 09:57:53 +02:00
parent 957edae092
commit 91fb68acc4

View File

@ -18,6 +18,10 @@ struct Args {
/// Run the model on cpu.
#[arg(long)]
cpu: bool,
/// Display word level timestamps.
#[arg(long)]
timestamps: bool,
}
fn device(cpu: bool) -> Result<Device> {
@ -32,6 +36,11 @@ fn device(cpu: bool) -> Result<Device> {
}
}
#[derive(Debug, serde::Deserialize)]
struct SttConfig {
audio_delay_seconds: f64,
}
#[derive(Debug, serde::Deserialize)]
struct Config {
mimi_name: String,
@ -45,6 +54,7 @@ struct Config {
num_heads: usize,
num_layers: usize,
causal: bool,
stt_config: SttConfig,
}
impl Config {
@ -89,11 +99,12 @@ impl Config {
struct Model {
state: moshi::asr::State,
text_tokenizer: sentencepiece::SentencePieceProcessor,
timestamps: bool,
dev: Device,
}
impl Model {
fn load_from_hf(hf_repo: &str, dev: &Device) -> Result<Self> {
fn load_from_hf(hf_repo: &str, timestamps: bool, dev: &Device) -> Result<Self> {
let dtype = dev.bf16_default_to_f32();
// Retrieve the model files from the Hugging Face Hub
@ -113,10 +124,12 @@ impl Model {
&config.model_config(),
moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
)?;
let state = moshi::asr::State::new(1, 0, 0., audio_tokenizer, lm)?;
let asr_delay_in_tokens = (config.stt_config.audio_delay_seconds * 12.5) as usize;
let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?;
Ok(Model {
state,
text_tokenizer,
timestamps,
dev: dev.clone(),
})
}
@ -124,39 +137,43 @@ impl Model {
fn run(&mut self, pcm: &[f32]) -> Result<()> {
use std::io::Write;
let mut last_word = None;
for pcm in pcm.chunks(1920) {
let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?;
let mut prev_text_token = 0;
for asr_msg in asr_msgs.iter() {
match asr_msg {
moshi::asr::AsrMsg::Step { .. } | moshi::asr::AsrMsg::EndWord { .. } => {}
moshi::asr::AsrMsg::Word { tokens, .. } => {
for &text_token in tokens.iter() {
let s = {
let prev_ids =
self.text_tokenizer.decode_piece_ids(&[prev_text_token]);
let ids = self
.text_tokenizer
.decode_piece_ids(&[prev_text_token, text_token]);
prev_text_token = text_token;
prev_ids.and_then(|prev_ids| {
ids.map(|ids| {
if ids.len() > prev_ids.len() {
ids[prev_ids.len()..].to_string()
} else {
String::new()
}
})
})?
};
print!("{s}");
moshi::asr::AsrMsg::Step { .. } => {}
moshi::asr::AsrMsg::EndWord { stop_time, .. } => {
if self.timestamps {
if let Some((word, start_time)) = last_word.take() {
println!("[{start_time:5.2}-{stop_time:5.2}] {word}");
}
}
}
moshi::asr::AsrMsg::Word {
tokens, start_time, ..
} => {
let word = self
.text_tokenizer
.decode_piece_ids(tokens)
.unwrap_or_else(|_| String::new());
if !self.timestamps {
print!(" {word}");
std::io::stdout().flush()?
} else {
if let Some((word, prev_start_time)) = last_word.take() {
println!("[{prev_start_time:5.2}-{start_time:5.2}] {word}");
}
last_word = Some((word, *start_time));
}
}
}
}
}
if let Some((word, start_time)) = last_word.take() {
println!("[{start_time:5.2}- ] {word}");
}
println!();
Ok(())
}
@ -177,7 +194,7 @@ fn main() -> Result<()> {
// Add some silence at the end to ensure all the audio is processed.
pcm.resize(pcm.len() + 1920 * 32, 0.0);
println!("Loading model from repository: {}", args.hf_repo);
let mut model = Model::load_from_hf(&args.hf_repo, &device)?;
let mut model = Model::load_from_hf(&args.hf_repo, args.timestamps, &device)?;
println!("Running inference");
model.run(&pcm)?;
Ok(())