diff --git a/scripts/tts_pytorch_streaming.py b/scripts/tts_pytorch_streaming.py index 20fcdd7..d808023 100644 --- a/scripts/tts_pytorch_streaming.py +++ b/scripts/tts_pytorch_streaming.py @@ -20,10 +20,17 @@ import typing as tp from moshi.models.loaders import CheckpointInfo from moshi.conditioners import dropout_all_conditions from moshi.models.lm import LMGen -from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel, ConditionAttributes +from moshi.models.tts import ( + DEFAULT_DSM_TTS_REPO, + DEFAULT_DSM_TTS_VOICE_REPO, + TTSModel, + ConditionAttributes, +) -def _make_null(all_attributes: tp.Sequence[ConditionAttributes]) -> list[ConditionAttributes]: +def _make_null( + all_attributes: tp.Sequence[ConditionAttributes], +) -> list[ConditionAttributes]: # When using CFG, returns the null conditions. return dropout_all_conditions(all_attributes) @@ -43,7 +50,8 @@ class TTSGen: if tts_model.valid_cfg_conditionings: raise ValueError( "This model does not support direct CFG, but was trained with " - "CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`.") + "CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`." + ) nulled = _make_null(attributes) attributes = list(attributes) + nulled @@ -53,7 +61,9 @@ class TTSGen: def _on_text_logits_hook(text_logits): if tts_model.padding_bonus: - text_logits[..., tts_model.machine.token_ids.pad] += tts_model.padding_bonus + text_logits[..., tts_model.machine.token_ids.pad] += ( + tts_model.padding_bonus + ) return text_logits def _on_audio_hook(audio_tokens): @@ -70,28 +80,51 @@ class TTSGen: for token in tokens: out_token, _ = tts_model.machine.process(self.offset, self.state, token) out_tokens.append(out_token) - text_tokens[:] = torch.tensor(out_tokens, dtype=torch.long, device=text_tokens.device) + text_tokens[:] = torch.tensor( + out_tokens, dtype=torch.long, device=text_tokens.device + ) tts_model.lm.dep_q = tts_model.n_q self.lm_gen = LMGen( - tts_model.lm, temp=tts_model.temp, temp_text=tts_model.temp, - cfg_coef=tts_model.cfg_coef, condition_tensors=condition_tensors, - on_text_logits_hook=_on_text_logits_hook, on_text_hook=_on_text_hook, on_audio_hook=_on_audio_hook, - cfg_is_masked_until=None, cfg_is_no_text=True) + tts_model.lm, + temp=tts_model.temp, + temp_text=tts_model.temp, + cfg_coef=tts_model.cfg_coef, + condition_tensors=condition_tensors, + on_text_logits_hook=_on_text_logits_hook, + on_text_hook=_on_text_hook, + on_audio_hook=_on_audio_hook, + cfg_is_masked_until=None, + cfg_is_no_text=True, + ) self.lm_gen.streaming_forever(1) + def process_last(self): + while len(self.state.entries) > 0 or self.state.end_step is not None: + self._step() + additional_steps = ( + self.tts_model.delay_steps + max(self.tts_model.lm.delays) + 8 + ) + for _ in range(additional_steps): + self._step() def process(self): while len(self.state.entries) > self.tts_model.machine.second_stream_ahead: - missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q - input_tokens = torch.full((1, missing, 1), self.tts_model.machine.token_ids.zero, - dtype=torch.long, device=self.tts_model.lm.device) - frame = self.lm_gen.step(input_tokens) - self.offset += 1 - if frame is not None: - if self.on_frame is not None: - self.on_frame(frame) + self._step() + def _step(self): + missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q + input_tokens = torch.full( + (1, missing, 1), + self.tts_model.machine.token_ids.zero, + dtype=torch.long, + device=self.tts_model.lm.device, + ) + frame = self.lm_gen.step(input_tokens) + self.offset += 1 + if frame is not None: + if self.on_frame is not None: + self.on_frame(frame) def append_entry(self, entry): self.state.entries.append(entry) @@ -179,7 +212,7 @@ def main(): for entry in entries: gen.append_entry(entry) gen.process() - time.sleep(3) + gen.process_last() while True: if pcms.qsize() == 0: break @@ -200,10 +233,10 @@ def main(): for entry in entries: gen.append_entry(entry) gen.process() + gen.process_last() pcm = np.concatenate(pcms, axis=-1) sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate) if __name__ == "__main__": main() -