diff --git a/scripts/tts_mlx.py b/scripts/tts_mlx.py
index 050fce5..8e08677 100644
--- a/scripts/tts_mlx.py
+++ b/scripts/tts_mlx.py
@@ -10,23 +10,24 @@
 
 import argparse
 import json
-from pathlib import Path
 import queue
+import sys
 import time
-import numpy as np
 
 import mlx.core as mx
 import mlx.nn as nn
+import numpy as np
 import sentencepiece
-import sphn
-import time
-
 import sounddevice as sd
-
-from moshi_mlx.client_utils import make_log
+import sphn
 from moshi_mlx import models
+from moshi_mlx.client_utils import make_log
+from moshi_mlx.models.tts import (
+    DEFAULT_DSM_TTS_REPO,
+    DEFAULT_DSM_TTS_VOICE_REPO,
+    TTSModel,
+)
 from moshi_mlx.utils.loaders import hf_get
-from moshi_mlx.models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO
 
 
 def log(level: str, msg: str):
@@ -34,15 +35,32 @@ def log(level: str, msg: str):
 
 
 def main():
-    parser = argparse.ArgumentParser(prog='moshi-tts', description='Run Moshi')
+    parser = argparse.ArgumentParser(
+        description="Run Kyutai TTS using the MLX implementation"
+    )
     parser.add_argument("inp", type=str, help="Input file, use - for stdin")
-    parser.add_argument("out", type=str, help="Output file to generate, use - for playing the audio")
-    parser.add_argument("--hf-repo", type=str, default=DEFAULT_DSM_TTS_REPO,
-                        help="HF repo in which to look for the pretrained models.")
-    parser.add_argument("--voice-repo", default=DEFAULT_DSM_TTS_VOICE_REPO,
-                        help="HF repo in which to look for pre-computed voice embeddings.")
-    parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav")
-    parser.add_argument("--quantize", type=int, help="The quantization to be applied, e.g. 8 for 8 bits.")
+    parser.add_argument(
+        "out", type=str, help="Output file to generate, use - for playing the audio"
+    )
+    parser.add_argument(
+        "--hf-repo",
+        type=str,
+        default=DEFAULT_DSM_TTS_REPO,
+        help="HF repo in which to look for the pretrained models.",
+    )
+    parser.add_argument(
+        "--voice-repo",
+        default=DEFAULT_DSM_TTS_VOICE_REPO,
+        help="HF repo in which to look for pre-computed voice embeddings.",
+    )
+    parser.add_argument(
+        "--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
+    )
+    parser.add_argument(
+        "--quantize",
+        type=int,
+        help="The quantization to be applied, e.g. 8 for 8 bits.",
+    )
 
     args = parser.parse_args()
     mx.random.seed(299792458)
@@ -96,7 +114,7 @@ def main():
     if tts_model.valid_cfg_conditionings:
         # Model was trained with CFG distillation.
         cfg_coef_conditioning = tts_model.cfg_coef
-        tts_model.cfg_coef = 1.
+        tts_model.cfg_coef = 1.0
         cfg_is_no_text = False
         cfg_is_no_prefix = False
     else:
@@ -105,17 +123,25 @@ def main():
     mimi = tts_model.mimi
 
     log("info", f"reading input from {args.inp}")
-    with open(args.inp, "r") as fobj:
-        text_to_tts = fobj.read().strip()
+    if args.inp == "-":
+        if sys.stdin.isatty():  # Interactive
+            print("Enter text to synthesize (Ctrl+D to end input):")
+        text_to_tts = sys.stdin.read().strip()
+    else:
+        with open(args.inp, "r") as fobj:
+            text_to_tts = fobj.read().strip()
 
     all_entries = [tts_model.prepare_script([text_to_tts])]
     if tts_model.multi_speaker:
         voices = [tts_model.get_voice_path(args.voice)]
     else:
         voices = []
-    all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
+    all_attributes = [
+        tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
+    ]
 
     wav_frames = queue.Queue()
+
     def _on_audio_hook(audio_tokens):
         if (audio_tokens == -1).any():
             return
@@ -141,16 +167,20 @@
         return result
 
     if args.out == "-":
+
         def audio_callback(outdata, _a, _b, _c):
             try:
                 pcm_data = wav_frames.get(block=False)
                 outdata[:, 0] = pcm_data
             except queue.Empty:
                 outdata[:] = 0
-        with sd.OutputStream(samplerate=mimi.sample_rate,
-                             blocksize=1920,
-                             channels=1,
-                             callback=audio_callback):
+
+        with sd.OutputStream(
+            samplerate=mimi.sample_rate,
+            blocksize=1920,
+            channels=1,
+            callback=audio_callback,
+        ):
             run()
             time.sleep(3)
             while True:
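
Note on the playback path in the last hunk: the generation loop pushes decoded PCM blocks onto a queue, and the sounddevice callback drains it without ever blocking inside the real-time audio thread. Below is a minimal standalone sketch of that pattern, not the script itself; the 1920-sample blocksize is taken from the diff, while the 24 kHz sample rate and the names SAMPLE_RATE / BLOCK_SIZE are assumptions standing in for mimi.sample_rate and the script's locals.

import queue

import numpy as np
import sounddevice as sd

SAMPLE_RATE = 24_000  # assumption: stands in for mimi.sample_rate
BLOCK_SIZE = 1920  # samples per callback, as hard-coded in the diff


wav_frames: queue.Queue = queue.Queue()


def audio_callback(outdata, _frames, _time, _status):
    # Non-blocking get: play the next decoded block if one is ready,
    # otherwise emit silence rather than stalling the audio thread.
    try:
        outdata[:, 0] = wav_frames.get(block=False)
    except queue.Empty:
        outdata[:] = 0


with sd.OutputStream(
    samplerate=SAMPLE_RATE, blocksize=BLOCK_SIZE, channels=1, callback=audio_callback
):
    # The producer (the TTS inference loop in the script) would put
    # float32 arrays of length BLOCK_SIZE here; this stand-in is silence.
    wav_frames.put(np.zeros(BLOCK_SIZE, dtype=np.float32))
    sd.sleep(500)  # keep the stream open for half a second

The silence-on-empty fallback is also why the script sleeps and loops after run(): presumably to give the callback time to drain the remaining queued frames before the stream is closed.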