Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions alf/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,7 +1426,9 @@ def train_from_unroll(self, experience, train_info):
return shape[0] * shape[1]

@common.mark_replay
def train_from_replay_buffer(self, update_global_counter=False):
def train_from_replay_buffer(self,
effective_unroll_steps,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add docstring for this arg?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This arg is now removed

update_global_counter=False):
"""This function can be called by any algorithm that has its own
replay buffer configured. There are several parameters specified in
``self._config`` that will affect how the training is performed:
Expand Down Expand Up @@ -1469,6 +1471,7 @@ def train_from_replay_buffer(self, update_global_counter=False):
``True``, it will affect the counter only if
``config.update_counter_every_mini_batch=True``.
"""

config: TrainerConfig = self._config

# returns 0 if haven't started training yet, when ``_replay_buffer`` is
Expand All @@ -1479,7 +1482,8 @@ def train_from_replay_buffer(self, update_global_counter=False):
# training is not started yet, ``_replay_buffer`` will be None since it
# is only lazily created later when online RL training started.
if (self._replay_buffer and self._replay_buffer.total_size
< config.initial_collect_steps):
< config.initial_collect_steps) or (effective_unroll_steps
== 0):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any situation that train_from_replay_buffer will be called with effective_unroll_steps=0?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

effective_unroll_steps is now removed from this function

assert (
self._replay_buffer.num_environments *
self._replay_buffer.max_length >= config.initial_collect_steps
Expand All @@ -1493,21 +1497,22 @@ def _replay():
# ``_replay_buffer`` for training.
# TODO: If this function can be called asynchronously, and using
# prioritized replay, then make sure replay and train below is atomic.
effective_num_updates_per_train_iter = config.num_updates_per_train_iter
with record_time("time/replay"):
mini_batch_size = config.mini_batch_size
if mini_batch_size is None:
mini_batch_size = self._replay_buffer.num_environments
if config.whole_replay_buffer_training:
experience, batch_info = self._replay_buffer.gather_all(
ignore_earliest_frames=True)
num_updates = config.num_updates_per_train_iter
num_updates = effective_num_updates_per_train_iter

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you need to make this change?

@Haichao-Zhang Haichao-Zhang May 23, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not necessary anymore. removed

else:
assert config.mini_batch_length is not None, (
"No mini_batch_length is specified for off-policy training"
)
experience, batch_info = self._replay_buffer.get_batch(
batch_size=(mini_batch_size *
config.num_updates_per_train_iter),
effective_num_updates_per_train_iter),
batch_length=config.mini_batch_length)
num_updates = 1
return experience, batch_info, num_updates, mini_batch_size
Expand Down
95 changes: 76 additions & 19 deletions alf/algorithms/rl_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import os
import time
import torch
from typing import Callable, Optional
from typing import Callable, List, Optional
from absl import logging

import alf
Expand Down Expand Up @@ -147,6 +147,7 @@ def __init__(self,
optimizer=None,
checkpoint=None,
is_eval: bool = False,
episodic_annotation: bool = False,
overwrite_policy_output=False,
debug_summaries=False,
name="RLAlgorithm"):
Expand Down Expand Up @@ -186,6 +187,8 @@ def __init__(self,
during deployment. In this case, the algorithm do not need to
create certain components such as value_network for ActorCriticAlgorithm,
critic_networks for SacAlgorithm.
episodic_annotation: if True, annotate the episode before being observed by the
replay buffer.
overwrite_policy_output (bool): if True, overwrite the policy output
with next_step.prev_action. This option can be used in some
cases such as data collection.
Expand All @@ -203,6 +206,7 @@ def __init__(self,
debug_summaries=debug_summaries,
name=name)
self._is_eval = is_eval
self._episodic_annotation = episodic_annotation

self._env = env
self._observation_spec = observation_spec
Expand Down Expand Up @@ -235,11 +239,14 @@ def __init__(self,
self._current_time_step = None
self._current_policy_state = None
self._current_transform_state = None

self._cached_exp = [] # for lazy observation
if self._env is not None and not self.on_policy:
replay_buffer_length = adjust_replay_buffer_length(
config, self._num_earliest_frames_ignored)

if self._episodic_annotation:
assert self._env.batch_size == 1, "only support non-batched environment"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add this to the docstring of episodic_annotation?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The assertion is not necessary here, so it has been removed as well.


if config.whole_replay_buffer_training and config.clear_replay_buffer:
# For whole replay buffer training, we would like to be sure
# that the replay buffer have enough samples in it to perform
Expand Down Expand Up @@ -598,19 +605,62 @@ def _async_unroll(self, unroll_length: int):

return experience

def should_post_process_episode(self, rollout_info, step_type: StepType):
"""A function that determines whether the ``post_process_episode`` function should
be applied to the current list of experiences.
"""

@runjerry runjerry May 9, 2025

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an interface mainly used for subclasses? Maybe mention this. Same for post_process_episode.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Added comments. Also for post_process_episode

return False

def post_process_episode(self, experiences: List[Experience]):
    """Post-process a list of cached experiences.

    This is a hook: the implementation here does nothing and returns
    ``None``. It is invoked from ``_process_unroll_step`` when episodic
    annotation is enabled and ``should_post_process_episode`` returns True.
    It can be used to create a number of useful features such as
    'hindsight relabeling' of a trajectory etc.

    Args:
        experiences: a list of experience, containing the experiences
            cached from the initial step at which
            ``should_post_process_episode`` is False up to and including
            the step where ``should_post_process_episode`` is True.

    Returns:
        ``None`` in this implementation.
        NOTE(review): the caller takes ``len()`` of the returned value and
        iterates over it to store each item in the replay buffer, so an
        override presumably has to return a list of experiences -- confirm.
    """
    return None

def _process_unroll_step(self, policy_step, action, time_step,
transformed_time_step, policy_state,
experience_list, original_reward_list):
self.observe_for_metrics(time_step.cpu())
exp = make_experience(time_step.cpu(),
alf.layers.to_float32(policy_step),
alf.layers.to_float32(policy_state))

store_exp_time = 0
if not self.on_policy:
t0 = time.time()
self.observe_for_replay(exp)
store_exp_time = time.time() - t0
effective_number_of_unroll_steps = 1
if self._episodic_annotation:
store_exp_time = 0
# if last step, annotate
rollout_info = policy_step.info
self._cached_exp.append(exp)
if self.should_post_process_episode(rollout_info,
time_step.step_type):

# 1) process
annotated_exp_list = self.post_process_episode(
self._cached_exp)
effective_number_of_unroll_steps = len(annotated_exp_list)
# 2) observe
if not self.on_policy:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this condition check should be performed earlier, since it seems a waste to do all the post_process_episode if self.on_policy?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

t0 = time.time()
for exp in annotated_exp_list:
self.observe_for_replay(exp)
store_exp_time = time.time() - t0
# clean up the exp cache
self._cached_exp = []

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to assume that all envs end on the same step? What if some envs are LAST, some are MID? cached_exp will be cleared even for those with MID steps?

Even when doing this for an env with batch_size 1, this annotation mode will delay experience from being stored into the replay buffer.

Ok to submit the change as is, but may need to do two things:

  1. rename the feature to something like store_experience_on_episode_end, and document its behavior clearly in the docstr.
    experience relabel should be done when reading data out of replay buffer as in hindsight relabel.

  2. assert that batch_size is 1 when enabled.

Also, delaying train_step because of delayed experience storage can have unexpected side effects, e.g. if episodes are 100 steps long, and unroll once per train iter, then summary will only happen every 100 train iters. It will also shift the distribution of the data training sees due to the delay.

Overall I think doing this episode level relabeling at the DataTransformer stage, after reading from replay_buffer is perhaps a better way, and a cleaner way as well (less scattered code). That would require the replay buffer to keep track of episode begin and end, which I think it already does.

@Haichao-Zhang Haichao-Zhang May 23, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to assume that all envs end on the same step? What if some envs are LAST, some are MID? cached_exp will be cleared even for those with MID steps?

There is no such assumption. It is totally up to the users to inject their own assumptions.
By default, the behavior is the same as before.
Sorry that the function names are a bit misleading; their role has been extended to handle the per-step case as well. Changed the function names and added more comments.

Even when doing this for an env with batch_size 1, this annotation mode will delay experience from being stored into the replay buffer.

No it won't. By default, the behavior is the same as before.

Ok to submit the change as is, but may need to do two things:

  1. rename the feature to something like store_experience_on_episode_end, and document its behavior clearly in the docstr.
    The suggested name is not appropriate.

experience relabel should be done when reading data out of replay buffer as in hindsight relabel.

Different use cases. This is an alternative interface that can support more than pure relabeling (e.g. excluding data), which is not directly supported by the replay buffer hindsight relabel.

  2. assert that batch_size is 1 when enabled.
    There is no such assumption in the current PR. It is up to the user.

Also, delaying train_step because of delayed experience storage can have unexpected side effects, e.g. if episodes are 100 steps long, and unroll once per train iter, then summary will only happen every 100 train iters. It will also shift the distribution of the data training sees due to the delay.
There is no delay.

Overall I think doing this episode level relabeling at the DataTransformer stage, after reading from replay_buffer is perhaps a better way, and a cleaner way as well (less scattered code). That would require the replay buffer to keep track of episode begin and end, which I think it already does.
As explained, it is more than pure relabeling.

else:
# effective unroll steps as 0 if not post_process_episode timepoint yet
effective_number_of_unroll_steps = 0
else:
store_exp_time = 0
if not self.on_policy:
t0 = time.time()
self.observe_for_replay(exp)
store_exp_time = time.time() - t0

exp_for_training = Experience(
time_step=transformed_time_step,
Expand All @@ -620,7 +670,7 @@ def _process_unroll_step(self, policy_step, action, time_step,

experience_list.append(exp_for_training)
original_reward_list.append(time_step.reward)
return store_exp_time
return store_exp_time, effective_number_of_unroll_steps

def reset_state(self):
"""Reset the state of the algorithm.
Expand Down Expand Up @@ -665,6 +715,7 @@ def _sync_unroll(self, unroll_length: int):
policy_step_time = 0.
env_step_time = 0.
store_exp_time = 0.
effective_unroll_steps = 0
for _ in range(unroll_length):
policy_state = common.reset_state_if_necessary(
policy_state, initial_state, time_step.is_first())
Expand Down Expand Up @@ -693,9 +744,10 @@ def _sync_unroll(self, unroll_length: int):
if self._overwrite_policy_output:
policy_step = policy_step._replace(
output=next_time_step.prev_action)
store_exp_time += self._process_unroll_step(
store_exp_time_i, effective_unroll_steps = self._process_unroll_step(
policy_step, action, time_step, transformed_time_step,
policy_state, experience_list, original_reward_list)
store_exp_time += store_exp_time_i

time_step = next_time_step
policy_state = policy_step.state
Expand Down Expand Up @@ -723,7 +775,7 @@ def _sync_unroll(self, unroll_length: int):
self._current_policy_state = common.detach(policy_state)
self._current_transform_state = common.detach(trans_state)

return experience
return experience, effective_unroll_steps

def train_iter(self):
"""Perform one iteration of training.
Expand Down Expand Up @@ -804,6 +856,7 @@ def _unroll_iter_off_policy(self):
unrolled = False
root_inputs = None
rollout_info = None
effective_unroll_steps = 0
if (alf.summary.get_global_counter()
>= self._rl_train_after_update_steps
and (unroll_length > 0 or config.unroll_length == 0) and
Expand All @@ -822,19 +875,21 @@ def _unroll_iter_off_policy(self):
# need to remember whether summary has been written between
# two unrolls.
with self._ensure_rollout_summary:
experience = self.unroll(unroll_length)
experience, effective_unroll_steps = self.unroll(
unroll_length)
if experience:
self.summarize_rollout(experience)
self.summarize_metrics()
rollout_info = experience.rollout_info
if config.use_root_inputs_for_after_train_iter:
root_inputs = experience.time_step
del experience
return unrolled, root_inputs, rollout_info
return unrolled, root_inputs, rollout_info, effective_unroll_steps

def _train_iter_off_policy(self):
"""User may override this for their own training procedure."""
unrolled, root_inputs, rollout_info = self._unroll_iter_off_policy()
unrolled, root_inputs, rollout_info, effective_unroll_steps = self._unroll_iter_off_policy(
)

# replay buffer may not have been created for two different reasons:
# 1. in online RL training (``has_offline`` is False), unroll is not
Expand All @@ -846,11 +901,13 @@ def _train_iter_off_policy(self):
return 0

self.train()
steps = self.train_from_replay_buffer(update_global_counter=True)

if unrolled:
with record_time("time/after_train_iter"):
self.after_train_iter(root_inputs, rollout_info)
steps = 0
for i in range(effective_unroll_steps):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unroll_steps is the wrong name? It should be called unroll_iterations to indicate training iterations, not env steps?

also rename effective_number_of_unroll_steps to be effective_unroll_iters to be consistent. (i.e. remove "number_of_")

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comments. Changed.

steps += self.train_from_replay_buffer(effective_unroll_steps=1,
update_global_counter=True)
if unrolled:
with record_time("time/after_train_iter"):
self.after_train_iter(root_inputs, rollout_info)

@runjerry runjerry May 9, 2025

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel that this update fundamentally changes the off-policy update logic w.r.t. its actual unroll in the env. Previously, between every call of self._unroll_iter_off_policy, the policy gets an "update" from self.train_from_replay_buffer. Now if self._episodic_annotation, policy training only happens after each episode, though the UTD stays the same. I feel that the episodic annotation function should be configurable independently of the choice of such unroll/update logic. Ideally, we may want to keep the previous version here while achieving the same effect of the change of above lines by configuring unroll_length and num_updates_per_train_iter.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If self._episodic_annotation is False, everything is the same as before.
If self._episodic_annotation is True, the default behavior (with the new commit) also reduces to the original logic, so everything is the same as before (policy training happens after each time step, not after each episode).

In the derived class, it is up to the user to determine what kind of annotation function they want to implement and use.


# For now, we only return the steps of the primary algorithm's training
return steps
Expand Down