Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 48 additions & 14 deletions deepmd/pt/utils/auto_batch_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,54 @@ def is_oom_error(self, e: Exception) -> bool:
e : Exception
Exception
"""
# several sources think CUSOLVER_STATUS_INTERNAL_ERROR is another out-of-memory error,
# such as https://github.com/JuliaGPU/CUDA.jl/issues/1924
# (the meaningless error message should be considered as a bug in cusolver)
if (
isinstance(e, RuntimeError)
and (
"CUDA out of memory." in e.args[0]
or "CUDA driver error: out of memory" in e.args[0]
or "cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR" in e.args[0]
# https://github.com/deepmodeling/deepmd-kit/issues/4594
Comment thread
OutisLi marked this conversation as resolved.
or "CUDA error: out of memory" in e.args[0]
)
) or isinstance(e, torch.cuda.OutOfMemoryError):
# Release all unoccupied cached memory
if isinstance(e, torch.cuda.OutOfMemoryError):
torch.cuda.empty_cache()
return True

if not isinstance(e, RuntimeError):
return False

# Gather messages from the exception itself and its chain. AOTInductor
# (.pt2) sometimes strips the underlying OOM message when rewrapping,
# but not always; checking ``__cause__`` / ``__context__`` catches the
# remaining cases when the original error is preserved.
msgs: list[str] = []
cur: BaseException | None = e
seen: set[int] = set()
while cur is not None and id(cur) not in seen:
seen.add(id(cur))
if cur.args:
first = cur.args[0]
if isinstance(first, str):
msgs.append(first)
cur = cur.__cause__ or cur.__context__

# Several sources treat CUSOLVER_STATUS_INTERNAL_ERROR as an OOM, e.g.
# https://github.com/JuliaGPU/CUDA.jl/issues/1924
plain_oom_markers = (
"CUDA out of memory.",
"CUDA driver error: out of memory",
"CUDA error: out of memory",
"cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR",
)
if any(m in msg for msg in msgs for m in plain_oom_markers):
torch.cuda.empty_cache()
return True

# AOTInductor (.pt2) wraps the underlying CUDA OOM as a generic
# ``run_func_(...) API call failed at .../model_container_runner.cpp``.
# https://github.com/deepmodeling/deepmd-kit/issues/4594
Comment thread
OutisLi marked this conversation as resolved.
Outdated
# The original "CUDA out of memory" text is printed to stderr only and
# is absent from the Python-level RuntimeError, so we match on the
# wrapper signature. If the root cause turns out to be something
# other than OOM, ``execute()`` will keep shrinking the batch and
# eventually raise ``OutOfMemoryError`` at batch size 1, which is a
# clean failure rather than an uncaught exception.
aoti_wrapped = any(
"run_func_(" in msg and "model_container_runner" in msg for msg in msgs
)
if aoti_wrapped:
torch.cuda.empty_cache()
return True
Comment thread
OutisLi marked this conversation as resolved.

return False
33 changes: 33 additions & 0 deletions source/tests/pt/test_auto_batch_size.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest
from unittest import (
mock,
)

import numpy as np

Expand All @@ -9,6 +12,36 @@


class TestAutoBatchSize(unittest.TestCase):
@mock.patch("deepmd.pt.utils.auto_batch_size.torch.cuda.empty_cache")
def test_is_oom_error_cuda_message(self, empty_cache) -> None:
    """A RuntimeError carrying the plain CUDA OOM text is reported as OOM."""
    batcher = AutoBatchSize(256, 2.0)
    oom = RuntimeError("CUDA out of memory.")

    detected = batcher.is_oom_error(oom)

    self.assertTrue(detected)
    # The patched ``torch.cuda.empty_cache`` must have been invoked exactly once.
    empty_cache.assert_called_once()

@mock.patch("deepmd.pt.utils.auto_batch_size.torch.cuda.empty_cache")
def test_is_oom_error_empty_runtime_error_from_cuda_oom(self, empty_cache) -> None:
    """An argument-less RuntimeError is reported as OOM via its ``__cause__``.

    The outer exception has no message of its own; only the chained cause
    carries the CUDA driver out-of-memory text.
    """
    batcher = AutoBatchSize(256, 2.0)
    wrapper = RuntimeError()
    wrapper.__cause__ = RuntimeError("CUDA driver error: out of memory")

    detected = batcher.is_oom_error(wrapper)

    self.assertTrue(detected)
    # The patched ``torch.cuda.empty_cache`` must have been invoked exactly once.
    empty_cache.assert_called_once()

@mock.patch("deepmd.pt.utils.auto_batch_size.torch.cuda.empty_cache")
def test_is_oom_error_aoti_wrapper(self, empty_cache) -> None:
    """A RuntimeError with the AOTInductor wrapper signature is reported as OOM.

    The message contains no out-of-memory text; it only mentions
    ``run_func_(`` and ``model_container_runner``.
    """
    batcher = AutoBatchSize(256, 2.0)
    wrapper_text = (
        "run_func_(...) API call failed at "
        "/tmp/torchinductor/model_container_runner.cpp"
    )
    wrapped = RuntimeError(wrapper_text)

    detected = batcher.is_oom_error(wrapped)

    self.assertTrue(detected)
    # The patched ``torch.cuda.empty_cache`` must have been invoked exactly once.
    empty_cache.assert_called_once()

def test_execute_all(self) -> None:
dd0 = np.zeros((10000, 2, 1, 3, 4))
dd1 = np.ones((10000, 2, 1, 3, 4))
Expand Down
Loading