Formatting fix.
This commit is contained in:
parent
00b856a94d
commit
45f01cc617
|
|
@ -76,7 +76,8 @@ class TTSGen:
|
|||
if tts_model.valid_cfg_conditionings:
|
||||
raise ValueError(
|
||||
"This model does not support direct CFG, but was trained with "
|
||||
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`.")
|
||||
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
|
||||
)
|
||||
nulled = _make_null(attributes)
|
||||
attributes = list(attributes) + nulled
|
||||
|
||||
|
|
@ -126,7 +127,6 @@ class TTSGen:
|
|||
# cfg_is_no_text=cfg_is_no_text,
|
||||
)
|
||||
|
||||
|
||||
def process_last(self):
|
||||
while len(self.state.entries) > 0 or self.state.end_step is not None:
|
||||
self._step()
|
||||
|
|
@ -143,8 +143,13 @@ class TTSGen:
|
|||
def _step(self):
|
||||
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
|
||||
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
|
||||
input_tokens = mx.ones((1, missing), dtype=mx.int64) * self.tts_model.machine.token_ids.zero
|
||||
self.lm_gen.step(input_tokens, ct=self.ct, cross_attention_src=self.cross_attention_src)
|
||||
input_tokens = (
|
||||
mx.ones((1, missing), dtype=mx.int64)
|
||||
* self.tts_model.machine.token_ids.zero
|
||||
)
|
||||
self.lm_gen.step(
|
||||
input_tokens, ct=self.ct, cross_attention_src=self.cross_attention_src
|
||||
)
|
||||
frame = self.lm_gen.last_audio_tokens()
|
||||
self.offset += 1
|
||||
if frame is not None:
|
||||
|
|
@ -307,4 +312,3 @@ def main():
|
|||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user