Profiler tensor shape to MojoOperator#359
Conversation
…o each op shows up in profiler traces with its class name, forward signature, and input tensor shape/dtype metadata. Adds helpers _get_forward_signature_repr (cached per class), _format_tensor_meta, and _build_mojo_shape_repr to build the record name.
There was a problem hiding this comment.
Code Review
This pull request adds profiling support to operators in mojo_opset/core/operator.py by overriding _call_impl to record execution details when PyTorch's profiler is active. It dynamically extracts the operator's forward signature and formats the shapes and types of tensor arguments. The review feedback suggests improving _build_mojo_shape_repr to recursively handle tensors nested within lists or tuples, which ensures that operators accepting collections of tensors are correctly profiled.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def _build_mojo_shape_repr(self, args, kwargs) -> str: | ||
| parts = [] | ||
| for a in args: | ||
| if isinstance(a, torch.Tensor): | ||
| parts.append(self._format_tensor_meta(a)) | ||
| for k, v in kwargs.items(): | ||
| if isinstance(v, torch.Tensor): | ||
| parts.append(f"{k}={self._format_tensor_meta(v)}") | ||
| return ",".join(parts) |
There was a problem hiding this comment.
The current implementation of _build_mojo_shape_repr only inspects direct torch.Tensor arguments. However, many operators accept lists or tuples of tensors (e.g., Concat, Stack, or custom multi-input/multi-output layers). In those cases, the tensor shapes and dtypes will be completely omitted from the profiler trace.
We can make this more robust by recursively formatting elements within lists and tuples so that their shapes and dtypes are also captured in the profiler traces.
def _build_mojo_shape_repr(self, args, kwargs) -> str:
def _format_item(item) -> Optional[str]:
if isinstance(item, torch.Tensor):
return self._format_tensor_meta(item)
elif isinstance(item, (list, tuple)):
inner = [_format_item(x) for x in item]
inner = [x for x in inner if x is not None]
if inner:
bracket_open, bracket_close = ("[", "]") if isinstance(item, list) else ("(", ")")
return f"{bracket_open}{','.join(inner)}{bracket_close}"
return None
parts = []
for a in args:
formatted = _format_item(a)
if formatted is not None:
parts.append(formatted)
for k, v in kwargs.items():
formatted = _format_item(v)
if formatted is not None:
parts.append(f"{k}={formatted}")
return ",".join(parts)There was a problem hiding this comment.
Pull request overview
Adds richer PyTorch profiler annotations for MojoOperator executions by wrapping module calls in torch.profiler.record_function with names that include the operator class, forward signature, and tensor shape/dtype metadata, improving trace readability when diagnosing performance.
Changes:
- Override
MojoOperator._call_implto emit arecord_functionrange when profiling is enabled. - Add per-class cached forward-signature rendering via
inspect.signature. - Add tensor metadata formatting helpers to render
(shape)dtypestrings into profiler range names.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def _call_impl(self, *args, **kwargs): | ||
| if torch.autograd._profiler_enabled(): | ||
| sig = self._get_forward_signature_repr() | ||
| shape_repr = self._build_mojo_shape_repr(args, kwargs) | ||
| name = f"{type(self).__name__}{sig}[{shape_repr}]" | ||
| with torch.profiler.record_function(name): | ||
| return super()._call_impl(*args, **kwargs) | ||
| return super()._call_impl(*args, **kwargs) |
| @staticmethod | ||
| def _format_tensor_meta(t: torch.Tensor) -> str: | ||
| dtype_str = str(t.dtype).replace("torch.", "") | ||
| shape = ",".join(str(s) for s in t.shape) | ||
| return f"({shape}){dtype_str}" |
Wrap MojoOperator._call_impl in torch.profiler.record_function so each forward shows up in traces as ClassName(sig)[shapes], with a per-class cached signature and tensor args rendered as (shape)dtype (bf16/fp16/fp32/i32/...).
