Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
a731fb2
TEMP - Make photocurrent_mapping.py run on MPS
trung-vt Jan 11, 2026
2b46674
TEMP - Seeing if can make MPS faster, not working yet
trung-vt Jan 11, 2026
1f7eb0e
TEMP - Make testcases and plots
trung-vt Jan 11, 2026
313435d
TEMP - Controls which algorithms to run
trung-vt Jan 11, 2026
af2fac6
TEMP - Custom in-order sampling ratios
trung-vt Jan 11, 2026
7108b46
TEMP - Change argument to sampling_ratio and remove subtract_from_J
trung-vt Jan 11, 2026
1b7958e
TEMP - Keep all decimal places in CSV
trung-vt Jan 11, 2026
8863d05
TEMP - Use dynamic progress bar and prep big tests
trung-vt Jan 11, 2026
d4ff1d6
TEMP - Big test, 119 testcases for just 1 image
trung-vt Jan 11, 2026
0217f9f
TEMP - Big test for other images
trung-vt Jan 11, 2026
04cc29c
TEMP - Big test for other images
trung-vt Jan 11, 2026
acfa4e5
TEMP - Plot PCM Testcases
trung-vt Jan 11, 2026
5770eb1
TEMP - Add matlab live script demo
trung-vt Jan 12, 2026
ff12855
TEMP - Add matlab text file version of the live script
trung-vt Jan 12, 2026
950e683
TEMP - PCM experiments
trung-vt Jan 15, 2026
d898a90
TEMP - PCM experiment with Si_256_512x512 data
trung-vt Jan 15, 2026
79639c3
TEMP - Fix bug
trung-vt Jan 15, 2026
aa53555
TEMP - Plot tests
trung-vt Jan 15, 2026
d878582
TEMP - Run test Si_256 real measurements
trung-vt Jan 15, 2026
bd107d9
TEMP - Plot results
trung-vt Jan 16, 2026
11e520f
TEMP - Run with different seeds
trung-vt Jan 16, 2026
bed2327
TEMP - Add trials tqdm
trung-vt Jan 16, 2026
c146c5b
TEMP - Compute means and stds of metrics across trials for a given me…
trung-vt Jan 16, 2026
0705ca1
TEMP - Run more trials
trung-vt Jan 16, 2026
f1e236e
TEMP - Skip some trials
trung-vt Jan 16, 2026
b71bd58
TEMP - Run more trials
trung-vt Jan 16, 2026
5f6c4f3
TEMP - Run more trials
trung-vt Jan 17, 2026
08cdfee
TEMP - Run more trials with uniform sampling
trung-vt Jan 17, 2026
38b5a04
TEMP - Run more trials with uniform sampling
trung-vt Jan 18, 2026
4160d99
TEMP - Run more trials with uniform sampling
trung-vt Jan 18, 2026
2fcea47
TEMP - Run more trials with uniform sampling
trung-vt Jan 18, 2026
505e9bc
TEMP - Run more trials with uniform sampling
trung-vt Jan 19, 2026
2fc5bdc
Visualise drift
trung-vt Jan 20, 2026
7e856f6
TEMP - Run more trials
trung-vt Jan 22, 2026
94a22a0
Run PCM PnP-ADMM with gradient-step denoiser
trung-vt Jan 22, 2026
bcf7741
Temp - Box plots update
trung-vt Mar 11, 2026
f597586
Temp - Rearrange scripts
trung-vt Mar 11, 2026
8723e6a
Temp - Rearrange scripts
trung-vt Mar 11, 2026
059f747
Temp - Rearrange scripts
trung-vt Mar 11, 2026
0fed3d1
Temp - Rearrange scripts
trung-vt Mar 11, 2026
3c4b3cc
Temp - Refactor
trung-vt Mar 11, 2026
6b42ddc
Temp - Refactor
trung-vt Mar 11, 2026
e2d4a99
Temp - Refactor
trung-vt Mar 11, 2026
8d414fe
Temp - Refactor
trung-vt Mar 11, 2026
9cd4ee9
Temp - Refactor
trung-vt Mar 15, 2026
54eb56f
Temp - Refactor
trung-vt Mar 15, 2026
ecef3f9
Temp - Refactor
trung-vt Mar 15, 2026
548d9ce
Temp - Refactor
trung-vt Mar 15, 2026
126bae1
New PCM experiment implementation
trung-vt Mar 16, 2026
a69dd02
New PCM experiment implementation
trung-vt Mar 16, 2026
904523a
New PCM experiment implementation
trung-vt Mar 17, 2026
e21fb27
New PCM experiment implementation
trung-vt Mar 17, 2026
c8d9294
New PCM experiment implementation
trung-vt Mar 17, 2026
7bbf36e
Update README
trung-vt Mar 17, 2026
4114ee3
New PCM experiment implementation
trung-vt Mar 17, 2026
5d82398
Update README
trung-vt Mar 17, 2026
611081b
Update
trung-vt Mar 17, 2026
a92e74d
New PCM experiment implementation
trung-vt Mar 18, 2026
b6e31aa
New PCM experiment implementation
trung-vt Mar 18, 2026
b359970
New PCM experiment implementation
trung-vt Mar 18, 2026
d2ba825
New PCM experiment implementation
trung-vt Mar 18, 2026
099fb84
Remove unneeded file
trung-vt Mar 18, 2026
9abb2b4
Remove unneeded file
trung-vt Mar 18, 2026
db5bd59
Remove unneeded file
trung-vt Mar 18, 2026
f4757a0
Remove unneeded files
trung-vt Mar 18, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ dist
example.py
/slurm

wandb
wandb
.DS_Store
*.zip
16 changes: 8 additions & 8 deletions LION/classical_algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""LION classical algorithms."""
# """LION classical algorithms."""

from LION.classical_algorithms.conjugate_gradient import conjugate_gradient
from LION.classical_algorithms.fdk import fdk
from LION.classical_algorithms.fista import fista_l1
from LION.classical_algorithms.sirt import sirt
from LION.classical_algorithms.spgl1_torch import spgl1_torch
from LION.classical_algorithms.tv_min import tv_min
# from LION.classical_algorithms.conjugate_gradient import conjugate_gradient
# from LION.classical_algorithms.fdk import fdk
# from LION.classical_algorithms.fista import fista_l1
# from LION.classical_algorithms.sirt import sirt
# from LION.classical_algorithms.spgl1_torch import spgl1_torch
# from LION.classical_algorithms.tv_min import tv_min

__all__ = ["conjugate_gradient", "fdk", "fista_l1", "sirt", "spgl1_torch", "tv_min"]
# __all__ = ["conjugate_gradient", "fdk", "fista_l1", "sirt", "spgl1_torch", "tv_min"]
13 changes: 9 additions & 4 deletions LION/classical_algorithms/conjugate_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ def conjugate_gradient(
d: torch.Tensor,
x0: torch.Tensor,
max_iter: int,
tol: float,
eps: float = 1e-14,
rel_tol: float = 0.0,
prog_bar: Callable | None = None,
) -> torch.Tensor:
"""
Conjugate gradient solver.
Expand Down Expand Up @@ -39,18 +41,21 @@ def conjugate_gradient(
d = r.clone()
rr = torch.sum(r**2)

for _ in range(max_iter):
iterator = (
prog_bar(range(max_iter), desc="CG iterations") if prog_bar else range(max_iter)
)
for _ in iterator:
z = matmul_closure(d)

dz = torch.sum(d * z)
# Check for breakdown
if abs(dz) < 1e-14:
if abs(dz) < eps:
break
alpha = rr / dz
x += alpha * d
r -= alpha * z

if torch.norm(r) / torch.norm(d) < tol:
if torch.norm(r) / torch.norm(d) < rel_tol:
break

rr_next = torch.sum(r**2)
Expand Down
11 changes: 7 additions & 4 deletions LION/classical_algorithms/fista.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""FISTA algorithm for l1-regularized problems."""

from __future__ import annotations

from typing import Callable
import math

import torch
Expand Down Expand Up @@ -41,7 +44,7 @@ def fista_l1(
tol: float = 1e-4,
L: float | None = None,
verbose: bool = False,
progress_bar: bool = False,
prog_bar: Callable | None = None,
) -> torch.Tensor:
r"""Solve :math:`\min_w \tfrac12\lVert A w - y\rVert_2^2 + \lambda \lVert w\rVert_1`
by FISTA.
Expand Down Expand Up @@ -111,9 +114,9 @@ def fista_l1(
z = w.clone()
t = 1.0

iterator = range(max_iter)
if progress_bar:
iterator = tqdm(iterator, desc="FISTA l1")
iterator = (
prog_bar(range(max_iter), desc="FISTA l1") if prog_bar else range(max_iter)
)
for k in iterator:
Az: torch.Tensor = op(z)
grad = op.adjoint(Az - y) # gradient of data term, shape (n,)
Expand Down
23 changes: 13 additions & 10 deletions LION/classical_algorithms/spgl1_torch.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
"""SPGL1 sparse reconstruction with torch operators."""

from typing import Any
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator
from spgl1 import spgl1
from spgl1 import spgl1, spg_bp

from LION.operators.Operator import Operator


def spgl1_torch(op: Operator, y: torch.Tensor, **spgl1_kwargs) -> torch.Tensor:
r"""Solve an l1 sparse reconstruction using SPGL1, wrapping torch operators.
def spgl1_torch(
op: Operator, y: torch.Tensor, **spgl1_kwargs
) -> tuple[torch.Tensor, Any]:
r"""Solve an l1 sparse reconstruction using SPGL1 with Basis Pursuit (BP), wrapping torch operators.

This is a thin wrapper around the Python SPGL1 solver ``spgl1.spgl1`` that
This is a thin wrapper around the Python SPGL1 solver ``spgl1.spg_bp`` that
uses torch operators for matrix-vector products. SPGL1 is a spectral
projected-gradient method for constrained l1 problems; see
[BergFriedlander2008]_ and [BergFriedlander2010]_.

This wrapper is built on top of the Python implementation ``spgl1.spgl1`` and
This wrapper is built on top of the Python implementation ``spgl1.spg_bp`` and
uses the same calling convention (argument names and behaviour); see
[SPGL1Python]_ for details.

Expand All @@ -28,7 +31,7 @@ def spgl1_torch(op: Operator, y: torch.Tensor, **spgl1_kwargs) -> torch.Tensor:
y : torch.Tensor
Measurements, shape ``(M,)``.
spgl1_kwargs : dict
Extra keyword args forwarded to ``spgl1.spgl1`` (for example
Extra keyword args forwarded to ``spgl1.spg_bp`` (for example
tolerances or iteration limits; see [SPGL1Python]_).

Returns
Expand Down Expand Up @@ -75,9 +78,9 @@ def rmatvec(r_np: np.ndarray) -> np.ndarray:
)

y_np = y.detach().cpu().numpy().ravel()
x0_np = np.zeros(n_w, dtype=np.float32)

x_np, _, _, _ = spgl1(A_linop, y_np, x0=x0_np, **spgl1_kwargs)
# x0_np = np.zeros(n_w, dtype=np.float32)
# x_np, _, _, _ = spgl1(A_linop, y_np, x0=x0_np, **spgl1_kwargs)
x_np, _, _, info = spg_bp(A_linop, y_np, **spgl1_kwargs)

w_hat = torch.from_numpy(x_np.astype(np.float32)).to(device).view_as(w0)
return w_hat
return w_hat, info
9 changes: 5 additions & 4 deletions LION/operators/DebiasOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from tabnanny import verbose
from typing import Callable

import torch
from tqdm import tqdm
Expand Down Expand Up @@ -112,7 +113,7 @@ def debias_ls(
support_tol: float = 1e-3,
max_iter: int = 200,
tol: float = 1e-5,
progress_bar: bool = False,
prog_bar: Callable | None = None,
) -> torch.Tensor:
"""Debiasing least squares on the support of w.

Expand Down Expand Up @@ -152,9 +153,9 @@ def debias_ls(

v = w[support].clone()

iterator = range(max_iter)
if progress_bar:
iterator = tqdm(iterator, desc="Debiasing LS")
iterator = (
prog_bar(range(max_iter), desc="Debiasing LS") if prog_bar else range(max_iter)
)
for _ in iterator:
r = op_s(v) - y
grad = op_s.adjoint(r)
Expand Down
41 changes: 12 additions & 29 deletions LION/operators/PhotocurrentMapOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from collections.abc import Sequence

import numpy as np
import torch
from spyrit.core.torch import fwht, ifwht
Expand All @@ -10,6 +12,7 @@


def normal_to_dyadic_permutation(J: int) -> np.ndarray:
"""Compute the permutation mapping normal order to dyadic order for WHT."""
nbits = 2 * J
n = 1 << nbits # total length = 2^(2J)
# ---- dyadic-by-scales permutation (dtype-safe bit ops) ----
Expand All @@ -28,48 +31,25 @@ def normal_to_dyadic_permutation(J: int) -> np.ndarray:
return permutation


class Subsampler:
def __init__(
self, n: int, delta: float, coarseJ: int, rng: np.random.Generator | None = None
) -> None:
if rng is None:
rng = np.random.default_rng()
# ---- random undersampling with coarseJ fully kept ----
m_total = int(np.ceil(delta * n))
m1 = min(1 << (2 * coarseJ), m_total)
m2 = m_total - m1
if m2 > 0:
idx_tail = rng.choice(n - m1, size=m2, replace=False) + m1
self._subsampled_indices = np.concatenate(
[np.arange(m1, dtype=np.int64), idx_tail.astype(np.int64)]
)
else:
self._subsampled_indices = np.arange(m1, dtype=np.int64)

@property
def subsampled_indices(self) -> np.ndarray:
return self._subsampled_indices


class PhotocurrentMapOp(Operator):
"""Photocurrent mapping operator using subsampled WHT and dyadic permutation.

Parameters
----------
J : int
The exponent such that the image size is (2^J, 2^J).
subsampler : Subsampler
The subsampler defining the measurement indices.
sampled_indices : Sequence[int] | np.ndarray | None, optional
The indices of the measurements to be taken, in coarse-to-fine (dyadic) order.
wht_dim : int, optional
The dimension along which to apply the WHT. Default is -1 (last dimension).
device : str or torch.device
device : str or torch.device | None, optional
Device where tensors are placed.
"""

def __init__(
self,
J: int,
subsampler: Subsampler,
sampled_indices: Sequence[int] | np.ndarray | None = None,
wht_dim: int = -1,
device: torch.device | str | None = None,
):
Expand All @@ -79,13 +59,16 @@ def __init__(
self.num_pixels = self.N * self.N
self.wht_dim = wht_dim

if sampled_indices is None:
sampled_indices = np.arange(self.num_pixels, dtype=np.int64)

# TODO: Add batch size
self._image_shape = (self.N, self.N)
self._data_shape = (subsampler.subsampled_indices.shape[0],)
self._data_shape = (len(sampled_indices),)

self.normal_to_dyadic_perm = normal_to_dyadic_permutation(J=J)
self.meas_indices_standard = torch.tensor(
self.normal_to_dyadic_perm[subsampler.subsampled_indices],
self.normal_to_dyadic_perm[sampled_indices],
dtype=torch.long,
device=self.device,
)
Expand Down
36 changes: 18 additions & 18 deletions LION/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
"""LION operators."""
# """LION operators."""

from LION.operators.CompositeOp import CompositeOp
from LION.operators.CTProjectionOp import CTProjectionOp
from LION.operators.DebiasOp import DebiasOp
from LION.operators.Operator import Operator
from LION.operators.PhotocurrentMapOp import PhotocurrentMapOp, Subsampler
from LION.operators.WalshHadamard2D import WalshHadamard2D
from LION.operators.Wavelet2D import Wavelet2D
# from LION.operators.CompositeOp import CompositeOp
# from LION.operators.CTProjectionOp import CTProjectionOp
# from LION.operators.DebiasOp import DebiasOp
# from LION.operators.Operator import Operator
# from LION.operators.PhotocurrentMapOp import PhotocurrentMapOp, Subsampler
# from LION.operators.WalshHadamard2D import WalshHadamard2D
# from LION.operators.Wavelet2D import Wavelet2D

__all__ = [
"CompositeOp",
"CTProjectionOp",
"DebiasOp",
"Operator",
"PhotocurrentMapOp",
"Subsampler",
"WalshHadamard2D",
"Wavelet2D",
]
# __all__ = [
# "CompositeOp",
# "CTProjectionOp",
# "DebiasOp",
# "Operator",
# "PhotocurrentMapOp",
# "Subsampler",
# "WalshHadamard2D",
# "Wavelet2D",
# ]
Loading
Loading