Noureddine RAMDI / dflash-mlx: Speculative decoding on Apple Silicon with Metal and MLX

Created Mon, 04 May 2026 10:23:02 +0000 Modified Sat, 23 May 2026 20:41:27 +0000

Aryagm/dflash-mlx

Speculative decoding can speed up autoregressive language model inference by proposing multiple tokens at once and verifying them in a single pass. dflash-mlx brings this technique natively to Apple Silicon using Metal and Apple’s MLX framework, which notably lacks built-in support for speculative decoding. The project rebuilds every piece of the decoding pipeline, from hidden-state extraction to per-layer KV cache rollback, so that it produces output bit-for-bit identical to the target model’s while needing fewer target-model forward passes.

What dflash-mlx does and its architecture

dflash-mlx is a native port of DFlash (Block Diffusion for Flash Speculative Decoding) for Apple Silicon GPUs, built on Metal and MLX. The core idea is to speed up autoregressive decoding by having a small “draft” model propose a whole block of tokens at once through block diffusion. A larger “target” model then verifies the block in a single forward pass and accepts the longest prefix that matches its own predictions, so the output is exactly what the target model would have produced decoding token by token.
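To make the accept-longest-prefix rule concrete, here is a minimal toy sketch assuming greedy (temperature 0) decoding. The two “models” are hypothetical integer functions standing in for real networks, the draft block is built sequentially rather than by block diffusion, and the single verification pass is simulated with one target call per position. It is not dflash-mlx’s implementation, only an illustration of why the output matches plain greedy decoding while needing fewer target passes.

```python
from typing import Callable, List, Sequence, Tuple

# A "model" here is a toy greedy next-token function: context -> token id.
Model = Callable[[Sequence[int]], int]

def toy_target(seq: Sequence[int]) -> int:
    """Hypothetical stand-in for the large target model (deterministic)."""
    return (sum(seq) * 31 + 7) % 50

def toy_draft(seq: Sequence[int]) -> int:
    """Hypothetical draft: agrees with the target except when len(seq) % 5 == 0."""
    guess = toy_target(seq)
    return (guess + 1) % 50 if len(seq) % 5 == 0 else guess

def greedy_decode(target: Model, prompt: List[int], n: int) -> List[int]:
    """Baseline: one target call per generated token."""
    out = list(prompt)
    for _ in range(n):
        out.append(target(out))
    return out[len(prompt):]

def speculative_decode(target: Model, draft: Model, prompt: List[int],
                       n: int, block: int = 4) -> Tuple[List[int], int]:
    """Greedy draft-then-verify loop; returns (tokens, number of verification passes)."""
    out = list(prompt)
    passes = 0
    while len(out) - len(prompt) < n:
        # 1. Draft a block of `block` tokens (sequentially here; DFlash drafts them in parallel).
        proposal: List[int] = []
        for _ in range(block):
            proposal.append(draft(out + proposal))
        # 2. One verification "pass": the target's greedy choice at each of the
        #    block + 1 positions (a real transformer gets all of them from one forward pass).
        passes += 1
        checks = [target(out + proposal[:i]) for i in range(block + 1)]
        # 3. Keep the longest prefix the target agrees with, then append the target's
        #    own token at the first mismatch (or after a fully accepted block).
        accepted = 0
        while accepted < block and proposal[accepted] == checks[accepted]:
            accepted += 1
        out += proposal[:accepted] + [checks[accepted]]
    return out[len(prompt):len(prompt) + n], passes

prompt = [1, 2, 3]
reference = greedy_decode(toy_target, prompt, 20)
fast, passes = speculative_decode(toy_target, toy_draft, prompt, 20)
assert fast == reference  # identical output to plain greedy decoding
print(passes)  # 5 verification passes instead of 20 target steps
```

Every emitted token is one the target itself chose, which is why the output cannot drift. The real speedup depends on how often the draft is accepted and on a batched verification pass costing not much more than a single-token step.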

Under the hood, this requires several challenging pieces to work together:

  • Hidden-state extraction: Efficiently capturing and managing intermediate states needed for verification.

  • Parallel block proposal: The draft model proposes every token in a block in parallel rather than one at a time.

  • Single-pass batched verification: The target model checks all proposed tokens in one forward pass.

  • Per-layer KV cache rollback: If verification rejects some tokens, the KV cache (key-value cache storing attention states) is rolled back layer by layer to the last accepted token, preserving correctness.
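Per-layer rollback can be sketched with a toy cache that stores one key/value entry per position for each layer. LayerCache and rollback are hypothetical names, and the sketch assumes plain attention caches that can simply be truncated; caches of other kinds would need their own rollback logic.

```python
from dataclasses import dataclass, field
from typing import Any, List

@dataclass
class LayerCache:
    """Hypothetical attention KV cache for one layer: one key/value entry per position."""
    keys: List[Any] = field(default_factory=list)
    values: List[Any] = field(default_factory=list)

    @property
    def length(self) -> int:
        return len(self.keys)

    def append(self, key: Any, value: Any) -> None:
        self.keys.append(key)
        self.values.append(value)

    def trim(self, n: int) -> None:
        """Drop the last n positions; n <= 0 is a no-op (del lst[-0:] would clear everything)."""
        if n > 0:
            del self.keys[-n:]
            del self.values[-n:]

def rollback(caches: List[LayerCache], committed_len: int) -> None:
    """Trim every layer back to the last accepted position.

    committed_len = prefix length + number of accepted draft tokens. In this
    sketch the target's corrective token is not cached yet; it is fed in on
    the next step like any other new token.
    """
    for cache in caches:
        cache.trim(cache.length - committed_len)

# Usage: 10 committed positions + 4 drafted ones were cached during verification,
# and only 2 drafts were accepted, so each layer goes back to 12 positions.
layers = [LayerCache() for _ in range(4)]
for cache in layers:
    for pos in range(14):
        cache.append(f"k{pos}", f"v{pos}")
rollback(layers, committed_len=12)
print([c.length for c in layers])  # [12, 12, 12, 12]
```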

Since MLX provides no primitives for speculative decoding, dflash-mlx engineers these mechanisms from scratch directly on Metal. This involves low-level GPU programming and careful cache management.

The repo currently supports Qwen3-4B and Qwen3.5-4B through an adapter pattern that isolates model-specific details such as layer IDs, cache types, and stop tokens. Adding a new model family is then a matter of implementing a single adapter file.
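A minimal sketch of the adapter idea, assuming an adapter is just a small record of model-specific facts plus a registry the engine consults. ModelAdapter, register, and get_adapter are hypothetical names, not dflash-mlx’s actual interface, and the layer IDs, cache types, and stop-token IDs below are placeholders rather than real Qwen values.

```python
from dataclasses import dataclass
from typing import Dict, FrozenSet, Tuple

@dataclass(frozen=True)
class ModelAdapter:
    """Hypothetical record of the model-specific facts a decoding engine needs."""
    family: str                     # model-name prefix this adapter handles
    layer_ids: Tuple[int, ...]      # layers whose hidden states the engine reads
    cache_types: Tuple[str, ...]    # cache kind per layer, e.g. "attention"
    stop_token_ids: FrozenSet[int]  # token ids that end generation

    def is_stop(self, token_id: int) -> bool:
        return token_id in self.stop_token_ids

ADAPTERS: Dict[str, ModelAdapter] = {}

def register(adapter: ModelAdapter) -> ModelAdapter:
    """Add an adapter to the registry, keyed by its family prefix."""
    ADAPTERS[adapter.family] = adapter
    return adapter

def get_adapter(model_name: str) -> ModelAdapter:
    """Return the adapter whose family prefix matches model_name (longest match wins)."""
    matches = [a for fam, a in ADAPTERS.items() if model_name.startswith(fam)]
    if not matches:
        raise KeyError(f"no adapter registered for {model_name!r}")
    return max(matches, key=lambda a: len(a.family))

# Placeholder values for illustration only, not the real Qwen3 configuration.
qwen3 = register(ModelAdapter("Qwen3-", (1, 2), ("attention", "attention"), frozenset({0})))
print(get_adapter("Qwen3-4B").family)  # Qwen3-
```

Keeping these details out of the decoding loop is what lets the engine stay model-agnostic: in this sketch, supporting another family means registering one more adapter.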

Beyond the core decoding engine, dflash-mlx includes:

  • CLI tools for running generation and interactive chat
  • A Python API for integration in applications
  • Streaming output support
  • An OpenAI-compatible local server interface
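Because the server speaks the OpenAI API, any HTTP client should work against it. The sketch below uses only the Python standard library and assumes the standard /v1/chat/completions route; the base URL, port, and model id are placeholders, since the project’s actual defaults are not stated here.

```python
import json
import urllib.request

def build_chat_request(base_url: str, model: str, prompt: str,
                       stream: bool = False) -> urllib.request.Request:
    """Build (but do not send) an OpenAI-style chat completion request."""
    body = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "stream": stream,
    }
    return urllib.request.Request(
        url=f"{base_url.rstrip('/')}/v1/chat/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

# Placeholder URL and model id; adjust to however the local server is started.
req = build_chat_request("http://localhost:8000", "placeholder-model-id",
                         "Explain KV caching in one sentence.")
# Sending requires the server to be running:
# with urllib.request.urlopen(req) as resp:
#     reply = json.load(resp)
#     print(reply["choices"][0]["message"]["content"])
```

In the OpenAI response format, a non-streaming reply carries the generated text under choices[0]["message"]["content"].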

The stack is primarily Python for orchestration, with Metal kernel code doing the heavy GPU work through MLX.

Technical strengths and tradeoffs

dflash-mlx stands out for building exact speculative decoding from the ground up on a platform that doesn’t support it natively. This is an engineering challenge because speculative decoding involves complex state management and verification logic that must be tightly coupled with the model’s KV cache and GPU execution.

The adapter pattern separating model-specific logic from the decoding engine is a clean architectural decision that simplifies maintenance and extensibility.

Tradeoffs include:

  • Hardware specificity: The solution is Apple Silicon-specific, relying on Metal and MLX, which limits portability.

  • Model support: Currently limited to Qwen3-4B and Qwen3.5-4B. Each new model family needs its own adapter and validation.

  • Complexity: Implementing per-layer KV cache rollback and batched verification is non-trivial and increases code complexity.

  • Model size and resource use: The model checkpoints take significant disk space (~12 GB for the defaults), and running a target and a draft model together on Apple Silicon demands substantial unified memory and compute.
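A rough back-of-the-envelope for the memory side: weight storage is parameters times bytes per parameter, and an attention KV cache grows with layers, KV heads, head dimension, and context length. The helper names are hypothetical, and the layer and head counts in the example are illustrative inputs, not values taken from the dflash-mlx repo.

```python
GB = 1e9  # decimal gigabytes

def weight_bytes(num_params: float, bits_per_param: int = 16) -> float:
    """Approximate weight storage: parameters x (bits / 8) bytes. BF16 is 16 bits."""
    return num_params * bits_per_param / 8

def kv_cache_bytes(layers: int, kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    """Keys and values (the factor 2) for every layer, KV head, and cached position."""
    return 2 * layers * kv_heads * head_dim * seq_len * bytes_per_elem

print(weight_bytes(4e9) / GB)                  # 8.0 for a nominal 4B model in BF16
print(kv_cache_bytes(36, 8, 128, 4096) / GB)   # about 0.6 for a 4096-token context
```

By this estimate a nominal 4B target in BF16 accounts for roughly 8 GB of weights on its own, before the draft model, activations, and KV cache are counted.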

The codebase is surprisingly clean given the low-level GPU programming involved. The CLI and Python API offer a decent developer experience, and the local OpenAI-compatible server eases integration into existing workflows.

Quick start

To try dflash-mlx, the README provides a straightforward quick start:

git clone https://github.com/aryagm/dflash-mlx.git && cd dflash-mlx
uv sync

uv run dflash-mlx --max-new-tokens 128

This runs generation with the default Qwen3-4B target model in BF16 precision. The first run downloads the models (~12 GB) into the Hugging Face cache.

You can override the default models with the --target-model and --draft-model flags.
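To script runs, for example repeated benchmarks, the quick-start command and the override flags can be assembled programmatically. This sketch assumes each override flag takes a single model id or path as its value and that it runs from the repo checkout; dflash_command and run_dflash are hypothetical helpers.

```python
import subprocess
from typing import List, Optional

def dflash_command(max_new_tokens: int = 128,
                   target_model: Optional[str] = None,
                   draft_model: Optional[str] = None) -> List[str]:
    """Assemble the quick-start command, adding model overrides only when given."""
    cmd = ["uv", "run", "dflash-mlx", "--max-new-tokens", str(max_new_tokens)]
    if target_model is not None:
        cmd += ["--target-model", target_model]
    if draft_model is not None:
        cmd += ["--draft-model", draft_model]
    return cmd

def run_dflash(**kwargs) -> str:
    """Run the CLI (from the repo checkout) and return its stdout; raises on failure."""
    result = subprocess.run(dflash_command(**kwargs),
                            check=True, capture_output=True, text=True)
    return result.stdout
```

Passing the command as a list avoids shell quoting issues with model paths that contain spaces.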

For interactive chat mode, use dflash-mlx-chat. Adding --json outputs machine-readable results.

Benchmark history can be recorded with --history or --history-file.

For the Python API, import the generator class:

from dflash_mlx import DFlashGenerator
# Instantiate and use the generator as needed

Verdict

dflash-mlx is a compelling project if you want to explore or deploy speculative decoding on Apple Silicon with native Metal acceleration. It’s particularly relevant for researchers and engineers working with autoregressive language models who want to reduce inference latency by cutting the number of sequential target-model forward passes.

The project’s strength lies in its ground-up implementation of speculative decoding primitives on a platform without native support, demonstrating the feasibility and tradeoffs involved.

However, it’s not a drop-in library for general use. The Apple Silicon requirement, limited model support, and engineering complexity mean it suits users comfortable with low-level ML engineering and willing to adapt it to their models.

If you’re targeting other hardware or need broader model compatibility out of the box, this might not be the right tool. But for Apple Silicon environments, dflash-mlx offers an impressive, low-level approach to speeding up generation without sacrificing output fidelity.


→ GitHub Repo: Aryagm/dflash-mlx ⭐ 361 · Python