Add word-level timestamps to the Rust example.
This commit is contained in:
parent
957edae092
commit
91fb68acc4
|
|
@ -18,6 +18,10 @@ struct Args {
|
||||||
/// Run the model on cpu.
|
/// Run the model on cpu.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
cpu: bool,
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Display word level timestamps.
|
||||||
|
#[arg(long)]
|
||||||
|
timestamps: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn device(cpu: bool) -> Result<Device> {
|
fn device(cpu: bool) -> Result<Device> {
|
||||||
|
|
@ -32,6 +36,11 @@ fn device(cpu: bool) -> Result<Device> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Deserialize)]
|
||||||
|
struct SttConfig {
|
||||||
|
audio_delay_seconds: f64,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, serde::Deserialize)]
|
#[derive(Debug, serde::Deserialize)]
|
||||||
struct Config {
|
struct Config {
|
||||||
mimi_name: String,
|
mimi_name: String,
|
||||||
|
|
@ -45,6 +54,7 @@ struct Config {
|
||||||
num_heads: usize,
|
num_heads: usize,
|
||||||
num_layers: usize,
|
num_layers: usize,
|
||||||
causal: bool,
|
causal: bool,
|
||||||
|
stt_config: SttConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
|
|
@ -89,11 +99,12 @@ impl Config {
|
||||||
struct Model {
|
struct Model {
|
||||||
state: moshi::asr::State,
|
state: moshi::asr::State,
|
||||||
text_tokenizer: sentencepiece::SentencePieceProcessor,
|
text_tokenizer: sentencepiece::SentencePieceProcessor,
|
||||||
|
timestamps: bool,
|
||||||
dev: Device,
|
dev: Device,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
fn load_from_hf(hf_repo: &str, dev: &Device) -> Result<Self> {
|
fn load_from_hf(hf_repo: &str, timestamps: bool, dev: &Device) -> Result<Self> {
|
||||||
let dtype = dev.bf16_default_to_f32();
|
let dtype = dev.bf16_default_to_f32();
|
||||||
|
|
||||||
// Retrieve the model files from the Hugging Face Hub
|
// Retrieve the model files from the Hugging Face Hub
|
||||||
|
|
@ -113,10 +124,12 @@ impl Model {
|
||||||
&config.model_config(),
|
&config.model_config(),
|
||||||
moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
|
moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
|
||||||
)?;
|
)?;
|
||||||
let state = moshi::asr::State::new(1, 0, 0., audio_tokenizer, lm)?;
|
let asr_delay_in_tokens = (config.stt_config.audio_delay_seconds * 12.5) as usize;
|
||||||
|
let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?;
|
||||||
Ok(Model {
|
Ok(Model {
|
||||||
state,
|
state,
|
||||||
text_tokenizer,
|
text_tokenizer,
|
||||||
|
timestamps,
|
||||||
dev: dev.clone(),
|
dev: dev.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -124,39 +137,43 @@ impl Model {
|
||||||
fn run(&mut self, pcm: &[f32]) -> Result<()> {
|
fn run(&mut self, pcm: &[f32]) -> Result<()> {
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
|
let mut last_word = None;
|
||||||
for pcm in pcm.chunks(1920) {
|
for pcm in pcm.chunks(1920) {
|
||||||
let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
|
let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
|
||||||
let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?;
|
let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?;
|
||||||
let mut prev_text_token = 0;
|
|
||||||
for asr_msg in asr_msgs.iter() {
|
for asr_msg in asr_msgs.iter() {
|
||||||
match asr_msg {
|
match asr_msg {
|
||||||
moshi::asr::AsrMsg::Step { .. } | moshi::asr::AsrMsg::EndWord { .. } => {}
|
moshi::asr::AsrMsg::Step { .. } => {}
|
||||||
moshi::asr::AsrMsg::Word { tokens, .. } => {
|
moshi::asr::AsrMsg::EndWord { stop_time, .. } => {
|
||||||
for &text_token in tokens.iter() {
|
if self.timestamps {
|
||||||
let s = {
|
if let Some((word, start_time)) = last_word.take() {
|
||||||
let prev_ids =
|
println!("[{start_time:5.2}-{stop_time:5.2}] {word}");
|
||||||
self.text_tokenizer.decode_piece_ids(&[prev_text_token]);
|
}
|
||||||
let ids = self
|
}
|
||||||
.text_tokenizer
|
}
|
||||||
.decode_piece_ids(&[prev_text_token, text_token]);
|
moshi::asr::AsrMsg::Word {
|
||||||
prev_text_token = text_token;
|
tokens, start_time, ..
|
||||||
prev_ids.and_then(|prev_ids| {
|
} => {
|
||||||
ids.map(|ids| {
|
let word = self
|
||||||
if ids.len() > prev_ids.len() {
|
.text_tokenizer
|
||||||
ids[prev_ids.len()..].to_string()
|
.decode_piece_ids(tokens)
|
||||||
} else {
|
.unwrap_or_else(|_| String::new());
|
||||||
String::new()
|
if !self.timestamps {
|
||||||
}
|
print!(" {word}");
|
||||||
})
|
|
||||||
})?
|
|
||||||
};
|
|
||||||
print!("{s}");
|
|
||||||
std::io::stdout().flush()?
|
std::io::stdout().flush()?
|
||||||
|
} else {
|
||||||
|
if let Some((word, prev_start_time)) = last_word.take() {
|
||||||
|
println!("[{prev_start_time:5.2}-{start_time:5.2}] {word}");
|
||||||
|
}
|
||||||
|
last_word = Some((word, *start_time));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if let Some((word, start_time)) = last_word.take() {
|
||||||
|
println!("[{start_time:5.2}- ] {word}");
|
||||||
|
}
|
||||||
println!();
|
println!();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -177,7 +194,7 @@ fn main() -> Result<()> {
|
||||||
// Add some silence at the end to ensure all the audio is processed.
|
// Add some silence at the end to ensure all the audio is processed.
|
||||||
pcm.resize(pcm.len() + 1920 * 32, 0.0);
|
pcm.resize(pcm.len() + 1920 * 32, 0.0);
|
||||||
println!("Loading model from repository: {}", args.hf_repo);
|
println!("Loading model from repository: {}", args.hf_repo);
|
||||||
let mut model = Model::load_from_hf(&args.hf_repo, &device)?;
|
let mut model = Model::load_from_hf(&args.hf_repo, args.timestamps, &device)?;
|
||||||
println!("Running inference");
|
println!("Running inference");
|
||||||
model.run(&pcm)?;
|
model.run(&pcm)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user