-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_ground_truth_only.py
More file actions
58 lines (45 loc) · 1.59 KB
/
Copy pathplot_ground_truth_only.py
File metadata and controls
58 lines (45 loc) · 1.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#!/usr/bin/env python3
"""
Simple script to just plot ground truth Bloch disk (no regression).
"""
import numpy as np
import torch
import matplotlib.pyplot as plt
from pathlib import Path
# ---- Load ground-truth belief states and report basic statistics ----------
print("Loading data...")
# map_location="cpu": if the batch was saved from GPU tensors, a plain
# torch.load fails on CPU-only machines, and on GPU machines it restores CUDA
# tensors on which .numpy() raises. Loading onto the CPU avoids both.
data = torch.load(
    Path("runs/blochwalk_test") / "analysis_batch.pt",
    map_location="cpu",
)
# detach(): .numpy() refuses tensors that still require grad.
beliefs = data["beliefs"].detach().numpy()  # expected (N, T, 3)
print(f"Beliefs shape: {beliefs.shape}")

# Validate before flattening. reshape(-1, 3) would silently interleave
# components if the trailing axis were not 3. min()/max() below would raise
# an unhelpful error on an empty array.
if beliefs.ndim == 0 or beliefs.shape[-1] != 3:
    raise ValueError(
        f"Expected beliefs with a trailing axis of size 3, got shape {beliefs.shape}"
    )
if beliefs.size == 0:
    raise ValueError("Beliefs array is empty; nothing to plot")

# Flatten all leading axes into a single list of 3-component points.
beliefs_flat = beliefs.reshape(-1, 3)
b_x, b_y, b_z = beliefs_flat[:, 0], beliefs_flat[:, 1], beliefs_flat[:, 2]
print(f"Total points: {len(b_x)}")
for component_name, component in (("b_x", b_x), ("b_y", b_y), ("b_z", b_z)):
    print(f"{component_name} range: [{component.min():.4f}, {component.max():.4f}]")
# ---- Scatter a reproducible subset of points onto the b_x / b_z plane ------
# Cap the number of plotted points at 5000 so the scatter stays legible. A
# fixed-seed RandomState makes the chosen subset identical across runs.
total_points = len(b_x)
sampler = np.random.RandomState(42)
idx = sampler.choice(total_points, min(5000, total_points), replace=False)
x_sample = b_x[idx]
z_sample = b_z[idx]

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
# Colour encodes b_z, the same quantity as the vertical axis.
# NOTE(review): colouring by b_y would show the out-of-plane component;
# left as b_z because the colorbar label below says that is intended.
scatter = ax.scatter(x_sample, z_sample, c=z_sample, cmap='viridis', s=2, alpha=0.6)

# Reference unit circle. Points are (b_x, b_z) pairs, so this is the circle
# the Bloch-sphere boundary traces in that plane.
angles = np.linspace(0, 2 * np.pi, 100)
ax.plot(
    np.cos(angles),
    np.sin(angles),
    'r-',
    linewidth=2,
    alpha=0.5,
    label='Bloch sphere boundary',
)

# Axis decoration: labels, title, square aspect, fixed limits just past
# the unit circle.
ax.set_xlabel('$b_x$', fontsize=14)
ax.set_ylabel('$b_z$', fontsize=14)
ax.set_title('Ground Truth Bloch Walk\n(Quantum Belief States)', fontsize=16, fontweight='bold')
ax.set_aspect('equal')
ax.grid(True, alpha=0.3)
ax.set_xlim(-1.1, 1.1)
ax.set_ylim(-1.1, 1.1)
ax.legend()
fig.colorbar(scatter, ax=ax, label='$b_z$ value')
fig.tight_layout()

# ---- Write the figure next to the input batch and release it ---------------
save_path = "runs/blochwalk_test/bloch_disk_ground_truth.png"
fig.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"\n✓ Saved to {save_path}")
plt.close(fig)
print("\n✅ Done!")