Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 58 additions & 27 deletions alf/algorithms/distributed_off_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def __init__(self,
port: int = 50000,
env: AlfEnvironment = None,
config: TrainerConfig = None,
optimizer: alf.optimizers.Optimizer = None,
debug_summaries: bool = False,
name: str = "DistributedOffPolicyAlgorithm",
**kwargs):
Expand All @@ -117,7 +116,6 @@ def __init__(self,
to always specify this argument.
port: port number for communication on the *current* machine.
env: The environment to interact with. Its batch size must be 1.
optimizer: optimizer for the training the core algorithm.
debug_summaries: True if debug summaries should be created.
name: the name of this algorithm.
*args: args to pass to ``core_alg_ctor``.
Expand All @@ -144,7 +142,6 @@ def __init__(self,
predict_state_spec=core_alg.predict_state_spec,
env=env,
config=config,
optimizer=optimizer,
# Prevent in-alg ckpt since there is no such use case.
checkpoint=None,
debug_summaries=debug_summaries,
Expand All @@ -155,14 +152,22 @@ def __init__(self,
self._ddp_rank = max(0, PerProcessContext().ddp_rank)
self._num_ranks = PerProcessContext().num_processes

def state_dict(self, *args, **kwargs):
    """Return the state dict of the wrapped core algorithm.

    All positional and keyword arguments are forwarded unchanged to
    ``self._core_alg.state_dict()``.

    Returns:
        whatever ``self._core_alg.state_dict()`` returns.
    """
    core_state = self._core_alg.state_dict(*args, **kwargs)
    return core_state

def load_state_dict(self, state_dict, strict=True, **kwargs):
    """Load ``state_dict`` into the wrapped core algorithm.

    Args:
        state_dict: the state dict handed to the core algorithm.
        strict: forwarded as the ``strict`` keyword of the core algorithm's
            ``load_state_dict()``.
        **kwargs: extra keyword arguments forwarded unchanged.

    Returns:
        whatever ``self._core_alg.load_state_dict()`` returns.
    """
    core_alg = self._core_alg
    return core_alg.load_state_dict(state_dict, strict=strict, **kwargs)

def _distributed_state_dict(self) -> dict:
"""Return `self._core_alg` state dict for distributed training.

This dict will be used for param syncing between a trainer and an unroller.
Sometimes optimizers have large state vectors which we want to exclude.
Also we should exclude other parameters such as those of pretrained models.
Also, we should exclude other parameters such as those of pretrained models.
"""
# Note that self._core_alg won't create a relay buffer so we don't have
# Note that self._core_alg won't create a replay buffer so we don't have
# to worry about including it in the state dict.
return {
k: v
Expand Down Expand Up @@ -203,6 +208,9 @@ def after_update(self, root_inputs, info):
def after_train_iter(self, root_inputs, rollout_info):
    """Forward the after-train-iteration hook to the core algorithm.

    Args:
        root_inputs: passed through to ``self._core_alg.after_train_iter()``.
        rollout_info: passed through to ``self._core_alg.after_train_iter()``.

    Returns:
        whatever the core algorithm's ``after_train_iter()`` returns.
    """
    core_alg = self._core_alg
    return core_alg.after_train_iter(root_inputs, rollout_info)

def summarize_metrics(self):
    """Ask the wrapped core algorithm to summarize its metrics.

    The call is delegated to ``self._core_alg.summarize_metrics()``; its
    result is discarded and this method returns ``None``.
    """
    core_alg = self._core_alg
    core_alg.summarize_metrics()


def receive_experience_data(replay_buffer: ReplayBuffer,
new_unroller_ips_and_ports: 'Manager.Queue',
Expand Down Expand Up @@ -253,7 +261,9 @@ def receive_experience_data(replay_buffer: ReplayBuffer,
unroller_id, message = socket.recv_multipart()

buffer = io.BytesIO(message)
exp_params = torch.load(buffer, map_location='cpu')
exp_params = torch.load(buffer,
map_location='cpu',
weights_only=False)
# we prune env_info according to the replay buffer for the following reasons:
# 1) avoid env_info mismatch and allow the distributed unroller to have
# a customized env_info for tb summarization,
Expand Down Expand Up @@ -308,7 +318,9 @@ def pull_params_from_trainer(memory_name: str, memory_lock: mp.Lock,


@alf.configurable(whitelist=[
'max_utd_ratio', 'push_params_every_n_grad_updates', 'name', 'optimizer'
'max_utd_ratio',
'push_params_every_n_grad_updates',
'name',
])
class DistributedTrainer(DistributedOffPolicyAlgorithm):

Expand All @@ -319,7 +331,6 @@ def __init__(self,
push_params_every_n_grad_updates: int = 1,
env: AlfEnvironment = None,
config: TrainerConfig = None,
optimizer: alf.optimizers.Optimizer = None,
debug_summaries: bool = False,
name: str = "DistributedTrainer",
**kwargs):
Expand Down Expand Up @@ -347,7 +358,6 @@ def __init__(self,
port=_trainer_addr_config.port,
env=env,
config=config,
optimizer=optimizer,
debug_summaries=debug_summaries,
name=name,
**kwargs)
Expand Down Expand Up @@ -421,9 +431,9 @@ def _send_params_to_unroller(self,
buffer = io.BytesIO()
torch.save(self._distributed_state_dict(), buffer)
self._params_socket.send_multipart([unroller_id1, buffer.getvalue()])
# 3 sec timeout for receiving unroller's acknowledgement
# 60 sec timeout for receiving unroller's acknowledgement
# In case some unrollers might die, we don't want to block forever
for _ in range(30):
for _ in range(600):
try:
_, message = self._params_socket.recv_multipart(
flags=zmq.NOBLOCK)
Expand Down Expand Up @@ -493,27 +503,42 @@ def _wait_unroller_registration():
thread.daemon = True
thread.start()

def _create_replay_buffer_sample_experience(self):
    """Create a sample experience used to initialize the replay buffer.

    A dummy experience is built from the environment's current time step
    and one rollout step starting from the initial rollout state. The
    replay state of that experience is then pruned exactly once.

    Returns:
        the experience returned by
        ``alf.utils.common.prune_exp_replay_state()``.
    """
    # The replay buffer is created in the main process; for this, we need
    # a dummy experience to set it up.
    time_step = self._env.current_time_step()
    rollout_state = self.get_initial_rollout_state(self._env.batch_size)
    alg_step = self.rollout_step(time_step, rollout_state)
    exp = make_experience(time_step, alg_step, rollout_state)
    # Prune only once: the result is returned directly rather than being
    # assigned back to ``exp`` and pruned a second time.
    return alf.utils.common.prune_exp_replay_state(
        exp, self._use_rollout_state, self.rollout_state_spec,
        self.train_state_spec)

# enable multi_processing in replay_buffer here, because we need to
# receive data in a subprocess and process the data in the main process.
def _create_multiprocessing_replay_buffer(self):
"""
Create the replay buffer in shared memory.
"""
ctx = mp.get_context('spawn')
self._set_replay_buffer(exp, mp_context=ctx)
assert self._replay_buffer is None
self._set_replay_buffer(self._create_replay_buffer_sample_experience(),
mp_context=ctx)
assert self._replay_buffer._allow_multiprocess, (
"The replay buffer must allow multi-processing.")

def _create_data_receiver_subprocess(self):
"""
Create a process to receive experience data from unrollers.

The warm-up train_iter() creates the normal trainer replay buffer in
multiprocessing form before checkpoint restore, so restored shards
load directly into the buffer that the receiver subprocess will use.
"""
assert self._replay_buffer is not None
assert self._replay_buffer._allow_multiprocess
ctx = mp.get_context('spawn')

# start the data receiver subprocess
# Need to create the subprocess with 'spawn' so that we can pass a Module
# object to subprocess with tensors in shared memory.
Expand All @@ -535,8 +560,13 @@ def _train_iter_off_policy(self):
if self._num_train_iters == 0:
# First time will be called by ``Trainer._restore_checkpoint()``
# where the ckpt (if any) will be loaded after this function.
# Create the normal trainer replay buffer in multiprocessing form
# before checkpoint load so replay data is restored directly into
# it. Do not start the receiver subprocess yet; it should only
# consume unroller data after checkpoint restore has completed.
self._create_multiprocessing_replay_buffer()
self._num_train_iters += 1
return super()._train_iter_off_policy()
return 0

if self._num_train_iters == 1:
# Only open the unroller registration after we are sure that
Expand Down Expand Up @@ -582,8 +612,7 @@ def _train_iter_off_policy(self):
return steps


@alf.configurable(
whitelist=['episode_length', 'name', 'optimizer', 'unroller_only'])
@alf.configurable(whitelist=['episode_length', 'name', 'unroller_only'])
class DistributedUnroller(DistributedOffPolicyAlgorithm):

def __init__(self,
Expand Down Expand Up @@ -732,8 +761,10 @@ def observe_for_replay(self, exp: Experience):
# Get the current worker id to send the exp to
worker_id = f'worker-{self._current_worker}'
self._num_exps += 1
episode_end = ((self._episode_length <= 0 and bool(exp.is_last()))
or (self._num_exps % self._episode_length == 0))
if self._episode_length <= 0:
episode_end = bool(exp.is_last())
else:
episode_end = (self._num_exps % self._episode_length == 0)

if self._is_first_step:
# When the unroller has a ``max_episode_length``, we need to correctly
Expand Down
Loading