diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt index 957862935a4..b46442d48d4 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -38,6 +38,7 @@ set(WEBGPU_SRCS runtime/ops/sdpa/Sdpa.cpp runtime/ops/select_as_symint/SelectAsSymint.cpp runtime/ops/quantized_linear/QuantizedLinear.cpp + runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp ) add_library(webgpu_backend ${WEBGPU_SRCS}) diff --git a/backends/webgpu/runtime/WebGPUBackend.cpp b/backends/webgpu/runtime/WebGPUBackend.cpp index aed769da4a4..ceca89d1710 100644 --- a/backends/webgpu/runtime/WebGPUBackend.cpp +++ b/backends/webgpu/runtime/WebGPUBackend.cpp @@ -98,20 +98,21 @@ Error WebGPUBackend::execute( const size_t num_outputs = graph->output_ids().size(); // Copy inputs from EValue tensors to GPU buffers - std::vector> inputs; + std::vector inputs; inputs.reserve(num_inputs); for (size_t i = 0; i < num_inputs; i++) { const auto& tensor = args[i]->toTensor(); - inputs.emplace_back(tensor.const_data_ptr(), tensor.nbytes()); + const bool host_is_int64 = + tensor.scalar_type() == executorch::aten::ScalarType::Long; + inputs.push_back({tensor.const_data_ptr(), tensor.nbytes(), host_is_int64}); } - graph->copy_inputs(inputs); - // Fail loud as a runtime Error so a throw never crosses the backend boundary. try { + graph->copy_inputs(inputs); graph->update_symints_from_inputs(inputs); graph->propagate_resize(); } catch (const std::exception& e) { - ET_LOG(Error, "WebGPU symint refresh/resize failed: %s", e.what()); + ET_LOG(Error, "WebGPU input copy / symint refresh failed: %s", e.what()); return Error::Internal; } diff --git a/backends/webgpu/runtime/WebGPUGraph.cpp b/backends/webgpu/runtime/WebGPUGraph.cpp index 1c977d130dd..8eb0c64f638 100644 --- a/backends/webgpu/runtime/WebGPUGraph.cpp +++ b/backends/webgpu/runtime/WebGPUGraph.cpp @@ -45,6 +45,19 @@ size_t vk_datatype_size(vkgraph::VkDataType dtype) { } } +bool vk_datatype_is_int(vkgraph::VkDataType dtype) { + switch (dtype) { + case vkgraph::VkDataType::BOOL: + case vkgraph::VkDataType::UINT8: + case vkgraph::VkDataType::INT8: + case vkgraph::VkDataType::INT32: + case vkgraph::VkDataType::INT64: + return true; + default: + return false; + } +} + } // namespace WebGPUGraph::WebGPUGraph() = default; @@ -61,7 +74,7 @@ WGPUBuffer WebGPUGraph::create_scratch_buffer(size_t nbytes) { } void WebGPUGraph::update_symints_from_inputs( - const std::vector>& inputs) { + const std::vector& inputs) { for (const auto& src : symint_sources_) { int pos = -1; for (size_t i = 0; i < input_ids_.size(); i++) { @@ -100,8 +113,8 @@ void WebGPUGraph::update_symints_from_inputs( // Reads the [0,..,index,..,0] element; symint sources are scalar-ish. const int64_t offset = static_cast(index) * stride; // elem_size back-derived from build-time numel (sources are static-shaped). - const void* host = inputs[pos].first; - const size_t elem_size = inputs[pos].second / static_cast(numel); + const void* host = inputs[pos].data; + const size_t elem_size = inputs[pos].nbytes / static_cast(numel); int32_t val; if (elem_size == sizeof(int64_t)) { val = static_cast(static_cast(host)[offset]); @@ -248,7 +261,9 @@ void WebGPUGraph::build( numel *= dims->Get(j); } } - tensor.nbytes = numel * vk_datatype_size(vk_tensor->datatype()); + tensor.elem_size = vk_datatype_size(vk_tensor->datatype()); + tensor.is_int = vk_datatype_is_int(vk_tensor->datatype()); + tensor.nbytes = numel * tensor.elem_size; int constant_id = vk_tensor->constant_id(); int mem_obj_id = vk_tensor->mem_obj_id(); @@ -484,16 +499,40 @@ WGPUBindGroupLayout WebGPUGraph::get_or_create_bgl( return bgl; } -void WebGPUGraph::copy_inputs( - const std::vector>& inputs) { +void WebGPUGraph::copy_inputs(const std::vector& inputs) { for (size_t i = 0; i < inputs.size() && i < input_ids_.size(); i++) { - if (inputs[i].second == 0) { + const InputData& in = inputs[i]; + if (in.nbytes == 0) { continue; } int tid = input_ids_[i]; const auto& tensor = tensors_[tid]; - wgpuQueueWriteBuffer( - queue_, tensor.buffer, 0, inputs[i].first, inputs[i].second); + + // Fast path: host and GPU element types match byte-for-byte. + if (in.nbytes == tensor.nbytes) { + wgpuQueueWriteBuffer(queue_, tensor.buffer, 0, in.data, tensor.nbytes); + continue; + } + + // Narrow int64 host indices into the int32 buffer (mirrors Vulkan). + const bool buffer_is_int32 = tensor.is_int && tensor.elem_size == 4; + if (in.host_is_int64 && buffer_is_int32 && in.nbytes == tensor.nbytes * 2) { + const size_t numel = tensor.nbytes / 4; + const int64_t* src = static_cast(in.data); + std::vector narrowed(numel); + for (size_t e = 0; e < numel; e++) { + narrowed[e] = static_cast(src[e]); + } + wgpuQueueWriteBuffer( + queue_, tensor.buffer, 0, narrowed.data(), tensor.nbytes); + continue; + } + + throw std::runtime_error( + "WebGPU: unsupported input copy for input " + std::to_string(i) + + " (host " + std::to_string(in.nbytes) + " bytes" + + (in.host_is_int64 ? " int64" : "") + " vs buffer " + + std::to_string(tensor.nbytes) + " bytes)"); } } diff --git a/backends/webgpu/runtime/WebGPUGraph.h b/backends/webgpu/runtime/WebGPUGraph.h index 3cff09ecb6d..5bd5b93b524 100644 --- a/backends/webgpu/runtime/WebGPUGraph.h +++ b/backends/webgpu/runtime/WebGPUGraph.h @@ -25,6 +25,16 @@ struct WebGPUTensor { WGPUBuffer buffer = nullptr; std::vector dims; size_t nbytes = 0; + // Serialized (GPU-side) element type, used to narrow wider host inputs. + size_t elem_size = 0; + bool is_int = false; +}; + +// Host-side view of one graph input, passed to copy_inputs. +struct InputData { + const void* data = nullptr; + size_t nbytes = 0; + bool host_is_int64 = false; }; struct WebGPUDispatch { @@ -75,7 +85,7 @@ class WebGPUGraph { const executorch::runtime::NamedDataMap* named_data_map = nullptr); // Copy input tensor data from host pointers into GPU buffers. - void copy_inputs(const std::vector>& inputs); + void copy_inputs(const std::vector& inputs); // Execute all recorded dispatches. void execute(); @@ -138,8 +148,7 @@ class WebGPUGraph { } // Execute-time select_as_symint read; mirrors Vulkan select_as_symint_impl. - void update_symints_from_inputs( - const std::vector>& inputs); + void update_symints_from_inputs(const std::vector& inputs); // Per-SymInt resize hook; mirrors Vulkan DynamicDispatchNode::trigger_resize. void add_resize_hook(int symint_id, std::function fn) { diff --git a/backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp b/backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp new file mode 100644 index 00000000000..5801b650f27 --- /dev/null +++ b/backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp @@ -0,0 +1,248 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace executorch::backends::webgpu { + +namespace { + +// Uniform layout matching the WGSL Params struct (16-byte aligned, 32 bytes). +struct EmbeddingParams { + uint32_t embed_dim; + uint32_t blocks_per_row; + uint32_t num_indices; + uint32_t group_size; + uint32_t groups_per_row; + uint32_t bytes_per_row; + uint32_t total_blocks; + uint32_t _pad; +}; +static_assert( + sizeof(EmbeddingParams) == 32, + "EmbeddingParams must be 32 bytes"); + +uint64_t numel_of(const std::vector& dims) { + uint64_t n = 1; + for (int64_t d : dims) { + n *= static_cast(d); + } + return n; +} + +// arg order mirrors Vulkan EmbeddingQ4gsw.cpp. +void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector& args) { + const int weight_id = args.at(0); + const int scales_id = args.at(1); + const int group_size_id = args.at(2); + const int indices_id = args.at(3); + const int is_linear_weight_id = args.at(4); + const int out_id = args.at(5); + + WGPUDevice device = graph.device(); + + const auto& weight = graph.get_tensor(weight_id); + const auto& scales = graph.get_tensor(scales_id); + const auto& indices = graph.get_tensor(indices_id); + const auto& out = graph.get_tensor(out_id); + + // Only the flat weight path is supported (linear-block unsupported). + bool is_linear = false; + if (graph.get_value_type(is_linear_weight_id) == + WebGPUGraph::ValueType::Bool) { + is_linear = graph.get_bool(is_linear_weight_id); + } else if ( + graph.get_value_type(is_linear_weight_id) == + WebGPUGraph::ValueType::Int) { + is_linear = graph.get_int(is_linear_weight_id) != 0; + } else { + throw std::runtime_error( + "WebGPU embedding_q4gsw: is_linear_weight must be Bool or Int"); + } + if (is_linear) { + throw std::runtime_error( + "WebGPU embedding_q4gsw: is_linear_weight=true is unsupported"); + } + + if (weight.dims.size() < 2 || scales.dims.size() < 2 || out.dims.empty() || + indices.dims.empty()) { + throw std::runtime_error("WebGPU embedding_q4gsw: malformed dims"); + } + + const uint32_t embed_dim = static_cast(out.dims.back()); + if (embed_dim == 0 || embed_dim % 32 != 0) { + throw std::runtime_error( + "WebGPU embedding_q4gsw: embed_dim must be a nonzero multiple of 32"); + } + if (static_cast(weight.dims[1]) * 2 != embed_dim) { + throw std::runtime_error( + "WebGPU embedding_q4gsw: weight row stride mismatch (embed_dim/2)"); + } + + int64_t group_size = 0; + if (graph.get_value_type(group_size_id) == WebGPUGraph::ValueType::Int) { + group_size = graph.get_int(group_size_id); + } + if (group_size <= 0) { + throw std::runtime_error("WebGPU embedding_q4gsw: group_size <= 0"); + } + + // Leading index dims flatten row-major (mirrors Vulkan num_indices). + const uint64_t out_numel = numel_of(out.dims); + const uint32_t num_indices = static_cast(out_numel / embed_dim); + const uint32_t groups_per_row = static_cast(scales.dims[1]); + const uint32_t blocks_per_row = embed_dim / 32u; + const uint32_t bytes_per_row = embed_dim / 2u; + const uint64_t total_blocks = + static_cast(num_indices) * blocks_per_row; + if (static_cast(groups_per_row) * group_size != embed_dim) { + throw std::runtime_error( + "WebGPU embedding_q4gsw: groups_per_row * group_size != embed_dim"); + } + if (weight.buffer == nullptr || scales.buffer == nullptr || + indices.buffer == nullptr || out.buffer == nullptr) { + throw std::runtime_error("WebGPU embedding_q4gsw: null buffer binding"); + } + + // Per-type byte guards (no runtime dtype): indices i32, weight u8, fp32 rest. + const uint64_t indices_numel = numel_of(indices.dims); + const uint64_t weight_numel = numel_of(weight.dims); + const uint64_t scales_numel = numel_of(scales.dims); + if (indices_numel != num_indices || + indices.nbytes != indices_numel * sizeof(int32_t) || + weight.nbytes != weight_numel || + scales.nbytes != scales_numel * sizeof(float) || + out.nbytes != out_numel * sizeof(float)) { + throw std::runtime_error( + "WebGPU embedding_q4gsw: dtype/byte-size mismatch " + "(indices int32, weight uint8, scales/out fp32)"); + } + if (total_blocks > UINT32_MAX) { + throw std::runtime_error( + "WebGPU embedding_q4gsw: total_blocks exceeds uint32 dispatch range"); + } + + // 1D dispatch: one thread per 32-dim block; validate before any alloc. + const uint32_t wg_size = + utils::clamp_workgroup_size(device, kEmbeddingQ4gswWorkgroupSizeX); + const uint32_t workgroup_count = utils::compute_1d_workgroup_count( + device, static_cast(total_blocks), wg_size, "embedding_q4gsw"); + + EmbeddingParams params = {}; + params.embed_dim = embed_dim; + params.blocks_per_row = blocks_per_row; + params.num_indices = num_indices; // std140 layout only; shader derives it + params.group_size = static_cast(group_size); + params.groups_per_row = groups_per_row; + params.bytes_per_row = bytes_per_row; + params.total_blocks = static_cast(total_blocks); + + WGPUBufferDescriptor uniform_desc = {}; + uniform_desc.size = sizeof(EmbeddingParams); + uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst; + uniform_desc.mappedAtCreation = true; + WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc); + void* mapped = + wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(EmbeddingParams)); + std::memcpy(mapped, ¶ms, sizeof(EmbeddingParams)); + wgpuBufferUnmap(uniform_buffer); + graph.add_uniform_buffer_bytes(sizeof(EmbeddingParams)); + + WGPUShaderSourceWGSL wgsl_desc = {}; + wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL; + wgsl_desc.code = {kEmbeddingQ4gswWGSL, WGPU_STRLEN}; + WGPUShaderModuleDescriptor shader_desc = {}; + shader_desc.nextInChain = &wgsl_desc.chain; + WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc); + + // Bind group layout: out (rw) + indices/weight/scales (ro storage) + uniform. + WGPUBindGroupLayoutEntry entries[5] = {}; + entries[0].binding = 0; + entries[0].visibility = WGPUShaderStage_Compute; + entries[0].buffer.type = WGPUBufferBindingType_Storage; + for (uint32_t i = 1; i <= 3; i++) { + entries[i].binding = i; + entries[i].visibility = WGPUShaderStage_Compute; + entries[i].buffer.type = WGPUBufferBindingType_ReadOnlyStorage; + } + entries[4].binding = 4; + entries[4].visibility = WGPUShaderStage_Compute; + entries[4].buffer.type = WGPUBufferBindingType_Uniform; + + WGPUBindGroupLayoutDescriptor bgl_desc = {}; + bgl_desc.entryCount = 5; + bgl_desc.entries = entries; + WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc); + + WGPUPipelineLayoutDescriptor pl_desc = {}; + pl_desc.bindGroupLayoutCount = 1; + pl_desc.bindGroupLayouts = &bgl; + WGPUPipelineLayout pipeline_layout = + wgpuDeviceCreatePipelineLayout(device, &pl_desc); + + WGPUConstantEntry wg_size_constant = {}; + wg_size_constant.key = {"wg_size", WGPU_STRLEN}; + wg_size_constant.value = static_cast(wg_size); + + WGPUComputePipelineDescriptor pipeline_desc = {}; + pipeline_desc.layout = pipeline_layout; + pipeline_desc.compute.module = shader; + pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN}; + pipeline_desc.compute.constantCount = 1; + pipeline_desc.compute.constants = &wg_size_constant; + WGPUComputePipeline pipeline = + wgpuDeviceCreateComputePipeline(device, &pipeline_desc); + + WGPUBindGroupEntry bg_entries[5] = {}; + bg_entries[0].binding = 0; + bg_entries[0].buffer = out.buffer; + bg_entries[0].size = out.nbytes; + bg_entries[1].binding = 1; + bg_entries[1].buffer = indices.buffer; + bg_entries[1].size = indices.nbytes; + bg_entries[2].binding = 2; + bg_entries[2].buffer = weight.buffer; + bg_entries[2].size = weight.nbytes; + bg_entries[3].binding = 3; + bg_entries[3].buffer = scales.buffer; + bg_entries[3].size = scales.nbytes; + bg_entries[4].binding = 4; + bg_entries[4].buffer = uniform_buffer; + bg_entries[4].size = sizeof(EmbeddingParams); + + WGPUBindGroupDescriptor bg_desc = {}; + bg_desc.layout = bgl; + bg_desc.entryCount = 5; + bg_desc.entries = bg_entries; + WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); + + graph.add_dispatch( + {pipeline, bind_group, workgroup_count, "embedding_q4gsw"}); + + wgpuShaderModuleRelease(shader); + wgpuBindGroupLayoutRelease(bgl); + wgpuPipelineLayoutRelease(pipeline_layout); + wgpuBufferRelease(uniform_buffer); +} + +} // namespace + +WEBGPU_REGISTER_OPERATORS { + WEBGPU_REGISTER_OP(et_vk.embedding_q4gsw.default, embedding_q4gsw_impl); +} + +} // namespace executorch::backends::webgpu diff --git a/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw.wgsl b/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw.wgsl new file mode 100644 index 00000000000..f16f3760d1c --- /dev/null +++ b/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw.wgsl @@ -0,0 +1,50 @@ +@group(0) @binding(0) var t_out: array; +@group(0) @binding(1) var t_indices: array; +@group(0) @binding(2) var t_weight: array; +@group(0) @binding(3) var t_scales: array; + +struct Params { + embed_dim: u32, + blocks_per_row: u32, + num_indices: u32, + group_size: u32, + groups_per_row: u32, + bytes_per_row: u32, + total_blocks: u32, + _pad: u32, +} +@group(0) @binding(4) var params: Params; + +override wg_size: u32 = 64u; + +// One thread per 32-dim block of one gathered row (flat-buffer weight path). +@compute @workgroup_size(wg_size, 1, 1) +fn main(@builtin(global_invocation_id) gid: vec3) { + let block = gid.x; + if (block >= params.total_blocks) { + return; + } + let indices_idx = block / params.blocks_per_row; + let base_dim = (block % params.blocks_per_row) * 32u; + + // token assumed in-range (mirrors Vulkan; no vocab clamp). + let token = u32(t_indices[indices_idx]); + let row_byte_base = token * params.bytes_per_row; + let out_base = indices_idx * params.embed_dim + base_dim; + + for (var t: u32 = 0u; t < 32u; t = t + 1u) { + let dim = base_dim + t; + let byte_idx = row_byte_base + (dim >> 1u); + let word = t_weight[byte_idx >> 2u]; + let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu; + var nib: u32; + if ((dim & 1u) == 0u) { + nib = (b >> 4u) & 0x0Fu; // even dim -> high nibble + } else { + nib = b & 0x0Fu; // odd dim -> low nibble + } + let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7] + let scale = t_scales[token * params.groups_per_row + dim / params.group_size]; + t_out[out_base + t] = q * scale; + } +} diff --git a/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw_wgsl.h b/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw_wgsl.h new file mode 100644 index 00000000000..e44c06a2ac5 --- /dev/null +++ b/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw_wgsl.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch::backends::webgpu { + +// @generated from embedding_q4gsw.wgsl - DO NOT EDIT. +// wgsl-sha256: 1fec9ed315696a88bb7db6c16454fc80e08ff73b0e39720b54515fda4ee1ef7c +inline constexpr const char* kEmbeddingQ4gswWGSL = R"( +@group(0) @binding(0) var t_out: array; +@group(0) @binding(1) var t_indices: array; +@group(0) @binding(2) var t_weight: array; +@group(0) @binding(3) var t_scales: array; + +struct Params { + embed_dim: u32, + blocks_per_row: u32, + num_indices: u32, + group_size: u32, + groups_per_row: u32, + bytes_per_row: u32, + total_blocks: u32, + _pad: u32, +} +@group(0) @binding(4) var params: Params; + +override wg_size: u32 = 64u; + +// One thread per 32-dim block of one gathered row (flat-buffer weight path). +@compute @workgroup_size(wg_size, 1, 1) +fn main(@builtin(global_invocation_id) gid: vec3) { + let block = gid.x; + if (block >= params.total_blocks) { + return; + } + let indices_idx = block / params.blocks_per_row; + let base_dim = (block % params.blocks_per_row) * 32u; + + // token assumed in-range (mirrors Vulkan; no vocab clamp). + let token = u32(t_indices[indices_idx]); + let row_byte_base = token * params.bytes_per_row; + let out_base = indices_idx * params.embed_dim + base_dim; + + for (var t: u32 = 0u; t < 32u; t = t + 1u) { + let dim = base_dim + t; + let byte_idx = row_byte_base + (dim >> 1u); + let word = t_weight[byte_idx >> 2u]; + let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu; + var nib: u32; + if ((dim & 1u) == 0u) { + nib = (b >> 4u) & 0x0Fu; // even dim -> high nibble + } else { + nib = b & 0x0Fu; // odd dim -> low nibble + } + let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7] + let scale = t_scales[token * params.groups_per_row + dim / params.group_size]; + t_out[out_base + t] = q * scale; + } +} +)"; + +inline constexpr uint32_t kEmbeddingQ4gswWorkgroupSizeX = 64; +inline constexpr uint32_t kEmbeddingQ4gswWorkgroupSizeY = 1; +inline constexpr uint32_t kEmbeddingQ4gswWorkgroupSizeZ = 1; + +} // namespace executorch::backends::webgpu