diff --git a/scripts/tts_mlx.py b/scripts/tts_mlx.py index af2b166..47faca6 100644 --- a/scripts/tts_mlx.py +++ b/scripts/tts_mlx.py @@ -76,6 +76,9 @@ def main(): moshi_weights = hf_get(moshi_name, args.hf_repo) tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo) lm_config = models.LmConfig.from_config_dict(raw_config) + # There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache. + # The following line gets around it for now. + lm_config.transformer.max_seq_len = lm_config.transformer.context model = models.Lm(lm_config) model.set_dtype(mx.bfloat16) diff --git a/scripts/tts_mlx_streaming.py b/scripts/tts_mlx_streaming.py index dc66156..bbdea7e 100644 --- a/scripts/tts_mlx_streaming.py +++ b/scripts/tts_mlx_streaming.py @@ -205,6 +205,9 @@ def main(): moshi_weights = hf_get(moshi_name, args.hf_repo) tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo) lm_config = models.LmConfig.from_config_dict(raw_config) + # There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache. + # The following line gets around it for now. + lm_config.transformer.max_seq_len = lm_config.transformer.context model = models.Lm(lm_config) model.set_dtype(mx.bfloat16)