Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,8 +817,7 @@ def export_networks(self, epoch):

if not self.opt.model_type in [
"palette",
"cm",
]: # Note: export is for generators from GANs only at the moment
]: # Note: export is supported only for GAN and consistency models
# For export
from util.export import export

Expand Down
4 changes: 1 addition & 3 deletions models/cm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(self, opt, rank):
self.visual_names.append(visual_outputs)

# Define network
# TODO this is hard coded
opt.alg_palette_sampling_method = ""
opt.alg_diffusion_cond_embed = opt.alg_diffusion_cond_image_creation
opt.alg_diffusion_cond_embed_dim = 256
Expand Down Expand Up @@ -148,7 +149,6 @@ def __init__(self, opt, rank):
self.iter_calculator_init()

def set_input(self, data):

if (
len(data["A"].to(self.device).shape) == 5
): # we're using temporal successive frames
Expand Down Expand Up @@ -203,7 +203,6 @@ def set_input(self, data):
self.real_B = self.gt_image

def compute_cm_loss(self):

y_0 = self.gt_image # ground truth
y_cond = self.cond_image # conditioning
mask = self.mask
Expand All @@ -224,7 +223,6 @@ def compute_cm_loss(self):
self.loss_G_tot = loss * self.opt.alg_diffusion_lambda_G

def inference(self):

if hasattr(self.netG_A, "module"):
netG = self.netG_A.module
else:
Expand Down
6 changes: 2 additions & 4 deletions models/modules/cm_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,6 @@ def forward(
mask=None,
x_cond=None,
):

num_timesteps = improved_timesteps_schedule(
self.current_t,
total_training_steps,
Expand Down Expand Up @@ -344,7 +343,6 @@ def forward(
)

def restoration(self, y, y_cond, sigmas, mask, clip_denoised=True):

if mask is not None:
mask = torch.clamp(
mask, min=0.0, max=1.0
Expand All @@ -370,7 +368,6 @@ def restoration(self, y, y_cond, sigmas, mask, clip_denoised=True):
x = x * mask + (1 - mask) * y

for sigma in sigmas[1:]:

sigma = torch.full((x.shape[0],), sigma, dtype=x.dtype, device=x.device)
x = x + pad_dims_like(
(sigma**2 - self.sigma_min**2) ** 0.5, x
Expand All @@ -393,5 +390,6 @@ def restoration(self, y, y_cond, sigmas, mask, clip_denoised=True):
return x

def embed_sigmas(self, sigmas):
emb = self.cm_cond_embed(sigmas).squeeze(dim=[2, 3])
# squeeze twice to be onnx compatible
emb = self.cm_cond_embed(sigmas).squeeze(dim=3).squeeze(dim=2)
return emb
52 changes: 47 additions & 5 deletions util/export.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,34 @@
import torch

from models import gan_networks, diffusion_networks

from models import gan_networks

class ConsistencyWrapper(torch.nn.Module):
"""
Consistency model wrapper for onnx & jit trace
"""

def __init__(self, model, sigmas):
super().__init__()
self.model = model
self.sigmas = sigmas

def forward(self, x, mask):
return self.model.restoration(x, None, self.sigmas, mask)


def export(opt, cuda, model_in_file, model_out_file, opset_version, export_type):
model = gan_networks.define_G(**vars(opt))
if opt.model_type == "palette":
raise ValueError('export() is not supported for model type "palette"')

if opt.model_type == "cm":
opt.alg_palette_sampling_method = ""
opt.alg_diffusion_cond_embed = opt.alg_diffusion_cond_image_creation
opt.alg_diffusion_cond_embed_dim = 256

model = diffusion_networks.define_G(**vars(opt))
else:
model = gan_networks.define_G(**vars(opt))

model.eval()
model.load_state_dict(torch.load(model_in_file))
Expand All @@ -18,21 +41,40 @@ def export(opt, cuda, model_in_file, model_out_file, opset_version, export_type)
device = "cuda"
else:
device = "cpu"
dummy_input = torch.randn(

dummy_image = torch.randn(
1, opt.model_input_nc, opt.data_crop_size, opt.data_crop_size, device=device
)
dummy_inputs = [dummy_image]

if opt.model_type == "cm":
# at the moment, consistency models have two inputs: origin image and mask
# TODO allow to change number of sigmas
sigmas = [80.0, 24.4, 5.84, 0.9, 0.661]
model = ConsistencyWrapper(model, sigmas)
dummy_inputs += [
torch.randn(
1,
opt.model_input_nc,
opt.data_crop_size,
opt.data_crop_size,
device=device,
),
]

dummy_inputs = tuple(dummy_inputs)

if export_type == "onnx":
torch.onnx.export(
model,
dummy_input,
dummy_inputs,
model_out_file,
verbose=False,
opset_version=opset_version,
)

elif export_type == "jit":
jit_model = torch.jit.trace(model, dummy_input)
jit_model = torch.jit.trace(model, dummy_inputs)
jit_model.save(model_out_file)

else:
Expand Down
2 changes: 0 additions & 2 deletions util/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def get_activations(
# This happens if you choose a dimensionality not equal 2048.

if len(pred.shape) == 4:

if pred.size(2) != 1 or pred.size(3) != 1:
pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

Expand All @@ -150,7 +149,6 @@ def _compute_statistics_of_dataloader(
nb_max_img=float("inf"),
root=None,
):

if path_sv is not None and os.path.isfile(path_sv):
print("Activations loaded for domain %s, from %s." % (domain, path_sv))
f = torch.load(path_sv)
Expand Down