Run Ruff on tts_mlx.py

This commit is contained in:
Václav Volhejn 2025-07-02 21:54:05 +02:00
parent 433dca3751
commit da83b4b63f

View File

@@ -10,23 +10,24 @@
import argparse import argparse
import json import json
from pathlib import Path
import queue import queue
import sys
import time import time
import numpy as np
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np
import sentencepiece import sentencepiece
import sphn
import time
import sounddevice as sd import sounddevice as sd
import sphn
from moshi_mlx.client_utils import make_log
from moshi_mlx import models from moshi_mlx import models
from moshi_mlx.client_utils import make_log
from moshi_mlx.models.tts import (
DEFAULT_DSM_TTS_REPO,
DEFAULT_DSM_TTS_VOICE_REPO,
TTSModel,
)
from moshi_mlx.utils.loaders import hf_get from moshi_mlx.utils.loaders import hf_get
from moshi_mlx.models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO
def log(level: str, msg: str): def log(level: str, msg: str):
@@ -34,15 +35,32 @@ def log(level: str, msg: str):
def main(): def main():
parser = argparse.ArgumentParser(prog='moshi-tts', description='Run Moshi') parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the PyTorch implementation"
)
parser.add_argument("inp", type=str, help="Input file, use - for stdin") parser.add_argument("inp", type=str, help="Input file, use - for stdin")
parser.add_argument("out", type=str, help="Output file to generate, use - for playing the audio") parser.add_argument(
parser.add_argument("--hf-repo", type=str, default=DEFAULT_DSM_TTS_REPO, "out", type=str, help="Output file to generate, use - for playing the audio"
help="HF repo in which to look for the pretrained models.") )
parser.add_argument("--voice-repo", default=DEFAULT_DSM_TTS_VOICE_REPO, parser.add_argument(
help="HF repo in which to look for pre-computed voice embeddings.") "--hf-repo",
parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav") type=str,
parser.add_argument("--quantize", type=int, help="The quantization to be applied, e.g. 8 for 8 bits.") default=DEFAULT_DSM_TTS_REPO,
help="HF repo in which to look for the pretrained models.",
)
parser.add_argument(
"--voice-repo",
default=DEFAULT_DSM_TTS_VOICE_REPO,
help="HF repo in which to look for pre-computed voice embeddings.",
)
parser.add_argument(
"--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
)
parser.add_argument(
"--quantize",
type=int,
help="The quantization to be applied, e.g. 8 for 8 bits.",
)
args = parser.parse_args() args = parser.parse_args()
mx.random.seed(299792458) mx.random.seed(299792458)
@@ -96,7 +114,7 @@ def main():
if tts_model.valid_cfg_conditionings: if tts_model.valid_cfg_conditionings:
# Model was trained with CFG distillation. # Model was trained with CFG distillation.
cfg_coef_conditioning = tts_model.cfg_coef cfg_coef_conditioning = tts_model.cfg_coef
tts_model.cfg_coef = 1. tts_model.cfg_coef = 1.0
cfg_is_no_text = False cfg_is_no_text = False
cfg_is_no_prefix = False cfg_is_no_prefix = False
else: else:
@@ -105,17 +123,25 @@ def main():
mimi = tts_model.mimi mimi = tts_model.mimi
log("info", f"reading input from {args.inp}") log("info", f"reading input from {args.inp}")
with open(args.inp, "r") as fobj: if args.inp == "-":
text_to_tts = fobj.read().strip() if sys.stdin.isatty(): # Interactive
print("Enter text to synthesize (Ctrl+D to end input):")
text_to_tts = sys.stdin.read().strip()
else:
with open(args.inp, "r") as fobj:
text_to_tts = fobj.read().strip()
all_entries = [tts_model.prepare_script([text_to_tts])] all_entries = [tts_model.prepare_script([text_to_tts])]
if tts_model.multi_speaker: if tts_model.multi_speaker:
voices = [tts_model.get_voice_path(args.voice)] voices = [tts_model.get_voice_path(args.voice)]
else: else:
voices = [] voices = []
all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)] all_attributes = [
tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
]
wav_frames = queue.Queue() wav_frames = queue.Queue()
def _on_audio_hook(audio_tokens): def _on_audio_hook(audio_tokens):
if (audio_tokens == -1).any(): if (audio_tokens == -1).any():
return return
@@ -141,16 +167,20 @@ def main():
return result return result
if args.out == "-": if args.out == "-":
def audio_callback(outdata, _a, _b, _c): def audio_callback(outdata, _a, _b, _c):
try: try:
pcm_data = wav_frames.get(block=False) pcm_data = wav_frames.get(block=False)
outdata[:, 0] = pcm_data outdata[:, 0] = pcm_data
except queue.Empty: except queue.Empty:
outdata[:] = 0 outdata[:] = 0
with sd.OutputStream(samplerate=mimi.sample_rate,
blocksize=1920, with sd.OutputStream(
channels=1, samplerate=mimi.sample_rate,
callback=audio_callback): blocksize=1920,
channels=1,
callback=audio_callback,
):
run() run()
time.sleep(3) time.sleep(3)
while True: while True: