Get the streaming example to work.

laurent 2025-07-16 20:52:32 +02:00
parent 49ce18be2e
commit edce117900


@@ -20,10 +20,17 @@ import typing as tp
 from moshi.models.loaders import CheckpointInfo
 from moshi.conditioners import dropout_all_conditions
 from moshi.models.lm import LMGen
-from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel, ConditionAttributes
+from moshi.models.tts import (
+    DEFAULT_DSM_TTS_REPO,
+    DEFAULT_DSM_TTS_VOICE_REPO,
+    TTSModel,
+    ConditionAttributes,
+)


-def _make_null(all_attributes: tp.Sequence[ConditionAttributes]) -> list[ConditionAttributes]:
+def _make_null(
+    all_attributes: tp.Sequence[ConditionAttributes],
+) -> list[ConditionAttributes]:
     # When using CFG, returns the null conditions.
     return dropout_all_conditions(all_attributes)
@@ -43,7 +50,8 @@ class TTSGen:
             if tts_model.valid_cfg_conditionings:
                 raise ValueError(
                     "This model does not support direct CFG, but was trained with "
-                    "CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`.")
+                    "CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
+                )
             nulled = _make_null(attributes)
             attributes = list(attributes) + nulled
@@ -53,7 +61,9 @@ class TTSGen:
         def _on_text_logits_hook(text_logits):
             if tts_model.padding_bonus:
-                text_logits[..., tts_model.machine.token_ids.pad] += tts_model.padding_bonus
+                text_logits[..., tts_model.machine.token_ids.pad] += (
+                    tts_model.padding_bonus
+                )
             return text_logits

         def _on_audio_hook(audio_tokens):
@@ -70,29 +80,52 @@ class TTSGen:
             for token in tokens:
                 out_token, _ = tts_model.machine.process(self.offset, self.state, token)
                 out_tokens.append(out_token)
-            text_tokens[:] = torch.tensor(out_tokens, dtype=torch.long, device=text_tokens.device)
+            text_tokens[:] = torch.tensor(
+                out_tokens, dtype=torch.long, device=text_tokens.device
+            )

         tts_model.lm.dep_q = tts_model.n_q
         self.lm_gen = LMGen(
-            tts_model.lm, temp=tts_model.temp, temp_text=tts_model.temp,
-            cfg_coef=tts_model.cfg_coef, condition_tensors=condition_tensors,
-            on_text_logits_hook=_on_text_logits_hook, on_text_hook=_on_text_hook, on_audio_hook=_on_audio_hook,
-            cfg_is_masked_until=None, cfg_is_no_text=True)
+            tts_model.lm,
+            temp=tts_model.temp,
+            temp_text=tts_model.temp,
+            cfg_coef=tts_model.cfg_coef,
+            condition_tensors=condition_tensors,
+            on_text_logits_hook=_on_text_logits_hook,
+            on_text_hook=_on_text_hook,
+            on_audio_hook=_on_audio_hook,
+            cfg_is_masked_until=None,
+            cfg_is_no_text=True,
+        )
         self.lm_gen.streaming_forever(1)

+    def process_last(self):
+        while len(self.state.entries) > 0 or self.state.end_step is not None:
+            self._step()
+        additional_steps = (
+            self.tts_model.delay_steps + max(self.tts_model.lm.delays) + 8
+        )
+        for _ in range(additional_steps):
+            self._step()
+
     def process(self):
         while len(self.state.entries) > self.tts_model.machine.second_stream_ahead:
-            missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
-            input_tokens = torch.full((1, missing, 1), self.tts_model.machine.token_ids.zero,
-                                      dtype=torch.long, device=self.tts_model.lm.device)
-            frame = self.lm_gen.step(input_tokens)
-            self.offset += 1
-            if frame is not None:
-                if self.on_frame is not None:
-                    self.on_frame(frame)
+            self._step()
+
+    def _step(self):
+        missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
+        input_tokens = torch.full(
+            (1, missing, 1),
+            self.tts_model.machine.token_ids.zero,
+            dtype=torch.long,
+            device=self.tts_model.lm.device,
+        )
+        frame = self.lm_gen.step(input_tokens)
+        self.offset += 1
+        if frame is not None:
+            if self.on_frame is not None:
+                self.on_frame(frame)

     def append_entry(self, entry):
         self.state.entries.append(entry)
@@ -179,7 +212,7 @@ def main():
         for entry in entries:
             gen.append_entry(entry)
             gen.process()
-        time.sleep(3)
+        gen.process_last()
         while True:
             if pcms.qsize() == 0:
                 break
@@ -200,10 +233,10 @@ def main():
         for entry in entries:
             gen.append_entry(entry)
             gen.process()
+        gen.process_last()

     pcm = np.concatenate(pcms, axis=-1)
     sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)


 if __name__ == "__main__":
     main()
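
Note: after this change the streaming flow is: queue script entries with append_entry(), call process() to run generation steps while enough entries are buffered, then call process_last() once the script is exhausted to flush the remaining entries plus the model delays (this replaces the old time.sleep(3) workaround). Below is a minimal sketch of that driving loop. The exact TTSGen constructor signature and the loading/script helpers (from_hf_repo, from_checkpoint_info, prepare_script, get_voice_path, make_condition_attributes) are assumptions taken from the moshi package and the rest of this example script; they are not part of this diff.

import numpy as np
import torch

from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, TTSModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
tts_model = TTSModel.from_checkpoint_info(
    checkpoint_info, n_q=32, temp=0.6, device=device
)

pcms: list[np.ndarray] = []

def on_frame(frame: torch.Tensor) -> None:
    # Frames still containing -1 tokens are not yet valid audio; skip them.
    if (frame != -1).all():
        # Drop the text stream (index 0) and decode the audio codebooks with Mimi.
        pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
        pcms.append(np.clip(pcm[0, 0], -1, 1))

entries = tts_model.prepare_script(["Hello, this is a streaming test."])
voice_path = tts_model.get_voice_path("VOICE_NAME")  # placeholder voice name
attributes = [tts_model.make_condition_attributes([voice_path], cfg_coef=2.0)]

gen = TTSGen(tts_model, attributes, on_frame=on_frame)  # constructor signature assumed
with tts_model.mimi.streaming(1):
    for entry in entries:
        gen.append_entry(entry)
        gen.process()      # step while enough entries are queued
    gen.process_last()     # flush the queue and the model delays

pcm = np.concatenate(pcms, axis=-1)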