-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathload_image_dataset.py
More file actions
120 lines (106 loc) · 3.83 KB
/
Copy pathload_image_dataset.py
File metadata and controls
120 lines (106 loc) · 3.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from pathlib import Path
import torchvision
import cub_dataset
import stanford_dogs_dataset
def _per_class_indices(image_and_targets, num_images_per_cls):
    """Group record indices by class label, ordered by image path.

    Args:
        image_and_targets: Sequence of ``(image_path, label)`` pairs whose
            position in the sequence is the record index in the dataset.
        num_images_per_cls: Maximum number of records kept per class.
            ``None``, ``0`` and ``-1`` all mean "keep every record".

    Returns:
        Dict mapping each label to the list of record indices of that class,
        ordered by image path and truncated to ``num_images_per_cls`` entries
        when a limit is given. Labels appear in first-seen order.
    """
    path_to_index = {}
    paths_by_label = {}
    for index, (image_path, label) in enumerate(image_and_targets):
        # A path listed more than once resolves to its last record index.
        path_to_index[image_path] = index
        paths_by_label.setdefault(label, []).append(image_path)

    keep_all = num_images_per_cls is None or num_images_per_cls in (-1, 0)

    indices_by_label = {}
    # Iterate over the labels that are actually present rather than assuming
    # they form the contiguous range 0..N-1; a split that lacks some class
    # must not raise KeyError.
    for label, paths in paths_by_label.items():
        # Sorting by path makes the selection independent of record order.
        ordered = sorted(paths)
        if not keep_all:
            ordered = ordered[:num_images_per_cls]
        indices_by_label[label] = [path_to_index[path] for path in ordered]
    return indices_by_label


def get_vision_dataset(
    config,
    per_cls_indices=False,
    preprocess=None,
    num_images_per_cls=None,
):
    """Build an image dataset and, optionally, per-class record indices.

    Args:
        config: Mapping with the keys ``"dataset"`` (one of ``"cub"``,
            ``"fgvc_aircraft"``, ``"flowers"``, ``"pets"``,
            ``"stanford_cars"``, ``"stanford_dogs"``), ``"root"`` (dataset
            root directory) and ``"split"`` (split name forwarded to the
            dataset class; for ``"cub"`` it is compared against ``"train"``).
        per_cls_indices: When truthy, also compute the per-class index map.
        preprocess: Passed to the dataset constructor as ``transform``.
        num_images_per_cls: Optional cap on the number of indices kept per
            class; ``None``, ``0`` and ``-1`` keep all of them.

    Returns:
        Tuple ``(dataset, per_cls_img_indices)``. ``per_cls_img_indices`` is
        ``None`` unless ``per_cls_indices`` is truthy, in which case it maps
        each label to a list of record indices ordered by image path.

    Raises:
        ValueError: If ``config["dataset"]`` is not a supported name.
    """
    name = config["dataset"]

    if name == "cub":
        # Prefer "<root>/cub"; fall back to the root itself when absent.
        cub_path = Path(config["root"], "cub")
        if not cub_path.exists():
            cub_path = Path(config["root"])
        split_ = config["split"]
        dataset = cub_dataset.Cub2011(
            root=cub_path.as_posix(),
            train=split_ == "train",
            transform=preprocess,
        )
        image_and_targets = []
        for idx in range(len(dataset)):
            sample = dataset.data.iloc[idx]
            path = Path(
                dataset.root, dataset.base_folder, sample.filepath
            ).as_posix()
            # Targets are shifted down by one (presumably 1-based in the CUB
            # metadata -- confirm against cub_dataset).
            image_and_targets.append((path, sample.target - 1))
    elif name == "fgvc_aircraft":
        split_ = config["split"]
        dataset = torchvision.datasets.FGVCAircraft(
            root=config["root"], transform=preprocess, split=split_
        )
        # NOTE: reads private torchvision attributes to get paths and labels.
        image_and_targets = [
            (Path(dataset._image_files[idx]).as_posix(), dataset._labels[idx])
            for idx in range(len(dataset))
        ]
    elif name == "flowers":
        split_ = config["split"]
        dataset = torchvision.datasets.Flowers102(
            root=config["root"], transform=preprocess, split=split_
        )
        image_and_targets = [
            (Path(dataset._image_files[idx]).as_posix(), dataset._labels[idx])
            for idx in range(len(dataset))
        ]
    elif name == "pets":
        split_ = config["split"]
        dataset = torchvision.datasets.OxfordIIITPet(
            root=config["root"], transform=preprocess, split=split_
        )
        image_and_targets = [
            (Path(dataset._images[idx]).as_posix(), dataset._labels[idx])
            for idx in range(len(dataset))
        ]
    elif name == "stanford_cars":
        split_ = config["split"]
        dataset = torchvision.datasets.StanfordCars(
            root=config["root"], transform=preprocess, split=split_
        )
        image_and_targets = [
            (Path(dataset._samples[idx][0]).as_posix(), dataset._samples[idx][1])
            for idx in range(len(dataset))
        ]
    elif name == "stanford_dogs":
        split_ = config["split"]
        dataset = stanford_dogs_dataset.StanfordDogs(
            root=config["root"], transform=preprocess, split=split_
        )
        image_and_targets = [
            (
                Path(dataset.root.joinpath("Images", dataset.files[idx])).as_posix(),
                dataset.labels[idx],
            )
            for idx in range(len(dataset))
        ]
    else:
        raise ValueError(f"Unsupported dataset: {name!r}")

    per_cls_img_indices = None
    if per_cls_indices:
        per_cls_img_indices = _per_class_indices(
            image_and_targets, num_images_per_cls
        )
    return dataset, per_cls_img_indices