Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions src/qonnx/custom_op/general/multithreshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_nodeattr_types(self):
"out_dtype": ("s", True, ""),
"out_scale": ("f", False, 1.0),
"out_bias": ("f", False, 0.0),
"data_layout": ("s", False, "NCHW"),
"data_layout": ("s", False, ""),
}

def make_shape_compatible_op(self, model):
Expand Down Expand Up @@ -130,12 +130,32 @@ def execute_node(self, context, graph):
# retrieve attributes if output scaling is used
out_scale = self.get_nodeattr("out_scale")
out_bias = self.get_nodeattr("out_bias")
# transpose input if NHWC data layout is chosen

# Consider the data layout for transposing the input into the format
# accepted by the multithreshold function above, i.e, the channel
# dimension is along the axis with index 1.
data_layout = self.get_nodeattr("data_layout")
channels_last = True if data_layout[-1] == "C" else False
# calculate output
# If there is no layout annotation, guess based on rank of the
# tensor
if not data_layout and len(v.shape) < 5:
# Maps tensor rank to layout annotation
rank_to_layout = {0: None, 1: None, 2: "NC", 3: "NWC", 4: "NCHW"}
# Lookup the layout required by this input shape
data_layout = rank_to_layout[len(v.shape)]
# Lookup the index of the channel dimension in the data layout
# Note: Assumes there is at most one "C" which denotes the channel
# dimension
if data_layout is not None:
cdim = data_layout.index("C") if "C" in data_layout else 1
else:
cdim = 1
# Rearrange the input to the expected (N, C, ...) layout
orig_shape = v.shape
output = multithreshold(v, thresholds, out_scale, out_bias, channels_last)
v = v.swapaxes(cdim, 1)
Comment thread
jmitrevs marked this conversation as resolved.
# Now we can use the multithreshold function to calculate output
output = multithreshold(v, thresholds, out_scale, out_bias)
# Rearrange the output back to the original layout
output = output.swapaxes(cdim, 1)
assert output.shape == orig_shape, "Shape changed during thresholding!"
context[node.output[0]] = output

Expand Down