Formatting fix.

This commit is contained in:
Laurent 2025-07-16 22:30:03 +02:00
parent 00b856a94d
commit 45f01cc617

View File

@ -76,7 +76,8 @@ class TTSGen:
if tts_model.valid_cfg_conditionings: if tts_model.valid_cfg_conditionings:
raise ValueError( raise ValueError(
"This model does not support direct CFG, but was trained with " "This model does not support direct CFG, but was trained with "
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`.") "CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
)
nulled = _make_null(attributes) nulled = _make_null(attributes)
attributes = list(attributes) + nulled attributes = list(attributes) + nulled
@ -126,7 +127,6 @@ class TTSGen:
# cfg_is_no_text=cfg_is_no_text, # cfg_is_no_text=cfg_is_no_text,
) )
def process_last(self): def process_last(self):
while len(self.state.entries) > 0 or self.state.end_step is not None: while len(self.state.entries) > 0 or self.state.end_step is not None:
self._step() self._step()
@ -143,8 +143,13 @@ class TTSGen:
def _step(self): def _step(self):
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
input_tokens = mx.ones((1, missing), dtype=mx.int64) * self.tts_model.machine.token_ids.zero input_tokens = (
self.lm_gen.step(input_tokens, ct=self.ct, cross_attention_src=self.cross_attention_src) mx.ones((1, missing), dtype=mx.int64)
* self.tts_model.machine.token_ids.zero
)
self.lm_gen.step(
input_tokens, ct=self.ct, cross_attention_src=self.cross_attention_src
)
frame = self.lm_gen.last_audio_tokens() frame = self.lm_gen.last_audio_tokens()
self.offset += 1 self.offset += 1
if frame is not None: if frame is not None:
@ -307,4 +312,3 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()