-
Notifications
You must be signed in to change notification settings - Fork 2.7k
[MISC] Add register-only tiled cholesky, and incremental H patching, for performance. #2659
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 45 commits
f9244a4
2e4778b
70f9374
01f3cbe
d0aca95
ee9335b
b6ba7fe
8348d06
fcf3f8e
348a822
52927a0
2030a00
016390c
37cd1c5
9ebf712
80f6d45
bac01c7
9c8c986
0b9a67c
f00edd7
1de2b6c
36be574
459c235
aeda794
7b45735
0356dce
eabefd9
f2225dc
61a70e7
1230f65
735cff4
aa32f7d
5fcf106
d5cb5f1
716ee91
e591827
c840edb
9c1690d
2234bd0
fcafdb5
6f90d8a
60f2af9
f6df0a5
3583394
e1be020
9561149
5654141
ef1c3e2
7bfed26
4c01929
02ff753
a4943cf
ae548f4
5fe6d2e
cc4d140
56a8aa6
cc978ea
a1ec2fb
4b49f2a
7cb0e3b
17ec942
d803919
6b90dc5
7714b96
e661aa3
86c3233
adbb55c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
|
|
||
| import numpy as np | ||
| import quadrants as qd | ||
| from quadrants.lang.simt.tile16 import make_tile16x16 | ||
| import torch | ||
| from frozendict import frozendict | ||
|
|
||
|
|
@@ -1456,6 +1457,7 @@ def func_hessian_direct_batch( | |
| def func_hessian_direct_tiled( | ||
| constraint_state: array_class.ConstraintState, | ||
| rigid_global_info: array_class.RigidGlobalInfo, | ||
| check_full_hessian: qd.template() = False, | ||
| ): | ||
| """Compute the Hessian matrix `H = M + J.T @ D @ J of the optimization problem for all environment at once. | ||
|
|
||
|
|
@@ -1467,6 +1469,9 @@ def func_hessian_direct_tiled( | |
| optimization problem fits in a single block, i.e. n_constraints <= 32 and n_dofs <= 64. | ||
|
|
||
| Note that only the lower triangular part will be updated for efficiency, because the Hessian matrix is symmetric. | ||
|
|
||
| When check_full_hessian is True (used with H patching), skips envs where | ||
| use_full_hessian == 0 (those get patched instead of rebuilt). | ||
|
hughperkins marked this conversation as resolved.
Outdated
|
||
| """ | ||
| _B = constraint_state.grad.shape[1] | ||
| n_dofs = constraint_state.nt_H.shape[1] | ||
|
|
@@ -1492,6 +1497,9 @@ def func_hessian_direct_tiled( | |
| continue | ||
| if constraint_state.n_constraints[i_b] == 0 or not constraint_state.improved[i_b]: | ||
| continue | ||
| if qd.static(check_full_hessian): | ||
| if constraint_state.use_full_hessian[i_b] == 0: | ||
| continue | ||
|
|
||
| jac_row = qd.simt.block.SharedArray((MAX_CONSTRAINTS_PER_BLOCK, MAX_DOFS_PER_BLOCK), gs.qd_float) | ||
| jac_col = qd.simt.block.SharedArray((MAX_CONSTRAINTS_PER_BLOCK, MAX_DOFS_PER_BLOCK), gs.qd_float) | ||
|
|
@@ -1607,6 +1615,7 @@ def func_cholesky_factor_direct_batch( | |
|
|
||
| n_dofs = constraint_state.nt_H.shape[1] | ||
|
|
||
| # In-place factorization on nt_H (batch path never uses H patching) | ||
| for i_d in range(n_dofs): | ||
| tmp = constraint_state.nt_H[i_b, i_d, i_d] | ||
| for j_d in range(i_d): | ||
|
|
@@ -1621,6 +1630,17 @@ def func_cholesky_factor_direct_batch( | |
| constraint_state.nt_H[i_b, j_d, i_d] = (constraint_state.nt_H[i_b, j_d, i_d] - dot) * tmp | ||
|
|
||
|
|
||
| Tile16x16 = make_tile16x16(gs.qd_float) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is this not done by Quadrants? In some kind of singleton factory.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How are you imagining this usage would look?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar to "annotations"? qd.types.Tile16x16(dtype=gs.qd_float) |
||
|
|
||
|
|
||
| @qd.func | ||
| def _butterfly_reduce_16(val, tid): | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I wonder if this should be moved into quadrants somewhere?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Probably right?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (also, I kind of lean towards the name describing the effect and behavior, rather than the detailed implementation?)
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, that would be much better in Genesis, where lower expertise from maintainers can be expected. Nice to mention the implementation in the docstring though.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Indeed, you should move it into Quadrants, I think. My previous comment still holds, since it would be public API. |
||
| """Sum val across 16 threads using butterfly reduction via subgroup shuffles (4 rounds).""" | ||
| for i in qd.static(range(4)): | ||
| val = val + qd.simt.subgroup.shuffle(val, qd.u32(tid ^ (8 >> i))) | ||
| return val | ||
|
|
||
|
|
||
| @qd.func | ||
| def func_cholesky_factor_direct_tiled( | ||
| constraint_state: array_class.ConstraintState, | ||
|
|
@@ -1629,77 +1649,194 @@ def func_cholesky_factor_direct_tiled( | |
| ): | ||
| """Compute the Cholesky factorization L of the Hessian matrix H = L @ L.T for a given environment `i_b`. | ||
|
|
||
| This implementation is specialized for GPU backend and highly optimized for it using shared memory and cooperative | ||
| threading. The current implementation only supports n_dofs <= 64 for 64bits precision and n_dofs <= 92 for 32bits | ||
| precision due to shared memory storage being limited to 48kB. Note that the amount of shared memory available is | ||
| hardware-specific, but the 48kB default limit without enabling dedicated GPU context flag is hardware-agnostic on | ||
| modern GPUs. | ||
| This implementation is specialized for GPU backend and highly optimized for it using a left-looking blocked algorithm | ||
| with Tile16x16 primitives (potrf, trsm, syr_sub, ger_sub), all operating entirely in registers via subgroup shuffles. | ||
| No shared memory or block synchronization needed. This function has no inherent DOF limit, but the fused variant | ||
| (func_cholesky_and_solve_fused_tiled) requires shared memory for L, so the caller gates both behind the same | ||
| shared-memory-based DOF threshold: n_dofs <= 64 (f64) or 96 (f32) with 48kB default shared memory, higher with | ||
| opt-in shared memory (e.g. 160/224 on RTX PRO 6000). | ||
|
|
||
| Beware the Hessian matrix is re-purposed to store its Cholesky factorization to sparse memory resources. | ||
| Beware the Hessian matrix is re-purposed to store its Cholesky factorization to spare memory resources. | ||
|
|
||
| Note that only the lower triangular part will be updated for efficiency, because the Hessian matrix is symmetric. | ||
| When n_dofs is not a multiple of 16, partial tiles are padded with identity (diagonal=1, off-diagonal=0) so the | ||
| factorization is correct for the original n_dofs x n_dofs submatrix. | ||
| """ | ||
| EPS = rigid_global_info.EPS[None] | ||
|
|
||
| _B = constraint_state.grad.shape[1] | ||
| n_dofs = constraint_state.nt_H.shape[1] | ||
| N_BLOCKS = (n_dofs + Tile16x16.SIZE - 1) // Tile16x16.SIZE | ||
|
|
||
| # Performance is optimal for BLOCK_DIM = 64 | ||
| BLOCK_DIM = qd.static(64) | ||
| qd.loop_config(name="cholesky_factor_direct_tiled", block_dim=Tile16x16.SIZE) | ||
| for i in range(_B * Tile16x16.SIZE): | ||
| tid = i % Tile16x16.SIZE | ||
| i_b = i // Tile16x16.SIZE | ||
| if i_b >= _B: | ||
| continue | ||
| if constraint_state.n_constraints[i_b] == 0 or not constraint_state.improved[i_b]: | ||
| continue | ||
|
|
||
| for kb in range(N_BLOCKS): | ||
| k0 = kb * Tile16x16.SIZE | ||
|
|
||
| L_kk = Tile16x16() | ||
| if k0 + tid < n_dofs: | ||
| L_kk[:] = constraint_state.nt_H[i_b, k0 : k0 + Tile16x16.SIZE, k0:n_dofs] | ||
| else: | ||
| L_kk.eye_() | ||
|
|
||
| for jb in range(kb): | ||
| j0 = jb * Tile16x16.SIZE | ||
| for t in range(Tile16x16.SIZE): | ||
| v = gs.qd_float(0.0) | ||
| if k0 + tid < n_dofs: | ||
| v = constraint_state.nt_H[i_b, k0 + tid, j0 + t] | ||
| L_kk -= qd.outer(v, v) | ||
|
|
||
| L_kk.cholesky_(EPS) | ||
|
|
||
| for ib in range(kb + 1, N_BLOCKS): | ||
| i0 = ib * Tile16x16.SIZE | ||
|
|
||
| L_ik = Tile16x16() | ||
| if i0 + tid < n_dofs: | ||
| L_ik[:] = constraint_state.nt_H[i_b, i0 : i0 + Tile16x16.SIZE, k0:n_dofs] | ||
|
|
||
| for jb in range(kb): | ||
| j0 = jb * Tile16x16.SIZE | ||
| for t in range(Tile16x16.SIZE): | ||
| v_own = gs.qd_float(0.0) | ||
| v_diag = gs.qd_float(0.0) | ||
| if i0 + tid < n_dofs: | ||
| v_own = constraint_state.nt_H[i_b, i0 + tid, j0 + t] | ||
| if k0 + tid < n_dofs: | ||
| v_diag = constraint_state.nt_H[i_b, k0 + tid, j0 + t] | ||
| L_ik -= qd.outer(v_own, v_diag) | ||
|
|
||
| L_kk.solve_triangular_(L_ik) | ||
|
|
||
| if i0 + tid < n_dofs: | ||
| constraint_state.nt_H[i_b, i0 : i0 + Tile16x16.SIZE, k0:n_dofs] = L_ik | ||
|
|
||
| if k0 + tid < n_dofs: | ||
| constraint_state.nt_H[i_b, k0 : k0 + Tile16x16.SIZE, k0:n_dofs] = L_kk | ||
|
hughperkins marked this conversation as resolved.
Outdated
|
||
|
|
||
|
|
||
| @qd.func | ||
| def func_cholesky_and_solve_fused_tiled( | ||
| constraint_state: array_class.ConstraintState, | ||
| rigid_global_info: array_class.RigidGlobalInfo, | ||
| static_rigid_sim_config: qd.template(), | ||
| ): | ||
| """Fused Cholesky factorization and triangular solve, keeping L in shared memory. | ||
|
|
||
| Factorizes H = L L^T using register-resident 16x16 tiles, storing completed L tiles | ||
| in shared memory. Then solves L L^T x = g (forward + backward substitution) in-place | ||
| and writes the result to Mgrad, without ever writing L to global memory. | ||
|
hughperkins marked this conversation as resolved.
Outdated
|
||
| """ | ||
| EPS = rigid_global_info.EPS[None] | ||
| MAX_DOFS = qd.static(static_rigid_sim_config.tiled_n_dofs) | ||
|
|
||
| n_lower_tri = n_dofs * (n_dofs + 1) // 2 | ||
| _B = constraint_state.grad.shape[1] | ||
| n_dofs = constraint_state.nt_H.shape[1] | ||
| N_BLOCKS = (n_dofs + Tile16x16.SIZE - 1) // Tile16x16.SIZE | ||
|
|
||
| qd.loop_config(name="cholesky_factor_direct_tiled", block_dim=BLOCK_DIM) | ||
| for i in range(_B * BLOCK_DIM): | ||
| tid = i % BLOCK_DIM | ||
| i_b = i // BLOCK_DIM | ||
| qd.loop_config(name="cholesky_and_solve_fused_tiled", block_dim=Tile16x16.SIZE) | ||
| for i in range(_B * Tile16x16.SIZE): | ||
| tid = i % Tile16x16.SIZE | ||
| i_b = i // Tile16x16.SIZE | ||
| if i_b >= _B: | ||
| continue | ||
| if constraint_state.n_constraints[i_b] == 0 or not constraint_state.improved[i_b]: | ||
| continue | ||
|
|
||
| # Padding +1 to avoid memory bank conflicts that would cause access serialization | ||
| H = qd.simt.block.SharedArray((MAX_DOFS, MAX_DOFS + 1), gs.qd_float) | ||
| L_sh = qd.simt.block.SharedArray((MAX_DOFS, MAX_DOFS), gs.qd_float) | ||
|
hughperkins marked this conversation as resolved.
Outdated
|
||
| v_sh = qd.simt.block.SharedArray((MAX_DOFS,), gs.qd_float) | ||
|
|
||
| for kb in range(N_BLOCKS): | ||
| k0 = kb * Tile16x16.SIZE | ||
|
|
||
| # Copy the lower triangular part of the entire Hessian matrix to shared memory for efficiency | ||
| i_pair = tid | ||
| while i_pair < n_lower_tri: | ||
| i_d1, i_d2 = linear_to_lower_tri(i_pair) | ||
| H[i_d1, i_d2] = constraint_state.nt_H[i_b, i_d1, i_d2] | ||
| i_pair = i_pair + BLOCK_DIM | ||
| L_kk = Tile16x16() | ||
| if k0 + tid < n_dofs: | ||
| L_kk[:] = constraint_state.nt_H[i_b, k0 : k0 + Tile16x16.SIZE, k0:n_dofs] | ||
| else: | ||
| L_kk.eye_() | ||
|
|
||
| for jb in range(kb): | ||
| j0 = jb * Tile16x16.SIZE | ||
| for t in range(Tile16x16.SIZE): | ||
| v = gs.qd_float(0.0) | ||
| if k0 + tid < n_dofs: | ||
| v = L_sh[k0 + tid, j0 + t] | ||
| L_kk -= qd.outer(v, v) | ||
|
|
||
| L_kk.cholesky_(EPS) | ||
|
|
||
| for ib in range(kb + 1, N_BLOCKS): | ||
| i0 = ib * Tile16x16.SIZE | ||
|
|
||
| L_ik = Tile16x16() | ||
| if i0 + tid < n_dofs: | ||
| L_ik[:] = constraint_state.nt_H[i_b, i0 : i0 + Tile16x16.SIZE, k0:n_dofs] | ||
|
|
||
| for jb in range(kb): | ||
| j0 = jb * Tile16x16.SIZE | ||
| for t in range(Tile16x16.SIZE): | ||
| v_own = gs.qd_float(0.0) | ||
| v_diag = gs.qd_float(0.0) | ||
| if i0 + tid < n_dofs: | ||
| v_own = L_sh[i0 + tid, j0 + t] | ||
| if k0 + tid < n_dofs: | ||
| v_diag = L_sh[k0 + tid, j0 + t] | ||
| L_ik -= qd.outer(v_own, v_diag) | ||
|
|
||
| L_kk.solve_triangular_(L_ik) | ||
|
|
||
| if i0 + tid < n_dofs: | ||
| L_sh[i0 : i0 + Tile16x16.SIZE, k0:n_dofs] = L_ik | ||
|
|
||
| if k0 + tid < n_dofs: | ||
| L_sh[k0 : k0 + Tile16x16.SIZE, k0:n_dofs] = L_kk | ||
|
hughperkins marked this conversation as resolved.
Outdated
|
||
|
|
||
| # --- Fused solve: Ly = grad (forward), L^T x = y (backward) --- | ||
| # L is fully computed in L_sh. Load gradient into v_sh. | ||
| k = tid | ||
| while k < n_dofs: | ||
| v_sh[k] = constraint_state.grad[k, i_b] | ||
| k = k + Tile16x16.SIZE | ||
| qd.simt.block.sync() | ||
|
|
||
| # Loop over all columns sequentially, which is an integral part of Cholesky-Crout algorithm and cannot be | ||
| # avoided. | ||
| # Forward substitution: solve L @ y = grad (parallel dot with 16 threads) | ||
|
hughperkins marked this conversation as resolved.
|
||
| for i_d in range(n_dofs): | ||
| # Compute the diagonal of the Cholesky factor L for the column i being considered, ie | ||
| # L_{i,i} = sqrt(A_{i,i} - sum_{j=1}^{i-1}(L_{i,j} ** 2 )) | ||
|
hughperkins marked this conversation as resolved.
|
||
| dot = gs.qd_float(0.0) | ||
| j = tid | ||
| while j < i_d: | ||
| dot = dot + L_sh[i_d, j] * v_sh[j] | ||
| j = j + Tile16x16.SIZE | ||
| dot = _butterfly_reduce_16(dot, tid) | ||
| if tid == 0: | ||
| tmp = H[i_d, i_d] | ||
| for j_d in range(i_d): | ||
| tmp = tmp - H[i_d, j_d] ** 2 | ||
| H[i_d, i_d] = qd.sqrt(qd.max(tmp, EPS)) | ||
| v_sh[i_d] = (v_sh[i_d] - dot) / L_sh[i_d, i_d] | ||
| qd.simt.block.sync() | ||
|
|
||
| # Compute all the off-diagonal terms of the Cholesky factor L for the column i being considered, ie | ||
| # L_{j,i} = 1 / L_{i,i} (A_{j,i} - sum_{k=1}^{i-1}(L_{j,k} L_{i,k}), for j > i | ||
| inv_diag = 1.0 / H[i_d, i_d] | ||
| j_d = i_d + 1 + tid | ||
| while j_d < n_dofs: | ||
| dot = gs.qd_float(0.0) | ||
| for k_d in range(i_d): | ||
| dot = dot + H[j_d, k_d] * H[i_d, k_d] | ||
| H[j_d, i_d] = (H[j_d, i_d] - dot) * inv_diag | ||
| j_d = j_d + BLOCK_DIM | ||
| # Backward substitution: solve L^T @ x = y (parallel dot with 16 threads) | ||
|
hughperkins marked this conversation as resolved.
|
||
| for i_d_ in range(n_dofs): | ||
| i_d = n_dofs - 1 - i_d_ | ||
| dot = gs.qd_float(0.0) | ||
| j = i_d + 1 + tid | ||
| while j < n_dofs: | ||
| dot = dot + L_sh[j, i_d] * v_sh[j] | ||
| j = j + Tile16x16.SIZE | ||
| dot = _butterfly_reduce_16(dot, tid) | ||
| if tid == 0: | ||
| v_sh[i_d] = (v_sh[i_d] - dot) / L_sh[i_d, i_d] | ||
| qd.simt.block.sync() | ||
|
|
||
| # Copy the final result back from shared memory, only considered the lower triangular part | ||
| i_pair = tid | ||
| while i_pair < n_lower_tri: | ||
| i_d1, i_d2 = linear_to_lower_tri(i_pair) | ||
| constraint_state.nt_H[i_b, i_d1, i_d2] = H[i_d1, i_d2] | ||
| i_pair = i_pair + BLOCK_DIM | ||
| # Write Mgrad to global memory | ||
| k = tid | ||
| while k < n_dofs: | ||
| constraint_state.Mgrad[k, i_b] = v_sh[k] | ||
| k = k + Tile16x16.SIZE | ||
|
|
||
|
|
||
| @qd.func | ||
|
|
@@ -1896,6 +2033,7 @@ def func_cholesky_solve_batch( | |
| ): | ||
| n_dofs = constraint_state.Mgrad.shape[0] | ||
|
|
||
| # Batch path: L is in nt_H (in-place factorization) | ||
| for i_d in range(n_dofs): | ||
| curr_out = constraint_state.grad[i_d, i_b] | ||
| for j_d in range(i_d): | ||
|
|
@@ -1954,7 +2092,7 @@ def func_cholesky_solve_tiled( | |
| (NUM_WARPS if qd.static(ENABLE_WARP_REDUCTION) else BLOCK_DIM,), gs.qd_float | ||
| ) | ||
|
|
||
| # Copy the lower triangular part of the entire Hessian matrix to shared memory for efficiency | ||
| # Copy the lower triangular part of L (Cholesky factor) to shared memory for efficiency | ||
| i_flat = tid | ||
| while i_flat < n_dofs_2: | ||
| i_d1 = i_flat // n_dofs | ||
|
|
@@ -3037,6 +3175,8 @@ def func_solve_init( | |
| qd.loop_config(name="init_improved", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) | ||
| for i_b in qd.ndrange(_B): | ||
| constraint_state.improved[i_b] = constraint_state.n_constraints[i_b] > 0 | ||
| constraint_state.use_full_hessian[i_b] = 1 | ||
| constraint_state.solver_iter_counter[()] = 0 | ||
|
|
||
| if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton): | ||
| func_hessian_and_cholesky_factor_direct( | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.