diff --git a/include/cutlass/detail/collective/mixed_input_utils.hpp b/include/cutlass/detail/collective/mixed_input_utils.hpp index 7c5ee0b34f..0dd32b7160 100644 --- a/include/cutlass/detail/collective/mixed_input_utils.hpp +++ b/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -553,6 +553,35 @@ struct MixedInputUtils { static constexpr auto KernelConversionMode = Collective::KernelConversionMode; static constexpr auto ModeHasScales = Collective::ModeHasScales; static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable; + static constexpr bool UseNvfp4Block16ScaleBroadcast = + cute::is_same_v && + cute::is_same_v && + (int(size<1>(SmemLayoutScale{})) > 1); + + static constexpr auto + get_mma_smem_layout_scale() { + if constexpr (UseNvfp4Block16ScaleBroadcast) { + auto compact_layout = SmemLayoutScale{}; + constexpr int ScaleK = int(size<1>(SmemLayoutScale{})); + static_assert(int(size<0>(SmemLayoutScale{})) % 16 == 0, + "NVFP4 scale broadcast assumes 16-row scale atoms."); + auto compact_k_stride = + compact_layout(_0{}, _1{}, _0{}) - compact_layout(_0{}, _0{}, _0{}); + auto broadcast_layout = make_layout( + make_shape(shape<0>(compact_layout), + make_shape(Int<16>{}, Int{}), + shape<2>(compact_layout)), + make_stride(stride<0>(compact_layout), + make_stride(Int<0>{}, compact_k_stride), + stride<2>(compact_layout))); + static_assert(cute::cosize_v == + cute::cosize_v); + return broadcast_layout; + } + else { + return SmemLayoutScale{}; + } + } public: static constexpr auto @@ -664,8 +693,14 @@ struct MixedInputUtils { copy(smem_tiled_copy_A, tCsA(_,_,k_block,read_stage), tCrA_copy_view(_,_,k_block)); - if (k_block == 0) { - // We are starting a new k-tile so copy the scale + bool copy_extra_inputs = k_block == 0; + if constexpr (size<1>(SmemLayoutScale{}) != 1) { + copy_extra_inputs = true; + } + + if (copy_extra_inputs) { + // One-scale-per-tile kernels only refresh at the first GMMA k-block. + // NVFP4 block-16 kernels use a broadcast MMA view over compact scale columns. if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { // nothing to do } @@ -862,7 +897,7 @@ struct MixedInputUtils { } else if constexpr (UseScaleLookupTable) { constexpr int num_elements = decltype(size(src))::value; - static_assert(is_same_v || is_same_v, + static_assert(is_same_v || is_same_v, "Lookup table supports int4b_t (Two's Complement) and float_e2m1_t (E2M1/FP4) quant types."); static_assert(sizeof_bits_v == 64, "Lookup table only supports 8 8bit scale values now."); static_assert(num_elements % 4 == 0 && num_elements >= 4, "Lookup table requires a vector size of 4x when converting."); @@ -886,7 +921,7 @@ struct MixedInputUtils { { auto&& scale_neg_ = reinterpret_cast const&>(scales_neg_vm_(i)); auto&& scale_pos_ = reinterpret_cast &>(scales_pos_vm_(i)); - + // Accept CUTLASS pseudo-FP as well if constexpr (cutlass::platform::is_floating_point::value || cute::is_same_v) { @@ -1022,7 +1057,7 @@ struct MixedInputUtils { Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, pack)); cute::transform(src_arr, dst_arr, Converter::convert); - + if constexpr (ModeHasScales) { auto const& scales = cute::get<1>(partitioned_extra_info)(_,_,_,k_block); @@ -1154,7 +1189,7 @@ struct MixedInputUtils { return cute::make_tuple(); } else if constexpr (UseScaleLookupTable) { - Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), get_mma_smem_layout_scale());// (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsS = mma_thread_slice.partition_A(sS); Tensor tCrS_neg = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); Tensor tCrS_pos = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); @@ -1164,7 +1199,7 @@ struct MixedInputUtils { } } else if constexpr (ModeHasScales) { - Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), get_mma_smem_layout_scale());// (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsS = mma_thread_slice.partition_A(sS); Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); @@ -1172,7 +1207,7 @@ struct MixedInputUtils { return cute::make_tuple(tCsS, tCrS); } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), get_mma_smem_layout_scale());// (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsZ = mma_thread_slice.partition_A(sZ); Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).layout()); return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp index 1ca5e7c96c..222603c271 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -200,7 +200,17 @@ struct CollectiveMma< static constexpr int NumProducerThreadEvents = 1; - using SmemLayoutAtomScale = Layout(SwappedSmemLayoutAtomA{})), cute::Int<1>>>; + static constexpr bool UseNvfp4Block16Scales = + cute::is_same_v && + cute::is_same_v && + ((int(size<2>(TileShape{})) % 16) == 0); + using ScaleAtomM = + cute::conditional_t, + decltype(cute::shape<0>(SwappedSmemLayoutAtomA{}))>; + static constexpr int ScaleAtomK = + UseNvfp4Block16Scales ? int(size<2>(TileShape{})) / 16 : 1; + using SmemLayoutAtomScale = + Layout>>; using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), shape<1>(SmemLayoutAtomScale{}))); static_assert(cute::rank(SwappedSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); @@ -234,9 +244,8 @@ struct CollectiveMma< static_assert(cute::is_same_v || cute::is_same_v, "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - // To relax them, we need to handle loading more than 1 row of scales for every main loop iteration. - // We must also handle updating the pipeline transaction bytes on the fly. - static_assert(size<1>(SmemLayoutAtomScale{}) == 1, "size<1>(SmemLayoutAtomScale) must be 1."); + static_assert(size<1>(SmemLayoutAtomScale{}) == 1 || UseNvfp4Block16Scales, + "Only NVFP4 block-16 scales may use multiple scale columns per K tile."); private: static constexpr ConversionMode @@ -378,6 +387,11 @@ struct CollectiveMma< init_M = get<1>(init_shape); init_N = get<0>(init_shape); } + if constexpr (IsGroupedGemmKernel) { + init_M = cute::max(init_M, int(size<0>(TileShape{}))); + init_N = cute::max(init_N, int(size<1>(TileShape{}))); + init_K = cute::max(init_K, int(size<2>(TileShape{}))); + } // Batches/Groups are managed by using appropriate pointers to input matrices const uint32_t mock_L = 1; SwappedElementA const* ptr_A_first_batch; @@ -491,7 +505,9 @@ struct CollectiveMma< else if constexpr (ModeHasScales) { auto scale_k = ceil_div(init_K, args.chunk_size); ElementScale const* ptr_S = reinterpret_cast(args.ptr_S); - StrideScale dS{}; + StrideScale dS = + make_stride(Int<1>{}, static_cast(init_M), + static_cast(init_M) * scale_k); Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_layout(make_shape(init_M,scale_k,mock_L), dS)); tma_load_scale = make_tma_copy( GmemTiledCopyScale{}, @@ -596,8 +612,14 @@ struct CollectiveMma< const int scale_k = ceil_div(K, args.chunk_size); constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); - implementable = implementable && (args.chunk_size == K || ((args.chunk_size % size<2>(TileShape{})) == 0)); implementable = implementable && args.chunk_size != 0; + if (args.chunk_size != 0) { + implementable = implementable && + (args.chunk_size == K || + ((args.chunk_size % size<2>(TileShape{})) == 0) || + (UseNvfp4Block16Scales && + ((int(size<2>(TileShape{})) % args.chunk_size) == 0))); + } implementable = implementable && (args.ptr_S != nullptr); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { implementable = implementable && (args.ptr_Z == nullptr);