Skip to content
Open
89 changes: 77 additions & 12 deletions genesis/engine/sensors/base_sensor.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import functools
from dataclasses import dataclass, field
from functools import partial
from typing import TYPE_CHECKING, ClassVar, Generic, Sequence, Type, TypeVar, get_args, get_origin

from typing_extensions import TypeVar as TypeVarWithDefault
from typing import TYPE_CHECKING, ClassVar, Generic, Sequence, TypeVar, get_args, get_origin

import numpy as np
import quadrants as qd
import torch
from typing_extensions import TypeVar as TypeVarWithDefault

import genesis as gs
from genesis.typing import NumArrayType, NumericType
from genesis.repr_base import RBC
from genesis.typing import NumArrayType, NumericType
from genesis.utils.geom import euler_to_quat
from genesis.utils.misc import broadcast_tensor, concat_with_tensor, make_tensor_field

Expand Down Expand Up @@ -39,6 +38,19 @@ def _to_tuple(*values: NumArrayType, length_per_value: int = 3) -> tuple[Numeric
return full_tuple


def assert_measured_cache_will_update(method):
    """Decorator for noise-option setters: fail fast when the setting would be a no-op.

    When ``update_ground_truth_only`` is True the measured cache is never
    refreshed, so changing a noise option would silently have no effect.
    Raise in that case instead of letting the call succeed.

    Fix: the original condition was inverted (``if not ...``), raising on the
    healthy path and contradicting its own error message.
    """

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        # Raise when the measured cache will NOT be updated (flag is True).
        if self._shared_metadata.update_ground_truth_only:
            gs.raise_exception(
                "Tried to update noise option but update_ground_truth_only is True. "
                "Set a noisy option to nonzero value so that the measured cache will be updated."
            )
        return method(self, *args, **kwargs)

    return wrapper


# Note: dataclass is used as opposed to pydantic.BaseModel since torch.Tensors are not supported by default
@dataclass
class SharedSensorMetadata:
Expand All @@ -48,6 +60,10 @@ class SharedSensorMetadata:

cache_sizes: list[int] = field(default_factory=list)
delays_ts: torch.Tensor = make_tensor_field((0, 0), dtype_factory=lambda: gs.tc_int)
history_lengths: list[int] = field(default_factory=list)
# If True, skip _update_shared_cache for this sensor class. Defaults True; concrete sensors set False when they
# need per-step measured-cache updates (cameras set True in BaseCameraSensor.build for lazy render-on-read).
update_ground_truth_only: bool = True

def __del__(self):
try:
Expand Down Expand Up @@ -120,41 +136,66 @@ def __init__(self, sensor_options: "SensorOptions", sensor_idx: int, sensor_mana
self._manager: "SensorManager" = sensor_manager
self._shared_metadata: SharedSensorMetadataT = sensor_manager._sensors_metadata[type(self)]
self._is_built = False
self._history_length: int = self._options.history_length

self._dt = self._manager._sim.dt
self._delay_ts = round(self._options.delay / self._dt)

self._cache_slices: list[slice] = []
return_format = self._get_return_format()
assert len(return_format) > 0
if isinstance(return_format[0], int):
return_format = (return_format,)
self._return_shapes: tuple[tuple[int, ...], ...] = return_format
intrinsic_shapes: tuple[tuple[int, ...], ...] = (
(return_format,) if isinstance(return_format[0], int) else return_format
)
self._intrinsic_return_shapes: tuple[tuple[int, ...], ...] = intrinsic_shapes

self._cache_size = 0
for shape in self._return_shapes:
for shape in intrinsic_shapes:
data_size = np.prod(shape)
self._cache_slices.append(slice(self._cache_size, self._cache_size + data_size))
self._cache_size += data_size

# Slices into the per-sensor tensor from get_cloned_from_cache (history stacks H frames on dim 1).
self._read_flat_slices: list[slice] = []
read_off = 0
for shape in intrinsic_shapes:
p = np.prod(shape)
span = p * self._history_length if self._history_length > 0 else p
self._read_flat_slices.append(slice(read_off, read_off + span))
read_off += span

if self._history_length > 0:
self._return_shapes = tuple((self._history_length, *s) for s in intrinsic_shapes)
else:
self._return_shapes = intrinsic_shapes

self._cache_idx: int = -1 # initialized by SensorManager during build

# =============================== methods to implement ===============================

def _options_require_measured_cache(self):
    """Return whether these options force per-step measured-cache updates.

    The base criterion is read delay only: any delay whose magnitude exceeds
    the epsilon tolerance requires delayed (measured) samples. Subclasses
    extend this check via ``super()``.
    """
    delay = self._options.delay
    return np.any(np.abs(delay) > gs.EPS)
Comment on lines +176 to +177
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why adding this helper? I hate one-liner functions.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sensor classes extend this with super()


def build(self):
    """
    Build the sensor.

    This method is called by SensorManager during the scene build phase.
    This is where any shared metadata should be initialized.
    """
    # Any option that alters the measured signal (delay, noise, ...) forces
    # per-step measured-cache updates for this whole sensor class.
    if self._options_require_measured_cache():
        self._shared_metadata.update_ground_truth_only = False

    # Append this sensor's delay (in timesteps) as a new column, broadcast
    # across all batch environments.
    self._shared_metadata.delays_ts = concat_with_tensor(
        self._shared_metadata.delays_ts,
        self._delay_ts,
        expand=(self._manager._sim._B, 1),
        dim=1,
    )
    self._shared_metadata.cache_sizes.append(self._cache_size)
    self._shared_metadata.history_lengths.append(self._options.history_length)
    # NOTE(review): likely redundant — _options_require_measured_cache() already
    # returns True for any nonzero delay; confirm and drop this branch if so.
    if self._delay_ts > 0:
        self._shared_metadata.update_ground_truth_only = False

@classmethod
def reset(cls, shared_metadata: SharedSensorMetadataT, shared_ground_truth_cache: torch.Tensor, envs_idx):
Expand Down Expand Up @@ -207,7 +248,9 @@ def _update_shared_cache(
Update the shared sensor cache for all sensors of this class using metadata in SensorManager.

The information in shared_cache should be the final measured sensor data after all noise and post-processing.
NOTE: The implementation should include applying the delay using the `_apply_delay_to_shared_cache()` method.
``buffered_data`` is a sliced view of the per-dtype ground-truth ring: SensorManager has already written this
step's GT into the current slot; use ``_apply_delay_to_shared_cache(..., buffered_data, ...)`` for read delay
(do not call ``set`` on it for that GT block).
"""
raise NotImplementedError(f"{cls.__name__} has not implemented `update_shared_cache()`.")

Expand Down Expand Up @@ -282,7 +325,7 @@ def _apply_delay_to_shared_cache(
shared_cache : torch.Tensor
The shared cache tensor.
buffered_data : TensorRingBuffer
The buffered data tensor.
Ground-truth timeline ring for this sensor class slice (current step already written by SensorManager).
cur_jitter_ts : torch.Tensor | None
The current jitter in timesteps (divided by simulation dt) before the sensor data is read.
interpolate : Sequence[bool] | None
Expand Down Expand Up @@ -327,7 +370,12 @@ def _get_formatted_data(self, tensor: torch.Tensor, envs_idx=None) -> torch.Tens
tensor_chunk = tensor[envs_idx].reshape((len(envs_idx), -1))

for i, shape in enumerate(self._return_shapes):
field_data = tensor_chunk[..., self._cache_slices[i]].reshape((len(envs_idx), *shape))
sl = self._read_flat_slices[i]
if self._history_length > 0:
intrinsic_shape = self._intrinsic_return_shapes[i]
field_data = tensor_chunk[..., sl].reshape((len(envs_idx), self._history_length, *intrinsic_shape))
else:
field_data = tensor_chunk[..., sl].reshape((len(envs_idx), *shape))
if self._manager._sim.n_envs == 0:
field_data = field_data[0]
return_values.append(field_data)
Expand Down Expand Up @@ -443,27 +491,33 @@ class NoisySensorMixin(Generic[NoisySensorMetadataMixinT]):
"""

@gs.assert_built
@assert_measured_cache_will_update
def set_resolution(self, resolution, envs_idx=None):
    """Set the quantization resolution applied to the measured output for the given envs."""
    self._set_metadata_field(resolution, self._shared_metadata.resolution, self._cache_size, envs_idx)

@gs.assert_built
@assert_measured_cache_will_update
def set_bias(self, bias, envs_idx=None):
    """Set the constant bias added to the measured output for the given envs."""
    self._set_metadata_field(bias, self._shared_metadata.bias, self._cache_size, envs_idx)

@gs.assert_built
@assert_measured_cache_will_update
def set_random_walk(self, random_walk, envs_idx=None):
    """Set the random-walk (drift) magnitude of the measured output for the given envs."""
    self._set_metadata_field(random_walk, self._shared_metadata.random_walk, self._cache_size, envs_idx)

@gs.assert_built
@assert_measured_cache_will_update
def set_noise(self, noise, envs_idx=None):
    """Set the per-step noise magnitude of the measured output for the given envs."""
    self._set_metadata_field(noise, self._shared_metadata.noise, self._cache_size, envs_idx)

@gs.assert_built
@assert_measured_cache_will_update
def set_jitter(self, jitter, envs_idx=None):
    """Set the read-time jitter for the given envs (given in seconds, stored in timesteps)."""
    # Convert seconds -> simulation timesteps before storing.
    jitter_ts = np.asarray(jitter, dtype=gs.np_float) / self._dt
    self._set_metadata_field(jitter_ts, self._shared_metadata.jitter_ts, 1, envs_idx)

@gs.assert_built
@assert_measured_cache_will_update
def set_delay(self, delay, envs_idx=None):
    """Set the read delay for the given envs.

    NOTE(review): unlike ``set_jitter``, ``delay`` is stored into
    ``delay_in_steps`` without dividing by dt — confirm callers are expected
    to pass steps here rather than seconds.
    """
    self._set_metadata_field(delay, self._shared_metadata.delay_in_steps, 1, envs_idx)

Expand Down Expand Up @@ -496,6 +550,17 @@ def build(self):
self._shared_metadata.cur_jitter_ts = torch.zeros_like(self._shared_metadata.jitter_ts, device=gs.device)
self._shared_metadata.interpolate.append(self._options.interpolate)

def _options_require_measured_cache(self) -> bool:
    """Return whether any noisy option forces per-step measured-cache updates.

    Extends the base criterion (read delay, via ``super()``) with the noise
    model options. The ``interpolate`` flag on its own does not require the
    measured cache: it only matters together with a delay or jitter, and both
    of those already trigger this check by themselves, so the original
    interpolate clause (and the duplicated scalar jitter compare) are dropped.
    """
    return bool(
        super()._options_require_measured_cache()
        # np.any handles both scalar and per-axis (array) options uniformly.
        or np.any(np.asarray(self._options.jitter) > gs.EPS)
        or np.any(np.abs(self._options.bias) > gs.EPS)
        or np.any(np.abs(self._options.noise) > gs.EPS)
        or np.any(np.abs(self._options.random_walk) > gs.EPS)
        or np.any(np.abs(self._options.resolution) > gs.EPS)
    )
Comment thread
Milotrince marked this conversation as resolved.
Comment on lines +553 to +562
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a huge fan of this branching. I would rather always have measured cache always enabled.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's check how much it affects performance


@classmethod
def reset(cls, shared_metadata: NoisySensorMetadataMixin, shared_ground_truth_cache: torch.Tensor, envs_idx):
super().reset(shared_metadata, shared_ground_truth_cache, envs_idx)
Expand Down
2 changes: 1 addition & 1 deletion genesis/engine/sensors/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@
from genesis.vis.rasterizer_context import RasterizerContext

from .base_sensor import (
OptionsT,
RigidSensorMetadataMixin,
RigidSensorMixin,
Sensor,
SharedSensorMetadata,
)
from .base_sensor import OptionsT

if TYPE_CHECKING:
from genesis.utils.ring_buffer import TensorRingBuffer
Expand Down
39 changes: 31 additions & 8 deletions genesis/engine/sensors/contact_force.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Type

import quadrants as qd
import numpy as np
import quadrants as qd
import torch

import genesis as gs
from genesis.options.sensors import (
Contact as ContactSensorOptions,
)
from genesis.options.sensors import (
ContactForce as ContactForceSensorOptions,
)
from genesis.utils.geom import inv_transform_by_quat, qd_inv_transform_by_quat, transform_by_quat
from genesis.utils.misc import concat_with_tensor, make_tensor_field, tensor_to_array, qd_to_torch
from genesis.utils.misc import concat_with_tensor, make_tensor_field, qd_to_torch, tensor_to_array

from .base_sensor import (
NoisySensorMetadataMixin,
Expand All @@ -23,8 +25,8 @@
)

if TYPE_CHECKING:
from genesis.engine.solvers import RigidSolver
from genesis.engine.entities.rigid_entity.rigid_link import RigidLink
from genesis.engine.solvers import RigidSolver
from genesis.ext.pyrender.mesh import Mesh
from genesis.utils.ring_buffer import TensorRingBuffer
from genesis.vis.rasterizer_context import RasterizerContext
Expand Down Expand Up @@ -76,6 +78,8 @@ class ContactSensorMetadata(SharedSensorMetadata):

solver: "RigidSolver | None" = None
expanded_links_idx: torch.Tensor = make_tensor_field((0,), dtype_factory=lambda: gs.tc_int)
# (num_contact_sensors, max_num_filter_links); unused slots are -1.
filter_links_idx: torch.Tensor = make_tensor_field((0, 0), dtype_factory=lambda: gs.tc_int)


class ContactSensor(Sensor[ContactSensorOptions, ContactSensorMetadata]):
Expand Down Expand Up @@ -103,6 +107,15 @@ def build(self):
self._shared_metadata.expanded_links_idx, link_idx, expand=(1,), dim=0
)

num_sensors, cur_num_filter_links = self._shared_metadata.filter_links_idx.shape
max_num_filter_links = max(cur_num_filter_links, len(self._options.filter_link_idx))
filter_links_idx = torch.full((num_sensors + 1, max_num_filter_links), -1, dtype=gs.tc_int, device=gs.device)
filter_links_idx[:num_sensors, :cur_num_filter_links] = self._shared_metadata.filter_links_idx
filter_links_idx[num_sensors, : len(self._options.filter_link_idx)] = torch.tensor(
self._options.filter_link_idx, dtype=gs.tc_int, device=gs.device
)
self._shared_metadata.filter_links_idx = filter_links_idx

def _get_return_format(self) -> tuple[int, ...]:
    """Output shape per sensor: a single scalar (the contact flag)."""
    return (1,)

Expand All @@ -122,9 +135,16 @@ def _update_shared_ground_truth_cache(
return
if shared_metadata.solver.n_envs == 0:
link_a, link_b = link_a[None], link_b[None]
is_contact_a = (link_a[..., None, :] == shared_metadata.expanded_links_idx[..., None]).any(dim=-1)
is_contact_b = (link_b[..., None, :] == shared_metadata.expanded_links_idx[..., None]).any(dim=-1)
shared_ground_truth_cache[:] = (is_contact_a | is_contact_b).T

is_contact_a = link_a[..., None, :] == shared_metadata.expanded_links_idx[..., None]
is_contact_b = link_b[..., None, :] == shared_metadata.expanded_links_idx[..., None]
if shared_metadata.filter_links_idx.numel() > 0:
filter = shared_metadata.filter_links_idx[None, :, None, :]
filtered_a = (link_b[:, None, :, None] == filter).any(dim=-1)
filtered_b = (link_a[:, None, :, None] == filter).any(dim=-1)
shared_ground_truth_cache[:] = ((is_contact_a & ~filtered_a) | (is_contact_b & ~filtered_b)).any(dim=-1).T
else:
shared_ground_truth_cache[:] = (is_contact_a | is_contact_b).any(dim=-1).T

@classmethod
def _update_shared_cache(
Expand All @@ -134,7 +154,6 @@ def _update_shared_cache(
shared_cache: torch.Tensor,
buffered_data: "TensorRingBuffer",
):
buffered_data.set(shared_ground_truth_cache)
cls._apply_delay_to_shared_cache(shared_metadata, shared_cache, buffered_data)

def _draw_debug(self, context: "RasterizerContext"):
Expand Down Expand Up @@ -198,6 +217,11 @@ def build(self):
self._shared_metadata.max_force, self._options.max_force, expand=(1, 3)
)

def _options_require_measured_cache(self) -> bool:
    """Return whether measured-cache updates are required for these options.

    In addition to the inherited criteria, any configured force bound
    (nonzero ``min_force`` threshold, or a finite ``max_force`` cap) means
    the measured signal differs from the ground truth.
    """
    opts = self._options
    has_min_bound = np.any(np.asarray(opts.min_force) > gs.EPS)
    has_max_bound = np.any(np.isfinite(opts.max_force))
    return super()._options_require_measured_cache() or has_min_bound or has_max_bound

def _get_return_format(self) -> tuple[int, ...]:
    """Output shape per sensor: a 3-vector (the contact force)."""
    return (3,)

Expand Down Expand Up @@ -256,7 +280,6 @@ def _update_shared_cache(
shared_cache: torch.Tensor,
buffered_data: "TensorRingBuffer",
):
buffered_data.set(shared_ground_truth_cache)
torch.normal(0.0, shared_metadata.jitter_ts, out=shared_metadata.cur_jitter_ts)
cls._apply_delay_to_shared_cache(
shared_metadata,
Expand Down
16 changes: 13 additions & 3 deletions genesis/engine/sensors/imu.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple, Type

import quadrants as qd
import numpy as np
import quadrants as qd
import torch

import genesis as gs
from genesis.options.sensors import CrossCouplingAxisType, IMU as IMUOptions
from genesis.options.sensors import IMU as IMUOptions
from genesis.options.sensors import CrossCouplingAxisType
from genesis.utils.geom import (
inv_transform_by_quat,
transform_by_quat,
Expand Down Expand Up @@ -100,18 +101,21 @@ def __init__(self, options: IMUOptions, shared_metadata: IMUSharedMetadata, mana

@gs.assert_built
def set_acc_cross_axis_coupling(self, cross_axis_coupling: CrossCouplingAxisType, envs_idx=None):
    """Set the accelerometer cross-axis-coupling alignment matrix for the given envs."""
    # NOTE(review): base sensors use the @assert_measured_cache_will_update
    # decorator; confirm this private-method variant exists and behaves the same.
    self._assert_measured_cache_will_update()
    envs_idx = self._sanitize_envs_idx(envs_idx)
    rot_matrix = _get_cross_axis_coupling_to_alignment_matrix(cross_axis_coupling)
    # Accelerometer occupies row offset 0 of this sensor's 3-row slot.
    self._shared_metadata.alignment_rot_matrix[envs_idx, self._idx * 3, :, :] = rot_matrix

@gs.assert_built
def set_gyro_cross_axis_coupling(self, cross_axis_coupling: CrossCouplingAxisType, envs_idx=None):
    """Set the gyroscope cross-axis-coupling alignment matrix for the given envs."""
    # NOTE(review): base sensors use the @assert_measured_cache_will_update
    # decorator; confirm this private-method variant exists and behaves the same.
    self._assert_measured_cache_will_update()
    envs_idx = self._sanitize_envs_idx(envs_idx)
    rot_matrix = _get_cross_axis_coupling_to_alignment_matrix(cross_axis_coupling)
    # Gyroscope occupies row offset 1 of this sensor's 3-row slot.
    self._shared_metadata.alignment_rot_matrix[envs_idx, self._idx * 3 + 1, :, :] = rot_matrix

@gs.assert_built
def set_mag_cross_axis_coupling(self, cross_axis_coupling: CrossCouplingAxisType, envs_idx=None):
    """Set the magnetometer cross-axis-coupling alignment matrix for the given envs."""
    # NOTE(review): base sensors use the @assert_measured_cache_will_update
    # decorator; confirm this private-method variant exists and behaves the same.
    self._assert_measured_cache_will_update()
    envs_idx = self._sanitize_envs_idx(envs_idx)
    rot_matrix = _get_cross_axis_coupling_to_alignment_matrix(cross_axis_coupling)
    # Magnetometer occupies row offset 2 of this sensor's 3-row slot.
    self._shared_metadata.alignment_rot_matrix[envs_idx, self._idx * 3 + 2, :, :] = rot_matrix
Expand Down Expand Up @@ -153,6 +157,13 @@ def build(self):
self.quat_offset = self._shared_metadata.offsets_quat[0, self._idx]
self.pos_offset = self._shared_metadata.offsets_pos[0, self._idx]

def _options_require_measured_cache(self) -> bool:
    """Return whether measured-cache updates are required for these options.

    In addition to the inherited criteria, any nonzero cross-axis coupling
    (accelerometer, gyroscope, or magnetometer) distorts the measured signal
    and therefore requires the measured cache.
    """
    couplings = (
        self._options.acc_cross_axis_coupling,
        self._options.gyro_cross_axis_coupling,
        self._options.mag_cross_axis_coupling,
    )
    return super()._options_require_measured_cache() or any(
        np.any(np.abs(coupling) > gs.EPS) for coupling in couplings
    )

def _get_return_format(self) -> tuple[tuple[int, ...], ...]:
    """Output shapes per sensor: three 3-vector fields.

    Presumably accelerometer, gyroscope, and magnetometer readings, matching
    the three cross-axis-coupling slots — confirm against the IMU cache layout.
    """
    return (3,), (3,), (3,)

Expand Down Expand Up @@ -216,7 +227,6 @@ def _update_shared_cache(
"""
Update the current measured sensor data for all IMU sensors.
"""
buffered_data.set(shared_ground_truth_cache)
torch.normal(0.0, shared_metadata.jitter_ts, out=shared_metadata.cur_jitter_ts)
cls._apply_delay_to_shared_cache(
shared_metadata,
Expand Down
Loading
Loading