Run Ruff on tts_mlx.py
This commit is contained in:
parent
433dca3751
commit
da83b4b63f
|
|
@ -10,23 +10,24 @@
|
|||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
import queue
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import sentencepiece
|
||||
import sphn
|
||||
import time
|
||||
|
||||
import sounddevice as sd
|
||||
|
||||
from moshi_mlx.client_utils import make_log
|
||||
import sphn
|
||||
from moshi_mlx import models
|
||||
from moshi_mlx.client_utils import make_log
|
||||
from moshi_mlx.models.tts import (
|
||||
DEFAULT_DSM_TTS_REPO,
|
||||
DEFAULT_DSM_TTS_VOICE_REPO,
|
||||
TTSModel,
|
||||
)
|
||||
from moshi_mlx.utils.loaders import hf_get
|
||||
from moshi_mlx.models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO
|
||||
|
||||
|
||||
def log(level: str, msg: str):
|
||||
|
|
@ -34,15 +35,32 @@ def log(level: str, msg: str):
|
|||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(prog='moshi-tts', description='Run Moshi')
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run Kyutai TTS using the PyTorch implementation"
|
||||
)
|
||||
parser.add_argument("inp", type=str, help="Input file, use - for stdin")
|
||||
parser.add_argument("out", type=str, help="Output file to generate, use - for playing the audio")
|
||||
parser.add_argument("--hf-repo", type=str, default=DEFAULT_DSM_TTS_REPO,
|
||||
help="HF repo in which to look for the pretrained models.")
|
||||
parser.add_argument("--voice-repo", default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||
help="HF repo in which to look for pre-computed voice embeddings.")
|
||||
parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav")
|
||||
parser.add_argument("--quantize", type=int, help="The quantization to be applied, e.g. 8 for 8 bits.")
|
||||
parser.add_argument(
|
||||
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-repo",
|
||||
type=str,
|
||||
default=DEFAULT_DSM_TTS_REPO,
|
||||
help="HF repo in which to look for the pretrained models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--voice-repo",
|
||||
default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||
help="HF repo in which to look for pre-computed voice embeddings.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize",
|
||||
type=int,
|
||||
help="The quantization to be applied, e.g. 8 for 8 bits.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(299792458)
|
||||
|
|
@ -96,7 +114,7 @@ def main():
|
|||
if tts_model.valid_cfg_conditionings:
|
||||
# Model was trained with CFG distillation.
|
||||
cfg_coef_conditioning = tts_model.cfg_coef
|
||||
tts_model.cfg_coef = 1.
|
||||
tts_model.cfg_coef = 1.0
|
||||
cfg_is_no_text = False
|
||||
cfg_is_no_prefix = False
|
||||
else:
|
||||
|
|
@ -105,17 +123,25 @@ def main():
|
|||
mimi = tts_model.mimi
|
||||
|
||||
log("info", f"reading input from {args.inp}")
|
||||
with open(args.inp, "r") as fobj:
|
||||
text_to_tts = fobj.read().strip()
|
||||
if args.inp == "-":
|
||||
if sys.stdin.isatty(): # Interactive
|
||||
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||
text_to_tts = sys.stdin.read().strip()
|
||||
else:
|
||||
with open(args.inp, "r") as fobj:
|
||||
text_to_tts = fobj.read().strip()
|
||||
|
||||
all_entries = [tts_model.prepare_script([text_to_tts])]
|
||||
if tts_model.multi_speaker:
|
||||
voices = [tts_model.get_voice_path(args.voice)]
|
||||
else:
|
||||
voices = []
|
||||
all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
|
||||
all_attributes = [
|
||||
tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
|
||||
]
|
||||
|
||||
wav_frames = queue.Queue()
|
||||
|
||||
def _on_audio_hook(audio_tokens):
|
||||
if (audio_tokens == -1).any():
|
||||
return
|
||||
|
|
@ -141,16 +167,20 @@ def main():
|
|||
return result
|
||||
|
||||
if args.out == "-":
|
||||
|
||||
def audio_callback(outdata, _a, _b, _c):
|
||||
try:
|
||||
pcm_data = wav_frames.get(block=False)
|
||||
outdata[:, 0] = pcm_data
|
||||
except queue.Empty:
|
||||
outdata[:] = 0
|
||||
with sd.OutputStream(samplerate=mimi.sample_rate,
|
||||
blocksize=1920,
|
||||
channels=1,
|
||||
callback=audio_callback):
|
||||
|
||||
with sd.OutputStream(
|
||||
samplerate=mimi.sample_rate,
|
||||
blocksize=1920,
|
||||
channels=1,
|
||||
callback=audio_callback,
|
||||
):
|
||||
run()
|
||||
time.sleep(3)
|
||||
while True:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user