Workaround for the mlx kv-cache bug. (#108)

This commit is contained in:
Laurent Mazare 2025-08-04 16:37:00 +02:00 committed by GitHub
parent 09468c239a
commit cf97f8d863
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 0 deletions

View File

@ -76,6 +76,9 @@ def main():
moshi_weights = hf_get(moshi_name, args.hf_repo) moshi_weights = hf_get(moshi_name, args.hf_repo)
tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo) tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
lm_config = models.LmConfig.from_config_dict(raw_config) lm_config = models.LmConfig.from_config_dict(raw_config)
# There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
# The following line gets around it for now.
lm_config.transformer.max_seq_len = lm_config.transformer.context
model = models.Lm(lm_config) model = models.Lm(lm_config)
model.set_dtype(mx.bfloat16) model.set_dtype(mx.bfloat16)

View File

@ -205,6 +205,9 @@ def main():
moshi_weights = hf_get(moshi_name, args.hf_repo) moshi_weights = hf_get(moshi_name, args.hf_repo)
tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo) tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
lm_config = models.LmConfig.from_config_dict(raw_config) lm_config = models.LmConfig.from_config_dict(raw_config)
# There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
# The following line gets around it for now.
lm_config.transformer.max_seq_len = lm_config.transformer.context
model = models.Lm(lm_config) model = models.Lm(lm_config)
model.set_dtype(mx.bfloat16) model.set_dtype(mx.bfloat16)