gregkocher/belief-state-geometry


Belief State Geometry: Mess3, Bloch Walk, and Joint Processes

Overview

Implementation and reproduction of "Neural networks leverage nominally quantum and post-quantum representations" by Riechers et al.

Core Idea: Neural networks spontaneously learn low-dimensional quantum-like representations of belief states when trained on next-token prediction for certain stochastic processes.

Implemented Processes

  1. Mess3: 3-state classical HMM requiring quantum (qutrit) representation
  2. Bloch Walk: Quantum process on Bloch sphere requiring qubit representation
  3. Joint Process: Cartesian product (Mess3 × Bloch Walk) with 12-token vocabulary

Example Results

Mess3: Classical HMM with Quantum Representation

The transformer learns a 3D probability simplex representation with fractal structure from input-dependent transitions.

Bloch Walk: Quantum Process on Bloch Sphere

The transformer learns a clustered representation of the Bloch disk (x-z plane of the Bloch sphere).

Joint Process: Product Space (Mess3 × Bloch Walk)

The transformer learns factored representations: 3D simplex (Mess3) × 3D Bloch disk (quantum). Trained on 100K sequences for 3 epochs.


Quick Start

Setup

# Install uv package manager
curl -LsSf https://astral.sh/uv/install.sh | sh

# Sync dependencies and activate environment
uv sync
source .venv/bin/activate  # On Windows: .venv\Scripts\activate

Run Complete Pipeline

# Make scripts executable (first time only)
chmod +x run_full_experiment.sh

# Run full experiment (training + analysis + visualization)
./run_full_experiment.sh mess3_run1 3 200000

Process 1: Mess3 (Classical → Quantum)

What is Mess3?

A 3-state Hidden Markov Model with input-dependent transitions:

  • Parameters: x=0.05, α=0.85 (from paper Appendix D)
  • States: A, B, C
  • Symbols: 0, 1, 2 (vocabulary size = 3)
  • Key Feature: Three different transition matrices T[0], T[1], T[2], selected based on the observed symbol
  • Optimal Memory: Single qutrit (3D quantum state) vs. infinite classical states
  • Belief Space: 3D probability simplex (triangle)
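The three labeled transition matrices, sequence sampling, and belief filtering can be sketched as below. This is a minimal sketch that assumes the Mess3 parameterization common in earlier belief-geometry work (sticky hidden dynamics with P(j | i) = 1 − 2x if j = i, else x; on entering state j the emitted symbol equals j with probability α, otherwise (1 − α)/2); the matrices in belief_geometry/processes.py may be arranged differently. The names mess3_matrices, sample_mess3, and mess3_filter are hypothetical. The filter applies the clipping-and-renormalization step mentioned under Key Implementation Details.

```python
import numpy as np


def mess3_matrices(x: float = 0.05, alpha: float = 0.85) -> np.ndarray:
    """Labeled transition matrices T[k, i, j] = P(emit k, next state j | state i).

    Assumed parameterization (hypothetical here; check processes.py):
    P(j | i) = 1 - 2x if j == i else x, and the symbol emitted on entering
    state j equals j with probability alpha, otherwise (1 - alpha) / 2.
    """
    beta = (1.0 - alpha) / 2.0
    y = 1.0 - 2.0 * x
    T = np.empty((3, 3, 3))
    for k in range(3):          # emitted symbol
        for i in range(3):      # current hidden state
            for j in range(3):  # next hidden state
                move = y if i == j else x
                emit = alpha if k == j else beta
                T[k, i, j] = move * emit
    return T


def sample_mess3(n_tokens: int, T: np.ndarray, rng: np.random.Generator) -> list[int]:
    """Draw a token sequence by jointly sampling (symbol, next state) at each step."""
    state = int(rng.integers(3))  # uniform start = stationary distribution for this symmetric T
    tokens = []
    for _ in range(n_tokens):
        joint = T[:, state, :].ravel()  # flat index = k * 3 + j
        k, state = divmod(int(rng.choice(joint.size, p=joint)), 3)
        tokens.append(k)
    return tokens


def mess3_filter(tokens, T: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Belief over hidden states after each observed token (Bayesian filtering)."""
    eta = np.full(3, 1.0 / 3.0)  # start from the stationary (uniform) belief
    beliefs = []
    for k in tokens:
        eta = eta @ T[k]               # unnormalized P(next state, token k | past)
        eta = np.clip(eta, eps, None)  # avoid exact zeros / underflow
        eta = eta / eta.sum()          # renormalize onto the simplex
        beliefs.append(eta)
    return np.array(beliefs)
```

For example, `mess3_filter(sample_mess3(8, T, np.random.default_rng(0)), T)` returns an 8 × 3 array of ground-truth beliefs for one length-8 context. The row sums of `T.sum(axis=0)` equal 1, which is a quick check that the three matrices together form a valid HMM.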

Training Mess3

# Basic training (200K sequences, 3 epochs)
python train.py --output_dir runs/mess3_test \
  --vocab_size 3 \
  --num_sequences 200000 \
  --epochs 3 \
  --max_ctx 8

# Large-scale training (500K sequences, 25 epochs)
python train.py --output_dir runs/mess3_large \
  --vocab_size 3 \
  --num_sequences 500000 \
  --epochs 25 \
  --max_ctx 10 \
  --d_model 512 \
  --n_heads 16 \
  --n_layers 6 \
  --save_every 5

Training Parameters:

  • --vocab_size: 3 for Mess3
  • --max_ctx: Context length (8-10 recommended)
  • --d_model: Hidden dimension (128-512)
  • --n_heads: Attention heads (4-16)
  • --n_layers: Transformer layers (2-6)
  • --save_every: Save checkpoint every N epochs (0=disabled)
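Under the usual next-token objective, a sequence of max_ctx + 1 tokens gives max_ctx input positions, each trained to predict the token that follows it. The sketch below assumes that convention; the repo's dataset.py may window or batch sequences differently, and make_next_token_pairs is a hypothetical helper name.

```python
import numpy as np


def make_next_token_pairs(sequences) -> tuple[np.ndarray, np.ndarray]:
    """Split (N, max_ctx + 1) token sequences into inputs and one-step-shifted targets.

    Assumes each row holds max_ctx + 1 tokens (hypothetical convention).
    """
    sequences = np.asarray(sequences)
    inputs = sequences[:, :-1]   # positions 0 .. max_ctx - 1 are fed to the model
    targets = sequences[:, 1:]   # position t is trained to predict token t + 1
    return inputs, targets
```

With `--max_ctx 8`, each generated sequence would then need 9 tokens, and every token id must be below `--vocab_size`.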

Analyzing Mess3

# Run regression analysis
python -m belief_geometry.regression_analysis \
  --analysis_batch runs/mess3_test/analysis_batch.pt \
  --output_dir runs/mess3_test/analysis \
  --alpha 1.0 \
  --markov_orders 2 3 5

What it does:

  • Linear regression: Activations → Belief states
  • Compares quantum (3D) vs. classical Markov-N baselines
  • Computes RMSE and R² metrics
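The regression step can be sketched with closed-form ridge regression in NumPy. This assumes that `--alpha` is an L2 (ridge) penalty and that activations and beliefs are already flattened to (samples, features) arrays; regression_analysis.py may use a different solver, intercept handling, or RMSE convention. The returned coefficient matrix and intercept are analogous to the quantum_3d_coef.npy and quantum_3d_intercept.npy outputs. Here RMSE is taken over all belief coordinates and R² is averaged uniformly over coordinates; fit_ridge, rmse, and r2 are hypothetical names.

```python
import numpy as np


def fit_ridge(acts: np.ndarray, beliefs: np.ndarray, alpha: float = 1.0):
    """Closed-form ridge regression: beliefs ≈ acts @ coef + intercept."""
    x_mean, y_mean = acts.mean(axis=0), beliefs.mean(axis=0)
    xc, yc = acts - x_mean, beliefs - y_mean  # centering leaves the intercept unpenalized
    gram = xc.T @ xc + alpha * np.eye(xc.shape[1])
    coef = np.linalg.solve(gram, xc.T @ yc)   # shape (d_model, belief_dim)
    intercept = y_mean - x_mean @ coef
    return coef, intercept


def rmse(pred: np.ndarray, target: np.ndarray) -> float:
    """Root mean squared error over all samples and belief coordinates."""
    return float(np.sqrt(np.mean((pred - target) ** 2)))


def r2(pred: np.ndarray, target: np.ndarray) -> float:
    """R² per belief coordinate, averaged uniformly."""
    ss_res = ((target - pred) ** 2).sum(axis=0)
    ss_tot = ((target - target.mean(axis=0)) ** 2).sum(axis=0)
    return float(np.mean(1.0 - ss_res / ss_tot))


if __name__ == "__main__":
    # Synthetic stand-in data: 500 "activations" of width 16 linearly encoding 3D targets.
    rng = np.random.default_rng(0)
    acts = rng.normal(size=(500, 16))
    beliefs = acts @ rng.normal(size=(16, 3)) + 0.01 * rng.normal(size=(500, 3))
    coef, intercept = fit_ridge(acts, beliefs, alpha=1.0)
    pred = acts @ coef + intercept
    print(f"RMSE={rmse(pred, beliefs):.4f}  R2={r2(pred, beliefs):.4f}")
```

The same functions apply to the Markov-N baselines: only the target array changes, so the RMSE values stay directly comparable.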

Expected Results:

  • Quantum 3D RMSE: 0.01-0.05 (low = learned quantum geometry)
  • Markov-N RMSE: 0.1-0.3 (higher = worse approximation)
  • Random RMSE: 0.4+ (baseline)

Visualizing Mess3

# Standard plots
python visualize_results.py \
  --analysis_dir runs/mess3_test/analysis \
  --analysis_batch runs/mess3_test/analysis_batch.pt \
  --output_dir runs/mess3_test/analysis/plots

# Paper-style simplex plots (ternary diagrams)
python plot_belief_simplex.py \
  --analysis_batch runs/mess3_test/analysis_batch.pt \
  --analysis_dir runs/mess3_test/analysis \
  --style both

# Enhanced density plots (shows fractal structure)
python plot_belief_simplex_enhanced.py \
  --analysis_batch runs/mess3_test/analysis_batch.pt \
  --analysis_dir runs/mess3_test/analysis \
  --output_dir runs/mess3_test/analysis/simplex_plots

Generated Plots:

  • rmse_comparison.png: Bar chart of all methods
  • belief_trajectories.png: Time evolution of beliefs
  • paper_style_comparison.png: 3-panel simplex figure
  • simplex_comparison.png: Ground truth vs. transformer simplex
  • simplex_density.png: Hexbin density plot showing fractal patterns
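Ternary (simplex) plots place each 3-component belief inside a triangle by treating its components as barycentric weights on the triangle's corners. A minimal sketch, assuming an equilateral triangle with corners at (0, 0), (1, 0), and (1/2, √3/2); the plotting scripts may orient the triangle differently, and simplex_to_xy is a hypothetical name.

```python
import numpy as np

# Corners of an equilateral triangle, one per hidden state (A, B, C).
CORNERS = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, np.sqrt(3) / 2]])


def simplex_to_xy(beliefs) -> np.ndarray:
    """Map (N, 3) probability vectors to (N, 2) points in the triangle."""
    beliefs = np.asarray(beliefs, dtype=float)
    # Rescale so each row sums to 1 (regression outputs may not).
    beliefs = beliefs / beliefs.sum(axis=1, keepdims=True)
    return beliefs @ CORNERS
```

The resulting coordinates can be passed directly to matplotlib's `hexbin` for density plots; a logarithmic color scale (`bins="log"`) tends to make the sparse fractal arms easier to see.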

Process 2: Bloch Walk (Quantum Process)

What is Bloch Walk?

A genuine quantum process on the Bloch sphere:

  • Representation: Single qubit (2D slice of Bloch sphere)
  • Symbols: 0, 1, 2, 3 (vocabulary size = 4)
  • Parameters: α=1, β=√51, γ=1/√208 (normalization)
  • Implementation: Kraus operators K₀, K₁, K₂, K₃ (Equations D4-D7 from paper)
  • Belief Space: 3D Bloch vectors [b_x, b_y, b_z], constrained to x-z plane (b_y ≈ 0)
  • Key Property: No finite classical HMM exists!
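The relation between the qubit state ρ and the plotted Bloch vector, and why the walk stays in the x-z plane, can be written out as follows. The b_y = 0 step assumes the Kraus operators and the initial state are real matrices; that is consistent with the b_y ≈ 0 constraint above, but it is an assumption rather than something this README states.

```latex
\rho = \tfrac{1}{2}\bigl(I + b_x\sigma_x + b_y\sigma_y + b_z\sigma_z\bigr),
\qquad b_i = \operatorname{tr}(\rho\,\sigma_i),
\qquad \lVert \mathbf{b} \rVert \le 1 .

\operatorname{tr}(\rho\,\sigma_y) = i\,(\rho_{12} - \rho_{21}),
\qquad \rho = \rho^{\top} \in \mathbb{R}^{2\times 2} \;\Rightarrow\; b_y = 0 .

K_n \in \mathbb{R}^{2\times 2}:\quad
\rho' = \frac{K_n \rho K_n^{\top}}{\operatorname{tr}(K_n \rho K_n^{\top})}
\;\text{ is again real symmetric, so } b_y' = 0 .
```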

Training Bloch Walk

# Small test (50K sequences, 2 epochs)
python train_and_plot_blochwalk.py

# Larger experiment (200K sequences, 5 epochs)
python train_blochwalk_200k.py

Visualizing Bloch Walk

# Generate 3-panel comparison (Ground Truth | Transformer | Performance)
python plot_blochwalk_complete.py  # For 50K run
python plot_blochwalk_200k.py      # For 200K run

Generated Plots:

  • bloch_walk_comparison.png: 3-panel figure showing b_x vs b_z projections
  • Colored by b_z coordinate (north/south pole position)
  • Transformer learns clustered representation of Bloch disk

Process 3: Joint Mess3 × Bloch Walk (Product Space)

What is the Joint Process?

Cartesian product of Mess3 and Bloch Walk running independently in parallel:

  • Vocabulary: 3 (Mess3) × 4 (Bloch) = 12 tokens
  • Belief Space: 6D product space (3D simplex × 3D Bloch)
  • Encoding: joint_token = mess3_token * 4 + bloch_token
  • Independence: Two processes evolve separately; network observes both
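The token encoding and the 6D belief target can be sketched as follows. The encoding matches `joint_token = mess3_token * 4 + bloch_token`; the function names are hypothetical, and the sketch assumes the Bloch component is stored as the 3D vector [b_x, b_y, b_z].

```python
import numpy as np

N_BLOCH = 4  # Bloch Walk vocabulary size


def encode_joint(mess3_token: int, bloch_token: int) -> int:
    """Map (m, b) with m in {0, 1, 2} and b in {0, ..., 3} to a joint token in {0, ..., 11}."""
    return mess3_token * N_BLOCH + bloch_token


def decode_joint(joint_token: int) -> tuple[int, int]:
    """Invert encode_joint: recover (mess3_token, bloch_token)."""
    return divmod(joint_token, N_BLOCH)


def joint_belief(mess3_belief, bloch_vector) -> np.ndarray:
    """Concatenate the 3D simplex belief and the 3D Bloch vector into a 6D target."""
    return np.concatenate([np.asarray(mess3_belief, float), np.asarray(bloch_vector, float)])
```

Because the two processes are independent, decoding each joint token recovers both observations, so each half of the belief can be computed with its own filter before concatenation.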

Training Joint Process

# Standard experiment (50K sequences, 5 epochs)
python train_joint.py

# Large-scale experiment (500K sequences, 25 epochs)
python train_joint_large.py

Visualizing Joint Process

# Generate decomposed visualization
python plot_joint.py

# For large-scale experiment
python plot_joint_large.py

Generated Plot (joint_comparison.png):

  • 4-panel decomposition:
    • Top left: Mess3 ground truth (simplex)
    • Top right: Mess3 transformer (simplex)
    • Bottom left: Bloch Walk ground truth (disk)
    • Bottom right: Bloch Walk transformer (disk)
  • Performance panel: RMSE for joint (6D), Mess3 (3D), Bloch (3D), and random baseline
  • Result: Network learns factored representation of product space!
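Component-wise RMSE amounts to slicing the 6D prediction into its Mess3 and Bloch blocks. A minimal sketch, assuming columns 0-2 hold the Mess3 belief and columns 3-5 the Bloch vector (the order listed under Key Implementation Details), and that RMSE averages squared error over all samples and coordinates in each block; component_rmse is a hypothetical name.

```python
import numpy as np


def component_rmse(pred: np.ndarray, target: np.ndarray) -> dict[str, float]:
    """RMSE for the full 6D joint belief and for each 3D component."""
    def rmse(p: np.ndarray, t: np.ndarray) -> float:
        return float(np.sqrt(np.mean((p - t) ** 2)))

    return {
        "joint_6d": rmse(pred, target),
        "mess3_3d": rmse(pred[:, :3], target[:, :3]),  # assumed Mess3 columns
        "bloch_3d": rmse(pred[:, 3:], target[:, 3:]),  # assumed Bloch columns
    }
```

With equal-sized blocks, this convention gives joint RMSE² = (Mess3 RMSE² + Bloch RMSE²) / 2, which is a quick sanity check on the performance panel.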

Large-Scale Experiments

Overnight Experiment Suite

For thorough exploration with progressively larger models:

# Make executable (first time only)
chmod +x run_overnight_experiments.sh

# Run suite (4-5 hours, 3 experiments)
./run_overnight_experiments.sh

# Monitor progress
./monitor_experiments.sh

Experiments:

  1. Large Model: 300K sequences, 15 epochs, d_model=384, 4 layers
  2. Very Large: 400K sequences, 20 epochs, d_model=512, 6 layers
  3. Massive: 500K sequences, 25 epochs, d_model=512, 6 layers

Each experiment includes:

  • Periodic checkpointing (--save_every)
  • Regression analysis
  • Standard and density simplex plots

File Structure

Training Outputs

runs/
├── mess3_test/
│   ├── config.json              # Training configuration
│   ├── model.pt                 # Trained model weights
│   ├── metrics_epoch_*.json     # Per-epoch metrics
│   ├── analysis_batch.pt        # Activations + ground truth beliefs
│   └── analysis/
│       ├── regression_results.json
│       ├── quantum_3d_coef.npy
│       ├── quantum_3d_intercept.npy
│       ├── plots/               # Standard plots
│       └── simplex_plots/       # Simplex visualizations
├── blochwalk_test/
│   ├── model.pt
│   ├── analysis_batch.pt
│   └── bloch_walk_comparison.png
└── joint_mess3_bloch/
    ├── model.pt
    ├── analysis_batch.pt
    └── joint_comparison.png

Source Code

belief_geometry/
├── processes.py       # HMM definitions, sequence generation
├── dataset.py         # PyTorch datasets for each process
├── model.py           # ResidualHookedTransformer
├── config.py          # TrainConfig dataclass
└── regression_analysis.py  # Analysis pipeline

Scripts

train.py                         # Main training script (Mess3)
train_and_plot_blochwalk.py      # Train Bloch Walk (50K)
train_blochwalk_200k.py          # Train Bloch Walk (200K)
train_joint.py                   # Train joint process
train_joint_large.py             # Train joint process (500K)
plot_belief_simplex.py           # Mess3 simplex plots
plot_belief_simplex_enhanced.py  # Mess3 density plots
plot_blochwalk_complete.py       # Bloch Walk visualization
plot_joint.py                    # Joint process visualization
plot_joint_large.py              # Joint process (large-scale)
test_joint_process.py            # Test joint data generation
run_full_experiment.sh           # Complete pipeline
run_overnight_experiments.sh     # Large-scale suite
monitor_experiments.sh           # Status checker

Key Implementation Details

Mess3 Process

  • Input-dependent transitions: Three T matrices, selected by previous observation
  • Numerical stability: Belief filtering includes clipping and renormalization
  • Emission matrix: Ambiguous emissions to generate interior simplex points

Bloch Walk Process

  • Kraus operators: More fundamental than GHMM formulation
  • Normalization: γ = 1/√208 (corrected from what appears to be a typo in the paper)
  • State representation: Density matrix ρ → Bloch vector via Pauli matrices
  • Sampling: Born rule probabilities p_n = tr(K_n ρ K_n†)
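The Kraus-operator update (Born-rule sampling, post-measurement state, Bloch vector) can be sketched as below. The four operators here are NOT the paper's Equations D4-D7, which this README does not reproduce. As a hypothetical stand-in, they are weak measurements along x or z (strength `eps`, an illustrative parameter), each chosen with probability 1/2. They are built so that Σ K_n†K_n = I and every operator is real, which keeps the state in the x-z plane. Substitute the actual Kraus operators to reproduce the real Bloch Walk.

```python
import numpy as np

I2 = np.eye(2, dtype=complex)
SX = np.array([[0, 1], [1, 0]], dtype=complex)
SY = np.array([[0, -1j], [1j, 0]], dtype=complex)
SZ = np.array([[1, 0], [0, -1]], dtype=complex)


def weak_measurement(n_sigma: np.ndarray, eps: float):
    """Hypothetical pair sqrt((I + eps n.sigma)/2), sqrt((I - eps n.sigma)/2)."""
    p_plus, p_minus = (I2 + n_sigma) / 2, (I2 - n_sigma) / 2  # projectors onto +n / -n
    s_hi, s_lo = np.sqrt((1 + eps) / 2), np.sqrt((1 - eps) / 2)
    return s_hi * p_plus + s_lo * p_minus, s_lo * p_plus + s_hi * p_minus


def make_kraus(eps: float = 0.5) -> list[np.ndarray]:
    """Four real Kraus operators (NOT the paper's D4-D7): weak x- or z-measurement, each w.p. 1/2."""
    ops = (*weak_measurement(SX, eps), *weak_measurement(SZ, eps))
    return [k / np.sqrt(2) for k in ops]


def rho_from_bloch(b) -> np.ndarray:
    """Density matrix rho = (I + b.sigma) / 2."""
    return 0.5 * (I2 + b[0] * SX + b[1] * SY + b[2] * SZ)


def bloch_from_rho(rho: np.ndarray) -> np.ndarray:
    """b_i = tr(rho sigma_i) via the Pauli matrices."""
    return np.array([np.trace(rho @ s).real for s in (SX, SY, SZ)])


def step(rho: np.ndarray, kraus: list[np.ndarray], rng: np.random.Generator):
    """Sample token n with p_n = tr(K_n rho K_n^dagger), then update rho."""
    probs = np.array([np.trace(k @ rho @ k.conj().T).real for k in kraus])
    n = int(rng.choice(len(kraus), p=probs / probs.sum()))
    k = kraus[n]
    return n, (k @ rho @ k.conj().T) / probs[n]
```

Starting from `rho_from_bloch([0, 0, 1])` and calling `step` repeatedly yields a token stream in {0, 1, 2, 3}. It also yields, via `bloch_from_rho`, the ground-truth Bloch vectors that the regression targets.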

Joint Process

  • Independent generation: Two processes run in parallel
  • Token encoding: Pairing function maps (m, b) → joint index
  • Belief concatenation: [Mess3 beliefs (3D), Bloch beliefs (3D)] = 6D

Understanding the Results

What the Transformer Learns

The paper's key finding: When trained on next-token prediction, transformers spontaneously discover minimal sufficient representations of belief states:

  1. Mess3: 3D simplex (quantum-like) instead of infinite-dimensional classical history
  2. Bloch Walk: Qubit representation (genuinely quantum)
  3. Joint: Factored product space (6D = 3D × 3D)

Metrics

  • RMSE: Root mean squared error of belief state prediction
    • Lower = better approximation
    • Quantum < Classical Markov-N < Random
  • R² score: Fraction of variance explained
    • Closer to 1 = better fit
  • Component-wise RMSE: For joint process, measure each component separately

Visualization Insights

  • Simplex (Mess3): Fractal structure emerges from input-dependent transitions
  • Bloch disk: Clustering shows discrete quantum measurement outcomes
  • Joint: Network learns to separate and represent both geometries independently

Comparison to Paper's Repository

Paper's full repo: adamimos/epsilon-transformers

Their implementation:

  • 4 processes (Mess3, Bloch Walk, FRDN, Moon)
  • 4 architectures (Transformer, LSTM, GRU, RNN)
  • 201 checkpoints per training run
  • wandb integration for tracking
  • HuggingFace datasets with pre-trained models

Our implementation:

  • Focused on Mess3 + Bloch Walk + Joint
  • Transformer only
  • Simplified training/analysis pipeline
  • Emphasis on reproducing key figures
  • Additional: Kraus operator formulation for Bloch Walk

Troubleshooting

Plots lack fractal structure

  • Solution: Ensure you are using the input-dependent Mess3 (3 T matrices)
  • Increase training data (500K+ sequences)
  • Train for more epochs (15-25)
  • Use larger models (d_model=512, n_layers=6)

Points outside simplex boundary

  • Solution: Already fixed via clipping/renormalization in filter_beliefs()

NaN or Inf during training

  • Solution: Reduce learning rate, add gradient clipping, check data generation

Poor Bloch Walk performance

  • Solution: Quantum processes are harder; use 200K+ sequences, 5+ epochs

Citation

If you use this code or reproduce the experiments, please cite:

For the original paper:

@article{riechers2025quantum,
  title={Neural networks leverage nominally quantum and post-quantum representations},
  author={Riechers, Paul and Crutchfield, James and others},
  journal={arXiv preprint arXiv:2507.07432},
  year={2025}
}

For this implementation:

@software{kocher2025belief_geometry,
  author={Kocher, Greg},
  title={Belief State Geometry: Implementation of Mess3, Bloch Walk, and Joint Processes},
  year={2025},
  url={https://github.com/gregkocher/belief-state-geometry}
}

Additional Documentation

  • QUICKSTART.md: Command reference cheat sheet
  • BLOCH_WALK_IMPLEMENTATION.md: Technical details of Kraus operator approach
  • DOCUMENTATION_UPDATE.md: Summary of documentation changes
  • SUMMARY.md: Complete repository summary
