diff --git a/scripts/tts_pytorch_streaming.py b/scripts/tts_pytorch_streaming.py
new file mode 100644
index 0000000..d808023
--- /dev/null
+++ b/scripts/tts_pytorch_streaming.py
@@ -0,0 +1,242 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "moshi==0.2.10",
#     "torch",
#     "sphn",
#     "sounddevice",
# ]
# ///
import argparse
import queue
import sys
import time
import typing as tp
from dataclasses import dataclass

import numpy as np
import sphn
import torch
from moshi.conditioners import dropout_all_conditions
from moshi.models.lm import LMGen
from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import (
    DEFAULT_DSM_TTS_REPO,
    DEFAULT_DSM_TTS_VOICE_REPO,
    ConditionAttributes,
    TTSModel,
)


def _make_null(
    all_attributes: tp.Sequence[ConditionAttributes],
) -> list[ConditionAttributes]:
    # When using CFG, returns the null conditions.
    return dropout_all_conditions(all_attributes)


@dataclass
class TTSGen:
    """Streaming TTS generator: consumes script entries and emits audio
    frames one model step at a time via the `on_frame` callback."""

    tts_model: TTSModel
    attributes: tp.Sequence[ConditionAttributes]
    on_frame: tp.Optional[tp.Callable[[torch.Tensor], None]] = None

    def __post_init__(self):
        tts_model = self.tts_model
        attributes = self.attributes
        self.offset = 0
        self.state = self.tts_model.machine.new_state([])
        if tts_model.cfg_coef != 1.0:
            if tts_model.valid_cfg_conditionings:
                raise ValueError(
                    "This model does not support direct CFG, but was trained with "
                    "CFG distillation. Pass `cfg_coef` to `make_condition_attributes` instead."
                )
            nulled = _make_null(attributes)
            attributes = list(attributes) + nulled

        assert tts_model.lm.condition_provider is not None
        prepared = tts_model.lm.condition_provider.prepare(attributes)
        condition_tensors = tts_model.lm.condition_provider(prepared)

        def _on_text_logits_hook(text_logits):
            # Apply the padding bonus to the pad token's logit, if configured.
            if tts_model.padding_bonus:
                text_logits[..., tts_model.machine.token_ids.pad] += (
                    tts_model.padding_bonus
                )
            return text_logits

        def _on_audio_hook(audio_tokens):
            # Zero out audio codebooks that are still within their
            # per-codebook delay at the current step.
            audio_offset = tts_model.lm.audio_offset
            delays = tts_model.lm.delays
            for q in range(audio_tokens.shape[1]):
                delay = delays[q + audio_offset]
                if self.offset < delay + tts_model.delay_steps:
                    audio_tokens[:, q] = tts_model.machine.token_ids.zero

        def _on_text_hook(text_tokens):
            # Let the TTS state machine rewrite the sampled text tokens so
            # the script is consumed in order.
            tokens = text_tokens.tolist()
            out_tokens = []
            for token in tokens:
                out_token, _ = tts_model.machine.process(self.offset, self.state, token)
                out_tokens.append(out_token)
            text_tokens[:] = torch.tensor(
                out_tokens, dtype=torch.long, device=text_tokens.device
            )

        tts_model.lm.dep_q = tts_model.n_q
        self.lm_gen = LMGen(
            tts_model.lm,
            temp=tts_model.temp,
            temp_text=tts_model.temp,
            cfg_coef=tts_model.cfg_coef,
            condition_tensors=condition_tensors,
            on_text_logits_hook=_on_text_logits_hook,
            on_text_hook=_on_text_hook,
            on_audio_hook=_on_audio_hook,
            cfg_is_masked_until=None,
            cfg_is_no_text=True,
        )
        self.lm_gen.streaming_forever(1)

    def process_last(self):
        # Flush the remaining entries, then run extra steps so the delayed
        # codebooks are fully emitted.
        while len(self.state.entries) > 0 or self.state.end_step is not None:
            self._step()
        additional_steps = (
            self.tts_model.delay_steps + max(self.tts_model.lm.delays) + 8
        )
        for _ in range(additional_steps):
            self._step()

    def process(self):
        # Step while enough entries are queued to keep the model ahead.
        while len(self.state.entries) > self.tts_model.machine.second_stream_ahead:
            self._step()

    def _step(self):
        missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
        input_tokens = torch.full(
            (1, missing, 1),
            self.tts_model.machine.token_ids.zero,
            dtype=torch.long,
            device=self.tts_model.lm.device,
        )
        frame = self.lm_gen.step(input_tokens)
        self.offset += 1
        if frame is not None and self.on_frame is not None:
            self.on_frame(frame)

    def append_entry(self, entry):
        self.state.entries.append(entry)
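

# A minimal sketch of driving `TTSGen` directly (all names are defined in this
# file; `main()` below is the complete version): queue entries produced by
# `prepare_script`, call `process()` to advance generation, and finish with
# `process_last()` to flush the delayed frames. Decoding the token frames to
# PCM (as `main()` does) additionally requires the `mimi.streaming` context.
#
#     frames: list[torch.Tensor] = []
#     gen = TTSGen(tts_model, [condition_attributes], on_frame=frames.append)
#     for entry in tts_model.prepare_script(["Hello!"], padding_between=1):
#         gen.append_entry(entry)
#         gen.process()
#     gen.process_last()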


@torch.no_grad()
def main():
    parser = argparse.ArgumentParser(
        description="Run Kyutai TTS using the PyTorch implementation"
    )
    parser.add_argument(
        "out", type=str, help="Output file to generate, use - for playing the audio"
    )
    parser.add_argument(
        "--hf-repo",
        type=str,
        default=DEFAULT_DSM_TTS_REPO,
        help="HF repo in which to look for the pretrained models.",
    )
    parser.add_argument(
        "--voice-repo",
        default=DEFAULT_DSM_TTS_VOICE_REPO,
        help="HF repo in which to look for pre-computed voice embeddings.",
    )
    parser.add_argument(
        "--voice",
        default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
        help="The voice to use, relative to the voice repo root. "
        f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device on which to run, defaults to 'cuda'.",
    )
    args = parser.parse_args()

    print("Loading model...")
    checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
    tts_model = TTSModel.from_checkpoint_info(
        checkpoint_info, n_q=32, temp=0.6, device=args.device
    )

    voice_path = tts_model.get_voice_path(args.voice)
    # The CFG coefficient goes here because the model was trained with CFG
    # distillation, so it is not _actually_ doing CFG at inference time.
    # If you are generating a dialog, put two voices in the list.
    condition_attributes = tts_model.make_condition_attributes(
        [voice_path], cfg_coef=2.0
    )

    if sys.stdin.isatty():  # Interactive
        print("Enter text to synthesize (Ctrl+D to end input):")

    if args.out == "-":
        # Stream the audio to the speakers using sounddevice.
        import sounddevice as sd

        pcms = queue.Queue()

        def _on_frame(frame):
            # Decode a frame of audio tokens to PCM and queue it for playback,
            # skipping frames that still contain the -1 placeholder for
            # ungenerated tokens.
            if (frame != -1).all():
                pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
                pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))

        def audio_callback(outdata, _frames, _time, _status):
            try:
                pcm_data = pcms.get(block=False)
                outdata[:, 0] = pcm_data
            except queue.Empty:
                outdata[:] = 0

        gen = TTSGen(tts_model, [condition_attributes], on_frame=_on_frame)

        # Enter both the output stream and Mimi's streaming decode state.
        with sd.OutputStream(
            samplerate=tts_model.mimi.sample_rate,
            blocksize=1920,
            channels=1,
            callback=audio_callback,
        ), tts_model.mimi.streaming(1):
            for line in sys.stdin:
                # TODO: Fix the following to only include bos on the first line.
                entries = tts_model.prepare_script([line.strip()], padding_between=1)
                for entry in entries:
                    gen.append_entry(entry)
                    gen.process()
            gen.process_last()
            # Let the playback queue drain before the stream closes.
            while pcms.qsize() > 0:
                time.sleep(1)
    else:
        pcms = []

        def _on_frame(frame: torch.Tensor):
            if (frame != -1).all():
                pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
                pcms.append(np.clip(pcm[0, 0], -1, 1))

        gen = TTSGen(tts_model, [condition_attributes], on_frame=_on_frame)
        with tts_model.mimi.streaming(1):
            for line in sys.stdin:
                # TODO: Fix the following to only include bos on the first line.
                entries = tts_model.prepare_script([line.strip()], padding_between=1)
                for entry in entries:
                    gen.append_entry(entry)
                    gen.process()
            gen.process_last()
        pcm = np.concatenate(pcms, axis=-1)
        sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)


if __name__ == "__main__":
    main()
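
Usage sketch (assuming `uv`, which reads the script's inline dependency block;
any PEP 723-compatible runner works):

    # Write the synthesized speech to a WAV file.
    echo "Hello from Kyutai TTS." | uv run scripts/tts_pytorch_streaming.py out.wav

    # Stream straight to the speakers instead.
    echo "Hello from Kyutai TTS." | uv run scripts/tts_pytorch_streaming.py -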