-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_linear.py
More file actions
294 lines (246 loc) · 15.6 KB
/
Copy pathgenerate_linear.py
File metadata and controls
294 lines (246 loc) · 15.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
import argparse
import os
import torch
import numpy as np
import json
from tqdm import tqdm
import nibabel as nib
import torch.amp
from flowlet.models import WaveletFlowMatching
from flowlet.generation.generator import calculate_crop_params
from flowlet.utils import setup_logging, set_seed, get_logger
logger = get_logger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Generate Samples with Linearly Interpolated Age using FlowLet")
# --- Model/Checkpoint Args ---
parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the model checkpoint (.pth file, e.g., fmw_best.pth).")
parser.add_argument("--config_path", type=str, default=None, help="(Optional) Path to the model configuration JSON file. If not provided, attempts to infer from checkpoint_path directory.")
# --- Generation Args ---
parser.add_argument("--output_dir", type=str, required=True, help="Directory to save the generated NIfTI files.")
parser.add_argument("--num_total_samples", type=int, default=3000, help="Total number of samples to generate across the Age range.")
parser.add_argument("--min_age", type=float, required=True, help="The minimum original Age for the linear interpolation range.")
parser.add_argument("--max_age", type=float, required=True, help="The maximum original Age for the linear interpolation range.")
parser.add_argument("--condition_ranges_path", type=str, default=None, help="Path to JSON file containing condition ranges (min/max) used during *training* for normalization. Crucial for correct Age mapping.")
parser.add_argument("--save_size", type=int, nargs=3, default=[91, 109, 91], metavar=('D', 'H', 'W'), help="Spatial size to crop generated images to before saving.")
parser.add_argument("--model_input_size", type=int, nargs=3, default=None, metavar=('D', 'H', 'W'), help="Model's expected padded input size (D, H, W). Required if not found in config.")
parser.add_argument("--filename_prefix", type=str, default="FlowLet", help="Prefix for the output filenames.")
parser.add_argument("--num_flow_steps", type=int, default=10, help="Number of flow steps for the model. Default is 10.")
# --- System Args ---
parser.add_argument("--seed", type=int, default=42, help="Random seed for generation noise (for reproducibility).")
parser.add_argument("--device", type=str, default="cuda", help="Device to use ('cuda' or 'cpu').")
args = parser.parse_args()
if args.min_age >= args.max_age:
parser.error("--min_age must be strictly less than --max_age")
return args
def load_config_from_checkpoint_dir(checkpoint_path):
"""Tries to find and load a config.json from the checkpoint's directory."""
config_path = os.path.join(os.path.dirname(checkpoint_path), "config.json")
if os.path.exists(config_path):
try:
with open(config_path, 'r') as f:
config = json.load(f)
logger.info(f"Loaded model configuration from {config_path}")
return config
except Exception as e:
logger.warning(f"Found config.json at {config_path} but failed to load: {e}. Need manual config.")
return None
else:
logger.warning(f"No config.json found in checkpoint directory ({os.path.dirname(checkpoint_path)}). Need manual config.")
return None
def main():
args = parse_args()
# --- Setup ---
setup_logging(log_dir=args.output_dir, filename_prefix="flowlet_generate_linear_ablation") # Log to main output dir
logger.info(f"Starting linear Age generation for ablation study.")
logger.info(f"Generating {args.num_total_samples} samples for Age range [{args.min_age:.2f}, {args.max_age:.2f}] across different modes.")
logger.info(f"Base output directory for modes: {args.output_dir}")
set_seed(args.seed)
device = torch.device(args.device if torch.cuda.is_available() and args.device == "cuda" else "cpu")
logger.info(f"Using device: {device}")
# --- Load Configuration ---
model_config = None
if args.config_path and os.path.exists(args.config_path):
try:
with open(args.config_path, 'r') as f:
model_config = json.load(f)
logger.info(f"Loaded model configuration from specified path: {args.config_path}")
except Exception as e:
logger.error(f"Failed to load specified config file {args.config_path}: {e}", exc_info=True)
return
elif args.checkpoint_path:
model_config = load_config_from_checkpoint_dir(args.checkpoint_path)
if model_config is None:
logger.error("Model configuration could not be loaded. Provide a valid --config_path, or "
"ensure a config.json sits next to the checkpoint. Refusing to guess the U-Net "
"architecture, since a mismatch with the checkpoint would corrupt the loaded weights.")
return
# Extract necessary info from config
try:
model_input_size = tuple(model_config['model_input_size'])
condition_vars = model_config.get('condition_vars', [])
if 'Age' not in condition_vars:
logger.warning(f"'Age' not found in model's condition_vars: {condition_vars}. Model might not be Age-conditional.")
attention_res = tuple(map(int, model_config['unet_attention_res'].split(',')))
channel_mult = tuple(map(int, model_config['unet_channel_mult'].split(',')))
# Use num_flow_steps from args if provided, otherwise from config, otherwise default
num_flow_steps_model_init = args.num_flow_steps if hasattr(args, 'num_flow_steps') and args.num_flow_steps is not None else model_config.get('num_flow_steps', 10)
except KeyError as e:
logger.error(f"Missing essential key in loaded/constructed config: {e}")
return
except Exception as e:
logger.error(f"Error parsing configuration values: {e}")
return
# --- Load Condition Ranges (CRITICAL for Normalization) ---
condition_ranges = None
if args.condition_ranges_path and os.path.exists(args.condition_ranges_path):
try:
with open(args.condition_ranges_path, 'r') as f:
condition_ranges = json.load(f)
logger.info(f"Loaded condition ranges from: {args.condition_ranges_path}")
except Exception as e:
logger.warning(f"Failed to load condition ranges file {args.condition_ranges_path}: {e}.", exc_info=True)
else:
# Try finding ranges in checkpoint dir
ranges_path_alt = os.path.join(os.path.dirname(args.checkpoint_path), "condition_ranges.json")
if os.path.exists(ranges_path_alt):
try:
with open(ranges_path_alt, 'r') as f:
condition_ranges = json.load(f)
logger.info(f"Loaded condition ranges from checkpoint directory: {ranges_path_alt}")
except Exception as e:
logger.warning(f"Failed to load condition_ranges.json from checkpoint dir: {e}.", exc_info=True)
if condition_ranges is None or 'Age' not in condition_ranges:
logger.error("Condition ranges for 'Age' could not be loaded. These are REQUIRED to normalize the target Age for the model. Please provide a valid --condition_ranges_path pointing to the JSON file saved during training.")
return
else:
logger.info(f"Using Age normalization range: Min={condition_ranges['Age']['min']:.2f}, Max={condition_ranges['Age']['max']:.2f}")
# --- Load Model ---
logger.info(f"Loading model checkpoint from: {args.checkpoint_path}")
if not os.path.exists(args.checkpoint_path):
logger.error(f"Checkpoint file not found: {args.checkpoint_path}")
return
try:
ckpt = torch.load(args.checkpoint_path, map_location=device)
condition_dims_dict = {var: 1 for var in condition_vars} if condition_vars else {}
unet_args = {
"in_channels": 8, "model_channels": model_config.get('unet_model_channels', 128), "out_channels": 8,
"num_res_blocks": model_config.get('unet_num_res_blocks', 2),
"attention_resolutions": attention_res,
"dropout": model_config.get('unet_dropout', 0.1),
"channel_mult": channel_mult,
"conv_resample": True, "dims": 3,
"use_checkpoint": model_config.get('use_checkpointing', False), # Should be False for inference
"num_heads": model_config.get('unet_num_heads', 8),
"num_head_channels": model_config.get('unet_num_head_channels', -1),
"use_scale_shift_norm": True, "resblock_updown": True,
"condition_dims": condition_dims_dict,
"condition_embedding_dim": model_config.get('condition_embedding_dim', 512),
"use_xformers": model_config.get('use_xformers', True),
"use_cross_attention": not model_config.get('unet_disable_cross_attn', False) and bool(condition_dims_dict),
"norm_num_groups": model_config.get('unet_norm_num_groups', 32), "norm_eps": 1e-6,
}
# Pass num_flow_steps from args/config to model for its default sampling behavior
wfm_model = WaveletFlowMatching(u_net_args=unet_args, num_flow_steps=num_flow_steps_model_init).to(device)
wfm_model.flow_net.use_checkpoint = False
state_dict = ckpt.get("flow_net_state_dict", ckpt.get("model_state_dict", ckpt))
if not state_dict: raise KeyError("Could not find a model state dictionary in the checkpoint.")
is_currently_compiled = hasattr(wfm_model.flow_net, '_orig_mod')
is_saved_compiled = any(k.startswith('_orig_mod.') for k in state_dict.keys())
if not is_currently_compiled and is_saved_compiled:
logger.info("Saved state_dict is compiled. Removing '_orig_mod.' prefix.")
state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
wfm_model.flow_net.load_state_dict(state_dict)
wfm_model.eval()
logger.info(f"Model loaded successfully from epoch {ckpt.get('epoch', -1)+1}")
except Exception as e:
logger.error(f"Failed to load model: {e}", exc_info=True)
return
# --- Prepare Cropping ---
do_crop = False
crop_indices = None
save_size_tuple = tuple(args.save_size)
if save_size_tuple != model_input_size:
try:
crop_indices = calculate_crop_params(save_size_tuple, model_input_size)
do_crop = True
logger.info(f"Will crop generated images from {model_input_size} to {save_size_tuple}.")
except ValueError as e:
logger.error(f"Cannot crop: {e}. Saving full size {model_input_size}.")
do_crop = False
# --- Generate Target Ages ---
target_ages = np.linspace(args.min_age, args.max_age, args.num_total_samples)
logger.info(f"Generated {len(target_ages)} target ages using linspace.")
# --- Generation Loop ---
train_min_age = condition_ranges['Age']['min']
train_max_age = condition_ranges['Age']['max']
age_range_norm = train_max_age - train_min_age # Renamed to avoid conflict
if age_range_norm <= 0:
logger.error(f"Invalid training Age range from condition_ranges.json: [{train_min_age}, {train_max_age}]")
return
# Define generation modes for ablation
generation_modes = {
"baseline": {"disable_cross_attn": False, "disable_cond_film": False, "suffix": "Baseline"},
"film_only": {"disable_cross_attn": True, "disable_cond_film": False, "suffix": "FiLM_Only"},
"crossattn_only": {"disable_cross_attn": False, "disable_cond_film": True, "suffix": "CrossAttn_Only"},
"unconditional": {"disable_cross_attn": True, "disable_cond_film": True, "suffix": "Unconditional"},
}
for mode_name, mode_flags in generation_modes.items():
logger.info(f"--- Generating samples for mode: {mode_name} ---")
# Create a unique subdirectory for this mode's outputs
current_output_dir_name = f"{args.filename_prefix}_{mode_flags['suffix']}_AgeLin_{args.min_age:.1f}-{args.max_age:.1f}_N{args.num_total_samples}_Steps{args.num_flow_steps}"
current_output_dir = os.path.join(args.output_dir, current_output_dir_name)
os.makedirs(current_output_dir, exist_ok=True)
logger.info(f"Output for this mode: {current_output_dir}")
pbar = tqdm(enumerate(target_ages), total=args.num_total_samples, desc=f"Generating {mode_name}")
for i, target_age in pbar:
pbar.set_postfix({"Age": f"{target_age:.2f}"})
# --- Prepare Condition for this sample ---
norm_age = (target_age - train_min_age) / age_range_norm
norm_age_clipped = np.clip(norm_age, 0.0, 1.0)
if norm_age != norm_age_clipped:
logger.debug(f"Target Age {target_age:.2f} normalized to {norm_age:.3f}, clipped to {norm_age_clipped:.3f}.")
norm_age = norm_age_clipped
model_conditions = {}
if 'Age' in wfm_model.condition_dims:
model_conditions['Age'] = torch.tensor([[norm_age]], dtype=torch.float32, device=device)
for k_model, dim in wfm_model.condition_dims.items():
if k_model != 'Age': # Default other conditions to 0.5 normalized
model_conditions[k_model] = torch.full((1, dim), 0.5, dtype=torch.float32, device=device)
# --- Generate Single Sample ---
try:
with torch.no_grad(), torch.amp.autocast(device_type=device.type, enabled=(device.type == 'cuda')):
synthetic_images = wfm_model.sample(
num_samples=1,
model_output_size=model_input_size,
conditions_dict=model_conditions,
disable_cross_attn_inference=mode_flags["disable_cross_attn"],
disable_cond_film_inference=mode_flags["disable_cond_film"]
)
# --- Post-process and Save ---
if do_crop:
try:
(d_start, d_end), (h_start, h_end), (w_start, w_end) = crop_indices
synthetic_images_final = synthetic_images[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
if synthetic_images_final.shape[-3:] != save_size_tuple:
logger.warning(f"Cropped shape {synthetic_images_final.shape[-3:]} != target {save_size_tuple}. Check cropping logic.")
except Exception as e:
logger.error(f"Error during cropping sample {i} ({mode_name}): {e}. Saving uncropped.", exc_info=True)
synthetic_images_final = synthetic_images
else:
synthetic_images_final = synthetic_images
img_to_save = (synthetic_images_final[0, 0].cpu().float() + 1.0) / 2.0
img_to_save = torch.clamp(img_to_save, 0.0, 1.0)
img_data = img_to_save.numpy().astype(np.float32)
img_nifti = nib.Nifti1Image(img_data, affine=np.eye(4))
sample_index_str = str(i).zfill(len(str(args.num_total_samples - 1)))
# Use original filename_prefix, add mode suffix here
filename = f"{args.filename_prefix}_{mode_flags['suffix']}_AGE_{target_age:.2f}_sample_{sample_index_str}.nii.gz"
save_path = os.path.join(current_output_dir, filename)
nib.save(img_nifti, save_path)
except Exception as e:
logger.error(f"Failed to generate or save sample {i} (Age: {target_age:.2f}, Mode: {mode_name}): {e}", exc_info=True)
continue
logger.info(f"Finished generating samples for mode {mode_name}.")
logger.info("Linear Age generation (all modes) finished.")
if __name__ == "__main__":
main()