-
Notifications
You must be signed in to change notification settings - Fork 59
Post Process Experience with Customizable Modes #1768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: pytorch
Are you sure you want to change the base?
Changes from 1 commit
00efea8
e4cdb81
a05e8da
9cfe6a5
26ab09a
734dae8
94a50bf
8fc3ff2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1426,7 +1426,9 @@ def train_from_unroll(self, experience, train_info): | |
| return shape[0] * shape[1] | ||
|
|
||
| @common.mark_replay | ||
| def train_from_replay_buffer(self, update_global_counter=False): | ||
| def train_from_replay_buffer(self, | ||
| effective_unroll_steps, | ||
| update_global_counter=False): | ||
| """This function can be called by any algorithm that has its own | ||
| replay buffer configured. There are several parameters specified in | ||
| ``self._config`` that will affect how the training is performed: | ||
|
|
@@ -1469,6 +1471,7 @@ def train_from_replay_buffer(self, update_global_counter=False): | |
| ``True``, it will affect the counter only if | ||
| ``config.update_counter_every_mini_batch=True``. | ||
| """ | ||
|
|
||
| config: TrainerConfig = self._config | ||
|
|
||
| # returns 0 if haven't started training yet, when ``_replay_buffer`` is | ||
|
|
@@ -1479,7 +1482,8 @@ def train_from_replay_buffer(self, update_global_counter=False): | |
| # training is not started yet, ``_replay_buffer`` will be None since it | ||
| # is only lazily created later when online RL training started. | ||
| if (self._replay_buffer and self._replay_buffer.total_size | ||
| < config.initial_collect_steps): | ||
| < config.initial_collect_steps) or (effective_unroll_steps | ||
| == 0): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any situation that
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. effective_unroll_steps is now removed from this function |
||
| assert ( | ||
| self._replay_buffer.num_environments * | ||
| self._replay_buffer.max_length >= config.initial_collect_steps | ||
|
|
@@ -1493,21 +1497,22 @@ def _replay(): | |
| # ``_replay_buffer`` for training. | ||
| # TODO: If this function can be called asynchronously, and using | ||
| # prioritized replay, then make sure replay and train below is atomic. | ||
| effective_num_updates_per_train_iter = config.num_updates_per_train_iter | ||
| with record_time("time/replay"): | ||
| mini_batch_size = config.mini_batch_size | ||
| if mini_batch_size is None: | ||
| mini_batch_size = self._replay_buffer.num_environments | ||
| if config.whole_replay_buffer_training: | ||
| experience, batch_info = self._replay_buffer.gather_all( | ||
| ignore_earliest_frames=True) | ||
| num_updates = config.num_updates_per_train_iter | ||
| num_updates = effective_num_updates_per_train_iter | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do you need to make this change?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not necessary anymore. removed |
||
| else: | ||
| assert config.mini_batch_length is not None, ( | ||
| "No mini_batch_length is specified for off-policy training" | ||
| ) | ||
| experience, batch_info = self._replay_buffer.get_batch( | ||
| batch_size=(mini_batch_size * | ||
| config.num_updates_per_train_iter), | ||
| effective_num_updates_per_train_iter), | ||
| batch_length=config.mini_batch_length) | ||
| num_updates = 1 | ||
| return experience, batch_info, num_updates, mini_batch_size | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,7 @@ | |
| import os | ||
| import time | ||
| import torch | ||
| from typing import Callable, Optional | ||
| from typing import Callable, List, Optional | ||
| from absl import logging | ||
|
|
||
| import alf | ||
|
|
@@ -147,6 +147,7 @@ def __init__(self, | |
| optimizer=None, | ||
| checkpoint=None, | ||
| is_eval: bool = False, | ||
| episodic_annotation: bool = False, | ||
| overwrite_policy_output=False, | ||
| debug_summaries=False, | ||
| name="RLAlgorithm"): | ||
|
|
@@ -186,6 +187,8 @@ def __init__(self, | |
| during deployment. In this case, the algorithm do not need to | ||
| create certain components such as value_network for ActorCriticAlgorithm, | ||
| critic_networks for SacAlgorithm. | ||
| episodic_annotation: if True, annotate the episode before being observed by the | ||
| replay buffer. | ||
| overwrite_policy_output (bool): if True, overwrite the policy output | ||
| with next_step.prev_action. This option can be used in some | ||
| cases such as data collection. | ||
|
|
@@ -203,6 +206,7 @@ def __init__(self, | |
| debug_summaries=debug_summaries, | ||
| name=name) | ||
| self._is_eval = is_eval | ||
| self._episodic_annotation = episodic_annotation | ||
|
|
||
| self._env = env | ||
| self._observation_spec = observation_spec | ||
|
|
@@ -235,11 +239,14 @@ def __init__(self, | |
| self._current_time_step = None | ||
| self._current_policy_state = None | ||
| self._current_transform_state = None | ||
|
|
||
| self._cached_exp = [] # for lazy observation | ||
| if self._env is not None and not self.on_policy: | ||
| replay_buffer_length = adjust_replay_buffer_length( | ||
| config, self._num_earliest_frames_ignored) | ||
|
|
||
| if self._episodic_annotation: | ||
| assert self._env.batch_size == 1, "only support non-batched environment" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add this to the docstring of
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The assertion is not necessary here so remove also |
||
|
|
||
| if config.whole_replay_buffer_training and config.clear_replay_buffer: | ||
| # For whole replay buffer training, we would like to be sure | ||
| # that the replay buffer have enough samples in it to perform | ||
|
|
@@ -598,19 +605,62 @@ def _async_unroll(self, unroll_length: int): | |
|
|
||
| return experience | ||
|
|
||
| def should_post_process_episode(self, rollout_info, step_type: StepType): | ||
| """A function that determines whether the ``post_process_episode`` function should | ||
| be applied to the current list of experiences. | ||
| """ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is an interface mainly used for subclasses? Maybe mention this. Same for
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. Added comments. Also for post_process_episode |
||
| return False | ||
|
|
||
| def post_process_episode(self, experiences: List[Experience]): | ||
| """A function for postprocessing a list of experience. It is called when | ||
| ``should_post_process_episode`` is True. | ||
| It can be used to create a number of useful features such as 'hindsight relabeling' | ||
| of a trajectory etc. | ||
|
|
||
| Args: | ||
| experiences: a list of experience, containing the experience starting from the | ||
| initial time when ``should_post_process_episode`` is False to the step where | ||
| ``should_post_process_episode`` is True. | ||
| """ | ||
| return None | ||
|
|
||
| def _process_unroll_step(self, policy_step, action, time_step, | ||
| transformed_time_step, policy_state, | ||
| experience_list, original_reward_list): | ||
| self.observe_for_metrics(time_step.cpu()) | ||
| exp = make_experience(time_step.cpu(), | ||
| alf.layers.to_float32(policy_step), | ||
| alf.layers.to_float32(policy_state)) | ||
|
|
||
| store_exp_time = 0 | ||
| if not self.on_policy: | ||
| t0 = time.time() | ||
| self.observe_for_replay(exp) | ||
| store_exp_time = time.time() - t0 | ||
| effective_number_of_unroll_steps = 1 | ||
| if self._episodic_annotation: | ||
| store_exp_time = 0 | ||
| # if last step, annotate | ||
| rollout_info = policy_step.info | ||
| self._cached_exp.append(exp) | ||
| if self.should_post_process_episode(rollout_info, | ||
| time_step.step_type): | ||
|
|
||
| # 1) process | ||
| annotated_exp_list = self.post_process_episode( | ||
| self._cached_exp) | ||
| effective_number_of_unroll_steps = len(annotated_exp_list) | ||
| # 2) observe | ||
| if not self.on_policy: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe this condition check should be performed earlier, since it seems a waste to do all the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated |
||
| t0 = time.time() | ||
| for exp in annotated_exp_list: | ||
| self.observe_for_replay(exp) | ||
| store_exp_time = time.time() - t0 | ||
| # clean up the exp cache | ||
| self._cached_exp = [] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to assume that all envs end on the same step? What if some envs are LAST, some are MID? cached_exp will be cleared even for those with MID steps? Even when doing this for an env with batch_size 1, this annotation mode will delay experience from being stored into the replay buffer. Ok to submit the change as is, but may need to do two things:
Also, delaying train_step because of delayed experience storage can have unexpected side effects, e.g. if episodes are 100 steps long, and unroll once per train iter, then summary will only happen every 100 train iters. It will also shift the distribution of the data training sees due to the delay. Overall I think doing this episode level relabeling at the DataTransformer stage, after reading from replay_buffer is perhaps a better way, and a cleaner way as well (less scattered code). That would require the replay buffer to keep track of episode begin and end, which I think it already does.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There is no such assumption. It is totally up to the users to inject their own assumptions.
No it won't. By default, the behavior is the same as before.
Different use cases. This is an alternative interface that can support more than pure relabeling (e.g. excluding data), which is not directly supported by the replay buffer hindsight relabel.
|
||
| else: | ||
| # effective unroll steps as 0 if not post_process_episode timepoint yet | ||
| effective_number_of_unroll_steps = 0 | ||
| else: | ||
| store_exp_time = 0 | ||
| if not self.on_policy: | ||
| t0 = time.time() | ||
| self.observe_for_replay(exp) | ||
| store_exp_time = time.time() - t0 | ||
|
|
||
| exp_for_training = Experience( | ||
| time_step=transformed_time_step, | ||
|
|
@@ -620,7 +670,7 @@ def _process_unroll_step(self, policy_step, action, time_step, | |
|
|
||
| experience_list.append(exp_for_training) | ||
| original_reward_list.append(time_step.reward) | ||
| return store_exp_time | ||
| return store_exp_time, effective_number_of_unroll_steps | ||
|
|
||
| def reset_state(self): | ||
| """Reset the state of the algorithm. | ||
|
|
@@ -665,6 +715,7 @@ def _sync_unroll(self, unroll_length: int): | |
| policy_step_time = 0. | ||
| env_step_time = 0. | ||
| store_exp_time = 0. | ||
| effective_unroll_steps = 0 | ||
| for _ in range(unroll_length): | ||
| policy_state = common.reset_state_if_necessary( | ||
| policy_state, initial_state, time_step.is_first()) | ||
|
|
@@ -693,9 +744,10 @@ def _sync_unroll(self, unroll_length: int): | |
| if self._overwrite_policy_output: | ||
| policy_step = policy_step._replace( | ||
| output=next_time_step.prev_action) | ||
| store_exp_time += self._process_unroll_step( | ||
| store_exp_time_i, effective_unroll_steps = self._process_unroll_step( | ||
| policy_step, action, time_step, transformed_time_step, | ||
| policy_state, experience_list, original_reward_list) | ||
| store_exp_time += store_exp_time_i | ||
|
|
||
| time_step = next_time_step | ||
| policy_state = policy_step.state | ||
|
|
@@ -723,7 +775,7 @@ def _sync_unroll(self, unroll_length: int): | |
| self._current_policy_state = common.detach(policy_state) | ||
| self._current_transform_state = common.detach(trans_state) | ||
|
|
||
| return experience | ||
| return experience, effective_unroll_steps | ||
|
|
||
| def train_iter(self): | ||
| """Perform one iteration of training. | ||
|
|
@@ -804,6 +856,7 @@ def _unroll_iter_off_policy(self): | |
| unrolled = False | ||
| root_inputs = None | ||
| rollout_info = None | ||
| effective_unroll_steps = 0 | ||
| if (alf.summary.get_global_counter() | ||
| >= self._rl_train_after_update_steps | ||
| and (unroll_length > 0 or config.unroll_length == 0) and | ||
|
|
@@ -822,19 +875,21 @@ def _unroll_iter_off_policy(self): | |
| # need to remember whether summary has been written between | ||
| # two unrolls. | ||
| with self._ensure_rollout_summary: | ||
| experience = self.unroll(unroll_length) | ||
| experience, effective_unroll_steps = self.unroll( | ||
| unroll_length) | ||
| if experience: | ||
| self.summarize_rollout(experience) | ||
| self.summarize_metrics() | ||
| rollout_info = experience.rollout_info | ||
| if config.use_root_inputs_for_after_train_iter: | ||
| root_inputs = experience.time_step | ||
| del experience | ||
| return unrolled, root_inputs, rollout_info | ||
| return unrolled, root_inputs, rollout_info, effective_unroll_steps | ||
|
|
||
| def _train_iter_off_policy(self): | ||
| """User may override this for their own training procedure.""" | ||
| unrolled, root_inputs, rollout_info = self._unroll_iter_off_policy() | ||
| unrolled, root_inputs, rollout_info, effective_unroll_steps = self._unroll_iter_off_policy( | ||
| ) | ||
|
|
||
| # replay buffer may not have been created for two different reasons: | ||
| # 1. in online RL training (``has_offline`` is False), unroll is not | ||
|
|
@@ -846,11 +901,13 @@ def _train_iter_off_policy(self): | |
| return 0 | ||
|
|
||
| self.train() | ||
| steps = self.train_from_replay_buffer(update_global_counter=True) | ||
|
|
||
| if unrolled: | ||
| with record_time("time/after_train_iter"): | ||
| self.after_train_iter(root_inputs, rollout_info) | ||
| steps = 0 | ||
| for i in range(effective_unroll_steps): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unroll_steps is the wrong name? It should be called unroll_iterations to indicate training iterations, not env steps? also rename effective_number_of_unroll_steps to be effective_unroll_iters to be consistent. (i.e. remove "number_of_")
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the comments. Changed. |
||
| steps += self.train_from_replay_buffer(effective_unroll_steps=1, | ||
| update_global_counter=True) | ||
| if unrolled: | ||
| with record_time("time/after_train_iter"): | ||
| self.after_train_iter(root_inputs, rollout_info) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel that this update fundamentally changes the off-policy update logic w.r.t. its actual unroll in the env. Previously, between every call of
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If In the derived class, it is up to the user for determining what kind of annotation function he/she wants to implement and use. |
||
|
|
||
| # For now, we only return the steps of the primary algorithm's training | ||
| return steps | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add docstring for this arg?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This arg is now removed