-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpart1_code.py
More file actions
204 lines (154 loc) · 7.39 KB
/
Copy pathpart1_code.py
File metadata and controls
204 lines (154 loc) · 7.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import imageio.v2 as imageio
import time
def positional_encoding(x, num_frequencies=6, incl_input=True):
"""
Apply positional encoding to the input.
Args:
x (torch.Tensor): Input tensor to be positionally encoded.
The dimension of x is [N, D], where N is the number of input coordinates,
and D is the dimension of the input coordinate.
num_frequencies (optional, int): The number of frequencies used in
the positional encoding (default: 6).
incl_input (optional, bool): If True, concatenate the input with the
computed positional encoding (default: True).
Returns:
(torch.Tensor): Positional encoding of the input tensor.
"""
results = []
if incl_input:
results.append(x)
############################# TODO 1(a) BEGIN ############################
# encode input tensor and append the encoded tensor to the list of results.
for L in range(num_frequencies):
results.append(torch.sin((2 ** L) * torch.pi * x))
results.append(torch.cos((2 ** L) * torch.pi * x))
############################# TODO 1(a) END ##############################
return torch.cat(results, dim=-1)
"""1.2 Complete the class model_2d() that will be used to fit the 2D image.
"""
class model_2d(nn.Module):
"""
Define a 2D model comprising of three fully connected layers,
two relu activations and one sigmoid activation.
"""
def __init__(self, filter_size=128, num_frequencies=6):
super().__init__()
############################# TODO 1(b) BEGIN ############################
# for autograder compliance, please follow the given naming for your layers
self.layer_in = nn.Linear(2 * (1+ 2 * num_frequencies), filter_size)
self.layer = nn.Linear(filter_size, filter_size)
self.layer_out = nn.Linear(filter_size, 3)
############################# TODO 1(b) END ##############################
def forward(self, x):
############################# TODO 1(b) BEGIN ############################
# example of forward through a layer: y = self.layer_in(x)
# first ReLu activation of input linear layer
x = F.relu(self.layer_in(x))
# second ReLu activation of the linear layer
x = F.relu(self.layer(x))
# final sigmoid activation of output linear layer
x = torch.sigmoid(self.layer_out(x))
############################# TODO 1(b) END ##############################
return x
def normalize_coord(height, width, num_frequencies=6):
"""
Creates the 2D normalized coordinates, and applies positional encoding to them
Args:
height (int): Height of the image
width (int): Width of the image
num_frequencies (optional, int): The number of frequencies used in
the positional encoding (default: 6).
Returns:
(torch.Tensor): Returns the 2D normalized coordinates after applying positional encoding to them.
"""
############################# TODO 1(c) BEGIN ############################
# Create the 2D normalized coordinates, and apply positional encoding to them
# Create a grid of coordinates in the range [0, 1]
vertical_coords = torch.linspace(0, 1, steps = height + 1)[:-1]
horizontal_coords = torch.linspace(0, 1, steps = width + 1)[:-1]
grid_horizontal, grid_vertical = torch.meshgrid(horizontal_coords, vertical_coords, indexing = 'ij')
grid_horizontal = grid_horizontal.flatten().reshape(-1, 1)
grid_vertical = grid_vertical.flatten().reshape(-1, 1)
x = torch.cat((grid_vertical, grid_horizontal), axis=1)
# positional encoding
embedded_coordinates = positional_encoding(x, num_frequencies, incl_input=True)
embedded_coordinates = embedded_coordinates.float() # ensure it is float
############################# TODO 1(c) END ############################
return embedded_coordinates
"""You need to complete 1.1 and 1.2 first before completing the train_2d_model function. Don't forget to transfer the completed functions from 1.1 and 1.2 to the part1.py file and upload it to the autograder.
Fill the gaps in the train_2d_model() function to train the model to fit the 2D image.
"""
def train_2d_model(test_img, num_frequencies, device, model=model_2d, positional_encoding=positional_encoding, show=True):
# Optimizer parameters
lr = 5e-4
iterations = 10000
height, width = test_img.shape[:2]
# Number of iters after which stats are displayed
display = 2000
# Define the model and initialize its weights.
model2d = model(num_frequencies=num_frequencies)
model2d.to(device)
def weights_init(m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight)
model2d.apply(weights_init)
############################# TODO 1(c) BEGIN ############################
# Define the optimizer
optimizer = torch.optim.Adam(model2d.parameters(), lr=5e-4)
############################# TODO 1(c) END ############################
# Seed RNG, for repeatability
seed = 5670
torch.manual_seed(seed)
np.random.seed(seed)
# Lists to log metrics etc.
psnrs = []
iternums = []
t = time.time()
t0 = time.time()
############################# TODO 1(c) BEGIN ############################
# Create the 2D normalized coordinates, and apply positional encoding to them
norm_coords = normalize_coord(height, width, num_frequencies)
############################# TODO 1(c) END ############################
for i in range(iterations+1):
optimizer.zero_grad()
############################# TODO 1(c) BEGIN ############################
# Run one iteration
pred = model2d(norm_coords).view(height, width, 3)
loss = torch.mean((pred - test_img) ** 2)
# Perform backpropagation and update the model's weights
optimizer.zero_grad() # Clearing gradients again for safety
loss.backward()
optimizer.step()
############################# TODO 1(c) END ############################
# Display images/plots/stats
if i % display == 0 and show:
############################# TODO 1(c) BEGIN ############################
# Calculate psnr
psnr = 10 * torch.log10(1 / loss)
############################# TODO 1(c) END ############################
print("Iteration %d " % i, "Loss: %.4f " % loss.item(), "PSNR: %.2f" % psnr.item(), \
"Time: %.2f secs per iter" % ((time.time() - t) / display), "%.2f secs in total" % (time.time() - t0))
t = time.time()
psnrs.append(psnr.item())
iternums.append(i)
plt.figure(figsize=(13, 4))
plt.subplot(131)
plt.imshow(pred.detach().cpu().numpy())
plt.title(f"Iteration {i}")
plt.subplot(132)
plt.imshow(test_img.cpu().numpy())
plt.title("Target image")
plt.subplot(133)
plt.plot(iternums, psnrs)
plt.title("PSNR")
plt.show()
print('Done!')
torch.save(model2d.state_dict(),'model_2d_' + str(num_frequencies) + 'freq.pt')
plt.imsave('van_gogh_' + str(num_frequencies) + 'freq.png',pred.detach().cpu().numpy())
return pred.detach().cpu()