-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain_dis_simu.py
More file actions
63 lines (51 loc) · 2.25 KB
/
Copy pathtrain_dis_simu.py
File metadata and controls
63 lines (51 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import toml
import torch
import torch.nn as nn
from models.single_SSL_model import disesti_3
# simulation
from utils.simulation import MyDataset
import trainers.simulation_trainer as trainer_run
# Global RNG seed for this script. torch.random.manual_seed is the function
# that torch.manual_seed aliases; it seeds torch's default generators.
# NOTE(review): Python's `random` module and NumPy are NOT seeded here. If the
# simulated dataset draws from them, runs may not be reproducible — confirm.
seed = 7
torch.random.manual_seed(seed)
def _build_split(config, dataset_key, gene_key, loader_key):
    """Build one data split: a ``MyDataset`` plus the ``DataLoader`` over it.

    Args:
        config: parsed TOML config (a dict of sections).
        dataset_key: section holding this split's ``MyDataset`` arguments,
            e.g. ``'train_dataset'``.
        gene_key: section holding this split's generation settings,
            e.g. ``'train_gene_setting'``.
        loader_key: section forwarded as keyword arguments to ``DataLoader``,
            e.g. ``'train_dataloader'``.

    Returns:
        ``(dataset, dataloader)``. The loader batches with the dataset's own
        ``collate_fn``.
    """
    # The shared 'FFT' section is forwarded to every split (and to the model
    # in run()), so all of them are configured from the same settings.
    dataset = MyDataset(**config[dataset_key], **config['FFT'], **config[gene_key])
    dataloader = torch.utils.data.DataLoader(dataset,
                                             **config[loader_key],
                                             collate_fn=dataset.collate_fn)
    return dataset, dataloader


def run(config, device):
    """Build data, model, loss and optimizer from ``config``, then train.

    Args:
        config: parsed TOML config. Sections read here are
            ``train_dataset``, ``validation_dataset``, ``FFT``,
            ``train_gene_setting``, ``val_gene_setting``,
            ``train_dataloader``, ``validation_dataloader`` and
            ``optimizer`` (its ``lr`` key). The whole dict is also passed to
            the trainer.
        device: ``torch.device`` that the model is moved to. It is also passed
            to the model constructor and to the trainer.
    """
    train_dataset, train_dataloader = _build_split(
        config,
        dataset_key='train_dataset',
        gene_key='train_gene_setting',
        loader_key='train_dataloader',
    )
    validation_dataset, validation_dataloader = _build_split(
        config,
        dataset_key='validation_dataset',
        gene_key='val_gene_setting',
        loader_key='validation_dataloader',
    )

    loss = nn.MSELoss()

    # Move the model to `device` exactly once. The parameters already live
    # there, so wrapping in DataParallel below needs no further .to() call.
    model = disesti_3(device, **config['FFT']).to(device)
    if torch.cuda.device_count() > 1:
        # Replicate across all visible GPUs.
        # NOTE(review): the __main__ guard sets CUDA_VISIBLE_DEVICES='0', so
        # this branch is presumably only taken when run() is called from
        # other code — confirm.
        model = nn.DataParallel(model)

    # For a DataParallel wrapper, parameters() yields the wrapped module's
    # parameters.
    optimizer = torch.optim.Adam(params=model.parameters(), lr=config['optimizer']['lr'])

    trainer = trainer_run.Trainer(config=config,
                                  model=model,
                                  optimizer=optimizer,
                                  loss_func=loss,
                                  train_dataset=train_dataset,
                                  train_dataloader=train_dataloader,
                                  validation_dataset=validation_dataset,
                                  validation_dataloader=validation_dataloader,
                                  device=device)
    trainer.train()
if __name__ == '__main__':
    # Prepend /sbin to PATH. Presumably this lets system tools (e.g. ldconfig)
    # be found by libraries invoked during training — TODO confirm why.
    inherited_path = os.environ.get('PATH', '')
    os.environ['PATH'] = '/sbin:' + inherited_path

    # Expose only GPU 0 to this process. Any value set in the caller's
    # environment is overwritten.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # Fall back to CPU when CUDA is unavailable.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # The config path is resolved relative to the current working directory,
    # so launch the script from the repository root.
    config_path = './configs/simulation/config.toml'
    config = toml.load(config_path)

    run(config, device)