Workaround for the mlx kv-cache bug.
This commit is contained in:
parent
09468c239a
commit
9de66be9ff
|
|
@@ -76,6 +76,9 @@ def main():
|
|||
moshi_weights = hf_get(moshi_name, args.hf_repo)
|
||||
tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
|
||||
lm_config = models.LmConfig.from_config_dict(raw_config)
|
||||
# There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
|
||||
# The following line gets around it for now.
|
||||
lm_config.transformer.max_seq_len = lm_config.transformer.context
|
||||
model = models.Lm(lm_config)
|
||||
model.set_dtype(mx.bfloat16)
|
||||
|
||||
|
|
|
|||
|
|
@@ -205,6 +205,9 @@ def main():
|
|||
moshi_weights = hf_get(moshi_name, args.hf_repo)
|
||||
tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
|
||||
lm_config = models.LmConfig.from_config_dict(raw_config)
|
||||
# There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
|
||||
# The following line gets around it for now.
|
||||
lm_config.transformer.max_seq_len = lm_config.transformer.context
|
||||
model = models.Lm(lm_config)
|
||||
model.set_dtype(mx.bfloat16)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user