// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

use anyhow::Result;
use candle::{Device, Tensor};
use clap::Parser;
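
// Example invocation (the binary/example name is a placeholder and depends on how
// this crate is built; the flags map to the `Args` struct below):
//   stt-example --timestamps --vad path/to/audio.wav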
#[derive(Debug, Parser)]
struct Args {
    /// The audio input file, in wav/mp3/ogg/... format.
    in_file: String,

    /// The repo where to get the model from.
    #[arg(long, default_value = "kyutai/stt-1b-en_fr-candle")]
    hf_repo: String,

    /// Run the model on cpu.
    #[arg(long)]
    cpu: bool,

    /// Display word level timestamps.
    #[arg(long)]
    timestamps: bool,

    /// Display the level of voice activity detection (VAD).
    #[arg(long)]
    vad: bool,
}
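
// Pick the best available device: CUDA if present, then Metal, otherwise fall back to the CPU.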
fn device(cpu: bool) -> Result<Device> {
    if cpu {
        Ok(Device::Cpu)
    } else if candle::utils::cuda_is_available() {
        Ok(Device::new_cuda(0)?)
    } else if candle::utils::metal_is_available() {
        Ok(Device::new_metal(0)?)
    } else {
        Ok(Device::Cpu)
    }
}
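
// These structs mirror the `config.json` file shipped in the model repository.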
#[derive(Debug, serde::Deserialize)]
struct SttConfig {
    audio_silence_prefix_seconds: f64,
    audio_delay_seconds: f64,
}

#[derive(Debug, serde::Deserialize)]
struct Config {
    mimi_name: String,
    tokenizer_name: String,
    card: usize,
    text_card: usize,
    dim: usize,
    n_q: usize,
    context: usize,
    max_period: f64,
    num_heads: usize,
    num_layers: usize,
    causal: bool,
    stt_config: SttConfig,
}

impl Config {
    fn model_config(&self, vad: bool) -> moshi::lm::Config {
        let lm_cfg = moshi::transformer::Config {
            d_model: self.dim,
            num_heads: self.num_heads,
            num_layers: self.num_layers,
            dim_feedforward: self.dim * 4,
            causal: self.causal,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: self.context,
            max_period: self.max_period as usize,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: None,
            gating: Some(candle_nn::Activation::Silu),
            norm: moshi::NormType::RmsNorm,
            positional_embedding: moshi::transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096 * 4,
            shared_cross_attn: false,
        };
        // The extra heads are only needed when reporting voice activity detection.
        let extra_heads = if vad {
            Some(moshi::lm::ExtraHeadsConfig {
                num_heads: 4,
                dim: 6,
            })
        } else {
            None
        };
        moshi::lm::Config {
            transformer: lm_cfg,
            depformer: None,
            audio_vocab_size: self.card + 1,
            text_in_vocab_size: self.text_card + 1,
            text_out_vocab_size: self.text_card,
            audio_codebooks: self.n_q,
            conditioners: Default::default(),
            extra_heads,
        }
    }
}

struct Model {
    state: moshi::asr::State,
    text_tokenizer: sentencepiece::SentencePieceProcessor,
    timestamps: bool,
    vad: bool,
    config: Config,
    dev: Device,
}

impl Model {
    fn load_from_hf(args: &Args, dev: &Device) -> Result<Self> {
        let dtype = dev.bf16_default_to_f32();

        // Retrieve the model files from the Hugging Face Hub
        let api = hf_hub::api::sync::Api::new()?;
        let repo = api.model(args.hf_repo.to_string());
        let config_file = repo.get("config.json")?;
        let config: Config = serde_json::from_str(&std::fs::read_to_string(&config_file)?)?;
        let tokenizer_file = repo.get(&config.tokenizer_name)?;
        let model_file = repo.get("model.safetensors")?;
        let mimi_file = repo.get(&config.mimi_name)?;

        let text_tokenizer = sentencepiece::SentencePieceProcessor::open(&tokenizer_file)?;
        let vb_lm =
            unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], dtype, dev)? };
        let audio_tokenizer = moshi::mimi::load(mimi_file.to_str().unwrap(), Some(32), dev)?;
        let lm = moshi::lm::LmModel::new(
            &config.model_config(args.vad),
            moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
        )?;
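        // The model steps at 12.5 Hz (one token per 80ms of audio), so the configured
        // delay in seconds is converted into a token count here.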
        let asr_delay_in_tokens = (config.stt_config.audio_delay_seconds * 12.5) as usize;
        let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?;
        Ok(Model {
            state,
            config,
            text_tokenizer,
            timestamps: args.timestamps,
            vad: args.vad,
            dev: dev.clone(),
        })
    }

    fn run(&mut self, mut pcm: Vec<f32>) -> Result<()> {
        use std::io::Write;

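        // The model consumes 24kHz audio, so the durations below are converted to
        // sample counts at that rate.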
        // Add the silence prefix to the audio.
        if self.config.stt_config.audio_silence_prefix_seconds > 0.0 {
            let silence_len =
                (self.config.stt_config.audio_silence_prefix_seconds * 24000.0) as usize;
            pcm.splice(0..0, vec![0.0; silence_len]);
        }
        // Add some silence at the end to ensure all the audio is processed.
        let suffix = (self.config.stt_config.audio_delay_seconds * 24000.0) as usize;
        pcm.resize(pcm.len() + suffix + 24000, 0.0);

        let mut last_word = None;
        let mut printed_eot = false;
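        // Feed the audio in 1920-sample chunks, i.e. 80ms frames at 24kHz.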
        for pcm in pcm.chunks(1920) {
            let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
            let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?;
            for asr_msg in asr_msgs.iter() {
                match asr_msg {
                    moshi::asr::AsrMsg::Step { prs, .. } => {
                        // prs is the probability of having no voice activity for different time
                        // horizons.
                        // In kyutai/stt-1b-en_fr-candle, these horizons are 0.5s, 1s, 2s, and 3s.
                        if self.vad && prs[2][0] > 0.5 && !printed_eot {
                            printed_eot = true;
                            if !self.timestamps {
                                print!(" <endofturn pr={}>", prs[2][0]);
                            } else {
                                println!("<endofturn pr={}>", prs[2][0]);
                            }
                        }
                    }
                    moshi::asr::AsrMsg::EndWord { stop_time, .. } => {
                        printed_eot = false;
                        if self.timestamps {
                            if let Some((word, start_time)) = last_word.take() {
                                println!("[{start_time:5.2}-{stop_time:5.2}] {word}");
                            }
                        }
                    }
                    moshi::asr::AsrMsg::Word {
                        tokens, start_time, ..
                    } => {
                        printed_eot = false;
                        let word = self
                            .text_tokenizer
                            .decode_piece_ids(tokens)
                            .unwrap_or_else(|_| String::new());
                        if !self.timestamps {
                            print!(" {word}");
                            std::io::stdout().flush()?
                        } else {
                            if let Some((word, prev_start_time)) = last_word.take() {
                                println!("[{prev_start_time:5.2}-{start_time:5.2}] {word}");
                            }
                            last_word = Some((word, *start_time));
                        }
                    }
                }
            }
        }
        if let Some((word, start_time)) = last_word.take() {
            println!("[{start_time:5.2}- ] {word}");
        }
        println!();
        Ok(())
    }
}

fn main() -> Result<()> {
    let args = Args::parse();
    let device = device(args.cpu)?;
    println!("Using device: {:?}", device);

    println!("Loading audio file from: {}", args.in_file);
    let (pcm, sample_rate) = kaudio::pcm_decode(&args.in_file)?;
    // Resample to the 24kHz sample rate expected by the model if needed.
    let pcm = if sample_rate != 24_000 {
        kaudio::resample(&pcm, sample_rate as usize, 24_000)?
    } else {
        pcm
    };
    println!("Loading model from repository: {}", args.hf_repo);
    let mut model = Model::load_from_hf(&args, &device)?;
    println!("Running inference");
    model.run(pcm)?;
    Ok(())
}