feat: Add LLaDA Diffusion Model Support #14


Open
wants to merge 5 commits into base: main
Conversation

@cavit99 (Contributor) commented Mar 15, 2025

Overview

This PR introduces support for diffusion-based language models (e.g., LLaDA) in mlx_lm, extending the framework beyond autoregressive generation to cover diffusion paradigms for further research. The implementation adds a new generate_diffusion function in utils.py, updates the generation pipeline to handle both model types, and resolves issues that arose during integration, taking care not to affect the autoregressive path. The changes stay aligned with mlx-lm's principles while providing robust initial functionality for LLaDA and, potentially, other diffusion models going forward.

Key Changes

  1. New generate_diffusion Function in utils.py

    • Added a new function to support diffusion-based text generation, tailored for models like LLaDA.
    • Key features:
      • Progressive token unmasking with configurable steps (steps), generation length (gen_length), and block sizes (block_length).
      • Supports Gumbel noise sampling (noise_temp) for stochasticity and classifier-free guidance (cfg) for controlled generation.
      • Implements semi-autoregressive block-wise diffusion with configurable unmasking strategies (topk or random).
      • Yields intermediate progress updates in verbose mode or final text otherwise, integrated with the GenerationResponse dataclass.
    • Optimizations:
      • Uses mx.compile for the sampling step to leverage MLX’s performance benefits.
      • Efficiently handles batch size of 1 (current limitation) with plans for future expansion.
  2. Updated stream_generate in utils.py

    • Modified to detect diffusion models via model.args.model_type == "llada" and delegate to generate_diffusion.
    • Preserves existing autoregressive paths (generate_step and speculative_generate_step) for non-diffusion models.
    • Added argument filtering to route diffusion-specific kwargs (e.g., steps, gen_length) to generate_diffusion and autoregressive kwargs (e.g., max_tokens, sampler) to their respective functions.
    • Ensures consistent streaming behavior across model types using the GenerationResponse interface.
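The kwarg routing might look roughly like the following. The key names mirror the PR's CLI flags, but the helper function and set-based split are hypothetical, not the PR's actual implementation.

```python
# Illustrative sketch: diffusion-specific keys are routed to
# generate_diffusion, everything else to the autoregressive path.
DIFFUSION_KEYS = {"steps", "gen_length", "noise_temp", "cfg", "block_length", "unmasking"}

def split_kwargs(kwargs: dict) -> tuple[dict, dict]:
    diffusion = {k: v for k, v in kwargs.items() if k in DIFFUSION_KEYS}
    autoregressive = {k: v for k, v in kwargs.items() if k not in DIFFUSION_KEYS}
    return diffusion, autoregressive

d, a = split_kwargs({"steps": 32, "gen_length": 64, "max_tokens": 256})
print(d)  # {'steps': 32, 'gen_length': 64}
print(a)  # {'max_tokens': 256}
```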
  3. Extended CLI in generate.py

    • Added diffusion-specific arguments with sensible defaults:
      • --steps (default: 32): Number of diffusion steps.
      • --gen-length (default: 64): Length of the generated sequence.
      • --noise-temp (default: 0.0): Temperature for Gumbel noise sampling.
      • --cfg (default: 1.0): Classifier-free guidance scale.
      • --block-length (default: None): Size of semi-autoregressive blocks.
      • --unmasking (default: "topk"): Strategy for unmasking tokens (topk or random).
    • Integrated these into the generate call, ensuring seamless invocation of diffusion generation when using an LLaDA model.
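The flags above map onto standard argparse additions. The sketch below is illustrative parser wiring with the defaults taken from the list; it is not a copy of generate.py.

```python
import argparse

# Sketch of the diffusion-specific CLI additions; flag names and defaults
# mirror the PR description, the surrounding parser setup is assumed.
parser = argparse.ArgumentParser()
parser.add_argument("--steps", type=int, default=32, help="Number of diffusion steps.")
parser.add_argument("--gen-length", type=int, default=64, help="Length of the generated sequence.")
parser.add_argument("--noise-temp", type=float, default=0.0, help="Temperature for Gumbel noise sampling.")
parser.add_argument("--cfg", type=float, default=1.0, help="Classifier-free guidance scale.")
parser.add_argument("--block-length", type=int, default=None, help="Size of semi-autoregressive blocks.")
parser.add_argument("--unmasking", choices=["topk", "random"], default="topk", help="Unmasking strategy.")

args = parser.parse_args(["--steps", "16", "--cfg", "2.0"])
print(args.steps, args.cfg, args.unmasking)  # 16 2.0 topk
```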

Implementation Details

  • Model Detection: Relies on model.args.model_type == "llada" to switch to diffusion mode, ensuring compatibility with custom LLaDA implementations (e.g., the repo's llada.py).
  • Performance: Leverages MLX's fast operations where possible (including mx.compile for the sampling step).
  • Backward Compatibility: Autoregressive generation paths (including speculative decoding) remain unchanged, with diffusion-specific logic isolated to new code paths.
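A minimal sketch of the model-type check, assuming model.args.model_type is the field consulted; the Model/Args stubs and the pick_generator helper are stand-ins for illustration, not mlx_lm's actual classes.

```python
# Hypothetical dispatch: a "llada" model_type selects the diffusion path,
# anything else falls through to the unchanged autoregressive path.
class Args:
    def __init__(self, model_type: str):
        self.model_type = model_type

class Model:
    def __init__(self, model_type: str):
        self.args = Args(model_type)

def pick_generator(model) -> str:
    if getattr(model.args, "model_type", None) == "llada":
        return "generate_diffusion"
    return "generate_step"

print(pick_generator(Model("llada")))  # generate_diffusion
print(pick_generator(Model("llama")))  # generate_step
```

Because the check is an exact string match, every existing model type falls through to generate_step, which is how the autoregressive paths stay untouched.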

Testing

  • Weights Conversion
    • Works as normal, including quantization.
    • Example:
      mlx_lm.convert --hf-path <llada-hf-repo> --mlx-path ./llada-mlx --quantize --q-bits 4
  • Diffusion:
    • Command:
      mlx_lm.generate \
          --model mlx-community/LLaDA-8B-Instruct-mlx-fp16 \
          --prompt "Tell me about Leonardo da Vinci." \
          --gen-length 32 \
          --steps 32 \
          --noise-temp 0.3 \
          --cfg 1.0 \
          --verbose true
    • Recommended: use the pre-converted weights mlx-community/LLaDA-8B-Instruct-mlx-fp16, or the quantized variants also on mlx-community.
    • Output:
      • Verbose mode shows block-wise progress (e.g., "Block 1/1 | Step 32/32 | Unmasked 32/32") with intermediate text.
      • Non-verbose mode prints only the final generated text.

Impact

  • New Feature: Users can now generate text with the LLaDA diffusion model, broadening mlx_lm's applicability.
  • Preserved Behaviour: Autoregressive models are unaffected, with minimal intrusion into the existing codebase.

Known Limitations

  • Verbose output in diffusion mode uses ANSI escape codes (\033[2J\033[H), which may not render perfectly in all terminals.

Future Work

  • Optimise performance
  • New research on diffusion LLMs drops weekly; there is plenty of scope for further improvements

@Blaizzy (Contributor) commented Mar 15, 2025

Well done @cavit99, this looks awesome!
