diff --git a/gimmik/__init__.py b/gimmik/__init__.py index b32ebdc..cd21134 100644 --- a/gimmik/__init__.py +++ b/gimmik/__init__.py @@ -8,6 +8,7 @@ from gimmik.hip import HIPMatMul from gimmik.metal import MetalMatMul from gimmik.opencl import OpenCLMatMul +from gimmik.ptx import PTXMatMul def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm', @@ -22,7 +23,8 @@ def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm', 'cuda': CUDAMatMul, 'ispc': ISPCMatMul, 'hip': HIPMatMul, - 'opencl': OpenCLMatMul + 'opencl': OpenCLMatMul, + 'ptx': PTXMatMul } mm = platmap[platform](alpha*mat, beta, None, n, ldb, ldc) diff --git a/gimmik/base.py b/gimmik/base.py index f547afc..0ecc29a 100644 --- a/gimmik/base.py +++ b/gimmik/base.py @@ -144,7 +144,8 @@ def _render_kernel(self, dtype, tplname, tplargs): src = tpl.render(**tplargs) # At single precision suffix all floating point constants by 'f' - if dtype == 'float': + # (PTX doesn't use an 'f' suffix for FP literals) + if dtype == 'float' and self.platform != 'ptx': src = re.sub(r'(?=\d*[.eE])(?=\.?\d)\d*\.?\d*(?:[eE][+-]?\d+)?', r'\g<0>f', src) diff --git a/gimmik/kernels/ptx/base.mako b/gimmik/kernels/ptx/base.mako new file mode 100644 index 0000000..b64ddc1 --- /dev/null +++ b/gimmik/kernels/ptx/base.mako @@ -0,0 +1,4 @@ +.version 8.7 +.target sm_${cc[0]}${cc[1]}${"a" if cc[0] >= 9 else ""} +.address_size 64 +${next.body()} diff --git a/gimmik/kernels/ptx/bstream-msplit.mako b/gimmik/kernels/ptx/bstream-msplit.mako new file mode 100644 index 0000000..0af5091 --- /dev/null +++ b/gimmik/kernels/ptx/bstream-msplit.mako @@ -0,0 +1,281 @@ +<%inherit file='base'/> + +<% +pftype = 'f32' if dtype == 'float' else 'f64' +dwidth_i = 4 if dtype == 'float' else 8 +fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' +has_zero_rows = any(jx == -1 for jx in afix) +mx = partition(A, into=msplit, by='rows') +bix_list = list(bix) +bchunks = chunk(bix_list, bsz) +m_per_group = max(len(mcx) for mcx in mx) +bsub_bytes = 2 * bsz * blockx * dwidth_i +def bsub_off(buf, idx): + return (buf * bsz + idx) * blockx * dwidth_i +use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) +%> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id, tid_x, tid_y; + .reg .u64 b, c, b_base, c_base, bsub_thread; +% if use_cpasync: + .reg .u32 bsub_sm_thread; +% endif + .reg .${pftype} bv, csub<${m_per_group}>; + .reg .pred p1, p_skip; + .shared .align 8 .b8 _bsub[${bsub_bytes}]; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + + { + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; + mov.u64 bsub_thread, _bsub; + add.u64 bsub_thread, bsub_thread, _tx_off; + } +% if use_cpasync: + { + .reg .u64 _sm64; + cvta.to.shared.u64 _sm64, bsub_thread; + cvt.u32.u64 bsub_sm_thread, _sm64; + } +% endif + +% for cid, mcx in enumerate(mx): +## cid = ${cid}, rows ${mcx} + setp.ne.u32 p_skip, tid_y, ${cid}; + @p_skip bra $L_END_CID_${cid}; + +% if use_cpasync: +## Async fill of chunk 0 +% for idx, kx in enumerate(bchunks[0]): +% if idx % msplit == cid: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [_bptr], ${dwidth_i}; + } +% else: + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${dwidth_i}; +% endif +% endif +% endfor + cp.async.commit_group; + cp.async.wait_all; + bar.sync 0; +% else: +## Sync fill of chunk 0 +% for idx, kx in enumerate(bchunks[0]): +% if idx % msplit == cid: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + .reg .${pftype} _bv; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.global.cg.${pftype} _bv, [_bptr]; + st.shared.${pftype} [bsub_thread + ${bsub_off(0, idx)}], _bv; + } +% else: + { + .reg .${pftype} _bv; + ld.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; + st.shared.${pftype} [bsub_thread + ${bsub_off(0, idx)}], _bv; + } +% endif +% endif +% endfor + bar.sync 0; +% endif + +## Main loop over B-chunks (double-buffered) +% for bb in range(len(bchunks)): +<% + buf_cur = bb % 2 + buf_next = (bb + 1) % 2 + is_last = (bb == len(bchunks) - 1) +%> +% if not is_last: +% for idx, kx in enumerate(bchunks[bb + 1]): +% if idx % msplit == cid: +% if use_cpasync: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [_bptr], ${dwidth_i}; + } +% else: + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${dwidth_i}; +% endif +% else: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + .reg .${pftype} _bv; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.global.cg.${pftype} _bv, [_bptr]; + st.shared.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], _bv; + } +% else: + { + .reg .${pftype} _bv; + ld.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; + st.shared.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], _bv; + } +% endif +% endif +% endif +% endfor +% if use_cpasync: + cp.async.commit_group; +% endif +% endif + +% for idx, kx in enumerate(bchunks[bb]): + ld.shared.${pftype} bv, [bsub_thread + ${bsub_off(buf_cur, idx)}]; +% for j, row_j in enumerate(mcx): +<% jx = A[row_j, kx] %> +% if jx != 0 and kx == afix[row_j]: + mul.${pftype} csub${j}, bv, ${jx}; +% elif jx != 0: + fma.rn.${pftype} csub${j}, bv, ${jx}, csub${j}; +% endif +% if kx == alix[row_j]: +% if beta == 0: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], csub${j}; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], csub${j}; +% endif +% else: + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; + st.global.${pftype} [_cptr], _ctmp; +% else: + ld.global.${pftype} _ctmp, [c_base + ${ldc*row_j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; + st.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _ctmp; +% endif + } +% endif +% endif +% endfor +% endfor +% if use_cpasync: +% if not is_last: + cp.async.wait_all; +% endif +% endif + bar.sync 0; +% endfor +## End of Main loop over B-chunks + +## Handle zero rows in this cid's group +% if has_zero_rows: +% for row_j in mcx: +% if afix[row_j] == -1: +% if beta == 0: + { + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; +% endif + } +% elif beta != 1: + { + .reg .${pftype} _tmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [_cptr], _tmp; +% else: + ld.global.${pftype} _tmp, [c_base + ${ldc*row_j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; +% endif + } +% endif +% endif +% endfor +% endif + +$L_END_CID_${cid}: +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/bstream.mako b/gimmik/kernels/ptx/bstream.mako new file mode 100644 index 0000000..f58e9b3 --- /dev/null +++ b/gimmik/kernels/ptx/bstream.mako @@ -0,0 +1,172 @@ +<%inherit file='base'/> + +<% +pftype = 'f32' if dtype == 'float' else 'f64' +dwidth_i = 4 if dtype == 'float' else 8 +fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' +has_zero_rows = any(jx == -1 for jx in afix) +bix_list = list(bix) +bix_pos = {kx: i for i, kx in enumerate(bix_list)} +%> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id; + .reg .u64 b, c, b_base, c_base; + .reg .${pftype} csub<${m}>, bv<${len(bix_list)}>; + .reg .pred p1; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _grd<3>; + mov.u32 _grd0, %ntid.x; + mov.u32 _grd1, %ctaid.x; + mov.u32 _grd2, %tid.x; + mad.lo.u32 id, _grd0, _grd1, _grd2; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + +## Batch-load active B columns +% for i, kx in enumerate(bix_list): +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} bv${i}, [_bptr]; + } +% else: + ld.weak.global.cg.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; +% endif +% endfor + +% if beta != 0: +## Pre-load C so per-row completion is a plain store +% for j in range(m): +% if afix[j] != -1: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} csub${j}, [_cptr]; + } +% else: + ld.weak.global.cg.${pftype} csub${j}, [c_base + ${ldc*j*dwidth_i}]; +% endif +% endif +% endfor +% if beta != 0 and beta != 1: +% for j in range(m): +% if afix[j] != -1: + mul.${pftype} csub${j}, csub${j}, ${float(beta)}; +% endif +% endfor +% endif +% endif + +## Main compute +% for kx in bix_list: +% for j, jx in enumerate(A[:, kx]): +% if jx != 0: +% if preload_c: + fma.rn.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}, csub${j}; +% elif kx == afix[j]: + mul.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}; +% else: + fma.rn.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}, csub${j}; +% endif +% endif +% if kx == alix[j]: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], csub${j}; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], csub${j}; +% endif + +% endif +% endfor +% endfor + +% if has_zero_rows: + { + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% for j, jx in enumerate(afix): +% if jx == -1 and beta == 0: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif + +% elif jx == -1: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [_cptr], _tmp; + } +% else: + ld.global.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif +% endif +% endfor + } +% endif + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/cstream-ksplit.mako b/gimmik/kernels/ptx/cstream-ksplit.mako new file mode 100644 index 0000000..1ba2491 --- /dev/null +++ b/gimmik/kernels/ptx/cstream-ksplit.mako @@ -0,0 +1,179 @@ +<%inherit file='base'/> + +<% +pftype = 'f32' if dtype == 'float' else 'f64' +dwidth_i = 4 if dtype == 'float' else 8 +fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' +kparts = partition(A, ksplit, by='cols') +cchunks = chunk(list(range(m)), csz) +cv_per_thread = -(-csz // ksplit) +bv_per_thread = max(len(kbx) for kbx in kparts) +csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i +%> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id, tid_x, tid_y; + .reg .u64 b, c, b_base, c_base, csub_thread; + .reg .${pftype} bv<${bv_per_thread}>, cv<${cv_per_thread}>, dotp; + .reg .pred p1, p_skip; + .shared .align 8 .b8 _csub[${csub_bytes}]; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + + { + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; + mov.u64 csub_thread, _csub; + add.u64 csub_thread, csub_thread, _tx_off; + } + +% for bid, kbx in enumerate(kparts): +## bid = ${bid}: ${len(kbx)} B columns, ksplit=${ksplit} + setp.ne.u32 p_skip, tid_y, ${bid}; + @p_skip bra $L_END_BID_${bid}; + +<% + loaded = set() + kbx_idx = {kx: i for i, kx in enumerate(kbx)} +%> + +% for cchunk_i, cchunk in enumerate(cchunks): +## Chunk ${cchunk_i}: partial dot-product +% for row_idx, j in enumerate(cchunk): +<% + nz = [(kbx_idx[kx], kx, A[j, kx]) for kx in kbx if A[j, kx] != 0] + owner_bid = row_idx % ksplit +%> +% for (kxi, kx, jx) in nz: +% if kx not in loaded: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.global.nc.${pftype} bv${kxi}, [_bptr]; + } +% else: + ld.global.nc.${pftype} bv${kxi}, [b_base + ${ldb*kx*dwidth_i}]; +% endif +<% loaded.add(kx) %> +% endif +% endfor +% if nz: +% for i, (kxi, kx, jx) in enumerate(nz): +% if i == 0: + mul.${pftype} dotp, bv${kxi}, ${jx}; +% else: + fma.rn.${pftype} dotp, bv${kxi}, ${jx}, dotp; +% endif +% endfor +% else: + mov.${pftype} dotp, ${fzero}; +% endif +% if owner_bid == bid: + mov.${pftype} cv${row_idx // ksplit}, dotp; +% else: +<% csub_idx = bid - (1 if bid > owner_bid else 0) %> + st.shared.${pftype} [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}], dotp; +% endif +% endfor + bar.sync 0; + +## Combine phase (owned rows only) +% for row_idx, j in enumerate(cchunk): +% if row_idx % ksplit == bid: + mov.${pftype} dotp, cv${row_idx // ksplit}; +% for other_bid in range(ksplit): +% if other_bid != bid: +<% csub_idx = other_bid - (1 if other_bid > (row_idx % ksplit) else 0) %> + { + .reg .${pftype} _tmp; + ld.shared.${pftype} _tmp, [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}]; + add.${pftype} dotp, dotp, _tmp; + } +% endif +% endfor +% if beta == 0: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], dotp; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; +% endif +% else: + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.global.${pftype} [_cptr], _ctmp; +% else: + ld.global.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; +% endif + } +% endif + +% endif +% endfor + bar.sync 0; +% endfor + +$L_END_BID_${bid}: +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/cstream-w2.mako b/gimmik/kernels/ptx/cstream-w2.mako new file mode 100644 index 0000000..c82ebab --- /dev/null +++ b/gimmik/kernels/ptx/cstream-w2.mako @@ -0,0 +1,96 @@ +<%inherit file='base'/> + +<% +pftype = 'f64' +dwidth_i = 8 +fzero = '0d0000000000000000' +bix_list = list(bix) +bix_pos = {kx: i for i, kx in enumerate(bix_list)} +row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] + for j in range(m)] +%> + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 n, id; + .reg .u64 b, c, b_base, c_base; + .reg .f64 bv_a<${len(bix_list)}>, bv_b<${len(bix_list)}>, dotp_a, dotp_b; + .reg .pred p1; + + mov.u32 n, ${-(-n // 2)}; + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x, _tid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 _tid_x, %tid.x; + mad.lo.u32 id, _ctaid_x, ${blockx}, _tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, 16, b; + mad.lo.u64 c_base, _id64, 16, c; + } + +## Batch-load B column pairs +% for i, kx in enumerate(bix_list): + ld.global.nc.v2.f64 {bv_a${i}, bv_b${i}}, [b_base + ${ldb*kx*dwidth_i}]; +% endfor + +## Main compute: two parallel dot-product streams per thread +% for j in range(m): +% if row_nz[j]: +% for i_nz, (kx, jx) in enumerate(row_nz[j]): +% if i_nz == 0: + mul.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}; + mul.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}; +% else: + fma.rn.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}, dotp_a; + fma.rn.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}, dotp_b; +% endif +% endfor +% if beta == 0: + st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; +% else: + { + .reg .f64 _ca, _cb; + ld.global.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.f64 _ca, _ca, ${float(beta)}, dotp_a; + fma.rn.f64 _cb, _cb, ${float(beta)}, dotp_b; + st.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + } +% endif + +% else: +## Zero row of A +% if beta == 0: + { + .reg .f64 _z; + mov.f64 _z, ${fzero}; + st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_z, _z}; + } +% elif beta != 1: + { + .reg .f64 _ca, _cb; + ld.global.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + mul.f64 _ca, _ca, ${float(beta)}; + mul.f64 _cb, _cb, ${float(beta)}; + st.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + } +% endif +% endif +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/cstream.mako b/gimmik/kernels/ptx/cstream.mako new file mode 100644 index 0000000..ec46934 --- /dev/null +++ b/gimmik/kernels/ptx/cstream.mako @@ -0,0 +1,157 @@ +<%inherit file='base'/> + +<% +pftype = 'f32' if dtype == 'float' else 'f64' +dwidth_i = 4 if dtype == 'float' else 8 +fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' +bix_list = list(bix) +bix_pos = {kx: i for i, kx in enumerate(bix_list)} +row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] + for j in range(m)] +%> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id; + .reg .u64 b, c, b_base, c_base; + .reg .${pftype} bv<${len(bix_list)}>, dotp; + .reg .pred p1; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _grd<3>; + mov.u32 _grd0, %ntid.x; + mov.u32 _grd1, %ctaid.x; + mov.u32 _grd2, %tid.x; + mad.lo.u32 id, _grd0, _grd1, _grd2; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + +## Batch-load active B columns +% for i, kx in enumerate(bix_list): +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.global.nc.${pftype} bv${i}, [_bptr]; + } +% else: + ld.global.nc.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; +% endif +% endfor + +## Compute and store each output row +% for j in range(m): +% if row_nz[j]: +% for i_nz, (kx, jx) in enumerate(row_nz[j]): +% if i_nz == 0: + mul.${pftype} dotp, bv${bix_pos[kx]}, ${jx}; +% else: + fma.rn.${pftype} dotp, bv${bix_pos[kx]}, ${jx}, dotp; +% endif +% endfor +% if beta == 0: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], dotp; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; +% endif +% else: + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.global.${pftype} [_cptr], _ctmp; +% else: + ld.global.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; +% endif + } +% endif + +% else: +## Zero row of A +% if beta == 0: + { + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif + } +% elif beta != 1: + { + .reg .${pftype} _tmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [_cptr], _tmp; +% else: + ld.global.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif + } +% endif +% endif +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/dense-mma-gAd.mako b/gimmik/kernels/ptx/dense-mma-gAd.mako new file mode 100644 index 0000000..ce8066d --- /dev/null +++ b/gimmik/kernels/ptx/dense-mma-gAd.mako @@ -0,0 +1,187 @@ +<%inherit file='base'/> + +<% +fzero = '0d0000000000000000' +%> + +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { + ${', '.join(a_u64)} +}; + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u64 b_ptr, c_ptr; + .reg .u32 warp_n_base; + .reg .u64 ag_thr_base, b_thr_base, c_thr_base; + .reg .pred pwarp_exit; + .reg .f64 a_frag; +% for nt in range(nn): + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; +% if not n_col_aligned: + .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; +% endif + .reg .f64 b_frag_${nt}; + .reg .f64 c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; +% endfor + + ld.param.u64 b_ptr, [_b]; + ld.param.u64 c_ptr, [_c]; + cvta.to.global.u64 b_ptr, b_ptr; + cvta.to.global.u64 c_ptr, c_ptr; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + shr.u32 r_div4, lane, 2; + and.b32 r_mod4, lane, 3; + + { + .reg .u32 cta; + mov.u32 cta, %ctaid.x; + mul.lo.u32 cta, cta, ${n_per_cta}; + mul.lo.u32 warp_n_base, warp, ${n_per_warp}; + add.u32 warp_n_base, warp_n_base, cta; + } + setp.ge.u32 pwarp_exit, warp_n_base, ${n}; + @pwarp_exit bra $L_EXIT; + +% for nt in range(nn): + add.u32 b_col_${nt}, warp_n_base, ${nt * 8}; + add.u32 b_col_${nt}, b_col_${nt}, r_div4; + { + .reg .u32 t; + shl.b32 t, r_mod4, 1; + add.u32 c_col0_${nt}, warp_n_base, ${nt * 8}; + add.u32 c_col0_${nt}, c_col0_${nt}, t; + add.u32 c_col1_${nt}, c_col0_${nt}, 1; + } +% if not n_col_aligned: + setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; + setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endif +% endfor + + // A thread base: &Ag[0] + lane*8 + { + .reg .u64 t64, a_glb_base, lane64; + mov.u64 a_glb_base, ${kname}_Ag; + cvta.to.global.u64 a_glb_base, a_glb_base; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 ag_thr_base, a_glb_base, t64; + } + + { + .reg .u64 t64, bcol64; + mul.wide.u32 t64, r_mod4, ${ldb}; + cvt.u64.u32 bcol64, b_col_0; + add.u64 t64, t64, bcol64; + shl.b64 t64, t64, 3; + add.u64 b_thr_base, b_ptr, t64; + } + + { + .reg .u64 t64, ccol64; + mul.wide.u32 t64, r_div4, ${ldc}; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; + } + +% for mt in range(m_tiles): +% if pm_runtime(mt): + .reg .pred pm_${mt}; + { + .reg .u32 crow; + add.u32 crow, r_div4, ${mt * 8}; + setp.lt.u32 pm_${mt}, crow, ${m}; + } +% endif +% endfor + +% for nt in range(nn): +% for mt in range(m_tiles): +% if beta == 0: + mov.f64 c0_${nt}_${mt}, ${fzero}; + mov.f64 c1_${nt}_${mt}, ${fzero}; +% else: +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None + needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; +% if needs_zero_init: + mov.f64 c0_${nt}_${mt}, ${fzero}; + mov.f64 c1_${nt}_${mt}, ${fzero}; +% endif + ${pred_emit(f'ld.global.f64 c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} + ${pred_emit(f'ld.global.f64 c1_{nt}_{mt}, [caddr + 8];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} + } +% endif +% endfor +% endfor + +% for ki in range(k_iters): +% for nt in range(nn): +<% + pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None + k_tail = (k_rem != 0 and ki == k_iters - 1) + needs_zero = pvb is not None or k_tail + pbrow = 'pbrow' if k_tail else None +%> + { + .reg .u64 baddr; + add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; +% if needs_zero: + mov.f64 b_frag_${nt}, ${fzero}; +% endif +% if k_tail: + .reg .pred pbrow; + { + .reg .u32 brow; + add.u32 brow, r_mod4, ${ki * 4}; + setp.lt.u32 pbrow, brow, ${k}; + } +% endif + ${pred_emit(f'ld.global.nc.f64 b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} + } +% endfor +% for mt in range(m_tiles): + ld.global.nc.f64 a_frag, [ag_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; +% for nt in range(nn): + mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 + {c0_${nt}_${mt}, c1_${nt}_${mt}}, + {a_frag}, + {b_frag_${nt}}, + {c0_${nt}_${mt}, c1_${nt}_${mt}}; +% endfor +% endfor +% endfor + +% for nt in range(nn): +% for mt in range(m_tiles): +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.global.f64 [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} + ${pred_emit(f'st.global.f64 [caddr + 8], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} + } +% endfor +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/dense-mma-smem-gA.mako b/gimmik/kernels/ptx/dense-mma-smem-gA.mako new file mode 100644 index 0000000..ec2f013 --- /dev/null +++ b/gimmik/kernels/ptx/dense-mma-smem-gA.mako @@ -0,0 +1,303 @@ +<%inherit file='base'/> + +<% +# Cooperative-copy params (gA-only) +blockx = 32 * warps_per_cta +a_pairs = a_elems // 2 +a_pairs_tail = a_elems % 2 +copy_v2_iters = (a_pairs + blockx - 1) // blockx +bs = bool(block_stealing) +%> + +% if bs: +.shared .align 8 .b64 ${kname}_mbar; +.shared .align 16 .b8 ${kname}_workid[16]; +% endif +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { + ${', '.join(a_u64)} +}; +.shared .align 16 .b64 ${kname}_As[${a_elems}]; + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u64 b_ptr, c_ptr; + .reg .u32 warp_n_base; + .reg .u64 as_thr_base, b_thr_base, c_thr_base; + .reg .pred pwarp_exit; + .reg .f64 a_frag; +% if bs: + .reg .u32 ctaid; + .reg .u32 mbar_a, work_a; + .reg .pred p_root, p_done, p_have; +% endif +% for nt in range(nn): + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; +% if not n_col_aligned: + .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; +% endif + .reg .f64 b_frag_${nt}; + .reg .f64 c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; +% endfor + + ld.param.u64 b_ptr, [_b]; + ld.param.u64 c_ptr, [_c]; + cvta.to.global.u64 b_ptr, b_ptr; + cvta.to.global.u64 c_ptr, c_ptr; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + shr.u32 r_div4, lane, 2; + and.b32 r_mod4, lane, 3; + +% if bs: + setp.eq.u32 p_root, tid, 0; + mov.u32 mbar_a, ${kname}_mbar; + mov.u32 work_a, ${kname}_workid; + @p_root mbarrier.init.shared::cta.b64 [mbar_a], 1; + bar.sync 0; +% endif + + // Cooperative copy A from .global to .shared via v2 loads + { + .reg .u64 a_glb_base, a_smem_base; + mov.u64 a_glb_base, ${kname}_Ag; + cvta.to.global.u64 a_glb_base, a_glb_base; + mov.u64 a_smem_base, ${kname}_As; +% for ci in range(copy_v2_iters): +<% + base_pair = ci * blockx + is_last = ci == copy_v2_iters - 1 + pairs_this = min(blockx, a_pairs - base_pair) +%> + { + .reg .u32 pidx; + .reg .u64 off64, gaddr, saddr; + .reg .f64 v0, v1; +% if is_last and pairs_this < blockx: + .reg .pred plast; + add.u32 pidx, tid, ${base_pair}; + setp.lt.u32 plast, pidx, ${a_pairs}; + mul.wide.u32 off64, pidx, 16; + add.u64 gaddr, a_glb_base, off64; + add.u64 saddr, a_smem_base, off64; + @plast ld.global.nc.v2.f64 {v0, v1}, [gaddr]; + @plast st.shared.v2.f64 [saddr], {v0, v1}; +% else: + add.u32 pidx, tid, ${base_pair}; + mul.wide.u32 off64, pidx, 16; + add.u64 gaddr, a_glb_base, off64; + add.u64 saddr, a_smem_base, off64; + ld.global.nc.v2.f64 {v0, v1}, [gaddr]; + st.shared.v2.f64 [saddr], {v0, v1}; +% endif + } +% endfor +% if a_pairs_tail: + // Tail element (only when a_elems is odd) + { + .reg .pred plast; + .reg .u64 gaddr, saddr; + .reg .f64 v; + setp.eq.u32 plast, tid, 0; + add.u64 gaddr, a_glb_base, ${(a_elems-1) * 8}; + add.u64 saddr, a_smem_base, ${(a_elems-1) * 8}; + @plast ld.global.nc.f64 v, [gaddr]; + @plast st.shared.f64 [saddr], v; + } +% endif + } + bar.sync 0; + + // Lane-only base; lifted out of the optional steal loop + { + .reg .u64 t64, a_smem_base, lane64; + mov.u64 a_smem_base, ${kname}_As; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 as_thr_base, a_smem_base, t64; + } + +% for mt in range(m_tiles): +% if pm_runtime(mt): + .reg .pred pm_${mt}; + { + .reg .u32 crow; + add.u32 crow, r_div4, ${mt * 8}; + setp.lt.u32 pm_${mt}, crow, ${m}; + } +% endif +% endfor + +% if bs: + mov.u32 ctaid, %ctaid.x; +$L_LOOP: +% endif + + { + .reg .u32 cta; +% if bs: + mov.u32 cta, ctaid; +% else: + mov.u32 cta, %ctaid.x; +% endif + mul.lo.u32 cta, cta, ${n_per_cta}; + mul.lo.u32 warp_n_base, warp, ${n_per_warp}; + add.u32 warp_n_base, warp_n_base, cta; + } + setp.ge.u32 pwarp_exit, warp_n_base, ${n}; +% if bs: + @pwarp_exit bra $L_STEAL; +% else: + @pwarp_exit bra $L_EXIT; +% endif + +% for nt in range(nn): + add.u32 b_col_${nt}, warp_n_base, ${nt * 8}; + add.u32 b_col_${nt}, b_col_${nt}, r_div4; + { + .reg .u32 t; + shl.b32 t, r_mod4, 1; + add.u32 c_col0_${nt}, warp_n_base, ${nt * 8}; + add.u32 c_col0_${nt}, c_col0_${nt}, t; + add.u32 c_col1_${nt}, c_col0_${nt}, 1; + } +% if not n_col_aligned: + setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; + setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endif +% endfor + + { + .reg .u64 t64, bcol64; + mul.wide.u32 t64, r_mod4, ${ldb}; + cvt.u64.u32 bcol64, b_col_0; + add.u64 t64, t64, bcol64; + shl.b64 t64, t64, 3; + add.u64 b_thr_base, b_ptr, t64; + } + + { + .reg .u64 t64, ccol64; + mul.wide.u32 t64, r_div4, ${ldc}; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; + } + +% for nt in range(nn): +% for mt in range(m_tiles): +% if beta == 0: + mov.f64 c0_${nt}_${mt}, 0d0000000000000000; + mov.f64 c1_${nt}_${mt}, 0d0000000000000000; +% else: +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None + needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; +% if needs_zero_init: + mov.f64 c0_${nt}_${mt}, 0d0000000000000000; + mov.f64 c1_${nt}_${mt}, 0d0000000000000000; +% endif + ${pred_emit(f'ld.global.f64 c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} + ${pred_emit(f'ld.global.f64 c1_{nt}_{mt}, [caddr + 8];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} + } +% endif +% endfor +% endfor + +% for ki in range(k_iters): +% for nt in range(nn): +<% + pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None + k_tail = (k_rem != 0 and ki == k_iters - 1) + needs_zero = pvb is not None or k_tail + pbrow = 'pbrow' if k_tail else None +%> + { + .reg .u64 baddr; + add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; +% if needs_zero: + mov.f64 b_frag_${nt}, 0d0000000000000000; +% endif +% if k_tail: + .reg .pred pbrow; + { + .reg .u32 brow; + add.u32 brow, r_mod4, ${ki * 4}; + setp.lt.u32 pbrow, brow, ${k}; + } +% endif + ${pred_emit(f'ld.global.nc.f64 b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} + } +% endfor +% for mt in range(m_tiles): + ld.shared.f64 a_frag, [as_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; +% for nt in range(nn): + mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 + {c0_${nt}_${mt}, c1_${nt}_${mt}}, + {a_frag}, + {b_frag_${nt}}, + {c0_${nt}_${mt}, c1_${nt}_${mt}}; +% endfor +% endfor +% endfor + +% for nt in range(nn): +% for mt in range(m_tiles): +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.global.f64 [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} + ${pred_emit(f'st.global.f64 [caddr + 8], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} + } +% endfor +% endfor + +% if bs: +$L_STEAL: + // Root issues async try_cancel + waits; bar.sync orders the workid load + @!p_root bra $L_AFTER_WAIT; + { + .reg .u64 state; + mbarrier.arrive.expect_tx.shared::cta.b64 state, [mbar_a], 16; + clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 [work_a], [mbar_a]; +$L_WAIT: + mbarrier.try_wait.shared::cta.b64 p_done, [mbar_a], state, 10000000; + @!p_done bra $L_WAIT; + } +$L_AFTER_WAIT: + bar.sync 0; + + { + .reg .b128 resp; + ld.shared::cta.b128 resp, [work_a]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_have, resp; + @!p_have bra $L_FIN; + // 1D grid: extract just x + clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 ctaid, resp; + } + bra.uni $L_LOOP; + +$L_FIN: + bar.sync 0; + @p_root mbarrier.inval.shared::cta.b64 [mbar_a]; +% endif + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/dense-mma-ws.mako b/gimmik/kernels/ptx/dense-mma-ws.mako new file mode 100644 index 0000000..e4b576a --- /dev/null +++ b/gimmik/kernels/ptx/dense-mma-ws.mako @@ -0,0 +1,438 @@ +<%inherit file='base'/> +<% +mbar_maxwait = '0x989680' +direct_store = (beta == 0) +%> + +<%def name="producer_init_setup()"> + // Producer warp: initial A bulk-copy + B load for ctaid_x's work + @!p_prod bra.uni $L_AFTER_INIT_B; + { + .reg .b32 n_start0; + .reg .u64 a_glb; + mul.lo.u32 n_start0, ctaid_x, ${n_per_cta}; + mov.u64 a_glb, ${kname}_Ag; + cvta.to.global.u64 a_glb, a_glb; + @p_warp_lead cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes + [a_smem], [a_glb], ${a_elems * 8}, [tma_mbar]; + @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes + [b1_smem], [bdesc_addr, {n_start0, 0}], [tma_mbar]; + @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 + [tma_mbar], ${b_tile_bytes + a_elems * 8}; + bar.warp.sync 0xffffffff; + .reg .b64 state; + .reg .pred p1; + mbarrier.arrive.shared::cta.b64 state, [tma_mbar]; +$L_TMA_INIT_W: + mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], state, ${mbar_maxwait}; + @!p1 bra.uni $L_TMA_INIT_W; + .reg .b64 _state2; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state2, [bready_mbar]; + } +$L_AFTER_INIT_B: + + +<%def name="compute_warp_body()"> + // --- Compute Warps + @!p_compute bra.uni $L_AFTER_COMPUTE; + + // Wait on B + { + .reg .pred p1; +$L_WAIT_BRDY: + mbarrier.try_wait.parity.shared::cta.b64 p1, [bready_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_BRDY; + } + + // MMA + { + .reg .b32 b_sm_a; + .reg .pred p_ph; + setp.ne.u32 p_ph, phase, 0; + selp.b32 b_sm_a, b2_smem, b1_smem, p_ph; + + .reg .b32 a_thr_a; + { + .reg .b32 t; + shl.b32 t, lane, 3; + add.u32 a_thr_a, a_smem, t; + } +% for nt in range(nn): + .reg .b32 b_thr_a_${nt}; + { + .reg .b32 bcol_g, t_off; + add.u32 bcol_g, base_bcol, ${8 * nt}; + shl.b32 t_off, bcol_g, 3; + add.u32 b_thr_a_${nt}, b_sm_a, t_off; + } +% endfor + +% if direct_store: + // direct_store: skip shared-staging entirely; compute warps store + // MMA outputs straight to global C with N-tail predication. + .reg .u64 c_glob_addr; + ld.param.u64 c_glob_addr, [_c]; + cvta.to.global.u64 c_glob_addr, c_glob_addr; +% else: + .reg .b32 c_thr_smem; + { + .reg .b32 t1, ccol_b; + mul.lo.u32 t1, base_crow, ${n_per_cta * 8}; + shl.b32 ccol_b, base_ccol, 3; + add.u32 c_thr_smem, c_smem, t1; + add.u32 c_thr_smem, c_thr_smem, ccol_b; + } +% endif + + // Zero accumulators +% for mt in range(m_tiles): +% for nt in range(nn): + .reg .f64 d_x_${mt}_${nt}, d_y_${mt}_${nt}; + mov.f64 d_x_${mt}_${nt}, 0d0000000000000000; + mov.f64 d_y_${mt}_${nt}, 0d0000000000000000; +% endfor +% endfor + + .reg .f64 a_f; +% for mt in range(m_tiles): +% for kt in range(k_iters): +<% + k_tail = (k_rem != 0 and kt == k_iters - 1) +%> + { + .reg .b32 a_a; + add.u32 a_a, a_thr_a, ${(kt * 32 + mt * 32 * k_iters) * 8}; + ld.shared.f64 a_f, [a_a]; +% if k_tail: + .reg .pred pbrow_${mt}_${kt}; + { + .reg .b32 brow; + add.u32 brow, base_brow, ${4 * kt}; + setp.lt.u32 pbrow_${mt}_${kt}, brow, ${k}; + } +% endif +% for nt in range(nn): + { + .reg .b32 b_a, b_row; + .reg .f64 b_f; + add.u32 b_row, base_brow, ${4 * kt}; + mul.lo.u32 b_row, b_row, ${n_per_cta * 8}; + add.u32 b_a, b_thr_a_${nt}, b_row; +% if k_tail: + mov.f64 b_f, 0d0000000000000000; + @pbrow_${mt}_${kt} ld.shared.f64 b_f, [b_a]; +% else: + ld.shared.f64 b_f, [b_a]; +% endif + mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 + {d_x_${mt}_${nt}, d_y_${mt}_${nt}}, {a_f}, {b_f}, + {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor + } +% endfor +% endfor + +% if direct_store: + .reg .u64 c_thr_glob_base; + { + .reg .u32 thr_col_off, thr_addr_off_lo; + add.u32 thr_col_off, base_ccol, n_start_curr; + mad.lo.u32 thr_addr_off_lo, base_crow, ${ldc}, thr_col_off; + .reg .u64 thr_byte_off; + mul.wide.u32 thr_byte_off, thr_addr_off_lo, 8; + add.u64 c_thr_glob_base, c_glob_addr, thr_byte_off; + } +% for mt in range(m_tiles): +<% + row_tail = (m_pad > m) and ((mt + 1) * 8 > m) +%> +% if row_tail: + .reg .pred p_row_${mt}; + { + .reg .b32 crow; + add.u32 crow, base_crow, ${8 * mt}; + setp.lt.u32 p_row_${mt}, crow, ${m}; + } +% endif +% for nt in range(nn): + { + .reg .pred p_st; + .reg .u32 g_ccol; + add.u32 g_ccol, base_ccol, ${8 * nt}; + add.u32 g_ccol, g_ccol, n_start_curr; + setp.lt.u32 p_st, g_ccol, ${n}; +% if row_tail: + and.pred p_st, p_st, p_row_${mt}; +% endif + .reg .u64 c_addr; + add.u64 c_addr, c_thr_glob_base, ${(mt * 8 * ldc + nt * 8) * 8}; + @p_st st.global.v2.f64 [c_addr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor +% endfor +% else: + // Wait until producer's prev-iter TMA-store of C has drained. + { + .reg .pred p1; +$L_WAIT_CSTORE: + mbarrier.try_wait.parity.shared::cta.b64 p1, [cstored_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_CSTORE; + } + + // Vector-store {d_x, d_y} pairs to csmem. M-tail / N-tail OOB rows + // are dropped by the C tensor map. +% for mt in range(m_tiles): +% for nt in range(nn): + { + .reg .b32 csaddr; + add.u32 csaddr, c_thr_smem, ${mt * c_mtile_smem_stride + nt * c_ntile_smem_stride}; + st.shared.v2.f64 [csaddr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor +% endfor +% endif + +% if not direct_store: + bar.sync 1, ${comp_threads}; + fence.proxy.async.shared::cta; + { + .reg .b64 _state; + @p_tid0 mbarrier.arrive.shared::cta.b64 _state, [cready_mbar]; + } +% endif + + // Wait for new work and unpack + { + .reg .pred p1, p_canc; + .reg .b128 resp; +$L_WAIT_WNEW_C: + mbarrier.try_wait.parity.shared::cta.b64 p1, [wid_new_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_WNEW_C; + + ld.shared::cta.b128 resp, [wid_smem]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_canc, resp; + @p_canc clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 block_idx_x, resp; + selp.b32 work, 1, 0, p_canc; + + .reg .b64 _state; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state, [wid_used_mbar]; + } + } +$L_AFTER_COMPUTE: + + +<%def name="data_warp_body()"> + // --- Data Movement Warp + @!p_prod bra.uni $L_AFTER_DATA; + { + .reg .b32 n_c_store; + mul.lo.u32 n_c_store, block_idx_x, ${n_per_cta}; + + // Wait for new work and unpack + { + .reg .pred p1, p_canc; + .reg .b128 resp; +$L_WAIT_WNEW_D: + mbarrier.try_wait.parity.shared::cta.b64 p1, [wid_new_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_WNEW_D; + + ld.shared::cta.b128 resp, [wid_smem]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_canc, resp; + @p_canc clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 block_idx_x, resp; + selp.b32 work, 1, 0, p_canc; + .reg .b64 _state; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state, [wid_used_mbar]; + } + + // TMA loads of next B + { + mul.lo.u32 n_start_next, block_idx_x, ${n_per_cta}; + .reg .b32 b_next; + .reg .pred p_ph; + setp.ne.u32 p_ph, phase, 0; + selp.b32 b_next, b1_smem, b2_smem, p_ph; + @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes + [b_next], [bdesc_addr, {n_start_next, 0}], [tma_mbar]; + @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 + [tma_mbar], ${b_tile_bytes}; + @p_warp_lead cp.async.bulk.commit_group; + } + bar.warp.sync 0xffffffff; + +% if not direct_store: + // TMA reduce+store of C (beta=1 only; beta=0 uses direct global + // stores from compute warps, so the producer does no C work). + { + .reg .pred p1; + .reg .b64 _c_state; +$L_WAIT_CRDY: + mbarrier.try_wait.parity.shared::cta.b64 p1, [cready_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_CRDY; + @p_warp_lead cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.tile.bulk_group + [cdesc_addr, {n_c_store, 0}], [c_smem]; + @p_warp_lead cp.async.bulk.commit_group; + @p_warp_lead cp.async.bulk.wait_group 0; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _c_state, [cstored_mbar]; + } +% endif + + // Wait for next B to be ready, then signal B and C ready + { + .reg .b64 b_state, _bready_state, _c_state; + .reg .pred p1; + mbarrier.arrive.shared::cta.b64 b_state, [tma_mbar]; +$L_WAIT_TMA: + mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], b_state, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_TMA; + + @p_warp_lead mbarrier.arrive.shared::cta.b64 _bready_state, [bready_mbar]; + } + } +$L_AFTER_DATA: + + +<%def name="ctrl_warp_body()"> + // --- Controller Warp + @!p_steal bra.uni $L_AFTER_CTRL; + { + .reg .pred p1, p2, p_canc; + .reg .b64 _state; + .reg .b128 resp; + @p_warp_lead fence.proxy.async.shared::cta; + @p_warp_lead clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 + [wid_smem], [steal_mbar]; + @p_warp_lead mbarrier.arrive.expect_tx.shared::cta.b64 + _state, [steal_mbar], 16; + +$L_WAIT_STEAL: + mbarrier.try_wait.parity.shared::cta.b64 p1, [steal_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_STEAL; + + // Signal new work + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state, [wid_new_mbar]; + + // Query if there's new work + ld.shared::cta.b128 resp, [wid_smem]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_canc, resp; + selp.b32 work, 1, 0, p_canc; + + // Wait for old work to be used +$L_WAIT_WUSED: + mbarrier.try_wait.parity.shared::cta.b64 p2, [wid_used_mbar], phase, ${mbar_maxwait}; + @!p2 bra.uni $L_WAIT_WUSED; + } +$L_AFTER_CTRL: + + +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { + ${', '.join(a_u64)} +}; +.extern .shared .align 128 .b8 ${kname}_dynm[]; +.const .align 64 .b8 ${kname}_bdesc[128]; +.const .align 64 .b8 ${kname}_cdesc[128]; + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +.maxntid ${blockx_total}, 1, 1 +{ + .reg .b32 tid, warp, lane, phase, ctaid_x; + .reg .b32 base_brow, base_bcol, base_crow, base_ccol; + .reg .b32 work, block_idx_x, n_start_curr, n_start_next; + .reg .u64 bdesc_addr, cdesc_addr; + .reg .b32 a_smem, b1_smem, b2_smem, c_smem; + .reg .b32 tma_mbar, wid_new_mbar, bready_mbar, cready_mbar, cstored_mbar, steal_mbar; + .reg .b32 wid_used_mbar, wid_smem; + .reg .pred p_compute, p_prod, p_steal; + .reg .pred p_warp_lead; + .reg .pred p_done; + .reg .pred p_tid0; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + mov.u32 ctaid_x, %ctaid.x; + + .reg .b32 dynm_base; + mov.u32 dynm_base, ${kname}_dynm; + add.u32 b1_smem, dynm_base, ${b1_off}; + add.u32 b2_smem, dynm_base, ${b2_off}; + add.u32 c_smem, dynm_base, ${c_off}; + add.u32 a_smem, dynm_base, ${a_off}; + add.u32 wid_smem, dynm_base, ${wid_off}; + + add.u32 tma_mbar, dynm_base, ${tma_mbar_off}; + add.u32 bready_mbar, dynm_base, ${bready_mbar_off}; + add.u32 cready_mbar, dynm_base, ${cready_mbar_off}; + add.u32 cstored_mbar, dynm_base, ${cstored_mbar_off}; + add.u32 steal_mbar, dynm_base, ${steal_mbar_off}; + add.u32 wid_new_mbar, dynm_base, ${wid_new_mbar_off}; + add.u32 wid_used_mbar, dynm_base, ${wid_used_mbar_off}; + + cvta.const.u64 bdesc_addr, ${kname}_bdesc; + cvta.const.u64 cdesc_addr, ${kname}_cdesc; + + setp.eq.u32 p_tid0, tid, 0; + + setp.lt.u32 p_compute, warp, ${n_comp_warps}; + setp.eq.u32 p_prod, warp, ${prod_warp}; + setp.eq.u32 p_steal, warp, ${steal_warp}; + + { + .reg .b32 _elect_lane; + elect.sync _elect_lane|p_warp_lead, 0xffffffff; + } + + // mbarrier init (tid 0 only); pre-arrive csmem_free so compute iter 0 + // can write csmem immediately. + { + .reg .pred p_init; + setp.eq.u32 p_init, tid, 0; + .reg .b64 _state; + @p_init mbarrier.init.shared::cta.b64 [tma_mbar], 32; + @p_init mbarrier.init.shared::cta.b64 [bready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [cready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [cstored_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [steal_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [wid_used_mbar], ${n_comp_warps + 1}; + @p_init mbarrier.init.shared::cta.b64 [wid_new_mbar], 1; + @p_init mbarrier.arrive.shared::cta.b64 _state, [cstored_mbar]; + @p_init fence.proxy.async.shared::cta; + } + bar.sync 0; + + // Compute-warp lane geometry (cheap; all warps execute uniformly) + { + .reg .b32 t, w_n_base; + and.b32 base_brow, lane, 3; + shr.u32 base_crow, lane, 2; + mul.lo.u32 w_n_base, warp, ${n_per_warp}; + add.u32 base_bcol, base_crow, w_n_base; + shl.b32 t, base_brow, 1; + add.u32 base_ccol, t, w_n_base; + } + + ${producer_init_setup()} + + mov.u32 block_idx_x, ctaid_x; + mov.u32 work, 1; + mov.u32 phase, 0; + +$L_LOOP: + setp.eq.u32 p_done, work, 0; + @p_done bra.uni $L_EXIT; + + mul.lo.u32 n_start_curr, block_idx_x, ${n_per_cta}; + + ${compute_warp_body()} + + ${data_warp_body()} + + ${ctrl_warp_body()} + + xor.b32 phase, phase, 1; + bra.uni $L_LOOP; + +$L_EXIT: + ret; +} diff --git a/gimmik/ptx.py b/gimmik/ptx.py new file mode 100644 index 0000000..ad429e8 --- /dev/null +++ b/gimmik/ptx.py @@ -0,0 +1,276 @@ +# -*- coding: utf-8 -*- + +import struct + +import numpy as np + +from gimmik.base import MatMul + + +class PTXMatMul(MatMul): + platform = 'ptx' + basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, + 'dynamic_shared': 0} + + @staticmethod + def is_sparse_suitable(arr): + nnz = int(np.count_nonzero(arr)) + nuq = int(len(np.unique(np.abs(arr)))) + density = nnz / arr.size + return (nuq <= 28) or (density <= 0.15) + + # Shape/arch gate for dense DMMA; n/ldb/ldc are validated at generate time + @staticmethod + def is_dense_suitable(arr, dtype, cc): + return (np.dtype(dtype) == np.float64 + and cc is not None and cc >= (9, 0) + and arr.shape[0] <= 128 and arr.shape[1] <= 128) + + @classmethod + def is_suitable(cls, arr, dtype, cc): + return (cls.is_sparse_suitable(arr) + or cls.is_dense_suitable(arr, dtype, cc)) + + def _kernel_generators(self, dtype, dsize, *, compute_capability=None): + base_args = {'cc': compute_capability, + 'pred_emit': self._pred_emit} + + yield from self._sparse_kernel_generators(dtype, dsize, base_args) + yield from self._dense_kernel_generators(dtype, dsize, base_args) + + def _sparse_kernel_generators(self, dtype, dsize, base_args): + if not self.is_sparse_suitable(self.A): + return + + # B loading, C streaming kernel + yield ('cstream', base_args, {'desc': 'cstream'}) + + # B streaming, C accumulation kernel + yield ('bstream', base_args, {'desc': 'bstream'}) + + # Four-way m-split B streaming, C accumulation kernel + ms, bsz, blkx = 4, 24, 32 + args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} + meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize, + 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}'} + yield ('bstream-msplit', args, meta) + + # Single-warp LDGSTS variant for medium-M beta=0 large-K cases + if self.beta == 0 and self.m <= 320 and len(self.bix) >= 64: + ms, bsz, blkx = 1, 32, 64 + args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} + meta = {'block': (blkx, ms, 1), + 'shared': 2*bsz*blkx*dsize, + 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}'} + yield ('bstream-msplit', args, meta) + + # Two-way k-split B loading, C streaming kernel + ks, csz, blkx = 2, 24, 32 + args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} + meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize, + 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}'} + yield ('cstream-ksplit', args, meta) + + # Four-way k-split for large K + K_used = len(self.bix) + if K_used > 500: + ks, csz, blkx = 4, 20, 32 + args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} + meta = {'block': (blkx, ks, 1), + 'shared': (ks - 1)*csz*blkx*dsize, + 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}'} + yield ('cstream-ksplit', args, meta) + + # Width-2 vector cstream for fp64 small-K + if (dtype == 'double' and self.n is not None and self.n % 2 == 0 + and K_used <= 100 + and (self.aligne is None or self.aligne % 2 == 0)): + blkx = 128 + args = base_args | {'blockx': blkx} + meta = {'block': (blkx, 1, 1), 'width': 2, + 'desc': f'cstream-w2/x{blkx}'} + yield ('cstream-w2', args, meta) + + def _dense_kernel_generators(self, dtype, dsize, base_args): + cc = base_args['cc'] or (0, 0) + if not (self.is_dense_suitable(self.A, dtype, cc) + and self.n is not None): + return + + # Some kernels can optional steal blocks + bs_default = cc >= (10, 0) + + if cc >= (10, 0): + # Warp specialised is uniformly better on sm_100+, so no need to JIT + # other versions + dense_configs = [('dense-mma-smem-gA', 4, 4)] + else: + dense_configs = [ + ('dense-mma-smem-gA', 1, 8), + ('dense-mma-smem-gA', 2, 4), + ('dense-mma-smem-gA', 4, 4), + ('dense-mma-gAd', 2, 2), + ('dense-mma-gAd', 4, 2), + ] + + for tpl, nn, w in dense_configs: + blkx = 32 * w + n_per_cta = 8 * nn * w + if n_per_cta > self.n: + continue + bs = (tpl == 'dense-mma-smem-gA') and bs_default + setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) + args = (base_args | {'warps_per_cta': w, 'nn': nn, + 'block_stealing': bs} | setup) + meta = { + 'block': (blkx, 1, 1), + 'grid': (-(-self.n // n_per_cta), 1, 1), + 'desc': f'{tpl}/nn{nn}-w{w}{"-bs" if bs else ""}', + } + yield (tpl, args, meta) + + # Warp-specialised dense DMMA + if cc >= (10, 0): + yield from self._dense_ws_kernel_generators(dtype, dsize, base_args) + + def _dense_ws_kernel_generators(self, dtype, dsize, base_args): + m_pad = -(-self.m // 8) * 8 + k_pad = -(-self.k // 4) * 4 + + # (nn, w_compute) -- block has w_compute + 2 warps (producer, stealer) + ws_configs = [(1, 4), (2, 4), (4, 4)] + for nn, w in ws_configs: + n_per_cta = 8 * nn * w + if n_per_cta > self.n: + continue + blkx = 32 * (w + 2) + setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) + ws_layout = self._dense_ws_layout( + n_comp_warps=w, n_per_cta=n_per_cta, + m_pad=m_pad, k_pad=k_pad, a_elems=setup['a_elems'] + ) + + if ws_layout['dynm_total_bytes'] > 200 * 1024: + continue + + args = (base_args + | {'warps_per_cta': w, 'nn': nn} + | setup | ws_layout) + meta = { + 'block': (blkx, 1, 1), + 'grid': (-(-self.n // n_per_cta), 1, 1), + 'desc': f'dense-mma-ws/nn{nn}-w{w}', + 'ws_tensor_map': True, + 'ws_n_per_cta': n_per_cta, + 'ws_k_pad': k_pad, + 'ws_m_pad': m_pad, + 'dynamic_shared': ws_layout['dynm_total_bytes'], + } + yield ('dense-mma-ws', args, meta) + + @staticmethod + def _dense_ws_layout(*, n_comp_warps, n_per_cta, m_pad, k_pad, a_elems): + n_total_warps = n_comp_warps + 2 + blockx_total = 32 * n_total_warps + + b_tile_bytes = k_pad * n_per_cta * 8 + c_tile_bytes = m_pad * n_per_cta * 8 + a_bytes = a_elems * 8 + + smem_size = {'b1': b_tile_bytes, 'b2': b_tile_bytes, 'c': c_tile_bytes, + 'a': a_bytes, 'wid': 16} + smem_off, off = {}, 0 + for k, v in smem_size.items(): + off = (off + 15) & ~15 + smem_off[f'{k}_off'] = off + off += v + + mbar_names = ('tma', 'bready', 'cready', 'cstored', + 'steal', 'wid_new', 'wid_used') + for k in mbar_names: + smem_off[f'{k}_mbar_off'] = off + off += 8 + + # Pad total to 16-byte multiple + dynm_total_bytes = (off + 15) & ~15 + + params = {'n_comp_warps': n_comp_warps, + 'blockx_total': blockx_total, + 'prod_warp': n_comp_warps, + 'steal_warp': n_comp_warps + 1, + 'comp_threads': 32 * n_comp_warps, + 'm_pad': m_pad, + 'k_pad': k_pad, + 'b_tile_doubles': k_pad * n_per_cta, + 'b_tile_bytes': b_tile_bytes, + 'c_tile_doubles': m_pad * n_per_cta, + 'c_mtile_smem_stride': 8 * n_per_cta * 8, + 'c_ntile_smem_stride': 8 * 8, + 'dynm_total_bytes': dynm_total_bytes, + } + params |= smem_off + return params + + def _dense_mma_setup(self, *, nn, warps_per_cta): + a = self.A + m, k = a.shape + m_tiles = -(-m // 8) + k_rem = k % 4 + k_iters = (k + (4 - k_rem if k_rem else 0)) // 4 + + # A in fragment layout: lane l -> A[m_tile*8 + l/4][k_iter*4 + l%4] + a_u64 = [] + for m_tile in range(m_tiles): + for k_iter in range(k_iters): + for lane in range(32): + i = m_tile * 8 + lane // 4 + j = k_iter * 4 + lane % 4 + v = float(a[i, j]) if (i < m and j < k) else 0.0 + u = struct.unpack(' m + + return { + 'm_tiles': m_tiles, + 'k_rem': k_rem, 'k_iters': k_iters, + 'a_u64': a_u64, + 'n_per_warp': n_per_warp, 'n_per_cta': n_per_cta, + 'a_elems': a_elems, + 'frag_stride_bytes': 32 * 8, + 'b_kiter_stride': 4 * (self.ldb or 0) * 8, + 'b_ntile_stride': 8 * 8, + 'c_mtile_stride': 8 * (self.ldc or 0) * 8, + 'c_ntile_stride': 8 * 8, + 'n_col_aligned': n_col_aligned, + 'pm_runtime': pm_runtime, + } + + @staticmethod + def _pred_emit(instr, *preds, pred_reg=None, indent=' ' * 8): + actual = [p for p in preds if p is not None] + if not actual: + return instr + if len(actual) == 1: + return f'@{actual[0]} {instr}' + if pred_reg is None: + raise ValueError('pred_reg required when combining multiple ' + 'predicates') + lines = [f'.reg .pred {pred_reg};', + f'and.pred {pred_reg}, {actual[0]}, {actual[1]};'] + for p in actual[2:]: + lines.append(f'and.pred {pred_reg}, {pred_reg}, {p};') + lines.append(f'@{pred_reg} {instr}') + return f'\n{indent}'.join(lines) + + def _process_meta(self, meta): + if self.n is not None and 'grid' not in meta: + div = meta['block'][0]*meta['width'] + meta['grid'] = (-(-self.n // div), 1, 1)