Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,5 @@ jobs:
- name: Test Tokamax ops with unittest
# TODO: Disable benchmarking tests for presubmit until container is fixed
run: |
pytest -s -vv -m "not long" --ignore-glob="tokamax/_src/ops/*" --ignore-glob="tokamax/_src/benchmarking_test*" tokamax
pytest -s -vv -m "not long" --ignore-glob="tokamax/_src/ops/*" tokamax

60 changes: 35 additions & 25 deletions tokamax/_src/benchmarking.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,9 @@

PyTree = Any

# Timer functions return the time delta in ms and a dictionary of metadata.
Timer: TypeAlias = Callable[[bool], tuple[float, dict[str, Any]]]

T = TypeVar('T')
# Timer functions return per-iteration times in ms and a dictionary of metadata.
Timer: TypeAlias = Callable[[T], tuple[list[float], dict[str, Any]]]
RetT: TypeAlias = T | list[jax.Array] | tuple[T, list[jax.Array]]

TimingMethod: TypeAlias = Literal[
Expand Down Expand Up @@ -386,39 +385,52 @@ def vjp_full(arrays: list[jax.Array]) -> tuple[T, list[jax.Array]]:
return func, arrays


def wallclock_timer(f: Callable[[T], Any], args: T) -> Timer:
def timer(_):
def wallclock_timer(f: Callable[[T], Any], iterations: int) -> Timer[T]:

def timer(args):
jax.block_until_ready(f(args)) # Warmup.
start_time = time.perf_counter()
jax.block_until_ready(f(args))
return (time.perf_counter() - start_time) * 10**3, {}
times = []
for _ in range(iterations):
start_time = time.perf_counter()
jax.block_until_ready(f(args))
times.append((time.perf_counter() - start_time) * 10**3)
return times, {}

return timer


def cupti_timer(f: Callable[[T], Any], args: T) -> Timer:
timer = profiler.Cupti(finalize=False).measure(f)
return lambda _: (timer(args)[1], {})
def cupti_timer(f: Callable[[T], Any], iterations: int) -> Timer[T]:
timer = profiler.Cupti().measure(f, iterations=iterations)
return lambda args: (timer(args)[1], {})


def xprof_timer(f: Callable[[T], Any], args: T) -> Timer:
def timer(return_metadata):
def xprof_timer( # pylint: disable=missing-function-docstring
f: Callable[[T], Any], iterations: int, *, return_metadata: bool = True
) -> Timer[T]:
def timer(args):
jax.block_until_ready(f(args)) # Warmup.
with XprofProfileSession(hermetic=not return_metadata) as profile:
jax.block_until_ready(f(args))

metadata = dict(xprof_url=profile.xprof_url) if return_metadata else {}
return profile.total_op_time / datetime.timedelta(milliseconds=1), metadata
times = []
metadata = {}
for i in range(iterations):
hermetic = not (return_metadata and i == iterations - 1)
with XprofProfileSession(hermetic=hermetic) as profile:
jax.block_until_ready(f(args))
times.append(profile.total_op_time / datetime.timedelta(milliseconds=1))

if not hermetic:
metadata = dict(xprof_url=profile.xprof_url)

return times, metadata

return timer


def hermetic_xprof_timer(f: Callable[[T], Any], args: T) -> Timer:
timer = xprof_timer(f, args)
return lambda _: timer(False)
def hermetic_xprof_timer(f: Callable[[T], Any], iterations: int) -> Timer[T]:
return xprof_timer(f, iterations, return_metadata=False)


_TIMERS: dict[str, Callable[[Callable[[T], Any], T], Timer]] = {
_TIMERS: dict[str, Callable[[Callable[[T], Any], int], Timer[T]]] = {
'wallclock': wallclock_timer,
'cupti': cupti_timer,
'xprof': xprof_timer,
Expand Down Expand Up @@ -488,13 +500,11 @@ def runner(
if platform not in ('gpu', 'tpu'):
raise ValueError('XProf profiling is only supported on GPU or TPU.')

timer = _TIMERS[method](f_compiled, x)
times = [timer(False)[0] for _ in range(iterations - 1)]
dt, metadata = timer(True) # Capture metadata on last iteration.
times, metadata = _TIMERS[method](f_compiled, iterations)(x)
return BenchmarkData(
lower_time_ms=lowering_time * 10**3,
compile_time_ms=compile_time * 10**3,
evaluation_times_ms=(*times, dt),
evaluation_times_ms=tuple(times),
peak_memory_mb=peak_mem_mb,
metadata=metadata,
)
Expand Down
Loading