Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@ set(WEBGPU_SRCS
runtime/WebGPUGraph.cpp
runtime/WebGPUDelegateHeader.cpp
runtime/WebGPUDevice.cpp
runtime/WebGPUQueryPool.cpp
runtime/ops/OperatorRegistry.cpp
runtime/ops/add/BinaryOp.cpp
runtime/ops/rms_norm/RmsNorm.cpp
runtime/ops/update_cache/UpdateCache.cpp
runtime/ops/sdpa/Sdpa.cpp
runtime/ops/select_as_symint/SelectAsSymint.cpp
runtime/ops/quantized_linear/QuantizedLinear.cpp
)

add_library(webgpu_backend ${WEBGPU_SRCS})
Expand Down Expand Up @@ -76,6 +79,17 @@ endif()

target_compile_options(webgpu_backend PRIVATE -fexceptions)

# Opt-in GPU timestamp profiling (WebGPUQueryPool); OFF so production builds
# request no TimestampQuery device feature. Mirrors Vulkan's compile-flag gate.
#
# The option defaults to OFF. When it is ON, the preprocessor symbol
# WGPU_BACKEND_ENABLE_PROFILING is defined for webgpu_backend only (PRIVATE:
# it is not propagated to targets that link against webgpu_backend).
# NOTE(review): WebGPUDevice.h adds members to WebGPUContext under this same
# macro, so any other target that includes that header presumably needs the
# identical definition to see the same struct layout -- confirm.
option(EXECUTORCH_BUILD_WEBGPU_PROFILING
"Enable WebGPU GPU timestamp-query profiling" OFF
)
if(EXECUTORCH_BUILD_WEBGPU_PROFILING)
target_compile_definitions(
webgpu_backend PRIVATE WGPU_BACKEND_ENABLE_PROFILING
)
endif()

# Link with --whole-archive for static registration of backend + ops
executorch_target_link_options_shared_lib(webgpu_backend)

Expand Down Expand Up @@ -114,6 +128,11 @@ function(add_webgpu_native_test test_name test_src)
target_link_libraries(${test_name} PRIVATE dl m pthread)
endif()
target_compile_options(${test_name} PRIVATE -fexceptions)
if(EXECUTORCH_BUILD_WEBGPU_PROFILING)
target_compile_definitions(
${test_name} PRIVATE WGPU_BACKEND_ENABLE_PROFILING
)
endif()
set_property(TARGET ${test_name} PROPERTY CXX_STANDARD 17)
endfunction()

Expand Down
19 changes: 19 additions & 0 deletions backends/webgpu/runtime/WebGPUDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#include <cstdlib>
#include <memory>
#include <stdexcept>
#ifdef WGPU_BACKEND_ENABLE_PROFILING
#include <vector>
#endif // WGPU_BACKEND_ENABLE_PROFILING

namespace executorch {
namespace backends {
Expand Down Expand Up @@ -137,6 +140,18 @@ WebGPUContext create_webgpu_context() {
WGPUStatus_Success) {
device_desc.requiredLimits = &supported_limits;
}

#ifdef WGPU_BACKEND_ENABLE_PROFILING
// Bench: enable TimestampQuery if available; fail-open (skip timing if not).
std::vector<WGPUFeatureName> required_features;
if (wgpuAdapterHasFeature(ctx.adapter, WGPUFeatureName_TimestampQuery)) {
required_features.push_back(WGPUFeatureName_TimestampQuery);
device_desc.requiredFeatureCount = required_features.size();
device_desc.requiredFeatures = required_features.data();
ctx.timestamp_supported = true;
}
#endif // WGPU_BACKEND_ENABLE_PROFILING

device_desc.uncapturedErrorCallbackInfo.callback = on_device_error;

WGPUWaitStatus device_wait = webgpu_wait(
Expand Down Expand Up @@ -192,6 +207,10 @@ WebGPUContext* get_default_webgpu_context() {
}

void destroy_webgpu_context(WebGPUContext& ctx) {
#ifdef WGPU_BACKEND_ENABLE_PROFILING
// Release device-child GPU resources before the device handle.
ctx.querypool.reset();
#endif // WGPU_BACKEND_ENABLE_PROFILING
if (ctx.queue) {
wgpuQueueRelease(ctx.queue);
ctx.queue = nullptr;
Expand Down
12 changes: 12 additions & 0 deletions backends/webgpu/runtime/WebGPUDevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@

#include <webgpu/webgpu.h>

#ifdef WGPU_BACKEND_ENABLE_PROFILING
#include <executorch/backends/webgpu/runtime/WebGPUQueryPool.h>

#include <memory>
#endif // WGPU_BACKEND_ENABLE_PROFILING

namespace executorch {
namespace backends {
namespace webgpu {
Expand All @@ -19,6 +25,12 @@ struct WebGPUContext {
WGPUAdapter adapter = nullptr;
WGPUDevice device = nullptr;
WGPUQueue queue = nullptr;
#ifdef WGPU_BACKEND_ENABLE_PROFILING
// True if the device was created with the TimestampQuery feature (bench).
bool timestamp_supported = false;
// Bench-only: timestamp-query pool, lazily created in execute() (env-gated).
std::unique_ptr<WebGPUQueryPool> querypool;
#endif // WGPU_BACKEND_ENABLE_PROFILING
};

WebGPUContext create_webgpu_context();
Expand Down
71 changes: 70 additions & 1 deletion backends/webgpu/runtime/WebGPUGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <executorch/backends/webgpu/runtime/WebGPUCompat.h>
#include <executorch/backends/webgpu/runtime/WebGPUDevice.h>

#include <cstdlib>
#include <cstring>
#include <stdexcept>

Expand Down Expand Up @@ -496,18 +497,57 @@ void WebGPUGraph::copy_inputs(
}
}

namespace {
// Profiling gate for per-pass GPU timestamp queries.
//
// Builds without WGPU_BACKEND_ENABLE_PROFILING always report false. Profiling
// builds report true when the WEBGPU_TIMESTAMP_QUERY environment variable is
// present (any value, including empty). The environment is read once, on the
// first call, and the answer is cached for the rest of the process.
bool should_timestamp_query() {
#ifndef WGPU_BACKEND_ENABLE_PROFILING
  return false;
#else
  static const bool requested = [] {
    const char* env_value = std::getenv("WEBGPU_TIMESTAMP_QUERY");
    return env_value != nullptr;
  }();
  return requested;
#endif
}
} // namespace

void WebGPUGraph::execute() {
const size_t n = dispatches_.size();
const size_t chunk = execute_config_.chunk_size;

if (chunk == 0 || n <= chunk) {
#ifdef WGPU_BACKEND_ENABLE_PROFILING
// Bench: timestamp-query pool, null unless env-gated + feature present.
WebGPUQueryPool* qp = nullptr;
if (should_timestamp_query() && n > 0) {
if (auto* ctx = get_default_webgpu_context()) {
if (ctx->timestamp_supported) {
if (!ctx->querypool || ctx->querypool->capacity() < n) {
ctx->querypool = std::make_unique<WebGPUQueryPool>();
ctx->querypool->initialize(device_, static_cast<uint32_t>(n));
}
qp = ctx->querypool.get();
qp->reset(static_cast<uint32_t>(n));
}
}
}
#endif // WGPU_BACKEND_ENABLE_PROFILING

WGPUCommandEncoderDescriptor enc_desc = {};
WGPUCommandEncoder encoder =
wgpuDeviceCreateCommandEncoder(device_, &enc_desc);

// One pass per dispatch: enforces storage RAW ordering across deps.
for (const auto& dispatch : dispatches_) {
for (size_t i = 0; i < n; i++) {
const auto& dispatch = dispatches_[i];
WGPUComputePassDescriptor pass_desc = {};
#ifdef WGPU_BACKEND_ENABLE_PROFILING
// tw must outlive BeginComputePass (the descriptor points at it).
WGPUPassTimestampWrites tw = {};
if (qp) {
tw = qp->writes_for(static_cast<uint32_t>(i));
pass_desc.timestampWrites = &tw;
}
#endif // WGPU_BACKEND_ENABLE_PROFILING
WGPUComputePassEncoder pass =
wgpuCommandEncoderBeginComputePass(encoder, &pass_desc);
wgpuComputePassEncoderSetPipeline(pass, dispatch.pipeline);
Expand All @@ -517,22 +557,51 @@ void WebGPUGraph::execute() {
pass, dispatch.workgroup_count_x, 1, 1);
wgpuComputePassEncoderEnd(pass);
wgpuComputePassEncoderRelease(pass);
#ifdef WGPU_BACKEND_ENABLE_PROFILING
if (qp) {
qp->record(
static_cast<uint32_t>(i),
dispatch.kernel_name,
{dispatch.workgroup_count_x, 1, 1},
{1, 1, 1});
}
#endif // WGPU_BACKEND_ENABLE_PROFILING
}

for (const auto& copy : output_copies_) {
wgpuCommandEncoderCopyBufferToBuffer(
encoder, copy.src_buffer, 0, copy.staging_buffer, 0, copy.nbytes);
}

#ifdef WGPU_BACKEND_ENABLE_PROFILING
if (qp) {
qp->resolve(encoder);
}
#endif // WGPU_BACKEND_ENABLE_PROFILING

WGPUCommandBufferDescriptor cmd_desc = {};
WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(encoder, &cmd_desc);
wgpuQueueSubmit(queue_, 1, &cmd);

wgpuCommandBufferRelease(cmd);
wgpuCommandEncoderRelease(encoder);

#ifdef WGPU_BACKEND_ENABLE_PROFILING
if (qp) {
qp->extract_results(instance_);
qp->print_results();
}
#endif // WGPU_BACKEND_ENABLE_PROFILING
return;
}

// GPU timestamp queries assume one submit; chunked execute is multi-submit.
if (should_timestamp_query()) {
throw std::runtime_error(
"WebGPU: WEBGPU_TIMESTAMP_QUERY is incompatible with chunked execute "
"(multi-submit); disable chunking to use GPU timestamp queries");
}

const size_t first_chunk = execute_config_.initial_chunk_size > 0
? execute_config_.initial_chunk_size
: chunk;
Expand Down
4 changes: 4 additions & 0 deletions backends/webgpu/runtime/WebGPUGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ struct WebGPUDispatch {
WGPUComputePipeline pipeline = nullptr;
WGPUBindGroup bind_group = nullptr;
uint32_t workgroup_count_x = 1;
std::string kernel_name; // bench label
};

struct OutputCopy {
Expand Down Expand Up @@ -105,6 +106,9 @@ class WebGPUGraph {
int64_t get_int(int id) const {
return ints_[id];
}
// Returns the entry of bools_ at position `id`. This accessor does no range
// validation of its own; the caller must supply a valid id.
bool get_bool(int id) const {
  const bool stored = bools_[id];
  return stored;
}

// Live-scalar (SymInt) API; mirrors the Vulkan SymInt/ParamsBuffer UBO.
// set_symint writes the buffer + marks dirty only if the value changed.
Expand Down
Loading
Loading