Profiler tensor shape to MojoOperator #359

Open
NASA1473 wants to merge 1 commit into dev/m13_ilu from profile_shape

Conversation

@NASA1473 NASA1473 commented Jun 12, 2026

Collaborator

Wrap MojoOperator._call_impl in torch.profiler.record_function so each forward shows up in traces as ClassName(sig)[shapes], with a per-class cached signature and tensor args rendered as (shape)dtype (bf16/fp16/fp32/i32/...).
Screenshot 2026-06-02 16 45 33

…o each op shows up in profiler traces with its class name, forward signature, and input tensor shape/dtype metadata. Adds helpers _get_forward_signature_repr (cached per class), _format_tensor_meta, and _build_mojo_shape_repr to build the record name.
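
The description mentions short dtype tags such as bf16, fp16, fp32 and i32. A minimal sketch of how such tags could be produced is shown here; the table `_DTYPE_SHORT` and the helpers `short_dtype` and `format_meta` are hypothetical names for illustration, not code from this PR. The sketch works on `str(dtype)` (for example `"torch.bfloat16"`), so it runs without importing PyTorch.

```python
# Hypothetical abbreviation table; the PR does not show the real mapping.
_DTYPE_SHORT = {
    "float32": "fp32",
    "float16": "fp16",
    "bfloat16": "bf16",
    "int32": "i32",
    "int64": "i64",
}


def short_dtype(dtype) -> str:
    """Map str(dtype), e.g. 'torch.bfloat16', to a short tag, e.g. 'bf16'."""
    name = str(dtype).replace("torch.", "")
    # Unknown dtypes fall back to the full name instead of raising.
    return _DTYPE_SHORT.get(name, name)


def format_meta(shape, dtype) -> str:
    """Render a shape and dtype as '(d0,d1,...)tag'."""
    dims = ",".join(str(d) for d in shape)
    return f"({dims}){short_dtype(dtype)}"


example = format_meta((4, 128, 64), "torch.bfloat16")
fallback = short_dtype("torch.float64")
```

Falling back to the full dtype name keeps the trace label readable for dtypes the table does not cover.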

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request adds profiling support to operators in mojo_opset/core/operator.py by overriding _call_impl to record execution details when PyTorch's profiler is active. It dynamically extracts the operator's forward signature and formats the shapes and types of tensor arguments. The review feedback suggests improving _build_mojo_shape_repr to recursively handle tensors nested within lists or tuples, which ensures that operators accepting collections of tensors are correctly profiled.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +105 to +113
    def _build_mojo_shape_repr(self, args, kwargs) -> str:
        parts = []
        for a in args:
            if isinstance(a, torch.Tensor):
                parts.append(self._format_tensor_meta(a))
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                parts.append(f"{k}={self._format_tensor_meta(v)}")
        return ",".join(parts)

Contributor

medium

The current implementation of _build_mojo_shape_repr only inspects direct torch.Tensor arguments. However, many operators accept lists or tuples of tensors (e.g., Concat, Stack, or custom multi-input/multi-output layers). In those cases, the tensor shapes and dtypes will be completely omitted from the profiler trace.

We can make this more robust by recursively formatting elements within lists and tuples so that their shapes and dtypes are also captured in the profiler traces.

    def _build_mojo_shape_repr(self, args, kwargs) -> str:
        def _format_item(item) -> Optional[str]:
            if isinstance(item, torch.Tensor):
                return self._format_tensor_meta(item)
            elif isinstance(item, (list, tuple)):
                inner = [_format_item(x) for x in item]
                inner = [x for x in inner if x is not None]
                if inner:
                    bracket_open, bracket_close = ("[", "]") if isinstance(item, list) else ("(", ")")
                    return f"{bracket_open}{','.join(inner)}{bracket_close}"
            return None

        parts = []
        for a in args:
            formatted = _format_item(a)
            if formatted is not None:
                parts.append(formatted)
        for k, v in kwargs.items():
            formatted = _format_item(v)
            if formatted is not None:
                parts.append(f"{k}={formatted}")
        return ",".join(parts)
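
To see what the recursive formatting would emit, here is a standalone sketch that can be run without PyTorch. `FakeTensor` is a hypothetical stand-in for `torch.Tensor`, and the free functions mirror the suggested method only loosely; treat the output format as an illustration under those assumptions.

```python
from typing import Optional


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype


def format_tensor_meta(t) -> str:
    shape = ",".join(str(s) for s in t.shape)
    return f"({shape}){t.dtype}"


def format_item(item) -> Optional[str]:
    # Tensors are rendered directly; lists and tuples are walked recursively.
    if isinstance(item, FakeTensor):
        return format_tensor_meta(item)
    if isinstance(item, (list, tuple)):
        inner = [s for s in map(format_item, item) if s is not None]
        if inner:
            open_b, close_b = ("[", "]") if isinstance(item, list) else ("(", ")")
            return f"{open_b}{','.join(inner)}{close_b}"
    # Non-tensor values (ints, None, empty containers) are skipped.
    return None


def build_shape_repr(args, kwargs) -> str:
    parts = [s for s in map(format_item, args) if s is not None]
    for k, v in kwargs.items():
        s = format_item(v)
        if s is not None:
            parts.append(f"{k}={s}")
    return ",".join(parts)


a = FakeTensor((2, 3), "float32")
b = FakeTensor((2, 5), "float32")
flat = build_shape_repr((a,), {"other": b})
nested = build_shape_repr(([a, b], 1), {"dim": 1})
```

With the non-recursive version, the `nested` call would produce an empty string, because the only tensors sit inside a list.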

Copilot AI left a comment

Pull request overview

Adds richer PyTorch profiler annotations for MojoOperator executions by wrapping module calls in torch.profiler.record_function with names that include the operator class, forward signature, and tensor shape/dtype metadata, improving trace readability when diagnosing performance.

Changes:

  • Override MojoOperator._call_impl to emit a record_function range when profiling is enabled.
  • Add per-class cached forward-signature rendering via inspect.signature.
  • Add tensor metadata formatting helpers to render (shape)dtype strings into profiler range names.
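
The per-class cached signature from the list above could look roughly like the following sketch. `OperatorBase`, `forward_signature_repr`, `_forward_sig_cache` and `ScaledAdd` are hypothetical names, and the rendered format (parameter names only) is an assumption; the PR's actual `_get_forward_signature_repr` is not shown on this page.

```python
import inspect


class OperatorBase:
    # One cache for all operator classes, keyed by the concrete class, so a
    # subclass never picks up the signature cached for its parent.
    _forward_sig_cache = {}

    @classmethod
    def forward_signature_repr(cls) -> str:
        cached = OperatorBase._forward_sig_cache.get(cls)
        if cached is None:
            params = inspect.signature(cls.forward).parameters
            names = [name for name in params if name != "self"]
            cached = "(" + ",".join(names) + ")"
            OperatorBase._forward_sig_cache[cls] = cached
        return cached

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class ScaledAdd(OperatorBase):  # hypothetical operator
    def forward(self, x, y, alpha=1.0):
        return x + alpha * y


sig = ScaledAdd.forward_signature_repr()
# Compose a profiler range name in the ClassName(sig)[shapes] layout.
record_name = f"{ScaledAdd.__name__}{sig}[(2,3)float32,(2,3)float32]"
```

Keying the cache by class rather than storing a single string attribute on the base class matters here: a plain class attribute set on first use would be inherited by every subclass and return the wrong signature.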

Comment on lines +71 to +78
    def _call_impl(self, *args, **kwargs):
        if torch.autograd._profiler_enabled():
            sig = self._get_forward_signature_repr()
            shape_repr = self._build_mojo_shape_repr(args, kwargs)
            name = f"{type(self).__name__}{sig}[{shape_repr}]"
            with torch.profiler.record_function(name):
                return super()._call_impl(*args, **kwargs)
        return super()._call_impl(*args, **kwargs)

Comment on lines +99 to +103
    @staticmethod
    def _format_tensor_meta(t: torch.Tensor) -> str:
        dtype_str = str(t.dtype).replace("torch.", "")
        shape = ",".join(str(s) for s in t.shape)
        return f"({shape}){dtype_str}"

@NASA1473 NASA1473 changed the title Add profiler tensor shape to MojoOperator Profiler tensor shape to MojoOperator Jun 12, 2026