diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts index 55d188516d40..0f5ba7bcf1cd 100644 --- a/web/src/webgpu.ts +++ b/web/src/webgpu.ts @@ -697,6 +697,14 @@ export class WebGPUContext { bindGroupLayouts: [bindGroupLayout] }); + // Pre-allocate typed array views for pod args (reused across dispatches) + const numPodSlots = podArgIndices.length + 1; // +1 for packDimX + const podArgBytes = numPodSlots * Int32Array.BYTES_PER_ELEMENT; + const podArgsArrayBuffer = new ArrayBuffer(podArgBytes); + const i32ViewCached = new Int32Array(podArgsArrayBuffer); + const u32ViewCached = new Uint32Array(podArgsArrayBuffer); + const f32ViewCached = new Float32Array(podArgsArrayBuffer); + // Function to create the pipeline. const createShaderFunc = (pipeline: GPUComputePipeline): Function => { const submitShader = (...args: Array): void => { @@ -756,35 +764,30 @@ export class WebGPUContext { }); } - const sizeOfI32 = 4; - const bufBytes = (podArgIndices.length + 1) * sizeOfI32; - const podArgBuffer = this.getUniformFromPool(bufBytes); - const i32View = new Int32Array(podArgIndices.length + 1); - const u32View = new Uint32Array(i32View.buffer); - const f32View = new Float32Array(i32View.buffer); + const podArgBuffer = this.getUniformFromPool(podArgBytes); for (let i = 0; i < podArgIndices.length; ++i) { const value = args[podArgIndices[i]]; const dtype = finfo.arg_types[podArgIndices[i]]; if (dtype.startsWith("int")) { - i32View[i] = value; + i32ViewCached[i] = value; } else if (dtype.startsWith("uint")) { - u32View[i] = value; + u32ViewCached[i] = value; } else if (dtype.startsWith("float")) { - f32View[i] = value; + f32ViewCached[i] = value; } else { throw Error("Unknown pod dtype " + dtype); } } - // always pass in dim z launching grid size in - u32View[podArgIndices.length] = packDimX; - this.device.queue.writeBuffer(podArgBuffer, 0, i32View.buffer); + // Pass the original grid X dimension so the shader can recover blockIdx.x from the z-split + u32ViewCached[podArgIndices.length] = packDimX; + this.device.queue.writeBuffer(podArgBuffer, 0, podArgsArrayBuffer); bindGroupEntries.push({ binding: bufferArgIndices.length, resource: { buffer: podArgBuffer, - size: i32View.buffer.byteLength + size: podArgsArrayBuffer.byteLength } });