feat: Add LLaDA Diffusion Model Support #14
+605
−44
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Overview
This PR introduces support for diffusion-based language models (e.g., LLaDA) in mlx_lm, extending the framework beyond autoregressive generation to include diffusion paradigms for further research. The implementation adds a new generate_diffusion function in utils.py, respectfully updates the generation pipeline to handle both model types, and resolves issues that arose during integration, being very careful not to affect autoregressive path compatibility. The changes maintain alignment with the mlx-lm principles while ensuring robust initial functionality for diffusion models LLaDA and potentially others going forward.
Key Changes
New
generate_diffusion
Function inutils.py
steps
), generation length (gen_length
), and block sizes (block_length
).noise_temp
) for stochasticity and classifier-free guidance (cfg
) for controlled generation.topk
orrandom
).GenerationResponse
dataclass.mx.compile
for the sampling step to leverage MLX’s performance benefits.Updated
stream_generate
inutils.py
model.args.model_type == "llada"
and delegate togenerate_diffusion
.generate_step
andspeculative_generate_step
) for non-diffusion models.steps
,gen_length
) togenerate_diffusion
and autoregressive kwargs (e.g.,max_tokens
,sampler
) to their respective functions.GenerationResponse
interface.Extended CLI in
generate.py
--steps
(default: 32): Number of diffusion steps.--gen-length
(default: 64): Length of the generated sequence.--noise-temp
(default: 0.0): Temperature for Gumbel noise sampling.--cfg
(default: 1.0): Classifier-free guidance scale.--block-length
(default: None): Size of semi-autoregressive blocks.--unmasking
(default: "topk"): Strategy for unmasking tokens (topk
orrandom
).generate
call, ensuring seamless invocation of diffusion generation when using an LLaDA model.Implementation Details
model.args.model_type == "llada"
to switch to diffusion mode, ensuring compatibility with custom LLaDA implementations (e.g., yourllada.py
).Testing
Impact
mlx_lm
’s applicability.Known Limitations
\033[2J\033[H
), which may not render perfectly in all terminals.Future Work