-
Notifications
You must be signed in to change notification settings - Fork 59
Post Process Experience with Customizable Modes #1768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: pytorch
Are you sure you want to change the base?
Changes from 7 commits
00efea8
e4cdb81
a05e8da
9cfe6a5
26ab09a
734dae8
94a50bf
8fc3ff2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,7 @@ | |
| import os | ||
| import time | ||
| import torch | ||
| from typing import Callable, Optional | ||
| from typing import Callable, List, Optional, Tuple | ||
| from absl import logging | ||
|
|
||
| import alf | ||
|
|
@@ -544,6 +544,7 @@ def _async_unroll(self, unroll_length: int): | |
| store_exp_time = 0. | ||
| step_time = 0. | ||
| max_step_time = 0. | ||
| effective_unroll_steps = 0 | ||
| qsize = self._async_unroller.get_queue_size() | ||
| unroll_results = self._async_unroller.gather_unroll_results( | ||
| unroll_length, self._config.max_unroll_length) | ||
|
|
@@ -566,10 +567,12 @@ def _async_unroll(self, unroll_length: int): | |
| step_time += unroll_result.step_time | ||
| max_step_time = max(max_step_time, unroll_result.step_time) | ||
|
|
||
| store_exp_time += self._process_unroll_step( | ||
| store_exp_time_i, effective_unroll_steps_i = self._process_unroll_step( | ||
| policy_step, policy_step.output, time_step, | ||
| transformed_time_step, policy_state, experience_list, | ||
| original_reward_list) | ||
| store_exp_time += store_exp_time_i | ||
| effective_unroll_steps += effective_unroll_steps_i | ||
|
|
||
| alf.summary.scalar("time/unroll_env_step", | ||
| env_step_time, | ||
|
|
@@ -596,20 +599,59 @@ def _async_unroll(self, unroll_length: int): | |
|
|
||
| self._current_transform_state = common.detach(trans_state) | ||
|
|
||
| return experience | ||
| # if the input unroll_length is 0 (e.g. fractional unroll), then this is treated as | ||
| # an effective unroll iter | ||
| effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length | ||
| return experience, effective_unroll_iters | ||
|
|
||
| def post_process_experience(self, rollout_info, step_type: StepType, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This name is confusing with the existing function
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. changed to |
||
| experiences: Experience) -> Tuple[List, int]: | ||
| """A function for postprocessing experience. By default, it returns the input | ||
| experience unmodified. Users can customize this function in the derived | ||
| class to achieve different effects. For example: | ||
| - per-step processing: return the current step of experience unmodified (by default) | ||
| or a modified version according to the customized ``post_process_experience``. | ||
| As another example, task filtering can be simply achieved by returning ``[]`` | ||
| for that particular task. | ||
| - per-episode processing: this can be achieved by returning a list of processed | ||
| experiences. For example, this can be used for success episode labeling. | ||
|
|
||
| Args: | ||
| rollout_info: the rollout info. | ||
| step_type: the step type of the current experience. | ||
| experiences: one step of experience. | ||
|
|
||
| Returns: | ||
| - a list of experiences. Users can customize this function in the | ||
| derived class to achieve different effects. For example: | ||
| * return a list that contains only the input experience (default behavior). | ||
| * return a list that contains a number of experiences. This can be useful | ||
| for episode processing such as success episode labeling. | ||
| - an integer representing the effective number of unroll steps per env. The | ||
| default value is 1, meaning the length of effective experience is 1 | ||
| after calling ``post_process_experience``, the same as the input length | ||
| of experience. | ||
| """ | ||
| return [experiences], 1 | ||
|
|
||
| def _process_unroll_step(self, policy_step, action, time_step, | ||
| transformed_time_step, policy_state, | ||
| experience_list, original_reward_list): | ||
| experience_list, | ||
| original_reward_list) -> Tuple[int, int]: | ||
| self.observe_for_metrics(time_step.cpu()) | ||
| exp = make_experience(time_step.cpu(), | ||
| alf.layers.to_float32(policy_step), | ||
| alf.layers.to_float32(policy_state)) | ||
|
|
||
| effective_unroll_steps = 1 | ||
| store_exp_time = 0 | ||
| if not self.on_policy: | ||
| # 1) post process | ||
| post_processed_exp_list, effective_unroll_steps = self.post_process_experience( | ||
| policy_step.info, time_step.step_type, exp) | ||
| # 2) observe | ||
| t0 = time.time() | ||
| self.observe_for_replay(exp) | ||
| for exp in post_processed_exp_list: | ||
| self.observe_for_replay(exp) | ||
| store_exp_time = time.time() - t0 | ||
|
|
||
| exp_for_training = Experience( | ||
|
|
@@ -620,7 +662,7 @@ def _process_unroll_step(self, policy_step, action, time_step, | |
|
|
||
| experience_list.append(exp_for_training) | ||
| original_reward_list.append(time_step.reward) | ||
| return store_exp_time | ||
| return store_exp_time, effective_unroll_steps | ||
|
|
||
| def reset_state(self): | ||
| """Reset the state of the algorithm. | ||
|
|
@@ -644,6 +686,8 @@ def _sync_unroll(self, unroll_length: int): | |
| Returns: | ||
| Experience: The stacked experience with shape :math:`[T, B, \ldots]` | ||
| for each of its members. | ||
| effective_unroll_iters: the effective number of unroll iterations. | ||
| Each unroll iteration contains ``unroll_length`` unroll steps. | ||
| """ | ||
| if self._current_time_step is None: | ||
| self._current_time_step = common.get_initial_time_step(self._env) | ||
|
|
@@ -665,6 +709,7 @@ def _sync_unroll(self, unroll_length: int): | |
| policy_step_time = 0. | ||
| env_step_time = 0. | ||
| store_exp_time = 0. | ||
| effective_unroll_steps = 0 | ||
| for _ in range(unroll_length): | ||
| policy_state = common.reset_state_if_necessary( | ||
| policy_state, initial_state, time_step.is_first()) | ||
|
|
@@ -693,9 +738,11 @@ def _sync_unroll(self, unroll_length: int): | |
| if self._overwrite_policy_output: | ||
| policy_step = policy_step._replace( | ||
| output=next_time_step.prev_action) | ||
| store_exp_time += self._process_unroll_step( | ||
| store_exp_time_i, effective_unroll_steps_i = self._process_unroll_step( | ||
| policy_step, action, time_step, transformed_time_step, | ||
| policy_state, experience_list, original_reward_list) | ||
| store_exp_time += store_exp_time_i | ||
| effective_unroll_steps += effective_unroll_steps_i | ||
|
|
||
| time_step = next_time_step | ||
| policy_state = policy_step.state | ||
|
|
@@ -723,7 +770,10 @@ def _sync_unroll(self, unroll_length: int): | |
| self._current_policy_state = common.detach(policy_state) | ||
| self._current_transform_state = common.detach(trans_state) | ||
|
|
||
| return experience | ||
| # if the input unroll_length is 0 (e.g. fractional unroll), then this is treated as | ||
| # an effective unroll iter | ||
| effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's strange to call unroll "iter"? The original definition is that each training iter we have one unroll. So what does unroll iters mean in this context?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added comments. One |
||
| return experience, effective_unroll_iters | ||
|
|
||
| def train_iter(self): | ||
| """Perform one iteration of training. | ||
|
|
@@ -747,7 +797,7 @@ def _compute_train_info_and_loss_info_on_policy(self, unroll_length): | |
| with record_time("time/unroll"): | ||
| with torch.cuda.amp.autocast(self._config.enable_amp, | ||
| dtype=self._config.amp_dtype): | ||
| experience = self.unroll(self._config.unroll_length) | ||
| experience, _ = self.unroll(self._config.unroll_length) | ||
| self.summarize_metrics() | ||
|
|
||
| train_info = experience.rollout_info | ||
|
|
@@ -788,6 +838,9 @@ def _unroll_iter_off_policy(self): | |
| unroll length, it may not have been called. | ||
| - root_inputs: root-level time step returned by the unroll | ||
| - rollout_info: rollout info returned by the unroll | ||
| - effective_unroll_iters: the effective number of unroll iterations. | ||
| ``train_from_replay_buffer`` will be run ``effective_unroll_iters`` times | ||
| during ``_train_iter_off_policy``. | ||
| """ | ||
| config: TrainerConfig = self._config | ||
|
|
||
|
|
@@ -804,6 +857,7 @@ def _unroll_iter_off_policy(self): | |
| unrolled = False | ||
| root_inputs = None | ||
| rollout_info = None | ||
| effective_unroll_iters = 0 | ||
| if (alf.summary.get_global_counter() | ||
| >= self._rl_train_after_update_steps | ||
| and (unroll_length > 0 or config.unroll_length == 0) and | ||
|
|
@@ -822,19 +876,21 @@ def _unroll_iter_off_policy(self): | |
| # need to remember whether summary has been written between | ||
| # two unrolls. | ||
| with self._ensure_rollout_summary: | ||
| experience = self.unroll(unroll_length) | ||
| experience, effective_unroll_iters = self.unroll( | ||
| unroll_length) | ||
| if experience: | ||
| self.summarize_rollout(experience) | ||
| self.summarize_metrics() | ||
| rollout_info = experience.rollout_info | ||
| if config.use_root_inputs_for_after_train_iter: | ||
| root_inputs = experience.time_step | ||
| del experience | ||
| return unrolled, root_inputs, rollout_info | ||
| return unrolled, root_inputs, rollout_info, effective_unroll_iters | ||
|
|
||
| def _train_iter_off_policy(self): | ||
| """User may override this for their own training procedure.""" | ||
| unrolled, root_inputs, rollout_info = self._unroll_iter_off_policy() | ||
| unrolled, root_inputs, rollout_info, effective_unroll_iters = self._unroll_iter_off_policy( | ||
| ) | ||
|
|
||
| # replay buffer may not have been created for two different reasons: | ||
| # 1. in online RL training (``has_offline`` is False), unroll is not | ||
|
|
@@ -846,11 +902,12 @@ def _train_iter_off_policy(self): | |
| return 0 | ||
|
|
||
| self.train() | ||
| steps = self.train_from_replay_buffer(update_global_counter=True) | ||
|
|
||
| if unrolled: | ||
| with record_time("time/after_train_iter"): | ||
| self.after_train_iter(root_inputs, rollout_info) | ||
| steps = 0 | ||
| for i in range(effective_unroll_iters): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. it's possible the effective_unroll_iters is always smaller than 1 in the case of num_envs > 1.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point. Now also handles the fractional unroll case. |
||
| steps += self.train_from_replay_buffer(update_global_counter=True) | ||
| if unrolled: | ||
| with record_time("time/after_train_iter"): | ||
| self.after_train_iter(root_inputs, rollout_info) | ||
|
|
||
| # For now, we only return the steps of the primary algorithm's training | ||
| return steps | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we lack a formal definition of "effective" in the code document.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added more comments with examples, especially in
preprocess_unroll_experience