Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ ignore = [
channels = ["conda-forge"]
platforms = ["linux-64", "linux-aarch64", "osx-arm64", "osx-64"]
requires-pixi = ">=0.39.0"
preview = ["pixi-build"]

[tool.pixi.environments]
# We resolve only two groups: cpu and gpu.
Expand Down
240 changes: 227 additions & 13 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,23 +916,33 @@ class LinkParametrizableShape:
Box: ClassVar[int] = 0
Cylinder: ClassVar[int] = 1
Sphere: ClassVar[int] = 2
Mesh: ClassVar[int] = 3


@jax_dataclasses.pytree_dataclass
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class HwLinkMetadata(JaxsimDataclass):
"""
Class storing the hardware parameters of a link.

Attributes:
link_shape: The shape of the link.
0 = box, 1 = cylinder, 2 = sphere, -1 = unsupported.
geometry: The dimensions of the link.
box: [lx,ly,lz], cylinder: [r,l,0], sphere: [r,0,0].
0 = box, 1 = cylinder, 2 = sphere, 3 = mesh, -1 = unsupported.
geometry: Shape parameters used by HW parametrization.
box: [lx,ly,lz], cylinder: [r,l,0], sphere: [r,0,0],
mesh: cumulative anisotropic scale factors [sx,sy,sz] (initialized to [1,1,1]).
density: The density of the link.
Comment thread
flferretti marked this conversation as resolved.
L_H_G: The homogeneous transformation matrix from the link frame to the CoM frame G.
L_H_vis: The homogeneous transformation matrix from the link frame to the visual frame.
L_H_pre_mask: The mask indicating the link's child joint indices.
L_H_pre: The homogeneous transforms for child joints.
mesh_moments: Precomputed volumetric moments for mesh shapes (n_links x 13).
Each row stores [V_ref, com_x, com_y, com_z, Σ_00..Σ_22] where V_ref is the
reference volume, com is the volumetric center of mass, and Σ is the
volumetric covariance matrix at the origin. Zero for non-mesh links.
mesh_vertices: The original centered mesh vertices (Nx3) for mesh shapes, None otherwise.
mesh_faces: The mesh triangle faces (Mx3 integer indices) for mesh shapes, None otherwise.
mesh_offset: The original mesh centroid offset (3D vector) for mesh shapes, None otherwise.
mesh_uri: The path to the mesh file for reference, None otherwise.
"""

link_shape: jtp.Vector
Expand All @@ -942,6 +952,11 @@ class HwLinkMetadata(JaxsimDataclass):
L_H_vis: jtp.Matrix
L_H_pre_mask: jtp.Vector
L_H_pre: jtp.Matrix
mesh_moments: jtp.Matrix
mesh_vertices: Static[tuple[HashedNumpyArray | None, ...] | None]
mesh_faces: Static[tuple[HashedNumpyArray | None, ...] | None]
mesh_offset: Static[tuple[HashedNumpyArray | None, ...] | None]
mesh_uri: Static[tuple[str | None, ...] | None]

@classmethod
def empty(cls) -> HwLinkMetadata:
Expand All @@ -954,7 +969,174 @@ def empty(cls) -> HwLinkMetadata:
L_H_vis=jnp.array([], dtype=float),
L_H_pre_mask=jnp.array([], dtype=bool),
L_H_pre=jnp.array([], dtype=float),
mesh_moments=jnp.zeros((0, 13), dtype=float),
mesh_vertices=None,
mesh_faces=None,
mesh_offset=None,
mesh_uri=None,
)

@staticmethod
def compute_mesh_inertia(
vertices: jtp.Matrix, faces: jtp.Matrix, density: jtp.Float
) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]:
"""
Compute mass, center of mass, and inertia tensor from mesh geometry.

Uses the divergence theorem to compute volumetric properties by integrating
over tetrahedra formed between the mesh surface and the origin.

Args:
vertices: Mesh vertices (Nx3) in the link frame, should be centered.
faces: Triangle face indices (Mx3), integer indices into vertices array.
density: Material density.

Returns:
A tuple containing the computed mass, the CoM position and the 3x3
inertia tensor at the CoM.
"""

# Extract triangles from vertices using face indices
triangles = vertices[faces.astype(int)]
A, B, C = triangles[:, 0], triangles[:, 1], triangles[:, 2]

# Compute signed volume of tetrahedra relative to origin
# vol = 1/6 * (A . (B x C))
tetrahedron_volumes = jnp.sum(A * jnp.cross(B, C), axis=1) / 6.0

total_signed_volume = jnp.sum(tetrahedron_volumes)

# Normalize the global winding sign so positive density yields non-negative mass.
orientation_sign = jnp.where(total_signed_volume < 0, -1.0, 1.0)
tetrahedron_volumes = tetrahedron_volumes * orientation_sign
total_volume = jnp.sum(tetrahedron_volumes)

eps = jnp.asarray(1e-12, dtype=total_volume.dtype)
is_valid_volume = jnp.abs(total_volume) > eps
safe_total_volume = jnp.where(is_valid_volume, total_volume, 1.0)
mass = jnp.where(is_valid_volume, total_volume * density, 0.0)

# Compute center of mass
tet_coms = (A + B + C) / 4.0
com_position = jnp.where(
is_valid_volume,
jnp.sum(tet_coms * tetrahedron_volumes[:, None], axis=0)
/ safe_total_volume,
jnp.zeros(3, dtype=vertices.dtype),
)
Comment thread
flferretti marked this conversation as resolved.

# Compute inertia tensor with covariance approach
def compute_tetrahedron_covariance(a, b, c, vol):
s = a + b + c
return (vol / 20.0) * (
jnp.outer(a, a) + jnp.outer(b, b) + jnp.outer(c, c) + jnp.outer(s, s)
)

covariance_matrices = jax.vmap(compute_tetrahedron_covariance)(
A, B, C, tetrahedron_volumes
)
Σ_origin = jnp.sum(covariance_matrices, axis=0)

# Shift to CoM using parallel axis theorem
Σ_com = Σ_origin * density - mass * jnp.outer(com_position, com_position)

# Convert covariance to inertia tensor
I_com = jnp.trace(Σ_com) * jnp.eye(3, dtype=vertices.dtype) - Σ_com
I_com = jnp.where(
is_valid_volume, I_com, jnp.zeros((3, 3), dtype=vertices.dtype)
)

return mass, com_position, I_com

@staticmethod
def precompute_mesh_moments(vertices: np.ndarray, faces: np.ndarray) -> np.ndarray:
"""
Precompute volumetric moments from reference mesh geometry.

Computes the reference volume, center of mass, and volumetric covariance
matrix at the origin using numpy. These 13 scalars are sufficient to
analytically reconstruct mass and inertia under any anisotropic scaling,
avoiding the need to embed full mesh arrays in JIT-compiled programs.

Args:
vertices: Mesh vertices (Nx3), should be centered.
faces: Triangle face indices (Mx3).

Returns:
A 13-element array: [V_ref, com_x, com_y, com_z, Σ_00..Σ_22].
"""

triangles = vertices[faces.astype(int)]
A, B, C = triangles[:, 0], triangles[:, 1], triangles[:, 2]

volumes = np.sum(A * np.cross(B, C), axis=1) / 6.0

total_signed = np.sum(volumes)
sign = np.sign(total_signed) if abs(total_signed) > 1e-12 else 1.0
volumes = volumes * sign
V_ref = np.sum(volumes)

if abs(V_ref) < 1e-12:
return np.zeros(13, dtype=np.float64)

# Center of mass
com = np.sum(volumes[:, None] * (A + B + C) / 4.0, axis=0) / V_ref

# Volumetric covariance at origin (same formula as compute_mesh_inertia)
S = A + B + C
cov = (volumes[:, None, None] / 20.0) * (
A[:, :, None] * A[:, None, :]
+ B[:, :, None] * B[:, None, :]
+ C[:, :, None] * C[:, None, :]
+ S[:, :, None] * S[:, None, :]
)
Sigma = np.sum(cov, axis=0)

return np.concatenate([[V_ref], com, Sigma.flatten()])

@staticmethod
def compute_mesh_inertia_from_moments(
moments: jtp.Vector, dims: jtp.Vector, density: jtp.Float
) -> tuple[jtp.Float, jtp.Matrix]:
"""
Compute mass and inertia tensor from precomputed volumetric moments.

Uses analytical scaling laws to derive physical properties under
anisotropic scaling without requiring the full mesh geometry.

Under scaling S = diag(sx, sy, sz):
- V' = det(S) * V_ref
- com' = S @ com_ref
- Σ_origin' = det(S) * S @ Σ_ref @ S

Args:
moments: Precomputed moments array of length 13.
dims: Current anisotropic scale factors [sx, sy, sz].
density: Current material density.

Returns:
A tuple of (mass, inertia_at_com).
"""

V_ref = moments[0]
com_ref = moments[1:4]
Sigma_ref = moments[4:13].reshape(3, 3)

det_s = dims[0] * dims[1] * dims[2]
S = jnp.diag(dims)

mass = density * V_ref * det_s
com = dims * com_ref

Sigma_scaled = det_s * (S @ Sigma_ref @ S)
Sigma_com = density * Sigma_scaled - mass * jnp.outer(com, com)
I_com = jnp.trace(Sigma_com) * jnp.eye(3) - Sigma_com

is_valid = V_ref > 1e-12
mass = jnp.where(is_valid, mass, 0.0)
I_com = jnp.where(is_valid, I_com, jnp.zeros((3, 3)))

return mass, I_com

@staticmethod
def compute_mass_and_inertia(
Expand All @@ -977,7 +1159,7 @@ def compute_mass_and_inertia(
- inertia: The computed inertia tensor of the hardware link.
"""

def box(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
def box(dims, density, _moments) -> tuple[jtp.Float, jtp.Matrix]:
lx, ly, lz = dims

mass = density * lx * ly * lz
Expand All @@ -991,7 +1173,7 @@ def box(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
)
return mass, inertia

def cylinder(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
def cylinder(dims, density, _moments) -> tuple[jtp.Float, jtp.Matrix]:
r, l, _ = dims

mass = density * (jnp.pi * r**2 * l)
Expand All @@ -1006,7 +1188,7 @@ def cylinder(dims, density) -> tuple[jtp.Float, jtp.Matrix]:

return mass, inertia

def sphere(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
def sphere(dims, density, _moments) -> tuple[jtp.Float, jtp.Matrix]:
r = dims[0]

mass = density * (4 / 3 * jnp.pi * r**3)
Expand All @@ -1015,16 +1197,35 @@ def sphere(dims, density) -> tuple[jtp.Float, jtp.Matrix]:

return mass, inertia

def compute_mass_inertia(shape_idx, dims, density):
return jax.lax.switch(shape_idx, (box, cylinder, sphere), dims, density)
def mesh(dims, density, moments) -> tuple[jtp.Float, jtp.Matrix]:
return HwLinkMetadata.compute_mesh_inertia_from_moments(
moments, dims, density
)

def compute_mass_inertia(shape_idx, dims, density, moments):
def unsupported_case(_):
return (
jnp.asarray(0.0, dtype=density.dtype),
jnp.zeros((3, 3), dtype=density.dtype),
)

def supported_case(idx):
return jax.lax.switch(
idx, (box, cylinder, sphere, mesh), dims, density, moments
)

return jax.lax.cond(
shape_idx < 0, unsupported_case, supported_case, shape_idx
)

mass, inertia = jax.vmap(compute_mass_inertia)(
masses, inertias = jax.vmap(compute_mass_inertia)(
hw_link_metadata.link_shape,
hw_link_metadata.geometry,
hw_link_metadata.density,
hw_link_metadata.mesh_moments,
)

return mass, inertia
return masses, inertias

@staticmethod
def _convert_scaling_to_3d_vector(
Expand All @@ -1034,7 +1235,7 @@ def _convert_scaling_to_3d_vector(
Convert scaling factors for specific shape dimensions into a 3D scaling vector.

Args:
link_shapes: The link_shapes of the link (e.g., box, sphere, cylinder).
link_shapes: The link_shapes of the link (e.g., box, sphere, cylinder, mesh).
scaling_factors: The scaling factors for the shape dimensions.

Returns:
Expand All @@ -1045,17 +1246,20 @@ def _convert_scaling_to_3d_vector(
- Box: [lx, ly, lz]
- Cylinder: [r, r, l]
- Sphere: [r, r, r]
- Mesh: [sx, sy, sz]
"""

# Index mapping for each shape type (link_shapes x 3 dims)
# Box: [lx, ly, lz] -> [0, 1, 2]
# Cylinder: [r, r, l] -> [0, 0, 1]
# Sphere: [r, r, r] -> [0, 0, 0]
# Mesh: [sx, sy, sz] -> [0, 1, 2]
shape_indices = jnp.array(
[
[0, 1, 2], # Box
[0, 0, 1], # Cylinder
[0, 0, 0], # Sphere
[0, 1, 2], # Mesh
]
)

Expand Down Expand Up @@ -1117,9 +1321,19 @@ def box(parent_idx, L_p_C):
]
)

def mesh(parent_idx, L_p_C):
sx, sy, sz = scaling_factors.dims[parent_idx]
return jnp.hstack(
[
L_p_C[0] * sx,
L_p_C[1] * sy,
L_p_C[2] * sz,
]
)

new_positions = jax.vmap(
lambda shape_idx, parent_idx, L_p_C: jax.lax.switch(
shape_idx, (box, cylinder, sphere), parent_idx, L_p_C
shape_idx, (box, cylinder, sphere, mesh), parent_idx, L_p_C
)
)(
parent_link_shapes,
Expand Down
Loading