Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .isort.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ sections=FUTURE,STDLIB,TEST,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
default_section=THIRDPARTY
multi_line_output=3
profile=black
ignore_comments=true
ignore_whitespace=true
honor_noqa=true
use_parentheses=true
70 changes: 70 additions & 0 deletions src/qonnx/transformation/composed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copies (deep-copies) python objects
Comment thread
iksnagreb marked this conversation as resolved.
Outdated
import copy

# QONNX wrapper of ONNX model graphs
from qonnx.core.modelwrapper import ModelWrapper

# QONNX graph transformations for annotating the graph with datatype and shape
# information
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.infer_shapes import InferShapes

# Cleanup transformations removing identities like multiplication by one or
# addition of zero
from qonnx.transformation.remove import RemoveIdentityOps

# Base class for all QONNX graph transformations and some basic cleanup
# transformations
# fmt: off
from qonnx.transformation.general import ( # isort: skip
GiveReadableTensorNames, GiveUniqueNodeNames, Transformation
)


# fmt: on


# Composes graph transformations such that each individual transformation as
# well as the whole sequence is applied exhaustively
class ComposedTransformation(Transformation):
# Initializes the transformation given a list of transformations
def __init__(self, transformations: list[Transformation]):
# Initialize the transformation base class
super().__init__()
# Register the list of transformations to be applied in apply()
self.transformations = transformations

# Applies the transform to a whole model graph
def apply(self, model: ModelWrapper): # noqa
# Keep track of whether the graph has been modified
graph_modified = False
# Iterate all transformations to be applied
for transformation in self.transformations:
# Start each transformation on a deep copy of the model to mimic the
# behavior of ModelWrapper.transform()
model = copy.deepcopy(model)
# Exhaustively apply the transformation until it no longer modifies
# the graph
while True:
# Apply the transformation once, reporting back whether any node
# or pattern has been modified
model, _graph_modified = transformation.apply(model)
# Keep track whether the graph has been modified at least once
graph_modified = graph_modified or _graph_modified
# Break the loop if this transformation did not change anything
if not _graph_modified:
break
# Apply the cleanup transformations of the ModelWrapper
model.cleanup()
# Apply some further cleanup transformations to the model graph
# removing some clutter and keeping all names readable and ordered
# at any time
model = model.transform(RemoveIdentityOps())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
# Return the transformed model and indicate whether the graph actually
# has been transformed by at least one transformation so the whole
# sequence of transformations will be reapplied
return model, graph_modified
Loading