korentomas/mlx-mcmc

MLX-MCMC: Bayesian Inference for Apple Silicon

License: MIT Python 3.11+ MLX

A Bayesian inference library optimized for Apple Silicon (M1/M2/M3/M4).

MLX-MCMC provides modern MCMC sampling on Apple's MLX framework with native Metal GPU acceleration and unified memory architecture.

Motivation

Existing Bayesian inference libraries have limitations on Apple Silicon:

  • PyMC/Stan: CPU-only on macOS; no Metal acceleration
  • JAX-Metal: experimental and unstable, with frequent version conflicts
  • NumPyro: GPU acceleration requires CUDA, i.e. NVIDIA hardware

MLX-MCMC addresses this gap by providing a native Apple Silicon solution with Metal GPU acceleration, unified memory (no CPU-GPU transfers), and a clean Pythonic API.

Status

Current Version: 0.1.0-alpha (Proof of Concept)

Implemented:

  • Core distributions (Normal, HalfNormal, Beta, Gamma, Exponential, Categorical)
  • Metropolis-Hastings sampler
  • Hamiltonian Monte Carlo (HMC) with automatic differentiation
  • No-U-Turn Sampler (NUTS) - adaptive HMC with automatic trajectory length
  • Step size adaptation via dual averaging
  • Basic diagnostics
  • Proof-of-concept validated
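To illustrate the HMC kernel listed above, here is a minimal, self-contained numpy sketch (not the library's implementation): leapfrog integration of Hamiltonian dynamics followed by a Metropolis accept step, targeting a standard normal with an analytic gradient.

```python
import numpy as np

def hmc_sample(log_prob_grad, x0, num_samples, step_size=0.2, num_leapfrog=10, seed=0):
    """Minimal 1-D Hamiltonian Monte Carlo: leapfrog integration + MH accept."""
    rng = np.random.default_rng(seed)
    x = x0
    samples = []
    for _ in range(num_samples):
        p = rng.standard_normal()          # resample momentum
        x_new = x
        _, g = log_prob_grad(x_new)
        p_new = p + 0.5 * step_size * g    # initial half step for momentum
        # Alternate full position steps and momentum steps (leapfrog)
        for i in range(num_leapfrog):
            x_new += step_size * p_new
            _, g = log_prob_grad(x_new)
            p_new += step_size * g if i < num_leapfrog - 1 else 0.5 * step_size * g
        # Metropolis accept/reject on the total energy (negative Hamiltonian)
        logp_old, _ = log_prob_grad(x)
        logp_new, _ = log_prob_grad(x_new)
        log_accept = (logp_new - 0.5 * p_new**2) - (logp_old - 0.5 * p**2)
        if np.log(rng.uniform()) < log_accept:
            x = x_new
        samples.append(x)
    return np.array(samples)

# Target: standard normal; log p(x) = -x^2/2 (up to a constant), gradient = -x
std_normal = lambda x: (-0.5 * x * x, -x)
draws = hmc_sample(std_normal, x0=0.0, num_samples=2000)
```

The same leapfrog core underlies NUTS; NUTS replaces the fixed `num_leapfrog` with an adaptively grown trajectory.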

In Development:

  • Multiple chain support
  • Comprehensive diagnostics (R-hat, ESS)
  • ArviZ integration

Installation

# Requirements
pip install mlx numpy matplotlib

# Install from source
git clone https://github.com/yourusername/mlx-mcmc
cd mlx-mcmc
pip install -e .

Quick Start

import numpy as np
import mlx.core as mx
from mlx_mcmc import Normal, HalfNormal, MCMC

# Generate synthetic data
y_observed = np.random.normal(5.0, 2.0, 100)

# Define log probability function
def log_prob(params):
    mu = params['mu']
    sigma = params['sigma']

    # Priors
    log_prior = (
        Normal(0, 10).log_prob(mu) +
        HalfNormal(5).log_prob(sigma)
    )

    # Likelihood
    log_likelihood = mx.sum(
        mx.array([Normal(mu, sigma).log_prob(mx.array(y))
                  for y in y_observed])
    )

    return log_prior + log_likelihood

# Run MCMC with Metropolis-Hastings
mcmc = MCMC(log_prob)
samples = mcmc.run(
    initial_params={'mu': 0.0, 'sigma': 1.0},
    num_samples=5000,
    num_warmup=1000,
    method='metropolis'
)

# Or use HMC for faster convergence (gradient-based)
samples_hmc = mcmc.run(
    initial_params={'mu': 0.0, 'sigma': 1.0},
    num_samples=5000,
    num_warmup=1000,
    method='hmc',
    step_size=0.1,
    num_leapfrog_steps=10
)

# Results
print(f"Estimated μ: {np.mean(samples['mu']):.3f}")
print(f"Estimated σ: {np.mean(samples['sigma']):.3f}")
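Beyond point estimates, the returned sample arrays can be summarized with plain numpy. A small sketch, using mock draws as a stand-in for `samples['mu']` (the exact shape of the returned arrays is an assumption here):

```python
import numpy as np

def summarize(draws, name, ci=0.94):
    """Print posterior mean, sd, and an equal-tailed credible interval."""
    lo, hi = np.percentile(draws, [100 * (1 - ci) / 2, 100 * (1 + ci) / 2])
    print(f"{name}: mean={np.mean(draws):.3f} sd={np.std(draws):.3f} "
          f"{ci:.0%} CI=[{lo:.3f}, {hi:.3f}]")
    return lo, hi

# Mock stand-in for samples['mu'] from the Quick Start above
rng = np.random.default_rng(1)
mock_mu = rng.normal(5.0, 0.2, 5000)
lo, hi = summarize(mock_mu, "mu")
```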

Performance

Preliminary benchmarks on an M3 Pro (16 GB):

Model size                     | MLX-MCMC (CPU) | PyMC + Accelerate | MLX-MCMC (GPU)*
Small (10 params, 1K obs)      | 15 sec         | 20 sec            | 8 sec*
Medium (100 params, 10K obs)   | 2 min          | 3 min             | 30 sec*
Large (1000 params, 100K obs)  | 30 min         | 45 min            | 8 min*

*GPU implementation in progress
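Numbers like these can be reproduced with a best-of-N wall-clock harness. A generic sketch (the `workload` closure is a hypothetical stand-in, not the actual benchmark script):

```python
import time
import numpy as np

def time_sampler(fn, repeats=3):
    """Best-of-N wall-clock timing; one warm-up call excludes one-time setup cost."""
    fn()  # warm-up run
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        times.append(time.perf_counter() - t0)
    return min(times)

# Hypothetical stand-in workload; replace with a closure around mcmc.run(...)
workload = lambda: np.linalg.cholesky(np.eye(100) + 0.01)
best = time_sampler(workload)
```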

Architecture

mlx-mcmc/
├── mlx_mcmc/
│   ├── distributions/      # Probability distributions
│   │   ├── normal.py
│   │   ├── halfnormal.py
│   │   ├── beta.py
│   │   └── ...
│   ├── kernels/           # MCMC samplers
│   │   ├── metropolis.py
│   │   ├── hmc.py
│   │   └── nuts.py
│   ├── diagnostics/       # Convergence checks
│   │   ├── rhat.py
│   │   └── ess.py
│   └── inference/         # High-level API
│       └── mcmc.py
├── examples/              # Example notebooks
├── tests/                 # Unit tests
└── benchmarks/            # Performance comparisons

Examples

See examples/ directory:

  • 01_simple_normal.py - Basic inference with Normal distribution
  • 02_hmc_comparison.py - Metropolis-Hastings vs HMC comparison
  • 03_ab_testing.py - Bayesian A/B testing with Beta distribution
  • 04_event_rates.py - Event rate modeling with Gamma and Exponential distributions
  • 05_categorical_model.py - Categorical outcomes with Dirichlet prior
  • 06_nuts_comparison.py - NUTS vs HMC: automatic trajectory length tuning
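For the A/B-testing example, the Beta-Binomial model is conjugate, so the exact posterior is available in closed form and makes a handy sanity check against MCMC output. A sketch with made-up conversion counts:

```python
import numpy as np

# Hypothetical A/B data: (conversions, trials)
a_conv, a_n = 40, 200
b_conv, b_n = 60, 200

# With a Beta(1, 1) prior, the posterior is Beta(1 + conversions, 1 + failures)
rng = np.random.default_rng(2)
post_a = rng.beta(1 + a_conv, 1 + a_n - a_conv, 100_000)
post_b = rng.beta(1 + b_conv, 1 + b_n - b_conv, 100_000)

p_b_better = np.mean(post_b > post_a)   # P(rate_B > rate_A | data)
lift = np.mean(post_b - post_a)         # expected absolute lift
```

An MCMC run on the same model (example 03) should recover essentially the same posterior summaries.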

Testing

# Run tests
pytest tests/

# Run benchmarks
python benchmarks/compare_frameworks.py

Contributing

Contributions are welcome. Priority areas:

  1. Multiple chain support with parallel execution
  2. Comprehensive diagnostics (R-hat, ESS via FFT)
  3. ArviZ integration for visualization
  4. More distributions (Poisson, Binomial, Student-t, Dirichlet, etc.)
  5. Mass matrix adaptation for HMC/NUTS
  6. Performance optimizations and GPU benchmarks
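For contributors interested in item 2, the classic split-R̂ diagnostic (Gelman–Rubin) is compact enough to sketch in numpy. This is an illustrative sketch, not the planned implementation:

```python
import numpy as np

def split_rhat(chains):
    """Split-R-hat (Gelman-Rubin); chains has shape (num_chains, num_draws)."""
    chains = np.asarray(chains)
    n = chains.shape[1] // 2
    # Split each chain in half to also detect within-chain non-stationarity
    halves = np.concatenate([chains[:, :n], chains[:, n:2 * n]], axis=0)
    m, n = halves.shape
    chain_means = halves.mean(axis=1)
    b = n * chain_means.var(ddof=1)          # between-chain variance
    w = halves.var(axis=1, ddof=1).mean()    # within-chain variance
    var_plus = (n - 1) / n * w + b / n       # pooled variance estimate
    return np.sqrt(var_plus / w)

rng = np.random.default_rng(3)
good = rng.standard_normal((4, 1000))        # well-mixed chains -> R-hat near 1
bad = good + np.arange(4)[:, None] * 3.0     # chains stuck at different modes
```

Values near 1.0 indicate the chains agree; values well above 1.0 flag non-convergence.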

Documentation

Research

MLX-MCMC is based on:

  • Betancourt (2017): "A Conceptual Introduction to Hamiltonian Monte Carlo"
  • Hoffman & Gelman (2014): "The No-U-Turn Sampler"
  • Neal (2011): "MCMC Using Hamiltonian Dynamics"

Citation

@software{mlx_mcmc_2026,
  title={MLX-MCMC: Bayesian Inference for Apple Silicon},
  year={2026},
  url={https://github.com/yourusername/mlx-mcmc}
}

Acknowledgments

  • Apple for the MLX framework
  • PyMC, NumPyro, and Stan development teams for inspiration and best practices

License

MIT License - see LICENSE file for details.

Roadmap

Version 0.1 (Current - Proof of Concept)

  • Core distribution infrastructure
  • Metropolis-Hastings sampler
  • Hamiltonian Monte Carlo (HMC)
  • No-U-Turn Sampler (NUTS)
  • More distributions (Beta, Gamma, Exponential, Categorical)
  • Package structure
  • Unit tests
  • Example scripts

Version 0.2 (Next - Multiple Chains & Diagnostics)

  • Multiple chain support with parallel execution
  • R-hat convergence diagnostic
  • Effective Sample Size (ESS via FFT)
  • Posterior predictive checks
  • ArviZ integration

Version 0.3 (Future - Advanced Features)

  • Mass matrix adaptation
  • Variational inference (ADVI)
  • More distributions (Poisson, Binomial, Student-t, Dirichlet)
  • Performance optimizations

Version 1.0 (Production)

  • Complete distribution library
  • Full diagnostic suite
  • Comprehensive examples
  • Performance optimizations
  • Documentation
  • PyPI release

Contact


Version 0.1.0-alpha | Last Updated: 2026-01-18 | Status: Experimental
