Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions examples/kbot/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,27 +94,13 @@ def step_fn(
command: Array,
carry: Array,
) -> tuple[Array, Array]:
x_vel = command[..., 0]
y_vel = command[..., 1]
ang_vel = command[..., 2]

# Converts to the expected command structure.
xy_vel = jnp.stack([x_vel, y_vel], axis=-1)
cmd_vel = jnp.linalg.norm(xy_vel, axis=-1)
cmd_yaw = jnp.arctan2(y_vel, x_vel)

linvel_cmd_2 = jnp.stack([cmd_vel, cmd_yaw], axis=-1)
angvel_cmd_1 = jnp.stack([ang_vel], axis=-1)

# Call the model.
obs = jnp.concatenate(
[
joint_angles,
joint_angular_velocities / 10.0,
projected_gravity,
gyroscope,
linvel_cmd_2,
angvel_cmd_1,
],
axis=-1,
)
Expand Down
149 changes: 48 additions & 101 deletions examples/kbot/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class HumanoidWalkingTaskConfig(ksim.PPOConfig):
help="The number of hidden layers for the MLPs.",
)
var_scale: float = xax.field(
value=0.5,
value=1.0,
help="The scale for the standard deviations of the actor.",
)
start_cutoff_frequency: float = xax.field(
Expand Down Expand Up @@ -455,8 +455,8 @@ def get_physics_randomizers(self, physics_model: ksim.PhysicsModel) -> dict[str,
"floor_friction": ksim.FloorFrictionRandomizer.from_geom_name(
physics_model,
"floor",
scale_lower=0.98,
scale_upper=1.02,
# scale_lower=0.98,
# scale_upper=1.02,
),
"armature": ksim.ArmatureRandomizer(),
"all_bodies_mass_multiplication": ksim.AllBodiesMassMultiplicationRandomizer(
Expand All @@ -473,21 +473,16 @@ def get_physics_randomizers(self, physics_model: ksim.PhysicsModel) -> dict[str,
def get_events(self, physics_model: ksim.PhysicsModel) -> dict[str, ksim.Event]:
return {
"linear_push": ksim.LinearPushEvent(
linvel=1.0,
linvel=3.0,
vel_range=(0.0, 1.0),
interval_range=(1.0, 2.0),
scale=ksim.QuadraticScale.from_endpoints(0.1, 1.0),
interval_range=(0.5, 2.0),
scale=ksim.LinearScale.from_endpoints(0.25, 1.0),
),
"angular_push": ksim.AngularPushEvent(
angvel=math.radians(90.0),
angvel=math.radians(360.0),
vel_range=(0.0, 1.0),
interval_range=(1.0, 2.0),
scale=ksim.QuadraticScale.from_endpoints(0.1, 1.0),
),
"jump": ksim.JumpEvent(
jump_height_range=(0.1, 0.3),
interval_range=(1.0, 2.0),
scale=ksim.QuadraticScale.from_endpoints(0.1, 1.0),
interval_range=(0.5, 2.0),
scale=ksim.LinearScale.from_endpoints(0.25, 1.0),
),
}

Expand All @@ -500,7 +495,7 @@ def get_resets(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reset]:

def get_observations(self, physics_model: ksim.PhysicsModel) -> dict[str, ksim.Observation]:
return {
"joint_position": ksim.JointPositionObservation(noise=ksim.AdditiveUniformNoise(mag=math.radians(2))),
"joint_position": ksim.JointPositionObservation(noise=ksim.AdditiveUniformNoise(mag=math.radians(5))),
"joint_velocity": ksim.JointVelocityObservation(noise=ksim.AdditiveUniformNoise(mag=math.radians(30))),
"actuator_force": ksim.ActuatorForceObservation(),
"center_of_mass_inertia": ksim.CenterOfMassInertiaObservation(),
Expand Down Expand Up @@ -558,92 +553,79 @@ def get_observations(self, physics_model: ksim.PhysicsModel) -> dict[str, ksim.O
}

def get_commands(self, physics_model: ksim.PhysicsModel) -> dict[str, ksim.Command]:
return {
"linvel": ksim.LinearVelocityCommand(
min_vel=self.config.min_linear_velocity,
max_vel=self.config.max_linear_velocity,
ctrl_dt=self.config.ctrl_dt,
linear_accel=self.config.linear_velocity_accel,
angular_accel=self.config.angular_velocity_accel,
max_yaw=self.config.linear_velocity_max_yaw,
zero_prob=self.config.linear_velocity_zero_prob,
backward_prob=self.config.linear_velocity_backward_prob,
switch_prob=self.config.linear_velocity_switch_prob,
),
"angvel": ksim.AngularVelocityCommand(
min_vel=self.config.min_angular_velocity,
max_vel=self.config.max_angular_velocity,
ctrl_dt=self.config.ctrl_dt,
angular_accel=self.config.angular_velocity_accel,
zero_prob=self.config.angular_velocity_zero_prob,
switch_prob=self.config.angular_velocity_switch_prob,
),
}
return {}

def get_rewards(self, physics_model: ksim.PhysicsModel) -> dict[str, ksim.Reward]:
zeros = {k: v for k, v in ZEROS}

rewards = {
"stay_alive": ksim.StayAliveReward(scale=500.0),
# Command tracking rewards.
"linvel": ksim.LinearVelocityReward(
cmd="linvel",
scale=ksim.LinearScale.from_endpoints(2.0, 10.0),
),
"angvel": ksim.AngularVelocityReward(
cmd="angvel",
scale=ksim.LinearScale.from_endpoints(0.5, 10.0),
),
# Gait rewards.
"foot_airtime": ksim.FeetAirTimeReward(
ctrl_dt=self.config.ctrl_dt,
gait_period=self.config.gait_period,
air_time_percent=self.config.air_time_percent,
contact_obs="feet_contact",
scale=ksim.QuadraticScale(scale=10.0),
scale=1.0,
),
"upright": ksim.UprightReward(
scale=ksim.LinearScale(scale=3.0),
scale=1.0,
),
"motionless": ksim.StandFrozenReward(
scale=1.0,
),
"foot_height": ksim.SparseTargetHeightReward(
contact_obs="feet_contact",
position_obs="feet_position",
height=self.config.max_foot_height,
scale=ksim.QuadraticScale(scale=10.0),
scale=1.0,
),
"foot_contact": ksim.ForcePenalty(
force_obs="feet_force",
ctrl_dt=self.config.ctrl_dt,
bias=500.0, # Weight of the robot is 350 Newtons.
scale=0.1,
scale=1.0,
),
"foot_intersection": ksim.IntersectionPenalty(
position_obs="feet_position",
min_distance=0.25,
scale=ksim.QuadraticScale(scale=10.0),
scale=ksim.LinearScale(scale=10.0),
),
# Normalization penalties.
"ctrl": ksim.TorquePenalty.create(
model=physics_model,
scale=1.0,
scale=ksim.LinearScale(scale=1.0),
),
"energy": ksim.EnergyPenalty.create(
model=physics_model,
scale=ksim.LinearScale(scale=1.0),
),
}

# Joint deviation penalties.
deviation_names = [
"dof_right_shoulder_roll_03",
"dof_left_shoulder_roll_03",
"dof_right_shoulder_yaw_02",
"dof_left_shoulder_yaw_02",
"dof_right_hip_roll_03",
"dof_left_hip_roll_03",
"dof_right_hip_yaw_03",
"dof_left_hip_yaw_03",
"dof_right_wrist_00",
"dof_left_wrist_00",
"dof_right_knee_04",
"dof_left_knee_04",
"dof_right_ankle_02",
"dof_left_ankle_02",
]
rewards.update(
{
f"joint_deviation_{k}": ksim.JointDeviationPenalty.create(
"joint_deviation": ksim.JointDeviationPenalty.create(
physics_model=physics_model,
joint_names=names,
joint_targets=[zeros[name] for name in names],
scale=ksim.QuadraticScale(scale=10.0),
)
for (k, names) in (
("shoulder_roll", ["dof_right_shoulder_roll_03", "dof_left_shoulder_roll_03"]),
("shoulder_yaw", ["dof_right_shoulder_yaw_02", "dof_left_shoulder_yaw_02"]),
("hip_roll", ["dof_right_hip_roll_03", "dof_left_hip_roll_03"]),
("hip_yaw", ["dof_right_hip_yaw_03", "dof_left_hip_yaw_03"]),
("wrist", ["dof_right_wrist_00", "dof_left_wrist_00"]),
joint_names=deviation_names,
joint_targets=[zeros[name] for name in deviation_names],
scale=1.0,
)
}
)
Expand All @@ -658,7 +640,7 @@ def get_rewards(self, physics_model: ksim.PhysicsModel) -> dict[str, ksim.Reward
left_zero=zeros[left_name],
right_zero=zeros[right_name],
flipped=flipped,
scale=ksim.QuadraticScale(scale=10.0),
scale=1.0,
)
for (k, right_name, left_name, flipped) in (
("shoulder", "dof_right_shoulder_pitch_03", "dof_left_shoulder_pitch_03", True),
Expand All @@ -668,22 +650,6 @@ def get_rewards(self, physics_model: ksim.PhysicsModel) -> dict[str, ksim.Reward
}
)

# Symmetry rewards.
rewards.update(
{
f"symmetry_{k}": ksim.SymmetryReward.create(
physics_model=physics_model,
joint_names=names,
joint_targets=[zeros[name] for name in names],
scale=ksim.QuadraticScale(scale=10.0),
)
for (k, names) in (
("knee", ["dof_right_knee_04", "dof_left_knee_04"]),
("ankle", ["dof_right_ankle_02", "dof_left_ankle_02"]),
)
}
)

return rewards

def get_terminations(self, physics_model: ksim.PhysicsModel) -> dict[str, ksim.Termination]:
Expand All @@ -704,9 +670,9 @@ def get_model(self, params: ksim.InitParams) -> Model:
return Model(
params.key,
physics_model=params.physics_model,
num_actor_inputs=49,
num_actor_inputs=46,
num_actor_outputs=len(ZEROS),
num_critic_inputs=463,
num_critic_inputs=460,
min_std=0.0001,
max_std=1.0,
var_scale=self.config.var_scale,
Expand Down Expand Up @@ -734,21 +700,11 @@ def run_actor(
proj_grav_3 = observations["noisy_imu_projected_gravity"]
imu_gyro_3 = observations["noisy_imu_gyro"]

# Command tensors.
linvel_cmd: ksim.LinearVelocityCommandValue = commands["linvel"]
angvel_cmd: ksim.AngularVelocityCommandValue = commands["angvel"]

# Stacks into tensors.
linvel_cmd_2 = jnp.stack([linvel_cmd.target_vel, linvel_cmd.target_yaw], axis=-1)
angvel_cmd_1 = jnp.stack([angvel_cmd.target_vel], axis=-1)

obs = [
joint_pos_n, # NUM_JOINTS
joint_vel_n / 10.0, # NUM_JOINTS
proj_grav_3, # 3
imu_gyro_3, # 3
linvel_cmd_2, # 2
angvel_cmd_1, # 1
]

obs_n = jnp.concatenate(obs, axis=-1)
Expand Down Expand Up @@ -782,14 +738,6 @@ def run_critic(
# Flattens the last two dimensions.
feet_force_obs_6 = feet_force_obs_23.reshape(*feet_force_obs_23.shape[:-2], 6)

# Command tensors.
linvel_cmd: ksim.LinearVelocityCommandValue = commands["linvel"]
angvel_cmd: ksim.AngularVelocityCommandValue = commands["angvel"]

# Stacks into tensors.
linvel_cmd_2 = jnp.stack([linvel_cmd.target_vel, linvel_cmd.target_yaw], axis=-1)
angvel_cmd_1 = jnp.stack([angvel_cmd.target_vel], axis=-1)

obs_n = jnp.concatenate(
[
dh_joint_pos_j, # NUM_JOINTS
Expand All @@ -807,8 +755,6 @@ def run_critic(
feet_contact_2,
feet_height_2,
feet_force_obs_6 / 100.0,
linvel_cmd_2,
angvel_cmd_1,
],
axis=-1,
)
Expand Down Expand Up @@ -846,6 +792,7 @@ def _model_scan_fn(
transition_ppo_variables = ksim.PPOVariables(
log_probs=log_probs,
values=value.squeeze(-1),
# entropy=actor_dist.entropy(),
)

next_carry = jax.tree.map(
Expand Down
44 changes: 43 additions & 1 deletion ksim/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
"SmallJointJerkReward",
"AvoidLimitsPenalty",
"TorquePenalty",
"EnergyPenalty",
"JointDeviationPenalty",
"FlatBodyReward",
"PositionTrackingReward",
"StandFrozenReward",
"UprightReward",
"NoRollReward",
"LinkAccelerationPenalty",
Expand Down Expand Up @@ -580,6 +582,14 @@ def create(cls, model: PhysicsModel, scale: float | Scale = 1.0) -> Self:
)


@attrs.define(frozen=True, kw_only=True)
class EnergyPenalty(TorquePenalty):
    """Penalty built from the product of scaled control and joint velocity.

    Reuses the ``ctrl_scales`` attribute inherited from ``TorquePenalty`` to
    normalise the control signal before it is multiplied by the velocity.
    """

    def get_reward(self, trajectory: Trajectory) -> Array:
        """Return the mean absolute ``ctrl * qvel`` product over the last axis.

        Args:
            trajectory: Object providing the ``ctrl`` and ``qvel`` arrays.

        Returns:
            ``|ctrl / ctrl_scales * qvel[..., 6:]|`` averaged over the last
            axis.
        """
        # Only the velocity entries from index 6 onward are used.
        # NOTE(review): presumably this skips a 6-DoF floating base so the
        # remainder lines up one-to-one with ``ctrl`` - confirm.
        joint_vel = trajectory.qvel[..., 6:]
        scaled_ctrl = trajectory.ctrl / jnp.array(self.ctrl_scales)
        power = scaled_ctrl * joint_vel
        return jnp.mean(jnp.abs(power), axis=-1)


@attrs.define(frozen=True, kw_only=True)
class JointDeviationPenalty(Reward):
"""Penalty for joint deviations from target positions."""
Expand Down Expand Up @@ -731,14 +741,46 @@ def create(
)


@attrs.define(frozen=True, kw_only=True)
class StandFrozenReward(Reward):
    """Reward computed from the norms of the first six ``qvel`` entries.

    Two terms are produced: one from the norm of ``qvel[..., 0:3]`` (treated
    as linear velocity) and one from the norm of ``qvel[..., 3:6]`` (treated
    as angular velocity). Each norm is handed to ``exp_kernel_with_penalty``
    together with the matching ``*_scale``, ``*_sq_scale`` and ``*_abs_scale``
    attributes.
    """

    # Arguments forwarded to the kernel for the angular-velocity term.
    # The ``*_sq_scale`` / ``*_abs_scale`` values are validated to be > 0.
    angvel_scale: float = attrs.field(default=0.25)
    angvel_sq_scale: float = attrs.field(default=0.1, validator=attrs.validators.gt(0.0))
    angvel_abs_scale: float = attrs.field(default=0.1, validator=attrs.validators.gt(0.0))
    # Arguments forwarded to the kernel for the linear-velocity term.
    linvel_scale: float = attrs.field(default=0.25)
    linvel_sq_scale: float = attrs.field(default=0.1, validator=attrs.validators.gt(0.0))
    linvel_abs_scale: float = attrs.field(default=0.1, validator=attrs.validators.gt(0.0))

    def get_reward(self, trajectory: Trajectory) -> dict[str, Array]:
        """Compute the linear and angular velocity reward terms.

        Args:
            trajectory: Object providing the ``qvel`` array.

        Returns:
            A dict with keys ``"linvel"`` and ``"angvel"``, each holding the
            output of ``exp_kernel_with_penalty`` for the respective norm.
        """
        qvel = trajectory.qvel

        # Reduce each 3-vector slice to its magnitude along the last axis.
        lin_speed = jnp.linalg.norm(qvel[..., 0:3], axis=-1)
        ang_speed = jnp.linalg.norm(qvel[..., 3:6], axis=-1)

        lin_term = exp_kernel_with_penalty(
            lin_speed,
            self.linvel_scale,
            self.linvel_sq_scale,
            self.linvel_abs_scale,
        )
        ang_term = exp_kernel_with_penalty(
            ang_speed,
            self.angvel_scale,
            self.angvel_sq_scale,
            self.angvel_abs_scale,
        )
        return {"linvel": lin_term, "angvel": ang_term}


@attrs.define(frozen=True, kw_only=True)
class UprightReward(Reward):
"""Reward for staying upright."""

angvel_scale: float = attrs.field(default=0.25)
pose_scale: float = attrs.field(default=0.25)
angvel_sq_scale: float = attrs.field(default=0.1, validator=attrs.validators.gt(0.0))
angvel_abs_scale: float = attrs.field(default=0.1, validator=attrs.validators.gt(0.0))
pose_scale: float = attrs.field(default=0.25)
pose_sq_scale: float = attrs.field(default=0.1, validator=attrs.validators.gt(0.0))
pose_abs_scale: float = attrs.field(default=0.1, validator=attrs.validators.gt(0.0))

Expand Down