Run Ruff on tts_mlx.py

This commit is contained in:
Václav Volhejn 2025-07-02 21:54:05 +02:00
parent 433dca3751
commit da83b4b63f

View File

@ -10,23 +10,24 @@
import argparse
import json
from pathlib import Path
import queue
import sys
import time
import numpy as np
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import sentencepiece
import sphn
import time
import sounddevice as sd
from moshi_mlx.client_utils import make_log
import sphn
from moshi_mlx import models
from moshi_mlx.client_utils import make_log
from moshi_mlx.models.tts import (
DEFAULT_DSM_TTS_REPO,
DEFAULT_DSM_TTS_VOICE_REPO,
TTSModel,
)
from moshi_mlx.utils.loaders import hf_get
from moshi_mlx.models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO
def log(level: str, msg: str):
@ -34,15 +35,32 @@ def log(level: str, msg: str):
def main():
parser = argparse.ArgumentParser(prog='moshi-tts', description='Run Moshi')
parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the PyTorch implementation"
)
parser.add_argument("inp", type=str, help="Input file, use - for stdin")
parser.add_argument("out", type=str, help="Output file to generate, use - for playing the audio")
parser.add_argument("--hf-repo", type=str, default=DEFAULT_DSM_TTS_REPO,
help="HF repo in which to look for the pretrained models.")
parser.add_argument("--voice-repo", default=DEFAULT_DSM_TTS_VOICE_REPO,
help="HF repo in which to look for pre-computed voice embeddings.")
parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav")
parser.add_argument("--quantize", type=int, help="The quantization to be applied, e.g. 8 for 8 bits.")
parser.add_argument(
"out", type=str, help="Output file to generate, use - for playing the audio"
)
parser.add_argument(
"--hf-repo",
type=str,
default=DEFAULT_DSM_TTS_REPO,
help="HF repo in which to look for the pretrained models.",
)
parser.add_argument(
"--voice-repo",
default=DEFAULT_DSM_TTS_VOICE_REPO,
help="HF repo in which to look for pre-computed voice embeddings.",
)
parser.add_argument(
"--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
)
parser.add_argument(
"--quantize",
type=int,
help="The quantization to be applied, e.g. 8 for 8 bits.",
)
args = parser.parse_args()
mx.random.seed(299792458)
@ -96,7 +114,7 @@ def main():
if tts_model.valid_cfg_conditionings:
# Model was trained with CFG distillation.
cfg_coef_conditioning = tts_model.cfg_coef
tts_model.cfg_coef = 1.
tts_model.cfg_coef = 1.0
cfg_is_no_text = False
cfg_is_no_prefix = False
else:
@ -105,17 +123,25 @@ def main():
mimi = tts_model.mimi
log("info", f"reading input from {args.inp}")
with open(args.inp, "r") as fobj:
text_to_tts = fobj.read().strip()
if args.inp == "-":
if sys.stdin.isatty(): # Interactive
print("Enter text to synthesize (Ctrl+D to end input):")
text_to_tts = sys.stdin.read().strip()
else:
with open(args.inp, "r") as fobj:
text_to_tts = fobj.read().strip()
all_entries = [tts_model.prepare_script([text_to_tts])]
if tts_model.multi_speaker:
voices = [tts_model.get_voice_path(args.voice)]
else:
voices = []
all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
all_attributes = [
tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
]
wav_frames = queue.Queue()
def _on_audio_hook(audio_tokens):
if (audio_tokens == -1).any():
return
@ -141,16 +167,20 @@ def main():
return result
if args.out == "-":
def audio_callback(outdata, _a, _b, _c):
try:
pcm_data = wav_frames.get(block=False)
outdata[:, 0] = pcm_data
except queue.Empty:
outdata[:] = 0
with sd.OutputStream(samplerate=mimi.sample_rate,
blocksize=1920,
channels=1,
callback=audio_callback):
with sd.OutputStream(
samplerate=mimi.sample_rate,
blocksize=1920,
channels=1,
callback=audio_callback,
):
run()
time.sleep(3)
while True: