-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathcheckpoint_utils.py
More file actions
46 lines (36 loc) · 1.42 KB
/
checkpoint_utils.py
File metadata and controls
46 lines (36 loc) · 1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
import re
import glob
import torch
def find_latest_checkpoint(output_dir, checkpoint_prefix):
    """Find the checkpoint file with the highest step number.

    Searches ``output_dir`` (non-recursively) for files named
    ``<checkpoint_prefix>*.pt`` whose name ends in ``step<digits>.pt`` and
    selects the one with the largest step.

    Args:
        output_dir: Directory to search.
        checkpoint_prefix: Leading part of the checkpoint file name. A falsy
            prefix (``None`` or ``""``) disables the search.

    Returns:
        ``(path, step)`` for the checkpoint with the largest step, or
        ``(None, 0)`` when the prefix is falsy or no file matches.
    """
    if not checkpoint_prefix:
        return None, 0

    # Escape glob metacharacters so that a directory or prefix containing
    # "[", "]", "*" or "?" (e.g. "runs/exp[1]") is matched literally instead
    # of being interpreted as a wildcard / character class.
    # A single "<prefix>*.pt" pattern already covers "<prefix>_step*.pt".
    pattern = os.path.join(
        glob.escape(os.fspath(output_dir)),
        f"{glob.escape(str(checkpoint_prefix))}*.pt",
    )

    step_re = re.compile(r'step(\d+)\.pt$')
    checkpoint_steps = {}
    for path in glob.glob(pattern):
        match = step_re.search(path)
        if match:
            checkpoint_steps[path] = int(match.group(1))

    if not checkpoint_steps:
        return None, 0

    latest = max(checkpoint_steps, key=checkpoint_steps.get)
    return latest, checkpoint_steps[latest]
def load_checkpoint(path, model, config):
    """Restore model weights from a checkpoint after validating its config.

    The checkpoint is loaded onto the CPU. Every key of ``config`` is compared
    against the ``'config'`` entry stored in the checkpoint (a key absent from
    the stored config compares as ``None``). Keys that exist only in the
    stored config are not checked.

    Args:
        path: Location of the checkpoint file passed to ``torch.load``.
        model: Object whose ``load_state_dict`` receives the checkpoint's
            ``'model_state_dict'`` entry.
        config: Mapping of the current run's settings to validate.

    Returns:
        The full checkpoint dict as loaded from ``path``.

    Raises:
        ValueError: On the first key (in ``config`` iteration order) whose
            stored value differs from the current one.
    """
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, so this must only be used on checkpoint files from a trusted
    # source.
    checkpoint = torch.load(path, map_location='cpu', weights_only=False)

    stored = checkpoint.get('config', {})
    # Lazily pair each current setting with its stored counterpart so the
    # scan stops at the first disagreement.
    candidates = ((name, stored.get(name), wanted) for name, wanted in config.items())
    conflict = next((item for item in candidates if item[1] != item[2]), None)

    if conflict is not None:
        key, saved_value, value = conflict
        raise ValueError(
            f"Config mismatch! {key}: saved={saved_value}, current={value}. "
            f"Use a fresh output_dir or delete old checkpoints."
        )

    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint