-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_simulated_env.py
More file actions
61 lines (54 loc) · 2.45 KB
/
Copy patheval_simulated_env.py
File metadata and controls
61 lines (54 loc) · 2.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from project_init import *
from tools import *
from simulated_env import *
def _load_train_data_var(mem_path):
    """Return ``np.var(memory['s'][0] / 255)`` for the replay memory at *mem_path*.

    The memory is loaded with ``load_env_samples`` and released right after the
    statistic is computed, so it is not held while the models are built.
    NOTE(review): only element 0 of ``memory['s']`` contributes, and the /255
    presumably rescales 8-bit pixel data — confirm both against the writer of
    the memory file.
    """
    memory = load_env_samples(mem_path)
    variance = np.var(memory['s'][0] / 255)
    del memory  # drop the (potentially large) sample buffer before returning
    return variance


def _build_vae(obs_shape, train_data_var, weights_path):
    """Construct the VQ-VAE from ``CONFIG`` settings and load its trained weights."""
    model = vq_vae_net(
        obs_shape=obs_shape,
        n_embeddings=CONFIG.vae_n_embeddings,
        d_embeddings=CONFIG.vae_d_embeddings,
        train_data_var=train_data_var,
        commitment_cost=CONFIG.vae_commitment_cost,
        frame_stack=CONFIG.vae_frame_stack,
        summary=CONFIG.model_summaries,
        tf_eager_mode=CONFIG.tf_eager_mode,
    )
    load_vae_weights(vae=model, weights_path=weights_path)
    return model


def _build_predictor(env_info, n_envs, vae, weights_path):
    """Construct the predictor network on top of *vae* and load its trained weights."""
    predictor = predictor_net(
        n_actions=env_info['n_actions'],
        obs_shape=env_info['obs_shape'],
        n_envs=n_envs,
        vae=vae,
        det_filters=CONFIG.pred_det_filters,
        prob_filters=CONFIG.pred_prob_filters,
        decider_lw=CONFIG.pred_decider_lw,
        n_models=CONFIG.pred_n_models,
        tensorboard_log=CONFIG.pred_tb_log,
        summary=CONFIG.model_summaries,
        tf_eager_mode=CONFIG.tf_eager_mode,
    )
    predictor.load_weights(weights_path)
    return predictor


def _rollout(env, max_steps=1000):
    """Play one episode with uniformly sampled actions and return its rewards.

    Prints the value returned by ``env.reset()``, then renders before every
    step. Stops when ``env.step`` reports ``done`` or after *max_steps* steps.
    """
    print(env.reset())
    collected = []
    for _ in range(max_steps):
        env.render()
        action = env.action_space.sample()
        _obs, reward, done, _info = env.step(action)
        collected.append(reward)
        if done:
            break
    return collected


if __name__ == '__main__':
    env_names, envs, env_info = gen_environments(CONFIG.env_setting)
    mix_mem_path = gen_mix_mem_path(env_names)
    vae_weights_path = gen_vae_weights_path(env_names)
    predictor_weights_path = gen_predictor_weights_path(env_names)

    # In this setting, evaluate from the fixed start position instead of random starts.
    if CONFIG.env_setting == 'gridworld_3_rooms_rand_starts':
        print('Deactivating random starts for control')
        for env in envs:
            env.player_random_start = False

    data_var = _load_train_data_var(mix_mem_path)
    vae = _build_vae(env_info['obs_shape'], data_var, vae_weights_path)
    predictor = _build_predictor(env_info, len(envs), vae, predictor_weights_path)

    # Alternative wrappers for manual comparison runs:
    #   MultiSimulatedLatentSpaceEnv(envs, predictor, vae, [0, 1, 2], 0.9)
    #   MultiLatentSpaceEnv(envs, vae, [0, 1, 2])
    #   MultiEnv(envs, [0, 1, 2])
    #   LatentSpaceEnv(envs[2], vae)
    # NOTE(review): this wraps envs[2] (so at least 3 environments are required)
    # but passes 0 as the last argument, where the multi-env variants above pass
    # what appear to be env indices [0, 1, 2] — confirm 0 is intended for envs[2].
    dream_env = SimulatedLatentSpaceEnv(envs[2], predictor, vae, 0)

    rewards = _rollout(dream_env)
    print(f'Return: {np.sum(rewards)}')