Skip to content

Commit b3ffea8

Browse files
committed
Update
[ghstack-poisoned]
1 parent d9c6c0a commit b3ffea8

11 files changed

Lines changed: 843 additions & 3 deletions

File tree

backends/xnnpack/CMakeLists.txt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,13 @@ set(xnnpack_third_party pthreadpool extension_threadpool cpuinfo)
101101
include(cmake/Dependencies.cmake)
102102

103103
# Graph runtime sources.
104-
list(APPEND _xnnpack_backend__srcs backends/xnnpack/runtime/core/tensor.cpp
105-
backends/xnnpack/runtime/core/quant_params.cpp
104+
list(
105+
APPEND
106+
_xnnpack_backend__srcs
107+
backends/xnnpack/runtime/core/tensor.cpp
108+
backends/xnnpack/runtime/core/quant_params.cpp
109+
backends/xnnpack/runtime/graph/graph.cpp
110+
backends/xnnpack/runtime/graph/graph_builder.cpp
106111
)
107112

108113
list(TRANSFORM _xnnpack_backend__srcs PREPEND "${EXECUTORCH_ROOT}/")
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
#include <executorch/backends/xnnpack/runtime/graph/graph.h>
2+
3+
#include <executorch/backends/xnnpack/runtime/core/variant_util.h>
4+
#include <executorch/runtime/core/error.h>
5+
#include <executorch/runtime/platform/log.h>
6+
7+
#include <cassert>
8+
9+
namespace executorch::backends::xnnpack::graph {
10+
11+
namespace {
12+
13+
void scan_spec(const TensorSpec& spec, uint32_t& max_id) {
14+
for (auto& dim : spec.sizes) {
15+
for (auto& term : dim.coeffs) {
16+
if (term.sym >= max_id) {
17+
max_id = term.sym + 1;
18+
}
19+
}
20+
}
21+
}
22+
23+
void scan_output_spec(const OutputSpec& os, uint32_t& max_id) {
24+
std::visit(
25+
overloaded{
26+
[&](const TensorSpec& s) { scan_spec(s, max_id); },
27+
[&](const std::vector<TensorSpec>& v) {
28+
for (auto& s : v)
29+
scan_spec(s, max_id);
30+
},
31+
},
32+
os);
33+
}
34+
35+
} // namespace
36+
37+
uint32_t Graph::symint_count() const {
38+
uint32_t count = 0;
39+
for (auto& spec : input_specs) {
40+
scan_spec(spec, count);
41+
}
42+
for (auto& node : nodes) {
43+
std::visit(
44+
overloaded{
45+
[](const InputNode&) {},
46+
[](const ConstantNode&) {},
47+
[&](const CallOperatorNode& n) {
48+
scan_output_spec(n.output_specs, count);
49+
},
50+
[&](const CallSubgraphNode& n) {
51+
scan_output_spec(n.output_specs, count);
52+
},
53+
},
54+
node.value);
55+
}
56+
return count;
57+
}
58+
59+
void Graph::update_users() {
60+
for (auto& node : nodes) {
61+
node.users.clear();
62+
}
63+
64+
for (NodeHandle i = 0; i < nodes.size(); ++i) {
65+
std::visit(
66+
overloaded{
67+
[](const InputNode&) {},
68+
[](const ConstantNode&) {},
69+
[&](const CallOperatorNode& n) {
70+
for (auto arg : n.args) {
71+
if (!arg.is_null()) {
72+
nodes[arg.node].users.push_back(i);
73+
}
74+
}
75+
},
76+
[&](const CallSubgraphNode& n) {
77+
for (auto arg : n.args) {
78+
if (!arg.is_null()) {
79+
nodes[arg.node].users.push_back(i);
80+
}
81+
}
82+
},
83+
},
84+
nodes[i].value);
85+
}
86+
}
87+
88+
runtime::Error Graph::compact_nodes() {
89+
std::vector<uint32_t> remap(nodes.size(), UINT32_MAX);
90+
uint32_t new_idx = 0;
91+
for (NodeHandle i = 0; i < nodes.size(); i++) {
92+
if ((nodes[i].flags & NodeFlags::Dead) != NodeFlags::None) {
93+
continue;
94+
}
95+
remap[i] = new_idx++;
96+
}
97+
98+
// Validate that no live node or output references a dead/invalid node before
99+
// mutating any handles, so a failure leaves the graph untouched.
100+
bool valid = true;
101+
auto check_vh = [&](const ValueHandle& vh) {
102+
if (vh.is_null()) {
103+
return;
104+
}
105+
if (vh.node >= remap.size() || remap[vh.node] == UINT32_MAX) {
106+
valid = false;
107+
}
108+
};
109+
for (NodeHandle i = 0; i < nodes.size(); i++) {
110+
if ((nodes[i].flags & NodeFlags::Dead) != NodeFlags::None) {
111+
continue;
112+
}
113+
for (const auto& a : nodes[i].get_args()) {
114+
check_vh(a);
115+
}
116+
}
117+
for (const auto& out : outputs) {
118+
check_vh(out);
119+
}
120+
ET_CHECK_OR_RETURN_ERROR(
121+
valid,
122+
Internal,
123+
"compact_nodes: a live node or output references a dead node");
124+
125+
auto rewrite_vh = [&](ValueHandle& vh) {
126+
if (!vh.is_null()) {
127+
vh.node = remap[vh.node];
128+
}
129+
};
130+
131+
for (NodeHandle i = 0; i < nodes.size(); i++) {
132+
if ((nodes[i].flags & NodeFlags::Dead) != NodeFlags::None) {
133+
continue;
134+
}
135+
std::visit(
136+
overloaded{
137+
[](InputNode&) {},
138+
[](ConstantNode&) {},
139+
[&](CallOperatorNode& n) {
140+
for (auto& a : n.args)
141+
rewrite_vh(a);
142+
},
143+
[&](CallSubgraphNode& n) {
144+
for (auto& a : n.args)
145+
rewrite_vh(a);
146+
},
147+
},
148+
nodes[i].value);
149+
}
150+
151+
for (auto& out : outputs) {
152+
rewrite_vh(out);
153+
}
154+
155+
std::vector<Node> compacted;
156+
compacted.reserve(new_idx);
157+
for (NodeHandle i = 0; i < nodes.size(); i++) {
158+
if ((nodes[i].flags & NodeFlags::Dead) != NodeFlags::None) {
159+
continue;
160+
}
161+
compacted.push_back(std::move(nodes[i]));
162+
}
163+
nodes = std::move(compacted);
164+
165+
update_users();
166+
return runtime::Error::Ok;
167+
}
168+
169+
} // namespace executorch::backends::xnnpack::graph
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#pragma once
2+
3+
#include <executorch/backends/xnnpack/runtime/core/variant_util.h>
4+
#include <executorch/backends/xnnpack/runtime/graph/handles.h>
5+
#include <executorch/backends/xnnpack/runtime/graph/node.h>
6+
#include <executorch/backends/xnnpack/runtime/graph/tensor_spec.h>
7+
#include <executorch/runtime/core/error.h>
8+
#include <vector>
9+
10+
namespace executorch::backends::xnnpack::graph {
11+
12+
/*
13+
* Describes a computational graph.
14+
*/
15+
struct Graph {
16+
std::vector<TensorSpec> input_specs;
17+
std::vector<Node> nodes;
18+
std::vector<ValueHandle> outputs;
19+
20+
/* Clean up nodes marked as dead. */
21+
[[nodiscard]] runtime::Error compact_nodes();
22+
23+
/* Returns the number of symints referenced in the graph. */
24+
uint32_t symint_count() const;
25+
26+
/* Regenerate user metadata on nodes. */
27+
void update_users();
28+
29+
/* Retrieve the output specifier for a given node handle. */
30+
inline OutputSpec get_output_spec_for_node(NodeHandle node) const {
31+
return std::visit(
32+
overloaded{
33+
[&](const InputNode& n) -> OutputSpec {
34+
return input_specs.at(n.input);
35+
},
36+
[](const ConstantNode& n) -> OutputSpec {
37+
TensorSpec spec;
38+
spec.dtype = n.tensor->dtype;
39+
spec.sizes.reserve(n.tensor->sizes.size());
40+
for (auto s : n.tensor->sizes) {
41+
spec.sizes.push_back(
42+
DimSizeSpec::constant(static_cast<int64_t>(s)));
43+
}
44+
spec.quant_params = n.quant_params;
45+
return spec;
46+
},
47+
[](const CallOperatorNode& n) -> OutputSpec {
48+
return n.output_specs;
49+
},
50+
[](const CallSubgraphNode& n) -> OutputSpec {
51+
return n.output_specs;
52+
},
53+
},
54+
nodes[node].value);
55+
}
56+
57+
/* Retrieve the tensor spec for a given value handle. */
58+
inline TensorSpec get_tensor_spec(ValueHandle vh) const {
59+
auto spec = get_output_spec_for_node(vh.node);
60+
return std::visit(
61+
overloaded{
62+
[](const TensorSpec& s) -> TensorSpec { return s; },
63+
[&](const std::vector<TensorSpec>& v) -> TensorSpec {
64+
return v.at(vh.output);
65+
},
66+
},
67+
spec);
68+
}
69+
};
70+
71+
} // namespace executorch::backends::xnnpack::graph
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#include <executorch/backends/xnnpack/runtime/graph/graph_builder.h>
2+
3+
#include <utility>
4+
5+
namespace executorch::backends::xnnpack::graph {
6+
7+
Graph GraphBuilder::build() {
8+
Graph g;
9+
g.input_specs = std::move(input_specs_);
10+
g.nodes = std::move(nodes_);
11+
g.outputs = std::move(outputs_);
12+
return g;
13+
}
14+
15+
ValueHandle GraphBuilder::createInput(TensorSpec spec) {
16+
input_specs_.push_back(std::move(spec));
17+
18+
InputHandle input = next_input_;
19+
next_input_++;
20+
21+
ValueHandle handle{static_cast<uint32_t>(nodes_.size())};
22+
Node node;
23+
node.value = InputNode{input};
24+
nodes_.push_back(std::move(node));
25+
return handle;
26+
}
27+
28+
ValueHandle GraphBuilder::createConstant(
29+
std::shared_ptr<const core::Tensor> tensor,
30+
std::optional<core::QuantParams> quant_params) {
31+
ValueHandle handle{static_cast<uint32_t>(nodes_.size())};
32+
ConstantNode cn;
33+
cn.tensor = std::move(tensor);
34+
cn.quant_params = std::move(quant_params);
35+
Node node;
36+
node.value = std::move(cn);
37+
nodes_.push_back(std::move(node));
38+
return handle;
39+
}
40+
41+
ValueHandle GraphBuilder::createOperator(
42+
Operator op,
43+
TensorSpec output_spec,
44+
ValueHandles args) {
45+
ValueHandle handle{static_cast<uint32_t>(nodes_.size())};
46+
CallOperatorNode con;
47+
con.args = std::move(args);
48+
con.op = op;
49+
con.output_specs = std::move(output_spec);
50+
Node node;
51+
node.value = std::move(con);
52+
nodes_.push_back(std::move(node));
53+
return handle;
54+
}
55+
56+
ValueHandle GraphBuilder::createOperator(
57+
Operator op,
58+
TensorSpec output_spec,
59+
ValueHandles args,
60+
std::vector<ConstantArg> constant_args,
61+
float output_min,
62+
float output_max) {
63+
ValueHandle handle{static_cast<uint32_t>(nodes_.size())};
64+
CallOperatorNode con;
65+
con.args = std::move(args);
66+
con.op = op;
67+
con.output_specs = std::move(output_spec);
68+
con.constant_args = std::move(constant_args);
69+
con.output_min = output_min;
70+
con.output_max = output_max;
71+
Node node;
72+
node.value = std::move(con);
73+
nodes_.push_back(std::move(node));
74+
return handle;
75+
}
76+
77+
ValueHandle GraphBuilder::createOperatorM(
78+
Operator op,
79+
std::vector<TensorSpec> output_specs,
80+
ValueHandles args) {
81+
ValueHandle handle{static_cast<uint32_t>(nodes_.size())};
82+
CallOperatorNode con;
83+
con.args = std::move(args);
84+
con.op = op;
85+
con.output_specs = std::move(output_specs);
86+
Node node;
87+
node.value = std::move(con);
88+
nodes_.push_back(std::move(node));
89+
return handle;
90+
}
91+
92+
OutputHandle GraphBuilder::createOutput(ValueHandle handle) {
93+
OutputHandle output = static_cast<OutputHandle>(outputs_.size());
94+
outputs_.push_back(handle);
95+
return output;
96+
}
97+
98+
SymIntHandle GraphBuilder::createSymInt() {
99+
return next_sym_int_++;
100+
}
101+
102+
} // namespace executorch::backends::xnnpack::graph

0 commit comments

Comments
 (0)