1111#include < executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1212
1313#include < executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
14+ #include < executorch/backends/vulkan/runtime/graph/ops/impl/Conv2dGemm.h>
1415#include < executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1516
1617#include < executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
@@ -296,6 +297,41 @@ Conv2dMethod get_conv2d_method(
296297 return Conv2dMethod::SlidingWindow;
297298}
298299
300+ // Decide whether a SlidingWindow conv2d should be computed via the
301+ // im2col + GEMM path (conv2d_gemm_impl) instead of the direct convolution
302+ // shader. Across 26 configs on Mali-G715 (buffer path) and Adreno SM8650
303+ // (texture path): FP32 cases were numerically verified against the reference;
304+ // FP16 cases were routing/dispatch-validated only (the reference is float-only
305+ // for the large shapes, so FP16 outputs were not numerically checked).
306+ //
307+ // Only called for SlidingWindow conv2d (1x1 is routed to conv2d_pw and
308+ // Depthwise/Transposed are handled before the call site).
309+ //
310+ // Preconditions (fall back to direct conv if any fail — the im2col path is
311+ // either not applicable or not beneficial):
312+ // - groups == 1
313+ // - dilation == 1 (all dims)
314+ //
315+ // Selection rule: use im2col on Mali universally, or once the output channel
316+ // count is large enough to amortize the fixed ~N*K_total im2col gather cost.
317+ constexpr int64_t kIm2colMinCOut = 128 ;
318+
319+ bool should_use_conv2d_im2col (
320+ ComputeGraph& graph,
321+ const ValueRef weight_data,
322+ const int64_t groups_val,
323+ const Kernel2dParams& kernel_params) {
324+ if (groups_val != 1 ) {
325+ return false ;
326+ }
327+ if (kernel_params.dilation [0 ] != 1 || kernel_params.dilation [1 ] != 1 ) {
328+ return false ;
329+ }
330+ const auto weight_sizes = graph.sizes_of (weight_data);
331+ const int64_t c_out = weight_sizes.at (0 );
332+ return graph.device_is_mali () || c_out >= kIm2colMinCOut ;
333+ }
334+
299335utils::uvec3 create_conv2d_global_wg_size (
300336 ComputeGraph& graph,
301337 const Conv2dMethod method,
@@ -425,7 +461,8 @@ void add_conv2d_node(
425461 const ValueRef out_min,
426462 const ValueRef out_max,
427463 const ValueRef out,
428- const bool clamp_out) {
464+ const bool clamp_out,
465+ const bool force_direct) {
429466 const bool transposed_val = graph.get_bool (transposed);
430467
431468 float out_min_val = 0 .0f ;
@@ -473,6 +510,37 @@ void add_conv2d_node(
473510 out_max_val);
474511 }
475512
513+ const Kernel2dParams kernel_params = create_kernel2d_params (
514+ graph,
515+ weight_data,
516+ /* kernel_size_only = */ false ,
517+ stride,
518+ padding,
519+ dilation);
520+
521+ // SlidingWindow conv2d: route to the im2col + GEMM path when the heuristic
522+ // indicates it is beneficial, falling back to the direct convolution shader
523+ // otherwise. `force_direct` bypasses the heuristic entirely and forces the
524+ // direct path (used by tests to exercise the direct shader regardless of
525+ // device); the default (false) reproduces the production routing exactly.
526+ const bool use_im2col = !force_direct &&
527+ method == Conv2dMethod::SlidingWindow &&
528+ should_use_conv2d_im2col (graph, weight_data, groups_val, kernel_params);
529+ if (use_im2col) {
530+ return conv2d_gemm_impl (
531+ graph,
532+ in,
533+ weight_data,
534+ bias,
535+ stride,
536+ padding,
537+ dilation,
538+ out,
539+ clamp_out,
540+ out_min_val,
541+ out_max_val);
542+ }
543+
476544 ValueRef arg_weight = prepack_weights (graph, weight_data, method);
477545 ValueRef arg_bias = prepack_biases (
478546 graph,
@@ -489,13 +557,6 @@ void add_conv2d_node(
489557
490558 check_conv_args (graph, in, out);
491559
492- Kernel2dParams kernel_params = create_kernel2d_params (
493- graph,
494- weight_data,
495- /* kernel_size_only = */ false ,
496- stride,
497- padding,
498- dilation);
499560 Conv2dParams extra_params =
500561 create_conv2d_params (graph, weight_data, kernel_params, transposed_val);
501562
0 commit comments