Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions alf/algorithms/data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,18 @@ def __init__(self, data_transformer_ctors, observation_spec):

@staticmethod
def _validate_order(data_transformers):
# Hindsight should probably not be used together with FrameStacker,
# unless done really carefully. Hindsight after FrameStacker is
# simply wrong, because Hindsight would read ``achieved_goal`` field
# of a future step directly from the replay buffer without stacking.
def _tier_of(data_transformer):
if isinstance(data_transformer, UntransformedTimeStep):
return 1
if isinstance(data_transformer,
(HindsightExperienceTransformer, FrameStacker)):
if isinstance(data_transformer, HindsightExperienceTransformer):
return 2
return 3
if isinstance(data_transformer, FrameStacker):
return 3
return 4

prev_tier = 0
for i in range(len(data_transformers)):
Expand Down
2 changes: 2 additions & 0 deletions alf/algorithms/merlin_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,8 @@ def __init__(self,
enc_layers.append(res_block)
in_channels = 64

if output_activation is None:
output_activation = alf.math.identity

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems unnecessary. The caller can pass alf.math.identity as the argument.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Removed.

enc_layers.extend([
nn.Flatten(),
alf.layers.FC(
Expand Down
6 changes: 5 additions & 1 deletion alf/algorithms/rl_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,19 +223,23 @@ def __init__(self,
replay_buffer_length = adjust_replay_buffer_length(
config, self._num_earliest_frames_ignored)

total_replay_size = replay_buffer_length * self._env.batch_size
if config.whole_replay_buffer_training and config.clear_replay_buffer:
# For whole replay buffer training, we would like to be sure
# that the replay buffer has enough samples in it to perform
# the training, which will most likely happen in the 2nd
# iteration. The minimum_initial_collect_steps guarantees that.
minimum_initial_collect_steps = replay_buffer_length * self._env.batch_size
minimum_initial_collect_steps = total_replay_size
if config.initial_collect_steps < minimum_initial_collect_steps:
common.info(
'Set the initial_collect_steps to minimum required '
f'value {minimum_initial_collect_steps} because '
'whole_replay_buffer_training is on.')
config.initial_collect_steps = minimum_initial_collect_steps

assert config.initial_collect_steps <= total_replay_size, \
"Training will not happen - insufficient replay buffer size."

self.set_replay_buffer(self._env.batch_size, replay_buffer_length,
config.priority_replay)

Expand Down
6 changes: 5 additions & 1 deletion alf/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,11 @@ def pre_config(configs):
try:
config1(name, value, mutable=False)
_HANDLED_PRE_CONFIGS.append((name, value))
except ValueError:
except ValueError as e:
# Most of the time, for command line flags, this warning is a false alarm.
# This can be useful in other failures, e.g. when the Config has already been used,
# before configuring its value.
logging.warning("pre_config potential error: %s", e)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This warning is hard to understand. It would be better to identify the case where the Config has already been used. Perhaps raise a different type of exception in config1() when the config has already been used?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Logging error in config1.

_PRE_CONFIGS.append((name, value))


Expand Down
3 changes: 2 additions & 1 deletion alf/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@ def _generate_time_step(batched,
if env_id is None:
env_id = md.arange(batch_size, dtype=md.int32)
if reward is not None:
assert reward.shape[:1] == outer_dims
assert reward.shape[:1] == outer_dims, "%s, %s" % (reward.shape,
outer_dims)
if prev_action is not None:
flat_action = nest.flatten(prev_action)
assert flat_action[0].shape[:1] == outer_dims
Expand Down
4 changes: 3 additions & 1 deletion alf/networks/critic_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(self,
joint_fc_layer_params=None,
activation=torch.relu_,
kernel_initializer=None,
last_bias_init_value=0.0,
use_fc_bn=False,
use_naive_parallel_network=False,
name="CriticNetwork"):
Expand Down Expand Up @@ -174,7 +175,8 @@ def __init__(self,
last_activation=math_ops.identity,
use_fc_bn=use_fc_bn,
last_kernel_initializer=last_kernel_initializer,
name=name)
last_bias_init_value=last_bias_init_value,
name=name + ".joint_encoder")
self._use_naive_parallel_network = use_naive_parallel_network

def make_parallel(self, n):
Expand Down
4 changes: 3 additions & 1 deletion alf/networks/encoding_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ def __init__(self,
last_layer_size=None,
last_activation=None,
last_kernel_initializer=None,
last_bias_init_value=0.0,
last_use_fc_bn=False,
name="EncodingNetwork"):
"""
Expand Down Expand Up @@ -540,7 +541,8 @@ def __init__(self,
last_layer_size,
activation=last_activation,
use_bn=last_use_fc_bn,
kernel_initializer=last_kernel_initializer))
kernel_initializer=last_kernel_initializer,
bias_init_value=last_bias_init_value))
input_size = last_layer_size

if output_tensor_spec is not None:
Expand Down
1 change: 1 addition & 0 deletions alf/trainers/policy_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ def __init__(self, config: TrainerConfig, ddp_rank: int = -1):
logging.info(
"observation_spec=%s" % pprint.pformat(env.observation_spec()))
logging.info("action_spec=%s" % pprint.pformat(env.action_spec()))
logging.info("reward_spec=%s" % pprint.pformat(env.reward_spec()))

# for offline buffer construction
untransformed_observation_spec = env.observation_spec()
Expand Down
27 changes: 20 additions & 7 deletions alf/utils/data_buffer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
DataItem = alf.data_structures.namedtuple(
"DataItem", [
"env_id", "x", "o", "reward", "step_type", "batch_info",
"replay_buffer", "rollout_info_field"
"replay_buffer", "rollout_info_field", "discount"
],
default_value=())

Expand All @@ -40,12 +40,20 @@ def get_batch(env_ids, dim, t, x):
batch_size = len(env_ids)
x = torch.as_tensor(x, dtype=torch.float32, device="cpu")
t = torch.as_tensor(t, dtype=torch.int32, device="cpu")
ox = (x * torch.arange(
batch_size, dtype=torch.float32, requires_grad=True,
device="cpu").unsqueeze(1) * torch.arange(
dim, dtype=torch.float32, requires_grad=True,
device="cpu").unsqueeze(0))
a = x * torch.ones(batch_size, dtype=torch.float32, device="cpu")
# ox = (x * torch.arange(
# batch_size, dtype=torch.float32, requires_grad=True,
# device="cpu").unsqueeze(1) * torch.arange(
# dim, dtype=torch.float32, requires_grad=True,
# device="cpu").unsqueeze(0))
if batch_size > 1 and x.ndim > 0 and batch_size == x.shape[0]:
a = x
else:
a = x * torch.ones(batch_size, dtype=torch.float32, device="cpu")
if batch_size > 1 and t.ndim > 0 and batch_size == t.shape[0]:
pass
else:
t = t * torch.ones(batch_size, dtype=torch.int32, device="cpu")
ox = a.unsqueeze(1).clone().requires_grad_(True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of this change?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is needed because we allow the x and t inputs to be scalars, which are then expanded to be consistent with batch_size. I made the code easier to read and added comments.

g = torch.zeros(batch_size, dtype=torch.float32, device="cpu")
# reward function adapted from ReplayBuffer: default_reward_fn
r = torch.where(
Expand All @@ -60,6 +68,10 @@ def get_batch(env_ids, dim, t, x):
"a": a,
"g": g
}),
discount=torch.tensor(
t != alf.data_structures.StepType.LAST,
dtype=torch.float32,
device="cpu"),
reward=r)


Expand All @@ -79,6 +91,7 @@ def __init__(self, *args):
"a": alf.TensorSpec(shape=(), dtype=torch.float32),
"g": alf.TensorSpec(shape=(), dtype=torch.float32)
}),
discount=alf.TensorSpec(shape=(), dtype=torch.float32),
reward=alf.TensorSpec(shape=(), dtype=torch.float32))

@parameterized.named_parameters([
Expand Down
2 changes: 2 additions & 0 deletions alf/utils/external_configurables.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,5 @@

gin.external_configurable(torch.nn.init.xavier_normal_,
'torch.nn.init.xavier_normal_')
gin.external_configurable(torch.nn.Embedding, 'torch.nn.Embedding')
gin.external_configurable(torch.nn.Sequential, 'torch.nn.Sequential')
6 changes: 5 additions & 1 deletion alf/utils/normalizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,11 @@ def _summary(name, val):
def _summarize_all(path, t, m2, m):
if path:
path += "."
spec = TensorSpec.from_tensor(m if m2 is None else m2)
if m2 is not None:
spec = TensorSpec.from_tensor(m2)
else:
assert m is not None
spec = TensorSpec.from_tensor(m)
_summary(path + "tensor.batch_min",
_reduce_along_batch_dims(t, spec, torch.min))
_summary(path + "tensor.batch_max",
Expand Down