Skip to content

Commit 03354c1

Browse files
committed
feat: steppable component can now perform backward pass and optimizer steps
1 parent 719e35e commit 03354c1

3 files changed

Lines changed: 124 additions & 25 deletions

File tree

src/modalities/utils/profilers/modalities_profiler.py

Lines changed: 109 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import os
1+
import pickle
22
import shutil
33
from dataclasses import dataclass
44
from pathlib import Path
@@ -31,6 +31,90 @@ class CustomComponentRegisterable:
3131
custom_config: type
3232

3333

34+
class SteppableProfilerIF:
    """Interface for profilers that are used as context managers and advanced step by step.

    Implementations are entered via ``with``, receive one ``step()`` call per
    profiled step, and are exited when the ``with`` block ends. Every method of
    this base class raises ``NotImplementedError`` and has to be overridden.
    """

    def __enter__(self) -> "SteppableProfilerIF":
        """Start the profiling session."""
        raise NotImplementedError

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """End the profiling session; receives the standard context manager exception triple."""
        raise NotImplementedError

    def step(self) -> None:
        """Signal that one profiled step has completed."""
        raise NotImplementedError
43+
44+
45+
class SteppableMemoryProfiler(SteppableProfilerIF):
    """Records CUDA memory history for a window of steps and pickles a snapshot of it.

    The step schedule is ``num_wait_steps + num_warmup_steps`` steps without
    recording, followed by ``num_active_steps`` recorded steps. Once the last
    active step has completed, the result of ``torch.cuda.memory._snapshot()``
    is pickled to ``memory_snapshot_path`` and recording is switched off again.

    NOTE: the ``torch.cuda.memory._record_memory_history`` / ``_snapshot``
    functions used here are underscore-prefixed (private) PyTorch APIs.
    """

    # upper bound passed as `max_entries` to torch.cuda.memory._record_memory_history
    MEMORY_SNAPSHOT_MAX_ENTRIES = 100_000

    def __init__(self, memory_snapshot_path: Path, num_wait_steps: int, num_warmup_steps: int, num_active_steps: int):
        """Initializes the memory profiler.

        Args:
            memory_snapshot_path (Path): File the pickled memory snapshot is written to.
            num_wait_steps (int): Number of initial steps without recording.
            num_warmup_steps (int): Number of further steps without recording.
            num_active_steps (int): Number of steps during which memory history is recorded.
        """
        self._memory_snapshot_path = memory_snapshot_path
        self._curr_step = None  # None <=> the context is currently not entered
        self._num_wait_steps = num_wait_steps
        self._num_warmup_steps = num_warmup_steps
        self._num_active_steps = num_active_steps
        self._is_recording = False

    def __enter__(self):
        """Starts the step counter and, if there are no wait / warmup steps, the recording.

        Returns:
            SteppableMemoryProfiler: This instance.
        """
        self._curr_step = 0
        if self._curr_step == self._num_wait_steps + self._num_warmup_steps:
            self._start_recording()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Stops any running recording and verifies that all scheduled steps were executed.

        Raises:
            RuntimeError: If the context was never entered, or if it is left
                without an in-flight exception before all wait, warmup and
                active steps were executed.
        """
        if self._curr_step is None:
            raise RuntimeError("SteppableMemoryProfilerContext exited without being entered")
        num_total_steps = self._num_wait_steps + self._num_warmup_steps + self._num_active_steps
        finished_all_steps = self._curr_step >= num_total_steps
        # Always switch the recording off and mark the context as left, so that
        # a later step() is rejected and memory history does not keep recording.
        self._stop_recording()
        self._curr_step = None
        # Only complain about an early exit if no other exception is already
        # propagating; raising here would otherwise mask the actual failure.
        if not finished_all_steps and exc_type is None:
            raise RuntimeError("SteppableMemoryProfilerContext exited before finishing all steps")
        return

    def step(self):
        """Advances the schedule by one step, starting the recording or dumping the snapshot when due.

        Raises:
            RuntimeError: If called while the context is not entered.
        """
        if self._curr_step is None:
            raise RuntimeError("SteppableMemoryProfilerContext.step() called outside of context manager")
        self._curr_step += 1
        recording_start_step = self._num_wait_steps + self._num_warmup_steps
        recording_end_step = recording_start_step + self._num_active_steps
        if self._curr_step == recording_start_step:
            self._start_recording()
        elif self._curr_step == recording_end_step and self._is_recording:
            with open(self._memory_snapshot_path, "wb") as output:
                pickle.dump(torch.cuda.memory._snapshot(), output)
            self._stop_recording()

    def _start_recording(self):
        """Switches memory history recording on, unless there is nothing to record or it already runs."""
        if self._num_active_steps > 0 and not self._is_recording:
            torch.cuda.memory._record_memory_history(max_entries=SteppableMemoryProfiler.MEMORY_SNAPSHOT_MAX_ENTRIES)
            self._is_recording = True

    def _stop_recording(self):
        """Switches memory history recording off if this profiler switched it on."""
        if self._is_recording:
            # enabled=None disables the recording of memory history
            torch.cuda.memory._record_memory_history(enabled=None)
            self._is_recording = False
84+
85+
86+
class ProfilerListContext(SteppableProfilerIF):
    """Bundles several steppable profilers so they can be entered, stepped and exited as one.

    The wrapped profilers are entered in list order and exited in reverse
    order, mirroring nested ``with`` statements.
    """

    def __init__(self, profiler_cms: list[SteppableProfilerIF]):
        """Initializes the profiler list context.

        Args:
            profiler_cms (list[SteppableProfilerIF]): Profilers to drive together. Each one
                has to provide ``__enter__``, ``__exit__`` and ``step``.
        """
        self.profiler_cms = profiler_cms
        self._entered = None  # None <=> the context is currently not entered

    def __enter__(self):
        """Enters all wrapped profilers in list order.

        If one profiler fails to enter, the profilers entered so far are exited
        again before the error is re-raised, and this context stays re-enterable.

        Returns:
            ProfilerListContext: This instance.

        Raises:
            RuntimeError: If the context is entered again without having been exited.
        """
        if self._entered is not None:
            raise RuntimeError("ProfilerListContext entered multiple times without exiting")
        entered = []
        try:
            for profiler_cm in self.profiler_cms:
                return_val = profiler_cm.__enter__()
                # step() / __exit__() are later called on whatever __enter__ returned,
                # falling back to the profiler itself if it returned nothing
                entered.append(profiler_cm if return_val is None else return_val)
        except Exception as err:
            # Roll back the partially entered profilers. Errors during this cleanup
            # are dropped on purpose: the enter failure is the one worth reporting.
            ProfilerListContext._exit_all(entered, type(err), err, err.__traceback__)
            raise
        self._entered = entered
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Exits all entered profilers in reverse order, even if some of them fail to exit.

        The return values of the wrapped ``__exit__`` calls are ignored, i.e. an
        in-flight exception is never suppressed.

        Raises:
            RuntimeError: If the context was not entered.
            Exception: The first error raised by a wrapped profiler's ``__exit__``,
                re-raised after all remaining profilers have been exited.
        """
        if self._entered is None:
            raise RuntimeError("ProfilerListContext exited without being entered")
        # reset the state first so the context is re-enterable even if an exit fails
        entered, self._entered = self._entered, None
        first_error = ProfilerListContext._exit_all(entered, exc_type, exc_value, traceback)
        if first_error is not None:
            raise first_error

    def step(self):
        """Forwards the step signal to all entered profilers.

        Raises:
            RuntimeError: If called while the context is not entered.
        """
        if self._entered is None:
            raise RuntimeError("ProfilerListContext.step() called outside of context manager")
        for profiler_cm in self._entered:
            profiler_cm.step()

    @staticmethod
    def _exit_all(entered, exc_type, exc_value, traceback):
        """Exits the given profilers in reverse order and returns the first error raised, if any."""
        first_error = None
        for profiler_cm in reversed(entered):
            try:
                profiler_cm.__exit__(exc_type, exc_value, traceback)
            except Exception as err:
                if first_error is None:
                    first_error = err
        return first_error
116+
117+
34118
class ModalitiesProfilerStarter:
35119
"""Starter class to run profiling either in single process or distributed mode."""
36120

@@ -71,15 +155,13 @@ def run_distributed(
71155

72156
global_rank = torch.distributed.get_rank()
73157
world_size = torch.distributed.get_world_size()
74-
local_rank = int(os.environ["LOCAL_RANK"])
75158

76159
ModalitiesProfilerStarter._run_helper(
77160
config_file_path=config_file_path,
78161
num_measurement_steps=num_measurement_steps,
79162
num_wait_steps=num_wait_steps,
80163
num_warmup_steps=num_warmup_steps,
81164
experiment_folder_path=experiment_root_path / experiment_id,
82-
local_rank=local_rank,
83165
global_rank=global_rank,
84166
world_size=world_size,
85167
profiled_ranks=profiled_ranks,
@@ -122,7 +204,6 @@ def run_single_process(
122204

123205
global_rank = 0
124206
world_size = 1
125-
local_rank = 0
126207
profiled_ranks = [0]
127208

128209
ModalitiesProfilerStarter._run_helper(
@@ -133,7 +214,6 @@ def run_single_process(
133214
experiment_folder_path=experiment_root_path / experiment_id,
134215
global_rank=global_rank,
135216
world_size=world_size,
136-
local_rank=local_rank,
137217
profiled_ranks=profiled_ranks,
138218
custom_component_registerables=custom_component_registerables,
139219
)
@@ -161,22 +241,35 @@ def _run_helper(
161241
experiment_folder_path: Path,
162242
profiled_ranks: list[int],
163243
global_rank: int,
164-
local_rank: int,
165244
world_size: int,
166245
custom_component_registerables: list[CustomComponentRegisterable] | None = None,
167246
):
168-
# build profiler
169-
profiler_activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
170-
profile_context_manager = profile(
247+
# build profilers
248+
profiler_activities = [ProfilerActivity.CUDA] # ProfilerActivity.CPU,
249+
kernel_profiler = profile(
171250
activities=profiler_activities,
172251
schedule=schedule(wait=num_wait_steps, warmup=num_warmup_steps, active=num_measurement_steps),
173-
record_shapes=True,
174-
profile_memory=True,
175-
with_flops=True,
176-
with_stack=True,
177-
with_modules=True,
252+
record_shapes=False,
253+
profile_memory=False,
254+
with_flops=False,
255+
with_stack=False,
256+
with_modules=False,
257+
# record_shapes=True,
258+
# profile_memory=True,
259+
# with_flops=True,
260+
# with_stack=True,
261+
# with_modules=True,
178262
)
179263

264+
SteppableMemoryProfiler(
265+
memory_snapshot_path=experiment_folder_path / f"memory_snapshot_ranks_{world_size}_rank_{global_rank}.pkl",
266+
num_wait_steps=num_wait_steps,
267+
num_warmup_steps=num_warmup_steps,
268+
num_active_steps=num_measurement_steps,
269+
)
270+
271+
profile_context_manager = ProfilerListContext(profiler_cms=[kernel_profiler]) # , memory_profiler]
272+
180273
# register custom components and build components from config
181274
# workaround to avoid triggering synchronization of experiment id in single process
182275
experiment_id = experiment_folder_path.name if world_size == 1 else None
@@ -199,15 +292,12 @@ def _run_helper(
199292
show_progress=(global_rank == profiled_ranks[0]), # only show progress on a single rank that is profiled
200293
)
201294
trace_output_path = experiment_folder_path / f"profiler_trace_ranks_{world_size}_rank_{global_rank}.json"
202-
memory_output_path = experiment_folder_path / f"profiler_memory_ranks_{world_size}_rank_{global_rank}.html"
203295
summary_output_path = experiment_folder_path / f"profiler_summary_ranks_{world_size}_rank_{global_rank}.txt"
204296

205297
ModalitiesProfiler.export_profiling_results(
206-
profiler_context_manager=profile_context_manager,
298+
profiler_context_manager=kernel_profiler,
207299
trace_output_path=trace_output_path,
208-
memory_output_path=memory_output_path,
209300
summary_output_path=summary_output_path,
210-
local_rank=local_rank,
211301
global_rank=global_rank,
212302
profiled_ranks=profiled_ranks,
213303
)
@@ -218,7 +308,7 @@ class ModalitiesProfiler:
218308
def profile(
219309
steppable_component: SteppableComponentIF,
220310
num_total_steps: int,
221-
profile_context_manager: torch.profiler.profile,
311+
profile_context_manager: SteppableProfilerIF,
222312
show_progress: bool = False,
223313
) -> None:
224314
"""Profile a steppable component using the provided profiler context manager.
@@ -243,10 +333,8 @@ def profile(
243333
def export_profiling_results(
244334
profiler_context_manager: torch.profiler.profile,
245335
trace_output_path: Path,
246-
memory_output_path: Path,
247336
summary_output_path: Path,
248337
global_rank: int,
249-
local_rank: int,
250338
profiled_ranks: list[int],
251339
) -> None:
252340
"""Export profiling results to specified output paths if the current rank is in profiled_ranks.
@@ -263,8 +351,6 @@ def export_profiling_results(
263351
if global_rank in profiled_ranks:
264352
logger.info(f"Saving profiling results for rank {global_rank}...")
265353
profiler_context_manager.export_chrome_trace(trace_output_path.as_posix())
266-
device = local_rank if local_rank is not None else None
267-
profiler_context_manager.export_memory_timeline(memory_output_path.as_posix(), device=device)
268354
table = profiler_context_manager.key_averages().table()
269355
with open(summary_output_path, "w", encoding="utf-8") as f:
270356
f.write(table)

src/modalities/utils/profilers/steppable_component_configs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from modalities.config.pydantic_if_types import (
44
PydanticDatasetBatchGeneratorIFType,
55
PydanticLossIFType,
6+
PydanticOptimizerIFType,
67
PydanticPytorchModuleType,
78
)
89

@@ -11,3 +12,4 @@ class SteppableForwardPassConfig(BaseModel):
1112
model: PydanticPytorchModuleType
1213
dataset_batch_generator: PydanticDatasetBatchGeneratorIFType
1314
loss_fn: PydanticLossIFType | None = None
15+
optimizer: PydanticOptimizerIFType | None = None

src/modalities/utils/profilers/steppable_components.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@ class SteppableForwardPass(SteppableComponentIF):
1515
The component is used for profiling.
1616
"""
1717

18-
def __init__(self, model: nn.Module, dataset_batch_generator: DatasetBatchGeneratorIF, loss_fn: Loss | None = None):
18+
def __init__(
19+
self,
20+
model: nn.Module,
21+
dataset_batch_generator: DatasetBatchGeneratorIF,
22+
loss_fn: Loss | None = None,
23+
optimizer: torch.optim.Optimizer | None = None,
24+
):
1925
"""Initializes the SteppableForwardPass component.
2026
2127
Args:
@@ -27,6 +33,7 @@ def __init__(self, model: nn.Module, dataset_batch_generator: DatasetBatchGenera
2733
self.loss_fn = loss_fn
2834
self.dataset_batch_generator = dataset_batch_generator
2935
self.device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
36+
self.optimizer = optimizer
3037

3138
def step(
3239
self,
@@ -37,4 +44,8 @@ def step(
3744
predictions = self.model(batch.samples)
3845
result_batch = InferenceResultBatch(targets=batch.targets, predictions=predictions)
3946
if self.loss_fn is not None:
40-
self.loss_fn(result_batch)
47+
loss = self.loss_fn(result_batch)
48+
loss.backward()
49+
if self.optimizer is not None:
50+
self.optimizer.step()
51+
self.optimizer.zero_grad()

0 commit comments

Comments
 (0)