-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathim_examples.py
More file actions
52 lines (44 loc) · 1.44 KB
/
Copy pathim_examples.py
File metadata and controls
52 lines (44 loc) · 1.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os

import torch
import torchvision
from PIL import Image
from torchvision import transforms, datasets

from augment_dataset import create_transforms, load_data
# Standalone pipeline: resize to 256, then convert to a tensor.
# NOTE(review): no other statement in this script reads `transform` (the
# dataset pipelines come from create_transforms), so this looks like a
# leftover — confirm before relying on it or removing it.
transform = transforms.Compose([transforms.Resize(256), transforms.ToTensor()])
# ---- Run configuration -------------------------------------------------
batch_size = 1  # the loader yields one image per batch; only images[0] is rendered
DATASET_NAME = "CIFAR10"
augmentation_type = "Sharpness"  # augmentation name handed to create_transforms
severity = 25  # handed to create_transforms and embedded in the output filename
augment_sign = True  # handed to create_transforms; also picks "min"/"plus" in the filename

# Explicitly seeded torch RNG (Generator.manual_seed returns the generator).
# NOTE(review): `g` is not passed to the DataLoader or any other call in this
# script, so the seed currently has no effect — confirm whether
# `generator=g` was intended.
g = torch.Generator().manual_seed(1)
# Ask augment_dataset.create_transforms for the preprocessing and augmentation
# pipelines that match the configuration above, then log both.
transform_kwargs = dict(
    random_cropping=False,
    aggressive_augmentation=True,
    custom=True,
    augmentation_name=augmentation_type,
    augmentation_severity=severity,
    augmentation_sign=augment_sign,
    dataset_name=DATASET_NAME,
)
transforms_preprocess, transforms_augmentation = create_transforms(**transform_kwargs)
print(f"Preprocess transforms: {transforms_preprocess}\nAugmentation transforms: {transforms_augmentation}")
# Build the train/test splits with the pipelines from create_transforms.
# Only the training split is consumed below.
trainset, testset = load_data(
    dataset_name=DATASET_NAME,
    transforms_preprocess=transforms_preprocess,
    transforms_augmentation=transforms_augmentation,
)

# Iterate the training split in stored order (no shuffling).
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=False,
)

# Class names are read through `trainset.dataset`. Presumably load_data
# returns a wrapper around the underlying dataset — verify in augment_dataset.
# NOTE(review): `classes` is not used anywhere else in this script.
classes = trainset.dataset.classes
# Pull the first batch. The loader is expected to yield
# (images, labels, confidences) triples — this unpacking fails otherwise.
images, labels, confidences = next(iter(trainloader))

# Convert the first image of the batch to PIL, upscale it for viewing, and show it.
to_pil = transforms.ToPILImage()
resize = transforms.Resize(256)
im = resize(to_pil(images[0]))
im.show()

# Save next to the other example plots. Create the directory first:
# PIL's Image.save raises FileNotFoundError when the parent directory is
# missing, e.g. on a fresh checkout.
# NOTE(review): augment_sign=True is labelled "min" and False "plus"; this
# presumably mirrors how create_transforms interprets augmentation_sign —
# confirm in augment_dataset.
sign_label = "min" if augment_sign else "plus"
output_dir = "final_plots/image_examples"
os.makedirs(output_dir, exist_ok=True)
im.save(f"{output_dir}/augmented_{augmentation_type}_{sign_label}_{severity}.png")