Add timestamps to the rust example.
commit 91fb68acc4
parent 957edae092
@@ -18,6 +18,10 @@ struct Args {
     /// Run the model on cpu.
     #[arg(long)]
     cpu: bool,
+
+    /// Display word level timestamps.
+    #[arg(long)]
+    timestamps: bool,
 }
 
 fn device(cpu: bool) -> Result<Device> {
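The new `--timestamps` flag is a plain boolean argument next to the existing `--cpu` flag. A minimal sketch of how it is picked up on the command line, assuming `Args` derives `clap::Parser` (which the `#[arg(long)]` attributes suggest); the real struct has additional fields such as the Hugging Face repo:

    // Hypothetical, stripped-down Args; only the two boolean flags are shown.
    use clap::Parser;

    #[derive(Debug, Parser)]
    struct Args {
        /// Run the model on cpu.
        #[arg(long)]
        cpu: bool,

        /// Display word level timestamps.
        #[arg(long)]
        timestamps: bool,
    }

    fn main() {
        // `cargo run -- --timestamps` sets `args.timestamps` to true.
        let args = Args::parse();
        println!("timestamps: {}", args.timestamps);
    }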
@@ -32,6 +36,11 @@ fn device(cpu: bool) -> Result<Device> {
     }
 }
 
+#[derive(Debug, serde::Deserialize)]
+struct SttConfig {
+    audio_delay_seconds: f64,
+}
+
 #[derive(Debug, serde::Deserialize)]
 struct Config {
     mimi_name: String,
@@ -45,6 +54,7 @@ struct Config {
     num_heads: usize,
     num_layers: usize,
     causal: bool,
+    stt_config: SttConfig,
 }
 
 impl Config {
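The new `SttConfig` block carries the audio delay that the ASR state needs. A minimal sketch of the expected shape of that section, assuming the model configuration is JSON (as is usual for Hugging Face repos) and deserializing with `serde_json`; only the delay-related fields are shown here, while the real `Config` also has `mimi_name`, `num_heads`, `num_layers`, and so on (serde ignores the extra keys by default):

    // Illustrative only: a trimmed-down Config with just the new stt_config section.
    #[derive(Debug, serde::Deserialize)]
    struct SttConfig {
        audio_delay_seconds: f64,
    }

    #[derive(Debug, serde::Deserialize)]
    struct Config {
        stt_config: SttConfig,
    }

    fn main() -> Result<(), serde_json::Error> {
        let raw = r#"{ "stt_config": { "audio_delay_seconds": 2.5 } }"#;
        let config: Config = serde_json::from_str(raw)?;
        println!("audio delay: {}s", config.stt_config.audio_delay_seconds);
        Ok(())
    }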
@@ -89,11 +99,12 @@ impl Config {
 struct Model {
     state: moshi::asr::State,
     text_tokenizer: sentencepiece::SentencePieceProcessor,
+    timestamps: bool,
     dev: Device,
 }
 
 impl Model {
-    fn load_from_hf(hf_repo: &str, dev: &Device) -> Result<Self> {
+    fn load_from_hf(hf_repo: &str, timestamps: bool, dev: &Device) -> Result<Self> {
         let dtype = dev.bf16_default_to_f32();
 
         // Retrieve the model files from the Hugging Face Hub
@@ -113,10 +124,12 @@ impl Model {
             &config.model_config(),
             moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
         )?;
-        let state = moshi::asr::State::new(1, 0, 0., audio_tokenizer, lm)?;
+        let asr_delay_in_tokens = (config.stt_config.audio_delay_seconds * 12.5) as usize;
+        let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?;
         Ok(Model {
             state,
             text_tokenizer,
+            timestamps,
             dev: dev.clone(),
         })
     }
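`asr_delay_in_tokens` converts the configured audio delay from seconds into a token count before it is handed to `moshi::asr::State::new`. The `12.5` factor presumably corresponds to a frame rate of 12.5 tokens per second; a small sketch of the arithmetic (the helper name and the 2.5 s example are illustrative, not part of the diff):

    // Seconds-to-tokens conversion as written in the diff: delay * 12.5, truncated.
    fn asr_delay_in_tokens(audio_delay_seconds: f64) -> usize {
        (audio_delay_seconds * 12.5) as usize
    }

    fn main() {
        // e.g. a 2.5 s delay becomes 31 tokens (31.25 truncated toward zero).
        assert_eq!(asr_delay_in_tokens(2.5), 31);
        assert_eq!(asr_delay_in_tokens(0.0), 0);
    }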
@@ -124,39 +137,43 @@ impl Model {
     fn run(&mut self, pcm: &[f32]) -> Result<()> {
         use std::io::Write;
 
+        let mut last_word = None;
         for pcm in pcm.chunks(1920) {
             let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
             let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?;
-            let mut prev_text_token = 0;
             for asr_msg in asr_msgs.iter() {
                 match asr_msg {
-                    moshi::asr::AsrMsg::Step { .. } | moshi::asr::AsrMsg::EndWord { .. } => {}
-                    moshi::asr::AsrMsg::Word { tokens, .. } => {
-                        for &text_token in tokens.iter() {
-                            let s = {
-                                let prev_ids =
-                                    self.text_tokenizer.decode_piece_ids(&[prev_text_token]);
-                                let ids = self
+                    moshi::asr::AsrMsg::Step { .. } => {}
+                    moshi::asr::AsrMsg::EndWord { stop_time, .. } => {
+                        if self.timestamps {
+                            if let Some((word, start_time)) = last_word.take() {
+                                println!("[{start_time:5.2}-{stop_time:5.2}] {word}");
+                            }
+                        }
+                    }
+                    moshi::asr::AsrMsg::Word {
+                        tokens, start_time, ..
+                    } => {
+                        let word = self
                             .text_tokenizer
-                                    .decode_piece_ids(&[prev_text_token, text_token]);
-                                prev_text_token = text_token;
-                                prev_ids.and_then(|prev_ids| {
-                                    ids.map(|ids| {
-                                        if ids.len() > prev_ids.len() {
-                                            ids[prev_ids.len()..].to_string()
-                                        } else {
-                                            String::new()
-                                        }
-                                    })
-                                })?
-                            };
-                            print!("{s}");
+                            .decode_piece_ids(tokens)
+                            .unwrap_or_else(|_| String::new());
+                        if !self.timestamps {
+                            print!(" {word}");
                             std::io::stdout().flush()?
+                        } else {
+                            if let Some((word, prev_start_time)) = last_word.take() {
+                                println!("[{prev_start_time:5.2}-{start_time:5.2}] {word}");
+                            }
+                            last_word = Some((word, *start_time));
                         }
                     }
                 }
             }
         }
+        if let Some((word, start_time)) = last_word.take() {
+            println!("[{start_time:5.2}- ] {word}");
+        }
         println!();
         Ok(())
     }
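With `--timestamps`, each decoded word is buffered together with its start time and only printed once the following `Word` or `EndWord` message supplies the end of its interval; the trailing `if let` flushes the last buffered word with an open-ended range. A small sketch of the resulting output format, reusing the `{:5.2}` format specifier from the diff (the words and times are made up):

    // Illustrative only: the same format string as the timestamped println! above.
    fn print_timestamped(word: &str, start_time: f64, stop_time: f64) {
        println!("[{start_time:5.2}-{stop_time:5.2}] {word}");
    }

    fn main() {
        print_timestamped("hello", 1.20, 1.44); // prints "[ 1.20- 1.44] hello"
        print_timestamped("world", 1.52, 1.80); // prints "[ 1.52- 1.80] world"
    }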
@@ -177,7 +194,7 @@ fn main() -> Result<()> {
     // Add some silence at the end to ensure all the audio is processed.
     pcm.resize(pcm.len() + 1920 * 32, 0.0);
     println!("Loading model from repository: {}", args.hf_repo);
-    let mut model = Model::load_from_hf(&args.hf_repo, &device)?;
+    let mut model = Model::load_from_hf(&args.hf_repo, args.timestamps, &device)?;
     println!("Running inference");
    model.run(&pcm)?;
     Ok(())