-
Notifications
You must be signed in to change notification settings - Fork 59
Post Process Experience with Customizable Modes #1768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: pytorch
Are you sure you want to change the base?
Changes from 3 commits
00efea8
e4cdb81
a05e8da
9cfe6a5
26ab09a
734dae8
94a50bf
8fc3ff2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,7 @@ | |
| import os | ||
| import time | ||
| import torch | ||
| from typing import Callable, Optional | ||
| from typing import Callable, List, Optional | ||
| from absl import logging | ||
|
|
||
| import alf | ||
|
|
@@ -147,6 +147,7 @@ def __init__(self, | |
| optimizer=None, | ||
| checkpoint=None, | ||
| is_eval: bool = False, | ||
| episodic_annotation: bool = False, | ||
| overwrite_policy_output=False, | ||
| debug_summaries=False, | ||
| name="RLAlgorithm"): | ||
|
|
@@ -186,6 +187,10 @@ def __init__(self, | |
| during deployment. In this case, the algorithm do not need to | ||
| create certain components such as value_network for ActorCriticAlgorithm, | ||
| critic_networks for SacAlgorithm. | ||
| episodic_annotation: episodic annotation is an operation that annotates the | ||
| episode after it being collected, and then the annotated episode will be | ||
| observed by the replay buffer. If True, annotate the episode before being | ||
| observed by the replay buffer. Otherwise, episodic annotation is not applied. | ||
| overwrite_policy_output (bool): if True, overwrite the policy output | ||
| with next_step.prev_action. This option can be used in some | ||
| cases such as data collection. | ||
|
|
@@ -203,6 +208,7 @@ def __init__(self, | |
| debug_summaries=debug_summaries, | ||
| name=name) | ||
| self._is_eval = is_eval | ||
| self._episodic_annotation = episodic_annotation | ||
|
|
||
| self._env = env | ||
| self._observation_spec = observation_spec | ||
|
|
@@ -235,7 +241,7 @@ def __init__(self, | |
| self._current_time_step = None | ||
| self._current_policy_state = None | ||
| self._current_transform_state = None | ||
|
|
||
| self._cached_exp = [] # for lazy observation | ||
| if self._env is not None and not self.on_policy: | ||
| replay_buffer_length = adjust_replay_buffer_length( | ||
| config, self._num_earliest_frames_ignored) | ||
|
|
@@ -566,10 +572,11 @@ def _async_unroll(self, unroll_length: int): | |
| step_time += unroll_result.step_time | ||
| max_step_time = max(max_step_time, unroll_result.step_time) | ||
|
|
||
| store_exp_time += self._process_unroll_step( | ||
| store_exp_time_i, effective_unroll_steps = self._process_unroll_step( | ||
| policy_step, policy_step.output, time_step, | ||
| transformed_time_step, policy_state, experience_list, | ||
| original_reward_list) | ||
| store_exp_time += store_exp_time_i | ||
|
|
||
| alf.summary.scalar("time/unroll_env_step", | ||
| env_step_time, | ||
|
|
@@ -596,7 +603,32 @@ def _async_unroll(self, unroll_length: int): | |
|
|
||
| self._current_transform_state = common.detach(trans_state) | ||
|
|
||
| return experience | ||
| return experience, effective_unroll_steps | ||
|
|
||
| def should_post_process_episode(self, rollout_info, step_type: StepType): | ||
| """A function that determines whether the ``post_process_episode`` function should | ||
| be applied to the current list of experiences. | ||
| Users can customize this function in the derived class. | ||
| Bu default, it returns True all the time steps. When this is combined with | ||
| ``post_process_episode`` which simply return the input unmodified (as the default | ||
| implementation in this class), it is a dummy version of eposodic annotation with | ||
| logic equivalent to the case of episodic_annotation=False. | ||
| """ | ||
| return True | ||
|
|
||
| def post_process_episode(self, experiences: List[Experience]): | ||
| """A function for postprocessing a list of experience. It is called when | ||
| ``should_post_process_episode`` is True. | ||
| By default, it returns the input unmodified. | ||
| Users can customize this function in the derived class, to create a number of | ||
| useful features such as 'hindsight relabeling' of a trajectory etc. | ||
|
|
||
| Args: | ||
| experiences: a list of experience, containing the experience starting from the | ||
| initial time when ``should_post_process_episode`` is False to the step where | ||
| ``should_post_process_episode`` is True. | ||
| """ | ||
| return experiences | ||
|
|
||
| def _process_unroll_step(self, policy_step, action, time_step, | ||
| transformed_time_step, policy_state, | ||
|
|
@@ -605,12 +637,36 @@ def _process_unroll_step(self, policy_step, action, time_step, | |
| exp = make_experience(time_step.cpu(), | ||
| alf.layers.to_float32(policy_step), | ||
| alf.layers.to_float32(policy_state)) | ||
|
|
||
| store_exp_time = 0 | ||
| if not self.on_policy: | ||
| t0 = time.time() | ||
| self.observe_for_replay(exp) | ||
| store_exp_time = time.time() - t0 | ||
| effective_number_of_unroll_steps = 1 | ||
| if self._episodic_annotation: | ||
| assert not self.on_policy, "only support episodic annotation for off policy training" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe assert this in the |
||
| store_exp_time = 0 | ||
| # if last step, annotate | ||
| rollout_info = policy_step.info | ||
| self._cached_exp.append(exp) | ||
| if self.should_post_process_episode(rollout_info, | ||
| time_step.step_type): | ||
|
|
||
| # 1) process | ||
| annotated_exp_list = self.post_process_episode( | ||
| self._cached_exp) | ||
| effective_number_of_unroll_steps = len(annotated_exp_list) | ||
| # 2) observe | ||
| t0 = time.time() | ||
| for exp in annotated_exp_list: | ||
| self.observe_for_replay(exp) | ||
| store_exp_time = time.time() - t0 | ||
| # clean up the exp cache | ||
| self._cached_exp = [] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to assume that all envs end on the same step? What if some envs are LAST, some are MID? cached_exp will be cleared even for those with MID steps? Even when doing this for an env with batch_size 1, this annotation mode will delay experience from being stored into the replay buffer. Ok to submit the change as is, but may need to do two things:
Also, delaying train_step because of delayed experience storage can have unexpected side effects, e.g. if episodes are 100 steps long, and unroll once per train iter, then summary will only happen every 100 train iters. It will also shift the distribution of the data training sees due to the delay. Overall I think doing this episode level relabeling at the DataTransformer stage, after reading from replay_buffer is perhaps a better way, and a cleaner way as well (less scattered code). That would require the replay buffer to keep track of episode begin and end, which I think it already does.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There is no such assumption. It is totally up to the users to inject their own assumptions.
No it won't. By default, the behavior is the same as before.
Different use cases. This is an alternative interface that can support more than pure relabeling (e.g. excluding data), which is not directly supported by the replay buffer hindsight relabel.
|
||
| else: | ||
| # effective unroll steps as 0 if not post_process_episode timepoint yet | ||
| effective_number_of_unroll_steps = 0 | ||
| else: | ||
| store_exp_time = 0 | ||
| if not self.on_policy: | ||
| t0 = time.time() | ||
| self.observe_for_replay(exp) | ||
| store_exp_time = time.time() - t0 | ||
|
|
||
| exp_for_training = Experience( | ||
| time_step=transformed_time_step, | ||
|
|
@@ -620,7 +676,7 @@ def _process_unroll_step(self, policy_step, action, time_step, | |
|
|
||
| experience_list.append(exp_for_training) | ||
| original_reward_list.append(time_step.reward) | ||
| return store_exp_time | ||
| return store_exp_time, effective_number_of_unroll_steps | ||
|
|
||
| def reset_state(self): | ||
| """Reset the state of the algorithm. | ||
|
|
@@ -665,6 +721,7 @@ def _sync_unroll(self, unroll_length: int): | |
| policy_step_time = 0. | ||
| env_step_time = 0. | ||
| store_exp_time = 0. | ||
| effective_unroll_steps = 0 | ||
| for _ in range(unroll_length): | ||
| policy_state = common.reset_state_if_necessary( | ||
| policy_state, initial_state, time_step.is_first()) | ||
|
|
@@ -693,9 +750,10 @@ def _sync_unroll(self, unroll_length: int): | |
| if self._overwrite_policy_output: | ||
| policy_step = policy_step._replace( | ||
| output=next_time_step.prev_action) | ||
| store_exp_time += self._process_unroll_step( | ||
| store_exp_time_i, effective_unroll_steps = self._process_unroll_step( | ||
| policy_step, action, time_step, transformed_time_step, | ||
| policy_state, experience_list, original_reward_list) | ||
| store_exp_time += store_exp_time_i | ||
|
|
||
| time_step = next_time_step | ||
| policy_state = policy_step.state | ||
|
|
@@ -723,7 +781,7 @@ def _sync_unroll(self, unroll_length: int): | |
| self._current_policy_state = common.detach(policy_state) | ||
| self._current_transform_state = common.detach(trans_state) | ||
|
|
||
| return experience | ||
| return experience, effective_unroll_steps | ||
|
|
||
| def train_iter(self): | ||
| """Perform one iteration of training. | ||
|
|
@@ -747,7 +805,7 @@ def _compute_train_info_and_loss_info_on_policy(self, unroll_length): | |
| with record_time("time/unroll"): | ||
| with torch.cuda.amp.autocast(self._config.enable_amp, | ||
| dtype=self._config.amp_dtype): | ||
| experience = self.unroll(self._config.unroll_length) | ||
| experience, _ = self.unroll(self._config.unroll_length) | ||
| self.summarize_metrics() | ||
|
|
||
| train_info = experience.rollout_info | ||
|
|
@@ -804,6 +862,7 @@ def _unroll_iter_off_policy(self): | |
| unrolled = False | ||
| root_inputs = None | ||
| rollout_info = None | ||
| effective_unroll_steps = 0 | ||
| if (alf.summary.get_global_counter() | ||
| >= self._rl_train_after_update_steps | ||
| and (unroll_length > 0 or config.unroll_length == 0) and | ||
|
|
@@ -822,19 +881,21 @@ def _unroll_iter_off_policy(self): | |
| # need to remember whether summary has been written between | ||
| # two unrolls. | ||
| with self._ensure_rollout_summary: | ||
| experience = self.unroll(unroll_length) | ||
| experience, effective_unroll_steps = self.unroll( | ||
| unroll_length) | ||
| if experience: | ||
| self.summarize_rollout(experience) | ||
| self.summarize_metrics() | ||
| rollout_info = experience.rollout_info | ||
| if config.use_root_inputs_for_after_train_iter: | ||
| root_inputs = experience.time_step | ||
| del experience | ||
| return unrolled, root_inputs, rollout_info | ||
| return unrolled, root_inputs, rollout_info, effective_unroll_steps | ||
|
|
||
| def _train_iter_off_policy(self): | ||
| """User may override this for their own training procedure.""" | ||
| unrolled, root_inputs, rollout_info = self._unroll_iter_off_policy() | ||
| unrolled, root_inputs, rollout_info, effective_unroll_steps = self._unroll_iter_off_policy( | ||
| ) | ||
|
|
||
| # replay buffer may not have been created for two different reasons: | ||
| # 1. in online RL training (``has_offline`` is False), unroll is not | ||
|
|
@@ -846,11 +907,12 @@ def _train_iter_off_policy(self): | |
| return 0 | ||
|
|
||
| self.train() | ||
| steps = self.train_from_replay_buffer(update_global_counter=True) | ||
|
|
||
| if unrolled: | ||
| with record_time("time/after_train_iter"): | ||
| self.after_train_iter(root_inputs, rollout_info) | ||
| steps = 0 | ||
| for i in range(effective_unroll_steps): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unroll_steps is the wrong name? It should be called unroll_iterations to indicate training iterations, not env steps? also rename effective_number_of_unroll_steps to be effective_unroll_iters to be consistent. (i.e. remove "number_of_")
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the comments. Changed. |
||
| steps += self.train_from_replay_buffer(update_global_counter=True) | ||
| if unrolled: | ||
| with record_time("time/after_train_iter"): | ||
| self.after_train_iter(root_inputs, rollout_info) | ||
|
|
||
| # For now, we only return the steps of the primary algorithm's training | ||
| return steps | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do you need to make this change?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not necessary anymore. removed