Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions alf/algorithms/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Optional, Callable
import torch
import alf
from alf.utils.schedulers import as_scheduler
from alf.utils.schedulers import ConstantScheduler, as_scheduler


@alf.configurable
Expand Down Expand Up @@ -143,13 +143,18 @@ def __init__(self,
total number of FRAMES will be (``num_env_steps*frame_skip``) for
calculating sample efficiency. See alf/environments/wrappers.py
for the definition of FrameSkip.
unroll_length (float): number of time steps each environment proceeds per
iteration. The total number of time steps from all environments per
iteration can be computed as: ``num_envs * env_batch_size * unroll_length``.
If ``unroll_length`` is not an integer, the actual unroll_length
unroll_length (float|Scheduler): number of time steps each environment
proceeds per iteration. The total number of time steps from all
environments per iteration can be computed as:
``num_envs * env_batch_size * unroll_length``. If
``unroll_length`` is not an integer, the actual unroll_length
being used will fluctuate between ``floor(unroll_length)`` and
``ceil(unroll_length)`` and the expectation will be equal to
``unroll_length``.
``unroll_length``. For sync off-policy training,
``unroll_length`` can also be a scheduler. In that case,
``async_unroll`` and ``whole_replay_buffer_training`` must both
be False. If a resolved value is 0, the iteration skips rollout
and only performs replay-buffer updates.
unroll_with_grad (bool): a bool flag indicating whether we require
grad during ``unroll()``. This flag is only used by
``OffPolicyAlgorithm`` where unrolling with grads is usually
Expand Down Expand Up @@ -389,6 +394,16 @@ def __init__(self,
self.unroll_with_grad = unroll_with_grad
self.use_root_inputs_for_after_train_iter = use_root_inputs_for_after_train_iter
self.async_unroll = async_unroll
if not isinstance(self._unroll_length, ConstantScheduler):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ConstantScheduler --> should check against a base class, e.g. Scheduler?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to check against ConstantScheduler here because a scalar input will be converted to one before this check due to the setter function on line 479.

Any non-constant scheduler should then raise an error if we're doing on-policy or async unroll.

assert not async_unroll, (
"scheduled unroll_length is not supported for async_unroll=True"
)
assert not whole_replay_buffer_training, (
"scheduled unroll_length is not supported for "
"whole_replay_buffer_training=True")
assert num_env_steps == 0, (
"scheduled unroll_length is not supported when num_env_steps "
"is used as a termination criterion")
if async_unroll:
assert not unroll_with_grad, ("unroll_with_grad is not supported "
"for async_unroll=True")
Expand Down Expand Up @@ -455,3 +470,11 @@ def __init__(self,
self.normalize_importance_weights_by_max = normalize_importance_weights_by_max
self.visualize_alf_tree = visualize_alf_tree
self.remote_training = remote_training

@property
def unroll_length(self):
return self._unroll_length()

@unroll_length.setter
def unroll_length(self, value):
self._unroll_length = as_scheduler(value)
12 changes: 10 additions & 2 deletions alf/algorithms/rl_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,11 @@ def _unroll_iter_off_policy(self):
if not config.update_counter_every_mini_batch:
alf.summary.increment_global_counter()

unroll_length = self._remaining_unroll_length_fraction + config.unroll_length
# Preserve the configured value so we can distinguish it from the
# integerized length after carrying over any fractional remainder.
requested_unroll_length = config.unroll_length
unroll_length = (self._remaining_unroll_length_fraction +
requested_unroll_length)
self._remaining_unroll_length_fraction = unroll_length - int(
unroll_length)
unroll_length = int(unroll_length)
Expand All @@ -823,9 +827,13 @@ def _unroll_iter_off_policy(self):
unrolled = False
root_inputs = None
rollout_info = None
# Async unroll still needs one unroll call to pump queued work even when
# the configured unroll length is exactly zero.
allow_zero_length_unroll = (config.async_unroll
and requested_unroll_length == 0)
if (alf.summary.get_global_counter()
>= self._rl_train_after_update_steps
and (unroll_length > 0 or config.unroll_length == 0) and
and (unroll_length > 0 or allow_zero_length_unroll) and
(config.num_env_steps == 0
or self.get_step_metrics()[1].result() < config.num_env_steps)):
unrolled = True
Expand Down
111 changes: 111 additions & 0 deletions alf/algorithms/rl_algorithm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import alf
from alf.utils import common, dist_utils, tensor_utils
from alf.utils.schedulers import StepScheduler, update_progress
from alf.data_structures import AlgStep, Experience, LossInfo, StepType, TimeStep
from alf.algorithms.rl_algorithm import RLAlgorithm
from alf.algorithms.config import TrainerConfig
Expand Down Expand Up @@ -174,6 +175,45 @@ def current_time_step(self):

class RLAlgorithmTest(unittest.TestCase):

class _ReplayOnlyAlg(MyAlg):

def __init__(self, config):
observation_spec = TensorSpec((2, ), dtype='float32')
action_spec = alf.BoundedTensorSpec(shape=(),
dtype='int64',
minimum=0,
maximum=2)
super().__init__(observation_spec=observation_spec,
action_spec=action_spec,
env=None,
config=config,
on_policy=False)
# A non-None sentinel is enough to make RLAlgorithm treat this as
# replay-buffer-backed during off-policy training.
self._replay_buffer = object()
# These counters let the test assert whether rollout work was
# skipped and whether replay-only hooks still ran.
self.unroll_calls = []
self.train_calls = 0
self.after_train_iter_calls = 0

def _unroll(self, unroll_length: int):
self.unroll_calls.append(unroll_length)
return None

def train_from_replay_buffer(self, update_global_counter=False):
# Return a fixed step count so the test can focus on control flow
# rather than replay buffer contents.
self.train_calls += 1
self.update_global_counter = update_global_counter
return 7

def after_train_iter(self, root_inputs, train_info):
self.after_train_iter_calls += 1

def tearDown(self):
update_progress('iterations', 0)

def test_on_policy_algorithm(self):
# root_dir is not used. We have to give it a value because
# it is a required argument of TrainerConfig.
Expand All @@ -198,6 +238,77 @@ def test_on_policy_algorithm(self):
self.assertTrue(torch.all(logits[1, :] > logits[0, :]))
self.assertTrue(torch.all(logits[1, :] > logits[2, :]))

def test_scheduled_unroll_length_guards(self):
unroll_length = StepScheduler('iterations', [(1, 1), (2, 0)])

with self.assertRaisesRegex(
AssertionError,
"scheduled unroll_length is not supported for async_unroll=True"
):
TrainerConfig(root_dir='/tmp/rl_algorithm_test',
unroll_length=unroll_length,
async_unroll=True,
max_unroll_length=1)

with self.assertRaisesRegex(
AssertionError, "scheduled unroll_length is not supported for "
"whole_replay_buffer_training=True"):
TrainerConfig(root_dir='/tmp/rl_algorithm_test',
unroll_length=unroll_length,
whole_replay_buffer_training=True)

with self.assertRaisesRegex(
AssertionError,
"scheduled unroll_length is not supported when num_env_steps "
"is used as a termination criterion"):
TrainerConfig(root_dir='/tmp/rl_algorithm_test',
unroll_length=unroll_length,
num_env_steps=1,
num_iterations=0,
whole_replay_buffer_training=False)

def test_scheduled_zero_unroll_skips_rollout(self):
config = TrainerConfig(root_dir='/tmp/rl_algorithm_test',
unroll_length=StepScheduler(
'iterations', [(1, 1), (2, 0)]),
mini_batch_length=1,
mini_batch_size=1,
whole_replay_buffer_training=False)
alg = self._ReplayOnlyAlg(config)

update_progress('iterations', 0)
self.assertEqual(alg._train_iter_off_policy(), 7)
self.assertEqual(alg.unroll_calls, [1])
self.assertEqual(alg.train_calls, 1)
self.assertEqual(alg.after_train_iter_calls, 1)
self.assertTrue(alg.update_global_counter)

update_progress('iterations', 1)
self.assertEqual(alg._train_iter_off_policy(), 7)
self.assertEqual(alg.unroll_calls, [1])
self.assertEqual(alg.train_calls, 2)
self.assertEqual(alg.after_train_iter_calls, 1)

def test_constant_unroll_length_keeps_scalar_behavior(self):
config = TrainerConfig(root_dir='/tmp/rl_algorithm_test',
unroll_length=5,
async_unroll=True,
max_unroll_length=5)
self.assertEqual(config.unroll_length, 5)
self.assertEqual(config.max_unroll_length, 5)

def test_on_policy_constant_unroll_length_still_works(self):
config = TrainerConfig(root_dir='/tmp/rl_algorithm_test',
unroll_length=3)
env = MyEnv(batch_size=2)
alg = MyAlg(observation_spec=env.observation_spec(),
action_spec=env.action_spec(),
env=env,
config=config,
on_policy=True)
steps = alg.train_iter()
self.assertEqual(steps, 6)

def test_off_policy_algorithm(self):
with tempfile.TemporaryDirectory() as root_dir:
common.run_under_record_context(
Expand Down
Loading