Skip to content

Commit caeb005

Browse files
pytorchbot and ssjia
authored and committed
[ET-VK][conv2d] Auto-route SlidingWindow conv2d to im2col/GEMM via device-aware heuristic (#20190)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #20059 by @SS-JIA
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/557/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/557/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/556/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/557/orig
Differential Revision: [D107595816](https://our.internmc.facebook.com/intern/diff/D107595816/)
@diff-train-skip-merge
Co-authored-by: ssjia <ssjia@devvm1479.ncg0.facebook.com>
1 parent fd2cf88 commit caeb005

4 files changed

Lines changed: 173 additions & 13 deletions

File tree

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1212

1313
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Conv2dGemm.h>
1415
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1516

1617
#include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
@@ -296,6 +297,41 @@ Conv2dMethod get_conv2d_method(
296297
return Conv2dMethod::SlidingWindow;
297298
}
298299

300+
// Decide whether a SlidingWindow conv2d should be computed via the
301+
// im2col + GEMM path (conv2d_gemm_impl) instead of the direct convolution
302+
// shader. Across 26 configs on Mali-G715 (buffer path) and Adreno SM8650
303+
// (texture path): FP32 cases were numerically verified against the reference;
304+
// FP16 cases were routing/dispatch-validated only (the reference is float-only
305+
// for the large shapes, so FP16 outputs were not numerically checked).
306+
//
307+
// Only called for SlidingWindow conv2d (1x1 is routed to conv2d_pw and
308+
// Depthwise/Transposed are handled before the call site).
309+
//
310+
// Preconditions (fall back to direct conv if any fail — the im2col path is
311+
// either not applicable or not beneficial):
312+
// - groups == 1
313+
// - dilation == 1 (all dims)
314+
//
315+
// Selection rule: always use im2col on Mali; on other devices, only once
316+
// c_out >= kIm2colMinCOut amortizes the fixed ~N*K_total im2col gather cost.
317+
constexpr int64_t kIm2colMinCOut = 128;
318+
319+
// Returns true when a SlidingWindow conv2d should be dispatched through the
// im2col + GEMM path rather than the direct convolution shader.
//
// The im2col path is rejected outright for grouped (groups != 1) or dilated
// convolutions. Otherwise it is chosen on Mali devices, or on any device once
// the output channel count reaches kIm2colMinCOut.
bool should_use_conv2d_im2col(
    ComputeGraph& graph,
    const ValueRef weight_data,
    const int64_t groups_val,
    const Kernel2dParams& kernel_params) {
  const bool is_ungrouped = groups_val == 1;
  const bool is_undilated =
      kernel_params.dilation[0] == 1 && kernel_params.dilation[1] == 1;
  if (!is_ungrouped || !is_undilated) {
    return false;
  }

  // Dim 0 of the weight tensor is used as the output channel count.
  const int64_t c_out = graph.sizes_of(weight_data).at(0);
  if (graph.device_is_mali()) {
    return true;
  }
  return c_out >= kIm2colMinCOut;
}
334+
299335
utils::uvec3 create_conv2d_global_wg_size(
300336
ComputeGraph& graph,
301337
const Conv2dMethod method,
@@ -425,7 +461,8 @@ void add_conv2d_node(
425461
const ValueRef out_min,
426462
const ValueRef out_max,
427463
const ValueRef out,
428-
const bool clamp_out) {
464+
const bool clamp_out,
465+
const bool force_direct) {
429466
const bool transposed_val = graph.get_bool(transposed);
430467

431468
float out_min_val = 0.0f;
@@ -473,6 +510,37 @@ void add_conv2d_node(
473510
out_max_val);
474511
}
475512

513+
const Kernel2dParams kernel_params = create_kernel2d_params(
514+
graph,
515+
weight_data,
516+
/*kernel_size_only = */ false,
517+
stride,
518+
padding,
519+
dilation);
520+
521+
// SlidingWindow conv2d: route to the im2col + GEMM path when the heuristic
522+
// indicates it is beneficial, falling back to the direct convolution shader
523+
// otherwise. `force_direct` bypasses the heuristic entirely and forces the
524+
// direct path (used by tests to exercise the direct shader regardless of
525+
// device); the default (false) reproduces the production routing exactly.
526+
const bool use_im2col = !force_direct &&
527+
method == Conv2dMethod::SlidingWindow &&
528+
should_use_conv2d_im2col(graph, weight_data, groups_val, kernel_params);
529+
if (use_im2col) {
530+
return conv2d_gemm_impl(
531+
graph,
532+
in,
533+
weight_data,
534+
bias,
535+
stride,
536+
padding,
537+
dilation,
538+
out,
539+
clamp_out,
540+
out_min_val,
541+
out_max_val);
542+
}
543+
476544
ValueRef arg_weight = prepack_weights(graph, weight_data, method);
477545
ValueRef arg_bias = prepack_biases(
478546
graph,
@@ -489,13 +557,6 @@ void add_conv2d_node(
489557

490558
check_conv_args(graph, in, out);
491559

492-
Kernel2dParams kernel_params = create_kernel2d_params(
493-
graph,
494-
weight_data,
495-
/*kernel_size_only = */ false,
496-
stride,
497-
padding,
498-
dilation);
499560
Conv2dParams extra_params =
500561
create_conv2d_params(graph, weight_data, kernel_params, transposed_val);
501562

backends/vulkan/runtime/graph/ops/impl/Convolution.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,26 @@ void resize_conv2d_node(
5656
const std::vector<ArgGroup>& args,
5757
const std::vector<ValueRef>& extra_args);
5858

59+
// `force_direct` overrides the im2col-vs-direct routing heuristic: when true,
60+
// a SlidingWindow conv2d always takes the direct sliding-window path,
61+
// bypassing should_use_conv2d_im2col(). The default (false) preserves the
62+
// production routing exactly. Pointwise / Depthwise / Transposed methods are
63+
// unaffected by this flag.
64+
void add_conv2d_node(
65+
ComputeGraph& graph,
66+
const ValueRef in,
67+
const ValueRef weight_data,
68+
const ValueRef bias,
69+
const ValueRef stride,
70+
const ValueRef padding,
71+
const ValueRef dilation,
72+
const ValueRef transposed,
73+
const ValueRef output_padding,
74+
const ValueRef groups,
75+
const ValueRef out_min,
76+
const ValueRef out_max,
77+
const ValueRef out,
78+
const bool clamp_out,
79+
const bool force_direct = false);
80+
5981
} // namespace vkcompute

backends/vulkan/test/custom_ops/impl/TestConv2d.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1212
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Conv2dGemm.h>
13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Convolution.h>
1314

1415
#include <optional>
1516

@@ -29,7 +30,10 @@ void test_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
2930
// args[10] = output [N, C_out, H_out, W_out]
3031
//
3132
// impl_selector grammar:
32-
// "" -> aten.convolution.default (direct sliding-window)
33+
// "" -> aten.convolution.default (heuristic-routed:
34+
// should_use_conv2d_im2col() picks direct vs im2col)
35+
// "direct" -> add_conv2d_node(force_direct=true): forces the direct
36+
// sliding-window path, bypassing the routing heuristic
3337
// "im2col" -> et_vk.conv2d_gemm.default, auto im2col storage
3438
// "im2col_buffer"-> im2col/GEMM, force buffer im2col intermediate
3539
// "im2col_tex2d" -> im2col/GEMM, force texture2d im2col intermediate
@@ -88,6 +92,31 @@ void test_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
8892
graph.add_scalar_list<int64_t>(std::vector<int64_t>{0, 0});
8993
ValueRef groups = graph.add_scalar<int64_t>(1);
9094

95+
// The "direct" selector must reach the same direct sliding-window dispatch
96+
// that add_conv2d_node falls back to when the heuristic declines im2col. The
97+
// registered op can only route via the heuristic, so call add_conv2d_node
98+
// directly with force_direct=true to bypass it (mirroring how the
99+
// forced-storage variants call conv2d_gemm_impl).
100+
if (impl_selector == "direct") {
101+
add_conv2d_node(
102+
graph,
103+
input,
104+
weight,
105+
bias,
106+
stride,
107+
padding,
108+
dilation,
109+
transposed,
110+
output_padding,
111+
groups,
112+
/*out_min=*/kDummyValueRef,
113+
/*out_max=*/kDummyValueRef,
114+
out,
115+
/*clamp_out=*/false,
116+
/*force_direct=*/true);
117+
return;
118+
}
119+
91120
const std::string target_op = (impl_selector == "im2col")
92121
? "et_vk.conv2d_gemm.default"
93122
: "aten.convolution.default";

backends/vulkan/test/custom_ops/test_conv2d.cpp

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -493,13 +493,49 @@ static std::vector<TestCase> generate_conv2d_test_cases() {
493493
true},
494494
};
495495

496-
// Two implementation variants: direct sliding-window (default) and im2col.
497-
const std::vector<std::string> impls = {"", "im2col"};
496+
// Boundary pair straddling the should_use_conv2d_im2col() c_out >= 128
497+
// routing threshold. Spatial dims are tiny (8x8) so the FP32 float reference
498+
// stays cheap, but c_out = 64 / 128 are both >= kRefDimSizeLimit, so these
499+
// get the PERF label. FP32 PERF cases are still numerically VERIFIED (the
500+
// reference's invalid_argument throw that skips the check only fires for
501+
// half), so both implementations are cross-checked against the float
502+
// reference at the boundary. Run all three impls: at c_out = 64 the heuristic
503+
// ("") picks direct on Adreno / im2col on Mali; at c_out = 128 it picks
504+
// im2col on both — and "direct"/"im2col" force each path regardless, proving
505+
// the two implementations agree at the boundary on either device.
506+
std::vector<Conv2dTestConfig> boundary_configs = {
507+
// c_out = 64 (< 128): below the threshold
508+
{InputDims(1, 16, 8, 8),
509+
64,
510+
KernelSize(3, 3),
511+
Stride(1, 1),
512+
Padding(1, 1),
513+
Dilation(1, 1),
514+
false},
515+
// c_out = 128 (== 128): at/above the threshold
516+
{InputDims(1, 16, 8, 8),
517+
128,
518+
KernelSize(3, 3),
519+
Stride(1, 1),
520+
Padding(1, 1),
521+
Dilation(1, 1),
522+
false},
523+
};
524+
525+
// Implementation variants exercised for every small ACCU shape:
526+
// "" -> heuristic-routed (should_use_conv2d_im2col picks direct on
527+
// Adreno for small c_out, im2col on Mali)
528+
// "im2col" -> forced im2col/GEMM path
529+
// "direct" -> forced direct sliding-window path (force_direct=true)
530+
// Including "direct" guarantees the direct shader gets reference-checked on
531+
// BOTH devices — without it, Mali would always route "" to im2col and never
532+
// exercise the direct path.
533+
const std::vector<std::string> impls = {"", "im2col", "direct"};
498534
// Forced-storage im2col variants for the per-variant ACCU coverage.
499535
const std::vector<std::string> forced_storage_impls = {
500536
"im2col_buffer", "im2col_tex2d", "im2col_tex3d"};
501537

502-
// Generate accuracy test cases for both impls and both dtypes. FP16 small
538+
// Generate accuracy test cases for all impls and both dtypes. FP16 small
503539
// shapes get a real reference check (gated in conv2d_reference_impl); we run
504540
// both dtypes so we catch correctness regressions in either path. Large-K
505541
// half stays timing-only via the reference's PERF-shape throw.
@@ -530,7 +566,19 @@ static std::vector<TestCase> generate_conv2d_test_cases() {
530566
}
531567
}
532568

533-
// Generate performance test cases (float and half) for both impls.
569+
// Generate the c_out boundary pair (FP32 only) through all three impls.
570+
// FP32 PERF cases are reference-VERIFIED, so the direct and im2col paths are
571+
// both cross-checked against the float reference at the routing threshold.
572+
for (const auto& config : boundary_configs) {
573+
for (auto st : storage_types) {
574+
for (const auto& impl : impls) {
575+
test_cases.push_back(
576+
create_conv2d_test_case(config, vkapi::kFloat, st, layout, impl));
577+
}
578+
}
579+
}
580+
581+
// Generate performance test cases (float and half) for all impls.
534582
for (const auto& config : perf_configs) {
535583
std::vector<vkapi::ScalarType> dtypes = {vkapi::kFloat, vkapi::kHalf};
536584
for (auto dtype : dtypes) {

0 commit comments

Comments
 (0)