Formatting fix.
This commit is contained in:
parent
00b856a94d
commit
45f01cc617
|
|
@ -76,7 +76,8 @@ class TTSGen:
|
||||||
if tts_model.valid_cfg_conditionings:
|
if tts_model.valid_cfg_conditionings:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"This model does not support direct CFG, but was trained with "
|
"This model does not support direct CFG, but was trained with "
|
||||||
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`.")
|
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
|
||||||
|
)
|
||||||
nulled = _make_null(attributes)
|
nulled = _make_null(attributes)
|
||||||
attributes = list(attributes) + nulled
|
attributes = list(attributes) + nulled
|
||||||
|
|
||||||
|
|
@ -126,7 +127,6 @@ class TTSGen:
|
||||||
# cfg_is_no_text=cfg_is_no_text,
|
# cfg_is_no_text=cfg_is_no_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def process_last(self):
|
def process_last(self):
|
||||||
while len(self.state.entries) > 0 or self.state.end_step is not None:
|
while len(self.state.entries) > 0 or self.state.end_step is not None:
|
||||||
self._step()
|
self._step()
|
||||||
|
|
@ -143,8 +143,13 @@ class TTSGen:
|
||||||
def _step(self):
|
def _step(self):
|
||||||
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
|
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
|
||||||
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
|
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
|
||||||
input_tokens = mx.ones((1, missing), dtype=mx.int64) * self.tts_model.machine.token_ids.zero
|
input_tokens = (
|
||||||
self.lm_gen.step(input_tokens, ct=self.ct, cross_attention_src=self.cross_attention_src)
|
mx.ones((1, missing), dtype=mx.int64)
|
||||||
|
* self.tts_model.machine.token_ids.zero
|
||||||
|
)
|
||||||
|
self.lm_gen.step(
|
||||||
|
input_tokens, ct=self.ct, cross_attention_src=self.cross_attention_src
|
||||||
|
)
|
||||||
frame = self.lm_gen.last_audio_tokens()
|
frame = self.lm_gen.last_audio_tokens()
|
||||||
self.offset += 1
|
self.offset += 1
|
||||||
if frame is not None:
|
if frame is not None:
|
||||||
|
|
@ -307,4 +312,3 @@ def main():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user