Get the streaming example to work.
This commit is contained in:
parent
49ce18be2e
commit
edce117900
|
|
@ -20,10 +20,17 @@ import typing as tp
|
||||||
from moshi.models.loaders import CheckpointInfo
|
from moshi.models.loaders import CheckpointInfo
|
||||||
from moshi.conditioners import dropout_all_conditions
|
from moshi.conditioners import dropout_all_conditions
|
||||||
from moshi.models.lm import LMGen
|
from moshi.models.lm import LMGen
|
||||||
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel, ConditionAttributes
|
from moshi.models.tts import (
|
||||||
|
DEFAULT_DSM_TTS_REPO,
|
||||||
|
DEFAULT_DSM_TTS_VOICE_REPO,
|
||||||
|
TTSModel,
|
||||||
|
ConditionAttributes,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _make_null(all_attributes: tp.Sequence[ConditionAttributes]) -> list[ConditionAttributes]:
|
def _make_null(
|
||||||
|
all_attributes: tp.Sequence[ConditionAttributes],
|
||||||
|
) -> list[ConditionAttributes]:
|
||||||
# When using CFG, returns the null conditions.
|
# When using CFG, returns the null conditions.
|
||||||
return dropout_all_conditions(all_attributes)
|
return dropout_all_conditions(all_attributes)
|
||||||
|
|
||||||
|
|
@ -43,7 +50,8 @@ class TTSGen:
|
||||||
if tts_model.valid_cfg_conditionings:
|
if tts_model.valid_cfg_conditionings:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"This model does not support direct CFG, but was trained with "
|
"This model does not support direct CFG, but was trained with "
|
||||||
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`.")
|
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
|
||||||
|
)
|
||||||
nulled = _make_null(attributes)
|
nulled = _make_null(attributes)
|
||||||
attributes = list(attributes) + nulled
|
attributes = list(attributes) + nulled
|
||||||
|
|
||||||
|
|
@ -53,7 +61,9 @@ class TTSGen:
|
||||||
|
|
||||||
def _on_text_logits_hook(text_logits):
|
def _on_text_logits_hook(text_logits):
|
||||||
if tts_model.padding_bonus:
|
if tts_model.padding_bonus:
|
||||||
text_logits[..., tts_model.machine.token_ids.pad] += tts_model.padding_bonus
|
text_logits[..., tts_model.machine.token_ids.pad] += (
|
||||||
|
tts_model.padding_bonus
|
||||||
|
)
|
||||||
return text_logits
|
return text_logits
|
||||||
|
|
||||||
def _on_audio_hook(audio_tokens):
|
def _on_audio_hook(audio_tokens):
|
||||||
|
|
@ -70,29 +80,52 @@ class TTSGen:
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
out_token, _ = tts_model.machine.process(self.offset, self.state, token)
|
out_token, _ = tts_model.machine.process(self.offset, self.state, token)
|
||||||
out_tokens.append(out_token)
|
out_tokens.append(out_token)
|
||||||
text_tokens[:] = torch.tensor(out_tokens, dtype=torch.long, device=text_tokens.device)
|
text_tokens[:] = torch.tensor(
|
||||||
|
out_tokens, dtype=torch.long, device=text_tokens.device
|
||||||
|
)
|
||||||
|
|
||||||
tts_model.lm.dep_q = tts_model.n_q
|
tts_model.lm.dep_q = tts_model.n_q
|
||||||
self.lm_gen = LMGen(
|
self.lm_gen = LMGen(
|
||||||
tts_model.lm, temp=tts_model.temp, temp_text=tts_model.temp,
|
tts_model.lm,
|
||||||
cfg_coef=tts_model.cfg_coef, condition_tensors=condition_tensors,
|
temp=tts_model.temp,
|
||||||
on_text_logits_hook=_on_text_logits_hook, on_text_hook=_on_text_hook, on_audio_hook=_on_audio_hook,
|
temp_text=tts_model.temp,
|
||||||
cfg_is_masked_until=None, cfg_is_no_text=True)
|
cfg_coef=tts_model.cfg_coef,
|
||||||
|
condition_tensors=condition_tensors,
|
||||||
|
on_text_logits_hook=_on_text_logits_hook,
|
||||||
|
on_text_hook=_on_text_hook,
|
||||||
|
on_audio_hook=_on_audio_hook,
|
||||||
|
cfg_is_masked_until=None,
|
||||||
|
cfg_is_no_text=True,
|
||||||
|
)
|
||||||
self.lm_gen.streaming_forever(1)
|
self.lm_gen.streaming_forever(1)
|
||||||
|
|
||||||
|
def process_last(self):
|
||||||
|
while len(self.state.entries) > 0 or self.state.end_step is not None:
|
||||||
|
self._step()
|
||||||
|
additional_steps = (
|
||||||
|
self.tts_model.delay_steps + max(self.tts_model.lm.delays) + 8
|
||||||
|
)
|
||||||
|
for _ in range(additional_steps):
|
||||||
|
self._step()
|
||||||
|
|
||||||
def process(self):
|
def process(self):
|
||||||
while len(self.state.entries) > self.tts_model.machine.second_stream_ahead:
|
while len(self.state.entries) > self.tts_model.machine.second_stream_ahead:
|
||||||
|
self._step()
|
||||||
|
|
||||||
|
def _step(self):
|
||||||
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
|
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
|
||||||
input_tokens = torch.full((1, missing, 1), self.tts_model.machine.token_ids.zero,
|
input_tokens = torch.full(
|
||||||
dtype=torch.long, device=self.tts_model.lm.device)
|
(1, missing, 1),
|
||||||
|
self.tts_model.machine.token_ids.zero,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.tts_model.lm.device,
|
||||||
|
)
|
||||||
frame = self.lm_gen.step(input_tokens)
|
frame = self.lm_gen.step(input_tokens)
|
||||||
self.offset += 1
|
self.offset += 1
|
||||||
if frame is not None:
|
if frame is not None:
|
||||||
if self.on_frame is not None:
|
if self.on_frame is not None:
|
||||||
self.on_frame(frame)
|
self.on_frame(frame)
|
||||||
|
|
||||||
|
|
||||||
def append_entry(self, entry):
|
def append_entry(self, entry):
|
||||||
self.state.entries.append(entry)
|
self.state.entries.append(entry)
|
||||||
|
|
||||||
|
|
@ -179,7 +212,7 @@ def main():
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
gen.append_entry(entry)
|
gen.append_entry(entry)
|
||||||
gen.process()
|
gen.process()
|
||||||
time.sleep(3)
|
gen.process_last()
|
||||||
while True:
|
while True:
|
||||||
if pcms.qsize() == 0:
|
if pcms.qsize() == 0:
|
||||||
break
|
break
|
||||||
|
|
@ -200,10 +233,10 @@ def main():
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
gen.append_entry(entry)
|
gen.append_entry(entry)
|
||||||
gen.process()
|
gen.process()
|
||||||
|
gen.process_last()
|
||||||
pcm = np.concatenate(pcms, axis=-1)
|
pcm = np.concatenate(pcms, axis=-1)
|
||||||
sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
|
sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user