Skip to content

Commit 4eced3b

Browse files
Update (base update)
[ghstack-poisoned]
1 parent 5526971 commit 4eced3b

18 files changed

Lines changed: 3244 additions & 2 deletions

backends/webgpu/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@ set(WEBGPU_SRCS
3030
runtime/WebGPUGraph.cpp
3131
runtime/WebGPUDelegateHeader.cpp
3232
runtime/WebGPUDevice.cpp
33+
runtime/WebGPUQueryPool.cpp
3334
runtime/ops/OperatorRegistry.cpp
3435
runtime/ops/add/BinaryOp.cpp
3536
runtime/ops/rms_norm/RmsNorm.cpp
3637
runtime/ops/update_cache/UpdateCache.cpp
38+
runtime/ops/sdpa/Sdpa.cpp
3739
runtime/ops/select_as_symint/SelectAsSymint.cpp
3840
)
3941

backends/webgpu/runtime/WebGPUDevice.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <cstdlib>
1414
#include <memory>
1515
#include <stdexcept>
16+
#include <vector>
1617

1718
namespace executorch {
1819
namespace backends {
@@ -137,6 +138,16 @@ WebGPUContext create_webgpu_context() {
137138
WGPUStatus_Success) {
138139
device_desc.requiredLimits = &supported_limits;
139140
}
141+
142+
// Bench: enable TimestampQuery if available; fail-open (skip timing if not).
143+
std::vector<WGPUFeatureName> required_features;
144+
if (wgpuAdapterHasFeature(ctx.adapter, WGPUFeatureName_TimestampQuery)) {
145+
required_features.push_back(WGPUFeatureName_TimestampQuery);
146+
device_desc.requiredFeatureCount = required_features.size();
147+
device_desc.requiredFeatures = required_features.data();
148+
ctx.timestamp_supported = true;
149+
}
150+
140151
device_desc.uncapturedErrorCallbackInfo.callback = on_device_error;
141152

142153
WGPUWaitStatus device_wait = webgpu_wait(

backends/webgpu/runtime/WebGPUDevice.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111
#include <webgpu/webgpu.h>
1212

13+
#include <executorch/backends/webgpu/runtime/WebGPUQueryPool.h>
14+
15+
#include <memory>
16+
1317
namespace executorch {
1418
namespace backends {
1519
namespace webgpu {
@@ -19,6 +23,10 @@ struct WebGPUContext {
1923
WGPUAdapter adapter = nullptr;
2024
WGPUDevice device = nullptr;
2125
WGPUQueue queue = nullptr;
26+
// True if the device was created with the TimestampQuery feature (bench).
27+
bool timestamp_supported = false;
28+
// Bench-only: timestamp-query pool, lazily created in execute() (env-gated).
29+
std::unique_ptr<WebGPUQueryPool> querypool;
2230
};
2331

2432
WebGPUContext create_webgpu_context();

backends/webgpu/runtime/WebGPUGraph.cpp

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <executorch/backends/webgpu/runtime/WebGPUCompat.h>
1616
#include <executorch/backends/webgpu/runtime/WebGPUDevice.h>
1717

18+
#include <cstdlib>
1819
#include <cstring>
1920
#include <stdexcept>
2021

@@ -496,18 +497,48 @@ void WebGPUGraph::copy_inputs(
496497
}
497498
}
498499

500+
namespace {
501+
// Bench gate: WEBGPU_TIMESTAMP_QUERY enables per-pass GPU timestamp queries.
502+
bool should_timestamp_query() {
503+
static const bool enabled = std::getenv("WEBGPU_TIMESTAMP_QUERY") != nullptr;
504+
return enabled;
505+
}
506+
} // namespace
507+
499508
void WebGPUGraph::execute() {
500509
const size_t n = dispatches_.size();
501510
const size_t chunk = execute_config_.chunk_size;
502511

503512
if (chunk == 0 || n <= chunk) {
513+
// Bench: timestamp-query pool, null unless env-gated + feature present.
514+
WebGPUQueryPool* qp = nullptr;
515+
if (should_timestamp_query() && n > 0) {
516+
if (auto* ctx = get_default_webgpu_context()) {
517+
if (ctx->timestamp_supported) {
518+
if (!ctx->querypool || ctx->querypool->capacity() < n) {
519+
ctx->querypool = std::make_unique<WebGPUQueryPool>();
520+
ctx->querypool->initialize(device_, static_cast<uint32_t>(n));
521+
}
522+
qp = ctx->querypool.get();
523+
qp->reset(static_cast<uint32_t>(n));
524+
}
525+
}
526+
}
527+
504528
WGPUCommandEncoderDescriptor enc_desc = {};
505529
WGPUCommandEncoder encoder =
506530
wgpuDeviceCreateCommandEncoder(device_, &enc_desc);
507531

508532
// One pass per dispatch: enforces storage RAW ordering across deps.
509-
for (const auto& dispatch : dispatches_) {
533+
for (size_t i = 0; i < n; i++) {
534+
const auto& dispatch = dispatches_[i];
535+
// tw must outlive BeginComputePass (the descriptor points at it).
536+
WGPUPassTimestampWrites tw = {};
510537
WGPUComputePassDescriptor pass_desc = {};
538+
if (qp) {
539+
tw = qp->writes_for(static_cast<uint32_t>(i));
540+
pass_desc.timestampWrites = &tw;
541+
}
511542
WGPUComputePassEncoder pass =
512543
wgpuCommandEncoderBeginComputePass(encoder, &pass_desc);
513544
wgpuComputePassEncoderSetPipeline(pass, dispatch.pipeline);
@@ -517,22 +548,45 @@ void WebGPUGraph::execute() {
517548
pass, dispatch.workgroup_count_x, 1, 1);
518549
wgpuComputePassEncoderEnd(pass);
519550
wgpuComputePassEncoderRelease(pass);
551+
if (qp) {
552+
qp->record(
553+
static_cast<uint32_t>(i),
554+
dispatch.kernel_name,
555+
{dispatch.workgroup_count_x, 1, 1},
556+
{1, 1, 1});
557+
}
520558
}
521559

522560
for (const auto& copy : output_copies_) {
523561
wgpuCommandEncoderCopyBufferToBuffer(
524562
encoder, copy.src_buffer, 0, copy.staging_buffer, 0, copy.nbytes);
525563
}
526564

565+
if (qp) {
566+
qp->resolve(encoder);
567+
}
568+
527569
WGPUCommandBufferDescriptor cmd_desc = {};
528570
WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(encoder, &cmd_desc);
529571
wgpuQueueSubmit(queue_, 1, &cmd);
530572

531573
wgpuCommandBufferRelease(cmd);
532574
wgpuCommandEncoderRelease(encoder);
575+
576+
if (qp) {
577+
qp->extract_results(instance_);
578+
qp->print_results();
579+
}
533580
return;
534581
}
535582

583+
// GPU timestamp queries assume one submit; chunked execute is multi-submit.
584+
if (should_timestamp_query()) {
585+
throw std::runtime_error(
586+
"WebGPU: WEBGPU_TIMESTAMP_QUERY is incompatible with chunked execute "
587+
"(multi-submit); disable chunking to use GPU timestamp queries");
588+
}
589+
536590
const size_t first_chunk = execute_config_.initial_chunk_size > 0
537591
? execute_config_.initial_chunk_size
538592
: chunk;

backends/webgpu/runtime/WebGPUGraph.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ struct WebGPUDispatch {
3131
WGPUComputePipeline pipeline = nullptr;
3232
WGPUBindGroup bind_group = nullptr;
3333
uint32_t workgroup_count_x = 1;
34+
std::string kernel_name; // bench label
3435
};
3536

3637
struct OutputCopy {
@@ -105,6 +106,9 @@ class WebGPUGraph {
105106
int64_t get_int(int id) const {
106107
return ints_[id];
107108
}
109+
bool get_bool(int id) const {
110+
return bools_[id];
111+
}
108112

109113
// Live-scalar (SymInt) API; mirrors the Vulkan SymInt/ParamsBuffer UBO.
110114
// set_symint writes the buffer + marks dirty only if the value changed.

0 commit comments

Comments
 (0)