Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,13 @@ class Model(torch.nn.Module):
num_warps=32,
)
return OUT

# Optional
# If this method is provided, the Xe-Forge will get input tensors by this method.
# If not, Xe-Forge will generate random inputs based on shapes and dtype.
# Don't provide this method (remove it) if you don't need it.
def get_example_inputs(self, input_shapes: list | None = None):
pass
```

### Model with Init Arguments
Expand Down
9 changes: 7 additions & 2 deletions src/xe_forge/core/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def execute(

# Create inputs if not provided
if inputs is None:
if input_shapes:
if hasattr(model, "get_example_inputs"):
inputs = model.get_example_inputs(input_shapes, self.device)
elif input_shapes:
inputs = self._create_inputs(
input_shapes, dtype=dtype, input_dtypes=input_dtypes
)
Expand Down Expand Up @@ -302,7 +304,10 @@ def _check_correctness(

# Shared inputs with deterministic seed
set_all_seeds(123)
inputs = self._create_inputs(input_shapes, dtype=dtype, input_dtypes=input_dtypes)
if hasattr(original_model, "get_example_inputs"):
inputs = original_model.get_example_inputs(input_shapes, self.device)
else:
inputs = self._create_inputs(input_shapes, dtype=dtype, input_dtypes=input_dtypes)

inputs_orig = [inp.clone() for inp in inputs]
inputs_opt = [inp.clone() for inp in inputs]
Expand Down
Loading