From 0f10335bc87deab08f6215e722290aed6ae06bda Mon Sep 17 00:00:00 2001
From: Laurent
Date: Mon, 23 Jun 2025 11:45:51 +0200
Subject: [PATCH] Add a script for using the mlx version with the built-in
 microphone.

---
 scripts/mlx-mic.py | 96 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 96 insertions(+)
 create mode 100644 scripts/mlx-mic.py

diff --git a/scripts/mlx-mic.py b/scripts/mlx-mic.py
new file mode 100644
index 0000000..1816efe
--- /dev/null
+++ b/scripts/mlx-mic.py
@@ -0,0 +1,96 @@
+# /// script
+# requires-python = ">=3.12"
+# dependencies = [
+#     "huggingface_hub",
+#     "moshi_mlx",
+#     "numpy",
+#     "rustymimi",
+#     "sentencepiece",
+#     "sounddevice",
+# ]
+# ///
+
+import argparse
+from dataclasses import dataclass
+import json
+import numpy as np
+import queue
+import sounddevice as sd
+
+from huggingface_hub import hf_hub_download
+import mlx.core as mx
+import mlx.nn as nn
+from moshi_mlx import models, utils
+import rustymimi
+import sentencepiece
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # type=int so that a user-supplied --max-steps reaches LmGen as an int,
+    # not the raw CLI string (argparse only type-checks the default otherwise).
+    parser.add_argument("--max-steps", default=4096, type=int)
+    parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
+    args = parser.parse_args()
+
+    lm_config = hf_hub_download(args.hf_repo, "config.json")
+    with open(lm_config, "r") as fobj:
+        lm_config = json.load(fobj)
+    mimi_weights = hf_hub_download(args.hf_repo, lm_config["mimi_name"])
+    moshi_name = lm_config.get("moshi_name", "model.safetensors")
+    moshi_weights = hf_hub_download(args.hf_repo, moshi_name)
+    tokenizer = hf_hub_download(args.hf_repo, lm_config["tokenizer_name"])
+
+    lm_config = models.LmConfig.from_config_dict(lm_config)
+    model = models.Lm(lm_config)
+    model.set_dtype(mx.bfloat16)
+    if moshi_weights.endswith(".q4.safetensors"):
+        nn.quantize(model, bits=4, group_size=32)
+    elif moshi_weights.endswith(".q8.safetensors"):
+        nn.quantize(model, bits=8, group_size=64)
+
+    print(f"loading model weights from {moshi_weights}")
+    model.load_weights(moshi_weights, strict=True)
+
+    print(f"loading the text tokenizer from {tokenizer}")
+    text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer)  # type: ignore
+
+    print(f"loading the audio tokenizer {mimi_weights}")
+    generated_codebooks = lm_config.generated_codebooks
+    other_codebooks = lm_config.other_codebooks
+    mimi_codebooks = max(generated_codebooks, other_codebooks)
+    audio_tokenizer = rustymimi.Tokenizer(mimi_weights, num_codebooks=mimi_codebooks)  # type: ignore
+    print("warming up the model")
+    model.warmup()
+    gen = models.LmGen(
+        model=model,
+        max_steps=args.max_steps,
+        text_sampler=utils.Sampler(top_k=25, temp=0),
+        audio_sampler=utils.Sampler(top_k=250, temp=0.8),
+        check=False,
+    )
+
+    block_queue = queue.Queue()
+
+    def audio_callback(indata, _frames, _time, _status):
+        block_queue.put(indata.copy())
+
+    print("start recording the user input")
+    with sd.InputStream(
+        channels=1,
+        dtype="float32",
+        samplerate=24000,
+        blocksize=1920,
+        callback=audio_callback,
+    ):
+        while True:
+            block = block_queue.get()
+            block = block[None, :, 0]
+            other_audio_tokens = audio_tokenizer.encode_step(block[None, 0:1])
+            other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[:, :, :other_codebooks]
+            text_token = gen.step(other_audio_tokens[0])
+            text_token = text_token[0].item()
+            audio_tokens = gen.last_audio_tokens()
+            _text = None
+            if text_token not in (0, 3):
+                _text = text_tokenizer.id_to_piece(text_token)  # type: ignore
+                _text = _text.replace("\u2581", " ")
+                print(_text, end="", flush=True)