1- import os
1+ import pickle
22import shutil
33from dataclasses import dataclass
44from pathlib import Path
@@ -31,6 +31,90 @@ class CustomComponentRegisterable:
3131 custom_config : type
3232
3333
34+ class SteppableProfilerIF :
35+ def __enter__ (self ):
36+ raise NotImplementedError
37+
38+ def __exit__ (self , exc_type , exc_value , traceback ):
39+ raise NotImplementedError
40+
41+ def step (self ):
42+ raise NotImplementedError
43+
44+
45+ class SteppableMemoryProfiler (SteppableProfilerIF ):
46+ MEMORY_SNAPSHOT_MAX_ENTRIES = 100_000
47+
48+ def __init__ (self , memory_snapshot_path : Path , num_wait_steps : int , num_warmup_steps : int , num_active_steps : int ):
49+ self ._memory_snapshot_path = memory_snapshot_path
50+ self ._curr_step = None
51+ self ._num_wait_steps = num_wait_steps
52+ self ._num_warmup_steps = num_warmup_steps
53+ self ._num_active_steps = num_active_steps
54+
55+ def __enter__ (self ):
56+ self ._curr_step = 0
57+ # start recording memory history if there is no wait / warmup steps
58+ if self ._curr_step == self ._num_wait_steps + self ._num_warmup_steps and self ._num_active_steps > 0 :
59+ torch .cuda .memory ._record_memory_history (max_entries = SteppableMemoryProfiler .MEMORY_SNAPSHOT_MAX_ENTRIES )
60+ return self
61+
62+ def __exit__ (self , exc_type , exc_value , traceback ):
63+ if self ._curr_step is None :
64+ raise RuntimeError ("SteppableMemoryProfilerContext exited without being entered" )
65+ if self ._curr_step < self ._num_wait_steps + self ._num_warmup_steps + self ._num_active_steps :
66+ # if we exit before finishing all steps, dump the memory snapshot
67+ raise RuntimeError ("SteppableMemoryProfilerContext exited before finishing all steps" )
68+ return
69+
70+ def step (self ):
71+ if self ._curr_step is None :
72+ raise RuntimeError ("SteppableMemoryProfilerContext.step() called outside of context manager" )
73+ self ._curr_step += 1
74+ if self ._curr_step < self ._num_wait_steps + self ._num_warmup_steps :
75+ return
76+ elif self ._curr_step == self ._num_wait_steps + self ._num_warmup_steps :
77+ torch .cuda .memory ._record_memory_history (max_entries = SteppableMemoryProfiler .MEMORY_SNAPSHOT_MAX_ENTRIES )
78+ elif (
79+ self ._curr_step == self ._num_wait_steps + self ._num_warmup_steps + self ._num_active_steps
80+ and self ._num_active_steps > 0
81+ ):
82+ with open (self ._memory_snapshot_path , "wb" ) as output :
83+ pickle .dump (torch .cuda .memory ._snapshot (), output )
84+
85+
86+ class ProfilerListContext (SteppableProfilerIF ):
87+ def __init__ (self , profiler_cms : list [SteppableProfilerIF ]):
88+ self .profiler_cms = profiler_cms
89+ self ._entered = None
90+
91+ def __enter__ (self ):
92+ if self ._entered is not None :
93+ raise RuntimeError ("ProfilerListContext entered multiple times without exiting" )
94+ self ._entered = []
95+ for profiler_cm in self .profiler_cms :
96+ return_val = profiler_cm .__enter__ ()
97+ if return_val is not None :
98+ self ._entered .append (return_val )
99+ else :
100+ self ._entered .append (profiler_cm )
101+
102+ return self
103+
104+ def __exit__ (self , exc_type , exc_value , traceback ):
105+ if self ._entered is None :
106+ raise RuntimeError ("ProfilerListContext exited without being entered" )
107+ for profiler_cm in self ._entered :
108+ profiler_cm .__exit__ (exc_type , exc_value , traceback )
109+ self ._entered = None
110+
111+ def step (self ):
112+ if self ._entered is None :
113+ raise RuntimeError ("ProfilerListContext.step() called outside of context manager" )
114+ for profiler_cm in self ._entered :
115+ profiler_cm .step ()
116+
117+
34118class ModalitiesProfilerStarter :
35119 """Starter class to run profiling either in single process or distributed mode."""
36120
@@ -71,15 +155,13 @@ def run_distributed(
71155
72156 global_rank = torch .distributed .get_rank ()
73157 world_size = torch .distributed .get_world_size ()
74- local_rank = int (os .environ ["LOCAL_RANK" ])
75158
76159 ModalitiesProfilerStarter ._run_helper (
77160 config_file_path = config_file_path ,
78161 num_measurement_steps = num_measurement_steps ,
79162 num_wait_steps = num_wait_steps ,
80163 num_warmup_steps = num_warmup_steps ,
81164 experiment_folder_path = experiment_root_path / experiment_id ,
82- local_rank = local_rank ,
83165 global_rank = global_rank ,
84166 world_size = world_size ,
85167 profiled_ranks = profiled_ranks ,
@@ -122,7 +204,6 @@ def run_single_process(
122204
123205 global_rank = 0
124206 world_size = 1
125- local_rank = 0
126207 profiled_ranks = [0 ]
127208
128209 ModalitiesProfilerStarter ._run_helper (
@@ -133,7 +214,6 @@ def run_single_process(
133214 experiment_folder_path = experiment_root_path / experiment_id ,
134215 global_rank = global_rank ,
135216 world_size = world_size ,
136- local_rank = local_rank ,
137217 profiled_ranks = profiled_ranks ,
138218 custom_component_registerables = custom_component_registerables ,
139219 )
@@ -161,22 +241,35 @@ def _run_helper(
161241 experiment_folder_path : Path ,
162242 profiled_ranks : list [int ],
163243 global_rank : int ,
164- local_rank : int ,
165244 world_size : int ,
166245 custom_component_registerables : list [CustomComponentRegisterable ] | None = None ,
167246 ):
168- # build profiler
169- profiler_activities = [ProfilerActivity .CPU , ProfilerActivity .CUDA ]
170- profile_context_manager = profile (
247+ # build profilers
248+ profiler_activities = [ProfilerActivity .CUDA ] # ProfilerActivity.CPU,
249+ kernel_profiler = profile (
171250 activities = profiler_activities ,
172251 schedule = schedule (wait = num_wait_steps , warmup = num_warmup_steps , active = num_measurement_steps ),
173- record_shapes = True ,
174- profile_memory = True ,
175- with_flops = True ,
176- with_stack = True ,
177- with_modules = True ,
252+ record_shapes = False ,
253+ profile_memory = False ,
254+ with_flops = False ,
255+ with_stack = False ,
256+ with_modules = False ,
257+ # record_shapes=True,
258+ # profile_memory=True,
259+ # with_flops=True,
260+ # with_stack=True,
261+ # with_modules=True,
178262 )
179263
264+ SteppableMemoryProfiler (
265+ memory_snapshot_path = experiment_folder_path / f"memory_snapshot_ranks_{ world_size } _rank_{ global_rank } .pkl" ,
266+ num_wait_steps = num_wait_steps ,
267+ num_warmup_steps = num_warmup_steps ,
268+ num_active_steps = num_measurement_steps ,
269+ )
270+
271+ profile_context_manager = ProfilerListContext (profiler_cms = [kernel_profiler ]) # , memory_profiler]
272+
180273 # register custom components and build components from config
181274 # workaround to avoid triggering synchronization of experiment id in single process
182275 experiment_id = experiment_folder_path .name if world_size == 1 else None
@@ -199,15 +292,12 @@ def _run_helper(
199292 show_progress = (global_rank == profiled_ranks [0 ]), # only show progress on a single rank that is profiled
200293 )
201294 trace_output_path = experiment_folder_path / f"profiler_trace_ranks_{ world_size } _rank_{ global_rank } .json"
202- memory_output_path = experiment_folder_path / f"profiler_memory_ranks_{ world_size } _rank_{ global_rank } .html"
203295 summary_output_path = experiment_folder_path / f"profiler_summary_ranks_{ world_size } _rank_{ global_rank } .txt"
204296
205297 ModalitiesProfiler .export_profiling_results (
206- profiler_context_manager = profile_context_manager ,
298+ profiler_context_manager = kernel_profiler ,
207299 trace_output_path = trace_output_path ,
208- memory_output_path = memory_output_path ,
209300 summary_output_path = summary_output_path ,
210- local_rank = local_rank ,
211301 global_rank = global_rank ,
212302 profiled_ranks = profiled_ranks ,
213303 )
@@ -218,7 +308,7 @@ class ModalitiesProfiler:
218308 def profile (
219309 steppable_component : SteppableComponentIF ,
220310 num_total_steps : int ,
221- profile_context_manager : torch . profiler . profile ,
311+ profile_context_manager : SteppableProfilerIF ,
222312 show_progress : bool = False ,
223313 ) -> None :
224314 """Profile a steppable component using the provided profiler context manager.
@@ -243,10 +333,8 @@ def profile(
243333 def export_profiling_results (
244334 profiler_context_manager : torch .profiler .profile ,
245335 trace_output_path : Path ,
246- memory_output_path : Path ,
247336 summary_output_path : Path ,
248337 global_rank : int ,
249- local_rank : int ,
250338 profiled_ranks : list [int ],
251339 ) -> None :
252340 """Export profiling results to specified output paths if the current rank is in profiled_ranks.
@@ -263,8 +351,6 @@ def export_profiling_results(
263351 if global_rank in profiled_ranks :
264352 logger .info (f"Saving profiling results for rank { global_rank } ..." )
265353 profiler_context_manager .export_chrome_trace (trace_output_path .as_posix ())
266- device = local_rank if local_rank is not None else None
267- profiler_context_manager .export_memory_timeline (memory_output_path .as_posix (), device = device )
268354 table = profiler_context_manager .key_averages ().table ()
269355 with open (summary_output_path , "w" , encoding = "utf-8" ) as f :
270356 f .write (table )
0 commit comments