Get the streaming example to work.
This commit is contained in:
parent
49ce18be2e
commit
edce117900
|
|
@ -20,10 +20,17 @@ import typing as tp
|
|||
from moshi.models.loaders import CheckpointInfo
|
||||
from moshi.conditioners import dropout_all_conditions
|
||||
from moshi.models.lm import LMGen
|
||||
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel, ConditionAttributes
|
||||
from moshi.models.tts import (
|
||||
DEFAULT_DSM_TTS_REPO,
|
||||
DEFAULT_DSM_TTS_VOICE_REPO,
|
||||
TTSModel,
|
||||
ConditionAttributes,
|
||||
)
|
||||
|
||||
|
||||
def _make_null(all_attributes: tp.Sequence[ConditionAttributes]) -> list[ConditionAttributes]:
|
||||
def _make_null(
|
||||
all_attributes: tp.Sequence[ConditionAttributes],
|
||||
) -> list[ConditionAttributes]:
|
||||
# When using CFG, returns the null conditions.
|
||||
return dropout_all_conditions(all_attributes)
|
||||
|
||||
|
|
@ -43,7 +50,8 @@ class TTSGen:
|
|||
if tts_model.valid_cfg_conditionings:
|
||||
raise ValueError(
|
||||
"This model does not support direct CFG, but was trained with "
|
||||
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`.")
|
||||
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
|
||||
)
|
||||
nulled = _make_null(attributes)
|
||||
attributes = list(attributes) + nulled
|
||||
|
||||
|
|
@ -53,7 +61,9 @@ class TTSGen:
|
|||
|
||||
def _on_text_logits_hook(text_logits):
|
||||
if tts_model.padding_bonus:
|
||||
text_logits[..., tts_model.machine.token_ids.pad] += tts_model.padding_bonus
|
||||
text_logits[..., tts_model.machine.token_ids.pad] += (
|
||||
tts_model.padding_bonus
|
||||
)
|
||||
return text_logits
|
||||
|
||||
def _on_audio_hook(audio_tokens):
|
||||
|
|
@ -70,29 +80,52 @@ class TTSGen:
|
|||
for token in tokens:
|
||||
out_token, _ = tts_model.machine.process(self.offset, self.state, token)
|
||||
out_tokens.append(out_token)
|
||||
text_tokens[:] = torch.tensor(out_tokens, dtype=torch.long, device=text_tokens.device)
|
||||
text_tokens[:] = torch.tensor(
|
||||
out_tokens, dtype=torch.long, device=text_tokens.device
|
||||
)
|
||||
|
||||
tts_model.lm.dep_q = tts_model.n_q
|
||||
self.lm_gen = LMGen(
|
||||
tts_model.lm, temp=tts_model.temp, temp_text=tts_model.temp,
|
||||
cfg_coef=tts_model.cfg_coef, condition_tensors=condition_tensors,
|
||||
on_text_logits_hook=_on_text_logits_hook, on_text_hook=_on_text_hook, on_audio_hook=_on_audio_hook,
|
||||
cfg_is_masked_until=None, cfg_is_no_text=True)
|
||||
tts_model.lm,
|
||||
temp=tts_model.temp,
|
||||
temp_text=tts_model.temp,
|
||||
cfg_coef=tts_model.cfg_coef,
|
||||
condition_tensors=condition_tensors,
|
||||
on_text_logits_hook=_on_text_logits_hook,
|
||||
on_text_hook=_on_text_hook,
|
||||
on_audio_hook=_on_audio_hook,
|
||||
cfg_is_masked_until=None,
|
||||
cfg_is_no_text=True,
|
||||
)
|
||||
self.lm_gen.streaming_forever(1)
|
||||
|
||||
def process_last(self):
|
||||
while len(self.state.entries) > 0 or self.state.end_step is not None:
|
||||
self._step()
|
||||
additional_steps = (
|
||||
self.tts_model.delay_steps + max(self.tts_model.lm.delays) + 8
|
||||
)
|
||||
for _ in range(additional_steps):
|
||||
self._step()
|
||||
|
||||
def process(self):
|
||||
while len(self.state.entries) > self.tts_model.machine.second_stream_ahead:
|
||||
self._step()
|
||||
|
||||
def _step(self):
|
||||
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
|
||||
input_tokens = torch.full((1, missing, 1), self.tts_model.machine.token_ids.zero,
|
||||
dtype=torch.long, device=self.tts_model.lm.device)
|
||||
input_tokens = torch.full(
|
||||
(1, missing, 1),
|
||||
self.tts_model.machine.token_ids.zero,
|
||||
dtype=torch.long,
|
||||
device=self.tts_model.lm.device,
|
||||
)
|
||||
frame = self.lm_gen.step(input_tokens)
|
||||
self.offset += 1
|
||||
if frame is not None:
|
||||
if self.on_frame is not None:
|
||||
self.on_frame(frame)
|
||||
|
||||
|
||||
def append_entry(self, entry):
|
||||
self.state.entries.append(entry)
|
||||
|
||||
|
|
@ -179,7 +212,7 @@ def main():
|
|||
for entry in entries:
|
||||
gen.append_entry(entry)
|
||||
gen.process()
|
||||
time.sleep(3)
|
||||
gen.process_last()
|
||||
while True:
|
||||
if pcms.qsize() == 0:
|
||||
break
|
||||
|
|
@ -200,10 +233,10 @@ def main():
|
|||
for entry in entries:
|
||||
gen.append_entry(entry)
|
||||
gen.process()
|
||||
gen.process_last()
|
||||
pcm = np.concatenate(pcms, axis=-1)
|
||||
sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user