Run Ruff on tts_mlx.py
This commit is contained in:
parent
433dca3751
commit
da83b4b63f
|
|
@ -10,23 +10,24 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
|
||||||
import queue
|
import queue
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
import numpy as np
|
||||||
import sentencepiece
|
import sentencepiece
|
||||||
import sphn
|
|
||||||
import time
|
|
||||||
|
|
||||||
import sounddevice as sd
|
import sounddevice as sd
|
||||||
|
import sphn
|
||||||
from moshi_mlx.client_utils import make_log
|
|
||||||
from moshi_mlx import models
|
from moshi_mlx import models
|
||||||
|
from moshi_mlx.client_utils import make_log
|
||||||
|
from moshi_mlx.models.tts import (
|
||||||
|
DEFAULT_DSM_TTS_REPO,
|
||||||
|
DEFAULT_DSM_TTS_VOICE_REPO,
|
||||||
|
TTSModel,
|
||||||
|
)
|
||||||
from moshi_mlx.utils.loaders import hf_get
|
from moshi_mlx.utils.loaders import hf_get
|
||||||
from moshi_mlx.models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO
|
|
||||||
|
|
||||||
|
|
||||||
def log(level: str, msg: str):
|
def log(level: str, msg: str):
|
||||||
|
|
@ -34,15 +35,32 @@ def log(level: str, msg: str):
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(prog='moshi-tts', description='Run Moshi')
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run Kyutai TTS using the PyTorch implementation"
|
||||||
|
)
|
||||||
parser.add_argument("inp", type=str, help="Input file, use - for stdin")
|
parser.add_argument("inp", type=str, help="Input file, use - for stdin")
|
||||||
parser.add_argument("out", type=str, help="Output file to generate, use - for playing the audio")
|
parser.add_argument(
|
||||||
parser.add_argument("--hf-repo", type=str, default=DEFAULT_DSM_TTS_REPO,
|
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||||
help="HF repo in which to look for the pretrained models.")
|
)
|
||||||
parser.add_argument("--voice-repo", default=DEFAULT_DSM_TTS_VOICE_REPO,
|
parser.add_argument(
|
||||||
help="HF repo in which to look for pre-computed voice embeddings.")
|
"--hf-repo",
|
||||||
parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav")
|
type=str,
|
||||||
parser.add_argument("--quantize", type=int, help="The quantization to be applied, e.g. 8 for 8 bits.")
|
default=DEFAULT_DSM_TTS_REPO,
|
||||||
|
help="HF repo in which to look for the pretrained models.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice-repo",
|
||||||
|
default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||||
|
help="HF repo in which to look for pre-computed voice embeddings.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--quantize",
|
||||||
|
type=int,
|
||||||
|
help="The quantization to be applied, e.g. 8 for 8 bits.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
mx.random.seed(299792458)
|
mx.random.seed(299792458)
|
||||||
|
|
@ -96,7 +114,7 @@ def main():
|
||||||
if tts_model.valid_cfg_conditionings:
|
if tts_model.valid_cfg_conditionings:
|
||||||
# Model was trained with CFG distillation.
|
# Model was trained with CFG distillation.
|
||||||
cfg_coef_conditioning = tts_model.cfg_coef
|
cfg_coef_conditioning = tts_model.cfg_coef
|
||||||
tts_model.cfg_coef = 1.
|
tts_model.cfg_coef = 1.0
|
||||||
cfg_is_no_text = False
|
cfg_is_no_text = False
|
||||||
cfg_is_no_prefix = False
|
cfg_is_no_prefix = False
|
||||||
else:
|
else:
|
||||||
|
|
@ -105,17 +123,25 @@ def main():
|
||||||
mimi = tts_model.mimi
|
mimi = tts_model.mimi
|
||||||
|
|
||||||
log("info", f"reading input from {args.inp}")
|
log("info", f"reading input from {args.inp}")
|
||||||
with open(args.inp, "r") as fobj:
|
if args.inp == "-":
|
||||||
text_to_tts = fobj.read().strip()
|
if sys.stdin.isatty(): # Interactive
|
||||||
|
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||||
|
text_to_tts = sys.stdin.read().strip()
|
||||||
|
else:
|
||||||
|
with open(args.inp, "r") as fobj:
|
||||||
|
text_to_tts = fobj.read().strip()
|
||||||
|
|
||||||
all_entries = [tts_model.prepare_script([text_to_tts])]
|
all_entries = [tts_model.prepare_script([text_to_tts])]
|
||||||
if tts_model.multi_speaker:
|
if tts_model.multi_speaker:
|
||||||
voices = [tts_model.get_voice_path(args.voice)]
|
voices = [tts_model.get_voice_path(args.voice)]
|
||||||
else:
|
else:
|
||||||
voices = []
|
voices = []
|
||||||
all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
|
all_attributes = [
|
||||||
|
tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
|
||||||
|
]
|
||||||
|
|
||||||
wav_frames = queue.Queue()
|
wav_frames = queue.Queue()
|
||||||
|
|
||||||
def _on_audio_hook(audio_tokens):
|
def _on_audio_hook(audio_tokens):
|
||||||
if (audio_tokens == -1).any():
|
if (audio_tokens == -1).any():
|
||||||
return
|
return
|
||||||
|
|
@ -141,16 +167,20 @@ def main():
|
||||||
return result
|
return result
|
||||||
|
|
||||||
if args.out == "-":
|
if args.out == "-":
|
||||||
|
|
||||||
def audio_callback(outdata, _a, _b, _c):
|
def audio_callback(outdata, _a, _b, _c):
|
||||||
try:
|
try:
|
||||||
pcm_data = wav_frames.get(block=False)
|
pcm_data = wav_frames.get(block=False)
|
||||||
outdata[:, 0] = pcm_data
|
outdata[:, 0] = pcm_data
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
outdata[:] = 0
|
outdata[:] = 0
|
||||||
with sd.OutputStream(samplerate=mimi.sample_rate,
|
|
||||||
blocksize=1920,
|
with sd.OutputStream(
|
||||||
channels=1,
|
samplerate=mimi.sample_rate,
|
||||||
callback=audio_callback):
|
blocksize=1920,
|
||||||
|
channels=1,
|
||||||
|
callback=audio_callback,
|
||||||
|
):
|
||||||
run()
|
run()
|
||||||
time.sleep(3)
|
time.sleep(3)
|
||||||
while True:
|
while True:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user