diff --git a/scripts/tts_mlx_streaming.py b/scripts/tts_mlx_streaming.py index 181888d..6106d30 100644 --- a/scripts/tts_mlx_streaming.py +++ b/scripts/tts_mlx_streaming.py @@ -76,7 +76,8 @@ class TTSGen: if tts_model.valid_cfg_conditionings: raise ValueError( "This model does not support direct CFG, but was trained with " - "CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`.") + "CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`." + ) nulled = _make_null(attributes) attributes = list(attributes) + nulled @@ -126,7 +127,6 @@ class TTSGen: # cfg_is_no_text=cfg_is_no_text, ) - def process_last(self): while len(self.state.entries) > 0 or self.state.end_step is not None: self._step() @@ -143,8 +143,13 @@ class TTSGen: def _step(self): missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q - input_tokens = mx.ones((1, missing), dtype=mx.int64) * self.tts_model.machine.token_ids.zero - self.lm_gen.step(input_tokens, ct=self.ct, cross_attention_src=self.cross_attention_src) + input_tokens = ( + mx.ones((1, missing), dtype=mx.int64) + * self.tts_model.machine.token_ids.zero + ) + self.lm_gen.step( + input_tokens, ct=self.ct, cross_attention_src=self.cross_attention_src + ) frame = self.lm_gen.last_audio_tokens() self.offset += 1 if frame is not None: @@ -307,4 +312,3 @@ def main(): if __name__ == "__main__": main() -