-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
304 lines (221 loc) · 8.93 KB
/
utils.py
File metadata and controls
304 lines (221 loc) · 8.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
import cv2
import os
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from architectures.swin_unet import Swin_Unet
def read_frames(image_folder):
    """
    Return the sorted list of frame filenames found in image_folder.

    Looks for .png files first; if none are found, falls back to .jpg.
    The list is sorted because os.listdir returns entries in arbitrary
    order, while callers such as frame_2_video need a deterministic
    frame sequence.
    """
    images = sorted(f for f in os.listdir(image_folder) if f.endswith(".png"))
    if not images:
        images = sorted(f for f in os.listdir(image_folder) if f.endswith(".jpg"))
    return images
def frame_2_gray(images) -> list:
    """
    Convert a list of BGR images to grayscale.

    Receives a list of images (OpenCV BGR arrays) and returns a new
    list with each image converted to single-channel grayscale.
    """
    return [cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) for frame in images]
def frame_2_video(image_folder, video_name, gray=False, frame_rate=16):
    """
    Build a video file from the frames found in image_folder.

    image_folder: directory containing the .png/.jpg frames
    video_name:   output path of the video (e.g. "out.mp4")
    gray:         when True, frames are converted to grayscale and the
                  writer is opened in single-channel mode (isColor=0)
    frame_rate:   frames per second of the output video

    Raises ValueError when the folder contains no frames.
    """
    # Sort to guarantee frame order regardless of filesystem listing order.
    images = sorted(read_frames(image_folder))
    if not images:
        raise ValueError(f"No frames found in {image_folder}")
    # Lowercase 'mp4v' is the correct fourcc tag for the MPEG-4 codec.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    first = cv2.imread(os.path.join(image_folder, images[0]))
    height, width, _ = first.shape
    if gray:
        video = cv2.VideoWriter(video_name, fourcc, frame_rate, (width, height), 0)
    else:
        video = cv2.VideoWriter(video_name, fourcc, frame_rate, (width, height))
    try:
        for image in images:
            frame = cv2.imread(os.path.join(image_folder, image), cv2.IMREAD_COLOR)
            if gray:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            video.write(frame)
    finally:
        # Always release the writer so the file is finalized even on error.
        video.release()
################## Losses #####################
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics import PeakSignalNoiseRatio
from torchmetrics.image.fid import FrechetInceptionDistance
from piq import ssim, SSIMLoss
from piq import VIFLoss
from piq import VSILoss
from torch.nn import KLDivLoss
from architectures_losses.vgg_loss import VGGLoss
from architectures_losses.smooth_loss import Loss
# from pytorch_tools import losses
def valid_loss(loss) -> torch.Tensor:
    """
    Normalize a loss value to a bare tensor.

    Some loss functions return a one-element list [tensor] instead of a
    tensor; unwrap the first element in that case and return it,
    otherwise return the value unchanged.
    """
    # isinstance is the idiomatic (and subclass-aware) type check.
    if isinstance(loss, list):
        loss = loss[0]
    return loss
def model_losses(losses: list, inputs: list):
    """
    Compute every loss function on its designated input pair.

    losses: list of callable loss functions
    inputs: list of at least four tensors; by position in `losses`:
            loss 0 -> (inputs[0], inputs[1])
            loss 1 -> (inputs[3], inputs[1])
            rest   -> (inputs[3], inputs[2])
            NOTE(review): pairing reconstructed from the original code,
            whose second `if` (instead of `elif`) silently overwrote the
            first loss's result — confirm against training code.

    Returns (total, dict_losses): the sum of all loss values and a dict
    mapping each loss class name to its individual value.
    """
    dict_losses = {}
    total_losses = 0
    for idx, loss in enumerate(losses):
        # Derive a readable key from the loss class, e.g. "VGGLoss".
        loss_name = str(type(loss)).split(".")[-1].split("'")[0]
        if idx == 0:
            value = loss(inputs[0], inputs[1])
        elif idx == 1:
            value = loss(inputs[3], inputs[1])
        else:
            value = loss(inputs[3], inputs[2])
        # Some losses return [tensor]; unwrap to the bare value.
        if isinstance(value, list):
            value = value[0]
        dict_losses[loss_name] = value
        total_losses += value
    # Duplicate class names would collapse dict entries — catch that here.
    assert len(dict_losses) == len(losses), "Loss and Output has to have same size"
    return total_losses, dict_losses
def commet_log_metric(experiment, metric_name: str, metric, step: int, me_type="train") -> None:
    """
    Log one metric value to a Comet-style experiment object.

    The metric is recorded under "<metric_name>_<me_type>" (e.g.
    "loss_train") at the given step.
    """
    full_name = f"{metric_name}_{me_type}"
    experiment.log_metric(full_name, metric, step=step)
def create_gray_videos(dataset, path_video_save):
    """
    Create a grayscale .mp4 for every class folder of a dataset.

    dataset: dataset folder name under ./data/train/
    path_video_save: destination directory (must end with a slash,
        since the video name is appended directly to it).
    """
    images_paths = f"./data/train/{dataset}"
    img_classes = os.listdir(images_paths)
    os.makedirs(path_video_save, exist_ok=True)
    for v_class in img_classes:
        frames_dir = f"./data/train/{dataset}/{v_class}"
        out_name = f'{path_video_save}{v_class}.mp4'
        frame_2_video(frames_dir, out_name, True)
    assert len(img_classes) == len(os.listdir(path_video_save)), "Created videos must be same amout of files that video classes."
    print("Gray videos created")
from architectures.color_model_simple import ColorNetwork
def load_trained_model(model_path, image_size, device):
    """
    Load a trained Swin_Unet colorization model from a checkpoint file.

    model_path: path to the saved state_dict checkpoint
    image_size: image size tuple/list expected by Swin_Unet (img_size)
    device:     torch device to place the model on

    Returns the model in its loaded state, moved to `device`.
    """
    # map_location lets checkpoints saved on GPU load on CPU-only machines.
    checkpoint = torch.load(model_path, map_location=device)
    # Network width of the first layer — fixed for the current checkpoints.
    ch_deep = 128
    model = Swin_Unet(net_dimension=ch_deep, c_out=3, img_size=image_size).to(device)
    model.load_state_dict(checkpoint)
    return model
def generate_paper_colored_samples(
    root_path_images = f"./Vit-autoencoder/temp_result/DAVIS",
    frame_number = "00042",
    root_path_destiny = "images_to_paper",
    video_name = "rallye.mp4"):
    """
    Collect one chosen frame from every model's result folder into a
    single comparison folder (used for the paper figures).

    For each model directory under root_path_images, copies
    <model>/<video_name>/<frame_number>.jpg into
    <root_path_destiny>/<video_name>/<model>_<frame_number>.jpg.
    """
    import shutil
    os.makedirs(root_path_destiny, exist_ok=True)
    for model in os.listdir(root_path_images):
        origin_path = f"{root_path_images}/{model}/{video_name}/{frame_number}.jpg"
        destiny_path = f"{root_path_destiny}/{video_name}"
        os.makedirs(destiny_path, exist_ok=True)
        shutil.copy(origin_path, f"{destiny_path}/{model}_{frame_number}.jpg")
# Gerenate grayscale videos of all videos in DAVIS_val
# for class_video in os.listdir("./data/train/DAVIS_val"):
# image_folder = f"./data/train/DAVIS_val/{class_video}"
# path_video_save = "./data/videos/gray/"
# video_name = f'{path_video_save}video_{class_video}.mp4'
# frame_2_video(image_folder, video_name, True)
# image_folder = "./Vit-autoencoder/temp_result/20221207_111837"
# class_video = "breakdance"
# image_folder = f"./data/train/DAVIS_val/{class_video}"
# path_video_save = "./data/videos/gray/"
# video_name = f'{path_video_save}video_{class_video}.mp4'
# frame_2_video(image_folder, video_name, True)
def get_model_time():
    """
    Return the current timestamp as a string formatted
    'YYYYMMDD_HHMMSS', used to tag model runs/checkpoints.
    """
    from datetime import datetime
    # The original computed two intermediate values and discarded them;
    # only the strftime result was ever returned.
    return datetime.now().strftime('%Y%m%d_%H%M%S')
def save_losses(dic_losses, filename="losses_network_v1"):
    """
    Write a dict of losses to losses/<filename>.csv, one "name,value"
    line per entry. Loss values are tensors (moved to CPU before being
    converted to float).
    """
    os.makedirs("losses", exist_ok=True)
    # The original path was garbled and ignored `filename`; write into
    # the losses/ directory created above.
    fout = f"losses/{filename}.csv"
    with open(fout, "w") as fo:
        for k, v in dic_losses.items():
            fo.write(str(k) + "," + str(float(v.cpu().numpy())) + '\n')
def to_img(x):
    """
    Map a tensor from the [-1, 1] range into [0, 1] and reshape it to
    (batch, 3, H, H) image layout.

    NOTE(review): uses x.shape[2] for both spatial dimensions, so it
    assumes square images — confirm with callers.
    """
    shifted = (x + 1) * 0.5
    shifted = shifted.clamp(0, 1)
    side = x.shape[2]
    return shifted.view(x.size(0), 3, side, side)
def create_samples(data, constrative=False):
    """
    Unpack one batch of data into the tensors used for training.

    data: a 3- or 4-tuple of (img, img_color, next_frame[, random_frame]);
        each element may itself be wrapped in a one-element list.
    img:        RGB ground-truth image
    img_color:  colored example frame (first of the scene)
    next_frame: the following frame

    Returns (img, img_gray, img_color, x) where x is next_frame when
    `constrative` is True, otherwise its grayscale version.
    """
    if len(data) == 4:
        img, img_color, next_frame, random_frame = data
        if isinstance(img, list):
            img, img_color, next_frame, random_frame = (
                img[0], img_color[0], next_frame[0], random_frame[0])
    else:
        img, img_color, next_frame = data
        if isinstance(img, list):
            img, img_color, next_frame = img[0], img_color[0], next_frame[0]
    # Grayscale with 3 channels so shapes stay compatible with RGB inputs.
    to_gray = transforms.Grayscale(num_output_channels=3)
    img_gray = to_gray(img)
    gray_next_frame = to_gray(next_frame)
    if constrative:
        return img, img_gray, img_color, next_frame
    return img, img_gray, img_color, gray_next_frame
def is_notebook():
    """
    Return True when running inside a Jupyter notebook/qtconsole,
    False in terminal IPython, plain Python, or anything else.
    """
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # get_ipython is undefined in a standard Python interpreter.
        return False
    # Only the ZMQ kernel (notebook/qtconsole) counts as a notebook.
    return shell_name == 'ZMQInteractiveShell'
def scale_0_and_1(tensor):
    """
    Rescale a tensor linearly so its values lie in [0, 1].

    The minimum maps to 0 and the maximum to 1. A constant tensor
    (max == min) is returned as all zeros instead of dividing by zero,
    which previously produced NaN/inf.
    """
    tensor_min = tensor.min()
    tensor_max = tensor.max()
    span = tensor_max - tensor_min
    if span == 0:
        return torch.zeros_like(tensor)
    return (tensor - tensor_min) / span
def plot_images(images):
    """
    Display a batch of image tensors side by side in one matplotlib
    figure (images concatenated horizontally along the width axis).
    """
    plt.figure(figsize=(32, 32))
    row = torch.cat(list(images.cpu()), dim=-1)
    grid = torch.cat([row], dim=-2)
    plt.imshow(grid.permute(1, 2, 0).cpu())
    plt.show()
def resume(model, filename):
    """Load the state_dict stored at `filename` into `model` in place."""
    state = torch.load(filename)
    model.load_state_dict(state)
def get_criterion_name(criterion):
    """
    Build an identifier from a list of loss objects by joining their
    class names, each prefixed with an underscore (e.g. "_VGGLoss_SSIMLoss").
    """
    return "".join(f"_{type(loss).__name__}" for loss in criterion)