Skip to content
Open
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 75 additions & 18 deletions alf/algorithms/rl_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import os
import time
import torch
from typing import Callable, Optional
from typing import Callable, List, Optional, Tuple
from absl import logging

import alf
Expand Down Expand Up @@ -544,6 +544,7 @@ def _async_unroll(self, unroll_length: int):
store_exp_time = 0.
step_time = 0.
max_step_time = 0.
effective_unroll_steps = 0

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we lack a formal definition of "effective" in the code document.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added more comments with examples, especially in preprocess_unroll_experience

qsize = self._async_unroller.get_queue_size()
unroll_results = self._async_unroller.gather_unroll_results(
unroll_length, self._config.max_unroll_length)
Expand All @@ -566,10 +567,12 @@ def _async_unroll(self, unroll_length: int):
step_time += unroll_result.step_time
max_step_time = max(max_step_time, unroll_result.step_time)

store_exp_time += self._process_unroll_step(
store_exp_time_i, effective_unroll_steps_i = self._process_unroll_step(
policy_step, policy_step.output, time_step,
transformed_time_step, policy_state, experience_list,
original_reward_list)
store_exp_time += store_exp_time_i
effective_unroll_steps += effective_unroll_steps_i

alf.summary.scalar("time/unroll_env_step",
env_step_time,
Expand All @@ -596,20 +599,59 @@ def _async_unroll(self, unroll_length: int):

self._current_transform_state = common.detach(trans_state)

return experience
# if the input unroll_length is 0 (e.g. fractional unroll), then this it treated as
# an effective unroll iter
effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length
return experience, effective_unroll_iters

def post_process_experience(self, rollout_info, step_type: StepType,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This name is confusing with the existing function preprocess_experience which might suggest that this happens after that but in fact this happens before training.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed to preprocess_unroll_experience

experiences: Experience) -> Tuple[List, int]:
"""A function for postprocessing experience. By default, it returns the input
experience unmodified. Users can customize this function in the derived
class to achieve different effects. For example:
- per-step processing: return the current step of experience unmodified (by default)
or a modified version according to the customized ``post_process_experience``.
As another example, task filtering can be simply achieved by returning ``[]``
for that particular task.
- per-episode processing: this can be achieved by returning a list of processed
experiences. For example, this can be used for success episode labeling.

Args:
rollout_info: the rollout info.
step_type: the step type of the current experience.
experiences: one step of experience.

Returns:
- a list of experiences. Users can customize this functions in the
derived class to achieve different effects. For example:
* return a list that contains only the input experience (default behavior).
* return a list that contains a number of experiences. This can be useful
for episode processing such as success episode labeling.
- an integer representing the effective number of unroll steps per env. The
default value of 1, meaning the length of effective experience is 1
after calling ``post_process_experience``, the same as the input length
of experience.
"""
return [experiences], 1

def _process_unroll_step(self, policy_step, action, time_step,
transformed_time_step, policy_state,
experience_list, original_reward_list):
experience_list,
original_reward_list) -> Tuple[int, int]:
self.observe_for_metrics(time_step.cpu())
exp = make_experience(time_step.cpu(),
alf.layers.to_float32(policy_step),
alf.layers.to_float32(policy_state))

effective_unroll_steps = 1
store_exp_time = 0
if not self.on_policy:
# 1) post process
post_processed_exp_list, effective_unroll_steps = self.post_process_experience(
policy_step.info, time_step.step_type, exp)
# 2) observe
t0 = time.time()
self.observe_for_replay(exp)
for exp in post_processed_exp_list:
self.observe_for_replay(exp)
store_exp_time = time.time() - t0

exp_for_training = Experience(
Expand All @@ -620,7 +662,7 @@ def _process_unroll_step(self, policy_step, action, time_step,

experience_list.append(exp_for_training)
original_reward_list.append(time_step.reward)
return store_exp_time
return store_exp_time, effective_unroll_steps

def reset_state(self):
"""Reset the state of the algorithm.
Expand All @@ -644,6 +686,8 @@ def _sync_unroll(self, unroll_length: int):
Returns:
Experience: The stacked experience with shape :math:`[T, B, \ldots]`
for each of its members.
effective_unroll_iters: the effective number of unroll iterations.
Each unroll iteration contains ``unroll_length`` unroll steps.
"""
if self._current_time_step is None:
self._current_time_step = common.get_initial_time_step(self._env)
Expand All @@ -665,6 +709,7 @@ def _sync_unroll(self, unroll_length: int):
policy_step_time = 0.
env_step_time = 0.
store_exp_time = 0.
effective_unroll_steps = 0
for _ in range(unroll_length):
policy_state = common.reset_state_if_necessary(
policy_state, initial_state, time_step.is_first())
Expand Down Expand Up @@ -693,9 +738,11 @@ def _sync_unroll(self, unroll_length: int):
if self._overwrite_policy_output:
policy_step = policy_step._replace(
output=next_time_step.prev_action)
store_exp_time += self._process_unroll_step(
store_exp_time_i, effective_unroll_steps_i = self._process_unroll_step(
policy_step, action, time_step, transformed_time_step,
policy_state, experience_list, original_reward_list)
store_exp_time += store_exp_time_i
effective_unroll_steps += effective_unroll_steps_i

time_step = next_time_step
policy_state = policy_step.state
Expand Down Expand Up @@ -723,7 +770,10 @@ def _sync_unroll(self, unroll_length: int):
self._current_policy_state = common.detach(policy_state)
self._current_transform_state = common.detach(trans_state)

return experience
# if the input unroll_length is 0 (e.g. fractional unroll), then this it treated as
# an effective unroll iter
effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's strange to call unroll "iter"? The original definition is that each training iter we have one unroll. So what does unroll iters mean in this context?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added comments. One effective_unroll_iter refers to the unroll_length times of calling of rollout_step in the unroll phase.

return experience, effective_unroll_iters

def train_iter(self):
"""Perform one iteration of training.
Expand All @@ -747,7 +797,7 @@ def _compute_train_info_and_loss_info_on_policy(self, unroll_length):
with record_time("time/unroll"):
with torch.cuda.amp.autocast(self._config.enable_amp,
dtype=self._config.amp_dtype):
experience = self.unroll(self._config.unroll_length)
experience, _ = self.unroll(self._config.unroll_length)
self.summarize_metrics()

train_info = experience.rollout_info
Expand Down Expand Up @@ -788,6 +838,9 @@ def _unroll_iter_off_policy(self):
unroll length, it may not have been called.
- root_inputs: root-level time step returned by the unroll
- rollout_info: rollout info returned by the unroll
- effective_unroll_iters: the effective number of unroll iterations.
``train_from_replay_buffer`` will be run ``effective_unroll_iters`` times
during ``_train_iter_off_policy``.
"""
config: TrainerConfig = self._config

Expand All @@ -804,6 +857,7 @@ def _unroll_iter_off_policy(self):
unrolled = False
root_inputs = None
rollout_info = None
effective_unroll_iters = 0
if (alf.summary.get_global_counter()
>= self._rl_train_after_update_steps
and (unroll_length > 0 or config.unroll_length == 0) and
Expand All @@ -822,19 +876,21 @@ def _unroll_iter_off_policy(self):
# need to remember whether summary has been written between
# two unrolls.
with self._ensure_rollout_summary:
experience = self.unroll(unroll_length)
experience, effective_unroll_iters = self.unroll(
unroll_length)
if experience:
self.summarize_rollout(experience)
self.summarize_metrics()
rollout_info = experience.rollout_info
if config.use_root_inputs_for_after_train_iter:
root_inputs = experience.time_step
del experience
return unrolled, root_inputs, rollout_info
return unrolled, root_inputs, rollout_info, effective_unroll_iters

def _train_iter_off_policy(self):
"""User may override this for their own training procedure."""
unrolled, root_inputs, rollout_info = self._unroll_iter_off_policy()
unrolled, root_inputs, rollout_info, effective_unroll_iters = self._unroll_iter_off_policy(
)

# replay buffer may not have been created for two different reasons:
# 1. in online RL training (``has_offline`` is False), unroll is not
Expand All @@ -846,11 +902,12 @@ def _train_iter_off_policy(self):
return 0

self.train()
steps = self.train_from_replay_buffer(update_global_counter=True)

if unrolled:
with record_time("time/after_train_iter"):
self.after_train_iter(root_inputs, rollout_info)
steps = 0
for i in range(effective_unroll_iters):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's possible the effective_unroll_iters is always smaller than 1 in the case of num_envs > 1.

@Haichao-Zhang Haichao-Zhang May 23, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Now also handles the fractional unroll case.

steps += self.train_from_replay_buffer(update_global_counter=True)
if unrolled:
with record_time("time/after_train_iter"):
self.after_train_iter(root_inputs, rollout_info)

# For now, we only return the steps of the primary algorithm's training
return steps
Expand Down
Loading