Implementation and reproduction of "Neural networks leverage nominally quantum and post-quantum representations" by Riechers et al.
Core Idea: Neural networks spontaneously learn low-dimensional quantum-like representations of belief states when trained on next-token prediction for certain stochastic processes.
- Mess3: 3-state classical HMM requiring quantum (qutrit) representation
- Bloch Walk: Quantum process on Bloch sphere requiring qubit representation
- Joint Process: Cartesian product (Mess3 × Bloch Walk) with 12-token vocabulary
Mess3: The transformer learns a 3D probability simplex representation whose fractal structure arises from the input-dependent transitions.
Bloch Walk: The transformer learns a clustered representation of the Bloch disk (the x-z plane of the Bloch sphere).
Joint Process: The transformer learns factored representations, 3D simplex (Mess3) × 3D Bloch disk (quantum), when trained on 100K sequences for 3 epochs.
# Install uv package manager
curl -LsSf https://astral.sh/uv/install.sh | sh
# Sync dependencies and activate environment
uv sync
source .venv/bin/activate # On Windows: .venv\Scripts\activate
# Make scripts executable (first time only)
chmod +x run_full_experiment.sh
# Run full experiment (training + analysis + visualization)
./run_full_experiment.sh mess3_run1 3 200000
A 3-state Hidden Markov Model with input-dependent transitions:
- Parameters: x=0.05, α=0.85 (from paper Appendix D)
- States: A, B, C
- Symbols: 0, 1, 2 (vocabulary size = 3)
- Key Feature: Three different transition matrices T[0], T[1], T[2], selected based on the observed symbol
- Optimal Memory: Single qutrit (3D quantum state) vs. infinite classical states
- Belief Space: 3D probability simplex (triangle)
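The belief update implied by the list above can be sketched as follows. This is a minimal illustration, not the repository's code: the function names `mess3_matrices` and `update_belief` are hypothetical, and the symmetric parameterization of the three matrices in terms of x and α is an assumption that may differ from what `belief_geometry/processes.py` defines. The clipping and renormalization step mirrors the numerical-stability note elsewhere in this README.

```python
import numpy as np

def mess3_matrices(x=0.05, alpha=0.85):
    """Symbol-labeled matrices T[k][i, j] = P(emit k, move to j | state i).

    ASSUMPTION: this symmetric parameterization is illustrative only and
    may not match the matrices defined in the repository.
    """
    beta = (1.0 - alpha) / 2.0
    y = 1.0 - 2.0 * x
    ay, by, ax, bx = alpha * y, beta * y, alpha * x, beta * x
    return np.array([
        [[ay, bx, bx], [ax, by, bx], [ax, bx, by]],
        [[by, ax, bx], [bx, ay, bx], [bx, ax, by]],
        [[by, bx, ax], [bx, by, ax], [bx, bx, ay]],
    ])

def update_belief(belief, symbol, T, eps=1e-12):
    """One Bayesian filtering step: belief' is proportional to belief @ T[symbol]."""
    unnormalized = belief @ T[symbol]
    # Clip before renormalizing so repeated updates cannot underflow to zero.
    unnormalized = np.clip(unnormalized, eps, None)
    return unnormalized / unnormalized.sum()

T = mess3_matrices()
belief = np.full(3, 1.0 / 3.0)      # start from the uniform belief
for symbol in [0, 2, 2, 1]:         # an arbitrary observed sequence
    belief = update_belief(belief, symbol, T)
print(belief)
```

Each observed symbol selects one matrix, so different symbol histories land on different points of the simplex; iterating the update over many sequences is what fills in the belief-state geometry.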
# Basic training (200K sequences, 3 epochs)
python train.py --output_dir runs/mess3_test \
--vocab_size 3 \
--num_sequences 200000 \
--epochs 3 \
--max_ctx 8
# Large-scale training (500K sequences, 25 epochs)
python train.py --output_dir runs/mess3_large \
--vocab_size 3 \
--num_sequences 500000 \
--epochs 25 \
--max_ctx 10 \
--d_model 512 \
--n_heads 16 \
--n_layers 6 \
--save_every 5
Training Parameters:
- --vocab_size: 3 for Mess3
- --max_ctx: Context length (8-10 recommended)
- --d_model: Hidden dimension (128-512)
- --n_heads: Attention heads (4-16)
- --n_layers: Transformer layers (2-6)
- --save_every: Save checkpoint every N epochs (0=disabled)
# Run regression analysis
python -m belief_geometry.regression_analysis \
--analysis_batch runs/mess3_test/analysis_batch.pt \
--output_dir runs/mess3_test/analysis \
--alpha 1.0 \
--markov_orders 2 3 5
What it does:
- Linear regression: Activations → Belief states
- Compares quantum (3D) vs. classical Markov-N baselines
- Computes RMSE and R² metrics
Expected Results:
- Quantum 3D RMSE: 0.01-0.05 (low = learned quantum geometry)
- Markov-N RMSE: 0.1-0.3 (higher = worse approximation)
- Random RMSE: 0.4+ (baseline)
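The regression step described above (activations mapped linearly to belief states, scored by RMSE and R²) can be sketched as follows. This is a hedged illustration on synthetic data: it assumes that `--alpha` is a ridge penalty, which the README does not state, and the helper names `fit_ridge`, `rmse`, and `r2_score` are hypothetical rather than taken from `belief_geometry/regression_analysis.py`.

```python
import numpy as np

def fit_ridge(activations, beliefs, alpha=1.0):
    """Closed-form ridge regression with an unpenalized intercept."""
    x_mean = activations.mean(axis=0)
    y_mean = beliefs.mean(axis=0)
    xc = activations - x_mean
    yc = beliefs - y_mean
    gram = xc.T @ xc + alpha * np.eye(xc.shape[1])
    coef = np.linalg.solve(gram, xc.T @ yc)   # shape: (d_model, belief_dim)
    intercept = y_mean - x_mean @ coef
    return coef, intercept

def rmse(pred, target):
    return float(np.sqrt(np.mean((pred - target) ** 2)))

def r2_score(pred, target):
    ss_res = np.sum((target - pred) ** 2)
    ss_tot = np.sum((target - target.mean(axis=0)) ** 2)
    return float(1.0 - ss_res / ss_tot)

# Synthetic stand-in for analysis_batch.pt: beliefs embedded linearly
# in a 64-dimensional "residual stream" plus a little noise.
rng = np.random.default_rng(0)
beliefs = rng.dirichlet(np.ones(3), size=2000)
mixing = rng.normal(size=(3, 64))
activations = beliefs @ mixing + 0.01 * rng.normal(size=(2000, 64))

coef, intercept = fit_ridge(activations, beliefs, alpha=1.0)
pred = activations @ coef + intercept
print(f"RMSE={rmse(pred, beliefs):.4f}  R2={r2_score(pred, beliefs):.4f}")
```

A low RMSE here only says the beliefs are linearly decodable from the activations; the comparison against Markov-N and random baselines is what makes the number interpretable.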
# Standard plots
python visualize_results.py \
--analysis_dir runs/mess3_test/analysis \
--analysis_batch runs/mess3_test/analysis_batch.pt \
--output_dir runs/mess3_test/analysis/plots
# Paper-style simplex plots (ternary diagrams)
python plot_belief_simplex.py \
--analysis_batch runs/mess3_test/analysis_batch.pt \
--analysis_dir runs/mess3_test/analysis \
--style both
# Enhanced density plots (shows fractal structure)
python plot_belief_simplex_enhanced.py \
--analysis_batch runs/mess3_test/analysis_batch.pt \
--analysis_dir runs/mess3_test/analysis \
--output_dir runs/mess3_test/analysis/simplex_plots
Generated Plots:
- rmse_comparison.png: Bar chart of all methods
- belief_trajectories.png: Time evolution of beliefs
- paper_style_comparison.png: 3-panel simplex figure
- simplex_comparison.png: Ground truth vs. transformer simplex
- simplex_density.png: Hexbin density plot showing fractal patterns
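For readers who want to build their own ternary diagrams, the projection from 3-component belief vectors to 2D triangle coordinates can be sketched as follows. This is a generic barycentric mapping, not necessarily what `plot_belief_simplex.py` does; the vertex placement and the name `simplex_to_xy` are assumptions made for illustration.

```python
import numpy as np

# Corners of an equilateral triangle with unit side length, one per state (A, B, C).
VERTICES = np.array([
    [0.0, 0.0],
    [1.0, 0.0],
    [0.5, np.sqrt(3.0) / 2.0],
])

def simplex_to_xy(beliefs):
    """Map belief vectors (rows summing to 1) to 2D points inside the triangle."""
    beliefs = np.asarray(beliefs, dtype=float)
    return beliefs @ VERTICES

corners = simplex_to_xy(np.eye(3))                 # pure states map to the corners
center = simplex_to_xy([[1 / 3, 1 / 3, 1 / 3]])    # uniform belief maps to the centroid
print(corners)
print(center)
```

The resulting (x, y) pairs can be passed directly to a scatter or hexbin plot.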
A genuine quantum process on the Bloch sphere:
- Representation: Single qubit (2D slice of Bloch sphere)
- Symbols: 0, 1, 2, 3 (vocabulary size = 4)
- Parameters: α=1, β=√51, γ=1/√208 (normalization)
- Implementation: Kraus operators K₀, K₁, K₂, K₃ (Equations D4-D7 from paper)
- Belief Space: 3D Bloch vectors [b_x, b_y, b_z], constrained to x-z plane (b_y ≈ 0)
- Key Property: No finite classical HMM exists!
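The mechanics of a Kraus-operator process (Born-rule sampling, state update, conversion to a Bloch vector) can be sketched as follows. The Kraus operators below are deliberately NOT the paper's Equations D4-D7: they are a simple illustrative set of four real-valued operators chosen so that they satisfy the completeness relation and keep the state in the x-z plane. All function names are hypothetical.

```python
import numpy as np

# Pauli matrices
SIGMA_X = np.array([[0, 1], [1, 0]], dtype=complex)
SIGMA_Y = np.array([[0, -1j], [1j, 0]], dtype=complex)
SIGMA_Z = np.array([[1, 0], [0, -1]], dtype=complex)

def projector(ket):
    ket = np.asarray(ket, dtype=complex).reshape(2, 1)
    return ket @ ket.conj().T

# ILLUSTRATIVE Kraus operators (an assumption, not the paper's D4-D7):
# scaled projectors onto |0>, |1>, |+>, |->, which sum to the identity.
s = 1.0 / np.sqrt(2.0)
KRAUS = [
    s * projector([1, 0]),
    s * projector([0, 1]),
    s * projector([s, s]),
    s * projector([s, -s]),
]

def born_probabilities(rho, kraus):
    """p_n = tr(K_n rho K_n^dagger), clipped to guard against round-off."""
    probs = np.real([np.trace(K @ rho @ K.conj().T) for K in kraus])
    return np.clip(probs, 0.0, None)

def update_state(rho, symbol, kraus):
    """Post-measurement state after observing `symbol`."""
    K = kraus[symbol]
    post = K @ rho @ K.conj().T
    return post / np.real(np.trace(post))

def bloch_vector(rho):
    """[b_x, b_y, b_z] with b_i = tr(rho sigma_i)."""
    return np.real([np.trace(rho @ sig) for sig in (SIGMA_X, SIGMA_Y, SIGMA_Z)])

rng = np.random.default_rng(0)
rho = np.eye(2, dtype=complex) / 2.0   # maximally mixed initial state
trajectory = []
for _ in range(5):
    probs = born_probabilities(rho, KRAUS)
    symbol = rng.choice(4, p=probs / probs.sum())
    rho = update_state(rho, symbol, KRAUS)
    trajectory.append(bloch_vector(rho))
print(np.array(trajectory))
```

Because every operator in this toy set is real, b_y stays at zero throughout, which is the same structural reason the belief space in the list above is confined to the x-z plane.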
# Small test (50K sequences, 2 epochs)
python train_and_plot_blochwalk.py
# Larger experiment (200K sequences, 5 epochs)
python train_blochwalk_200k.py
# Generate 3-panel comparison (Ground Truth | Transformer | Performance)
python plot_blochwalk_complete.py # For 50K run
python plot_blochwalk_200k.py # For 200K run
Generated Plots:
- bloch_walk_comparison.png: 3-panel figure showing b_x vs b_z projections
- Colored by b_z coordinate (north/south pole position)
- Transformer learns clustered representation of Bloch disk
Cartesian product of Mess3 and Bloch Walk running independently in parallel:
- Vocabulary: 3 (Mess3) × 4 (Bloch) = 12 tokens
- Belief Space: 6D product space (3D simplex × 3D Bloch)
- Encoding: joint_token = mess3_token * 4 + bloch_token
- Independence: Two processes evolve separately; network observes both
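The encoding formula above, together with its inverse, can be sketched as follows. The function and constant names are hypothetical; only the formula `mess3_token * 4 + bloch_token` comes from this README.

```python
MESS3_VOCAB = 3   # Mess3 symbols 0..2
BLOCH_VOCAB = 4   # Bloch Walk symbols 0..3

def encode_joint(mess3_token, bloch_token):
    """Pair a Mess3 symbol and a Bloch Walk symbol into one joint token."""
    return mess3_token * BLOCH_VOCAB + bloch_token

def decode_joint(joint_token):
    """Recover (mess3_token, bloch_token) from a joint token."""
    return divmod(joint_token, BLOCH_VOCAB)

all_tokens = [
    encode_joint(m, b)
    for m in range(MESS3_VOCAB)
    for b in range(BLOCH_VOCAB)
]
print(all_tokens)
print(decode_joint(11))
```

Because the mapping is a bijection onto 0..11, the joint sequence carries exactly the information of the two component sequences.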
# Standard experiment (50K sequences, 5 epochs)
python train_joint.py
# Large-scale experiment (500K sequences, 25 epochs)
python train_joint_large.py
# Generate decomposed visualization
python plot_joint.py
# For large-scale experiment
python plot_joint_large.py
Generated Plot (joint_comparison.png):
- 4-panel decomposition:
- Top left: Mess3 ground truth (simplex)
- Top right: Mess3 transformer (simplex)
- Bottom left: Bloch Walk ground truth (disk)
- Bottom right: Bloch Walk transformer (disk)
- Performance panel: RMSE for joint (6D), Mess3 (3D), Bloch (3D), and random baseline
- Result: Network learns factored representation of product space!
For thorough exploration with progressively larger models:
# Make executable (first time only)
chmod +x run_overnight_experiments.sh
# Run suite (4-5 hours, 3 experiments)
./run_overnight_experiments.sh
# Monitor progress
./monitor_experiments.sh
Experiments:
- Large Model: 300K sequences, 15 epochs, d_model=384, 4 layers
- Very Large: 400K sequences, 20 epochs, d_model=512, 6 layers
- Massive: 500K sequences, 25 epochs, d_model=512, 6 layers
Each experiment includes:
- Periodic checkpointing (--save_every)
- Regression analysis
- Standard and density simplex plots
runs/
├── mess3_test/
│ ├── config.json # Training configuration
│ ├── model.pt # Trained model weights
│ ├── metrics_epoch_*.json # Per-epoch metrics
│ ├── analysis_batch.pt # Activations + ground truth beliefs
│ └── analysis/
│ ├── regression_results.json
│ ├── quantum_3d_coef.npy
│ ├── quantum_3d_intercept.npy
│ ├── plots/ # Standard plots
│ └── simplex_plots/ # Simplex visualizations
├── blochwalk_test/
│ ├── model.pt
│ ├── analysis_batch.pt
│ └── bloch_walk_comparison.png
└── joint_mess3_bloch/
├── model.pt
├── analysis_batch.pt
└── joint_comparison.png
belief_geometry/
├── processes.py # HMM definitions, sequence generation
├── dataset.py # PyTorch datasets for each process
├── model.py # ResidualHookedTransformer
├── config.py # TrainConfig dataclass
└── regression_analysis.py # Analysis pipeline
train.py # Main training script (Mess3)
train_and_plot_blochwalk.py # Train Bloch Walk (50K)
train_blochwalk_200k.py # Train Bloch Walk (200K)
train_joint.py # Train joint process
train_joint_large.py # Train joint process (500K)
plot_belief_simplex.py # Mess3 simplex plots
plot_belief_simplex_enhanced.py # Mess3 density plots
plot_blochwalk_complete.py # Bloch Walk visualization
plot_joint.py # Joint process visualization
plot_joint_large.py # Joint process (large-scale)
test_joint_process.py # Test joint data generation
run_full_experiment.sh # Complete pipeline
run_overnight_experiments.sh # Large-scale suite
monitor_experiments.sh # Status checker
- Input-dependent transitions: Three T matrices, selected by previous observation
- Numerical stability: Belief filtering includes clipping and renormalization
- Emission matrix: Ambiguous emissions to generate interior simplex points
- Kraus operators: More fundamental than GHMM formulation
- Normalization: γ = 1/√208 (corrected from paper's potential typo)
- State representation: Density matrix ρ → Bloch vector via Pauli matrices
- Sampling: Born rule probabilities p_n = tr(K_n ρ K_n†)
- Independent generation: Two processes run in parallel
- Token encoding: Pairing function maps (m, b) → joint index
- Belief concatenation: [Mess3 beliefs (3D), Bloch beliefs (3D)] = 6D
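The belief concatenation in the last bullet, and the per-component scoring it enables, can be sketched as follows. This is an illustration on synthetic data with hypothetical helper names; it assumes the Mess3 beliefs occupy the first three coordinates and the Bloch vector the last three, as the bullet's ordering suggests.

```python
import numpy as np

def rmse(pred, target):
    return float(np.sqrt(np.mean((pred - target) ** 2)))

def joint_beliefs(mess3_beliefs, bloch_beliefs):
    """Stack [Mess3 beliefs (3D), Bloch vector (3D)] into one 6D target."""
    return np.concatenate([mess3_beliefs, bloch_beliefs], axis=-1)

def component_rmse(pred, target):
    """RMSE over all six coordinates and over each 3D component separately."""
    return {
        "joint": rmse(pred, target),
        "mess3": rmse(pred[..., :3], target[..., :3]),
        "bloch": rmse(pred[..., 3:], target[..., 3:]),
    }

# Synthetic stand-in data: simplex points and Bloch vectors in the x-z plane.
rng = np.random.default_rng(0)
mess3 = rng.dirichlet(np.ones(3), size=100)
bloch = rng.uniform(-0.5, 0.5, size=(100, 3))
bloch[:, 1] = 0.0

target = joint_beliefs(mess3, bloch)
pred = target + 0.01                 # pretend predictions with a constant offset
scores = component_rmse(pred, target)
print(scores)
```

Since both components have three coordinates, the joint mean squared error is the average of the two component mean squared errors, so a large gap between the component scores shows which geometry the network represents less accurately.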
The paper's key finding: When trained on next-token prediction, transformers spontaneously discover minimal sufficient representations of belief states:
- Mess3: 3D simplex (quantum-like) instead of infinite-dimensional classical history
- Bloch Walk: Qubit representation (genuinely quantum)
- Joint: Factored product space (6D = 3D × 3D)
- RMSE: Root mean squared error of belief state prediction
- Lower = better approximation
- Quantum < Classical Markov-N < Random
- R² score: Fraction of variance explained
- Closer to 1 = better fit
- Component-wise RMSE: For joint process, measure each component separately
- Simplex (Mess3): Fractal structure emerges from input-dependent transitions
- Bloch disk: Clustering shows discrete quantum measurement outcomes
- Joint: Network learns to separate and represent both geometries independently
Paper's full repo: adamimos/epsilon-transformers
Their implementation:
- 4 processes (Mess3, Bloch Walk, FRDN, Moon)
- 4 architectures (Transformer, LSTM, GRU, RNN)
- 201 checkpoints per training run
- wandb integration for tracking
- HuggingFace datasets with pre-trained models
Our implementation:
- Focused on Mess3 + Bloch Walk + Joint
- Transformer only
- Simplified training/analysis pipeline
- Emphasis on reproducing key figures
- Additional: Kraus operator formulation for Bloch Walk
- Solution: Ensure you are using the input-dependent Mess3 (three T matrices)
- Increase training data (500K+ sequences)
- Train for more epochs (15-25)
- Use larger models (d_model=512, n_layers=6)
- Solution: Already fixed via clipping/renormalization in filter_beliefs()
- Solution: Reduce learning rate, add gradient clipping, check data generation
- Solution: Quantum processes are harder; use 200K+ sequences, 5+ epochs
If you use this code or reproduce the experiments, please cite:
For the original paper:
@article{riechers2025quantum,
title={Neural networks leverage nominally quantum and post-quantum representations},
author={Riechers, Paul and Crutchfield, James and others},
journal={arXiv preprint arXiv:2507.07432},
year={2025}
}
For this implementation:
@software{kocher2025belief_geometry,
author={Kocher, Greg},
title={Belief State Geometry: Implementation of Mess3, Bloch Walk, and Joint Processes},
year={2025},
url={https://github.com/gregkocher/belief-state-geometry}
}
- QUICKSTART.md: Command reference cheat sheet
- BLOCH_WALK_IMPLEMENTATION.md: Technical details of Kraus operator approach
- DOCUMENTATION_UPDATE.md: Summary of documentation changes
- SUMMARY.md: Complete repository summary