From e347ec3a1c01b767886b01db7d1b7b28ba61d11b Mon Sep 17 00:00:00 2001 From: Fletterio Date: Tue, 24 Mar 2026 21:37:55 -0300 Subject: [PATCH 01/15] mid subgroup fft --- examples_tests | 2 +- include/nbl/builtin/hlsl/fft2/common.hlsl | 178 ++++++++++++++++++++ include/nbl/builtin/hlsl/math/intutil.hlsl | 9 +- include/nbl/builtin/hlsl/subgroup2/fft.hlsl | 111 ++++++++++++ 4 files changed, 297 insertions(+), 3 deletions(-) create mode 100644 include/nbl/builtin/hlsl/fft2/common.hlsl create mode 100644 include/nbl/builtin/hlsl/subgroup2/fft.hlsl diff --git a/examples_tests b/examples_tests index 8f045a1c27..02fac2d163 160000 --- a/examples_tests +++ b/examples_tests @@ -1 +1 @@ -Subproject commit 8f045a1c27a198f8542456378f865032765378b8 +Subproject commit 02fac2d1633dd0405210e40a463e1180426133b9 diff --git a/include/nbl/builtin/hlsl/fft2/common.hlsl b/include/nbl/builtin/hlsl/fft2/common.hlsl new file mode 100644 index 0000000000..e265ae841a --- /dev/null +++ b/include/nbl/builtin/hlsl/fft2/common.hlsl @@ -0,0 +1,178 @@ +#ifndef _NBL_BUILTIN_HLSL_FFT_COMMON_INCLUDED_ +#define _NBL_BUILTIN_HLSL_FFT_COMMON_INCLUDED_ + +#include +#include +#include +#include +#include +#include + +namespace nbl +{ +namespace hlsl +{ +namespace fft2 +{ + +template 0 && N <= 4 && mpl::is_pot_v) +/** +* @brief Returns the size of the full FFT computed, in terms of number of complex elements. If the signal is real, you MUST provide a valid value for `firstAxis` (this is to run two real FFTs as one complex). + If the signal is complex, you must NOT pass any value to `firstAxis`. +* The padding rule is the following: if FFT is subgroup-sized, size is rounded up to PoT. If bigger than that, it's rounded to the smallest number that is either PoT or of the form +* 3 * 2^n or 5 * 2^n. +* +* @tparam N Number of dimensions of the signal to perform FFT on. +* +* @param [in] dimensions Size of the signal. +* @param [in] firstAxis Indicates which axis the FFT is performed on first. 
Only relevant for real-valued signals. +*/ +inline vector padDimensions(vector dimensions, uint16_t firstAxis = N) +{ + const subgroupFFTSize = SubgroupSize << 1; + vector newDimensions; + for (uint16_t i = 0u; i < N; i++) + { + newDimensions[i] = hlsl::roundUpToPoT(dimensions[i]); + if (dimensions[i] <= subgroupFFTSize) + continue; + // Consider a factor of 3 + newDimensions[i] = hlsl::min(newDimensions[i], 3 * hlsl::roundUpToPoT(hlsl::ceilDiv(dimensions[i], 3))); + // Do the same for a factor of 5 + newDimensions[i] = hlsl::min(newDimensions[i], 5 * hlsl::roundUpToPoT(hlsl::ceilDiv(dimensions[i], 5))); + // TODO: Consider if factor of 7 is viable + } + // If real, first axis gets halved since we run two real FFTs at once + if (firstAxis < N) + newDimensions[firstAxis] /= 2; + return newDimensions; +} + +template 0 && N <= 4) +/** +* @brief Returns the size required by a buffer to hold the result of the FFT of a signal after a certain pass. +* +* @tparam N Number of dimensions of the signal to perform FFT on. +* +* @param [in] numChannels Number of channels of the signal. +* @param [in] inputDimensions Size of the signal. +* @param [in] passIx Which pass the size is being computed for. +* @param [in] axisPassOrder Order of the axis in which the FFT is computed in. Default is xyzw. +* @param [in] realFFT True if the signal is real. False by default. +* @param [in] halfFloats True if using half-precision floats. False by default. +*/ +inline uint64_t getOutputBufferSize( + uint32_t numChannels, + vector inputDimensions, + uint16_t passIx, + vector axisPassOrder = _static_cast >(uint16_t4(0, 1, 2, 3)), + bool realFFT = false, + bool halfFloats = false +) +{ + const vector paddedDimensions = padDimensions(inputDimensions, realFFT ? 
axisPassOrder[0] : N); + vector axesDone = promote, bool>(false); + for (uint16_t i = 0; i <= passIx; i++) + axesDone[axisPassOrder[i]] = true; + const vector passOutputDimension = lerp(inputDimensions, paddedDimensions, axesDone); + uint64_t numberOfComplexElements = uint64_t(numChannels); + for (uint16_t i = 0; i < N; i++) + numberOfComplexElements *= uint64_t(passOutputDimension[i]); + return numberOfComplexElements * (halfFloats ? sizeof(complex_t) : sizeof(complex_t)); +} + +template 0 && N <= 4) +/** +* @brief Returns the size required by a buffer to hold the result of the FFT of a signal after a certain pass, when using the FFT to convolve it against a kernel. +* +* @tparam N Number of dimensions of the signal to perform FFT on. +* +* @param [in] numChannels Number of channels of the signal. +* @param [in] inputDimensions Size of the signal. +* @param [in] kernelDimensions Size of the kernel. +* @param [in] passIx Which pass the size is being computed for. +* @param [in] axisPassOrder Order of the axis in which the FFT is computed in. Default is xyzw. +* @param [in] realFFT True if the signal is real. False by default. +* @param [in] halfFloats True if using half-precision floats. False by default. +*/ +inline uint64_t getOutputBufferSizeConvolution( + uint32_t numChannels, + vector inputDimensions, + vector kernelDimensions, + uint16_t passIx, + vector axisPassOrder = _static_cast >(uint16_t4(0, 1, 2, 3)), + bool realFFT = false, + + bool halfFloats = false +) +{ + const vector paddedDimensions = padDimensions(inputDimensions + kernelDimensions, realFFT ? 
axisPassOrder[0] : N); + vector axesDone = promote, bool>(false); + for (uint16_t i = 0; i <= passIx; i++) + axesDone[axisPassOrder[i]] = true; + const vector passOutputDimension = lerp(inputDimensions, paddedDimensions, axesDone); + uint64_t numberOfComplexElements = uint64_t(numChannels); + for (uint16_t i = 0; i < N; i++) + numberOfComplexElements *= uint64_t(passOutputDimension[i]); + return numberOfComplexElements * (halfFloats ? sizeof(complex_t) : sizeof(complex_t)); +} + + +// Computes the kth element in the group of N roots of unity +// Notice 0 <= k < N/2, rotating counterclockwise in the forward (DIF) transform and clockwise in the inverse (DIT) +template +complex_t twiddle(uint32_t k, uint32_t halfN) +{ + complex_t retVal; + const Scalar kthRootAngleRadians = numbers::pi *Scalar(k) / Scalar(halfN); + retVal.real(cos(kthRootAngleRadians)); + if (!inverse) + retVal.imag(sin(-kthRootAngleRadians)); + else + retVal.imag(sin(kthRootAngleRadians)); + return retVal; +} + +template +struct DIX +{ + static void radix2(complex_t twiddle, NBL_REF_ARG(complex_t) lo, NBL_REF_ARG(complex_t) hi) + { + plus_assign< complex_t > plusAss; + //Decimation in time - inverse + if (inverse) { + complex_t wHi = twiddle * hi; + hi = lo - wHi; + plusAss(lo, wHi); + } + //Decimation in frequency - forward + else { + complex_t diff = lo - hi; + plusAss(lo, hi); + hi = twiddle * diff; + } + } +}; + +template +using DIT = DIX; + +template +using DIF = DIX; + +// ------------------------------------------------- Utils --------------------------------------------------------- +// +// Util to unpack two values from the packed FFT X + iY - get outputs in the same input arguments, storing x to lo and y to hi +template +void unpack(NBL_REF_ARG(complex_t) lo, NBL_REF_ARG(complex_t) hi) +{ + complex_t x = (lo + conj(hi)) * Scalar(0.5); + hi = rotateRight(lo - conj(hi)) * Scalar(0.5); + lo = x; +} + +} +} +} + +#endif \ No newline at end of file diff --git 
a/include/nbl/builtin/hlsl/math/intutil.hlsl b/include/nbl/builtin/hlsl/math/intutil.hlsl index 7394e03ae4..1f00b53cf6 100644 --- a/include/nbl/builtin/hlsl/math/intutil.hlsl +++ b/include/nbl/builtin/hlsl/math/intutil.hlsl @@ -34,11 +34,16 @@ NBL_CONSTEXPR_FORCED_INLINE_FUNC Integer roundDownToPoT(Integer value) return Integer(0x1u) << hlsl::findMSB(value); } +template) + NBL_CONSTEXPR_FORCED_INLINE_FUNC Integer ceilDiv(Integer dividend, Integer divisor) +{ + return (dividend + divisor - 1) / divisor; +} + template) NBL_CONSTEXPR_FORCED_INLINE_FUNC Integer roundUp(Integer value, Integer multiple) { - Integer tmp = (value + multiple - 1u) / multiple; - return tmp * multiple; + return ceilDiv(value, multiple) * multiple; } template) diff --git a/include/nbl/builtin/hlsl/subgroup2/fft.hlsl b/include/nbl/builtin/hlsl/subgroup2/fft.hlsl new file mode 100644 index 0000000000..34044fb964 --- /dev/null +++ b/include/nbl/builtin/hlsl/subgroup2/fft.hlsl @@ -0,0 +1,111 @@ +#ifndef _NBL_BUILTIN_HLSL_SUBGROUP_FFT_INCLUDED_ +#define _NBL_BUILTIN_HLSL_SUBGROUP_FFT_INCLUDED_ + +#include "nbl/builtin/hlsl/fft/common.hlsl" +#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl" +#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl" +#include "nbl/builtin/hlsl/concepts/accessors/fft.hlsl" + +namespace nbl +{ +namespace hlsl +{ +namespace subgroup2 +{ + +// ----------------------------------------------------------------------------------------------------------------------------------------------------------------- +template +struct FFT +{ + template + static void __call(NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor); +}; + +// ---------------------------------------- Radix 2 forward transform - DIF ------------------------------------------------------- + +template +struct FFT +{ + static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t) lo, NBL_REF_ARG(complex_t) hi) + { + const bool topHalf = 
bool(glsl::gl_SubgroupInvocationID() & stride); + const vector toTrade = topHalf ? vector (lo.real(), lo.imag()) : vector (hi.real(), hi.imag()); + const vector exchanged = glsl::subgroupShuffleXor< vector > (toTrade, stride); + if (topHalf) + { + lo.real(exchanged.x); + lo.imag(exchanged.y); + } + else + { + hi.real(exchanged.x); + hi.imag(exchanged.y); + } + // Get twiddle with k = subgroupInvocation mod stride, halfN = stride + fft::DIF::radix2(fft::twiddle(glsl::gl_SubgroupInvocationID() & (stride - 1), stride), lo, hi); + } + + static void __call(NBL_REF_ARG(complex_t) lo, NBL_REF_ARG(complex_t) hi) + { + const uint32_t subgroupSize = glsl::gl_SubgroupSize(); //This is N/2 + + // special first iteration + fft::DIF::radix2(fft::twiddle(glsl::gl_SubgroupInvocationID(), subgroupSize), lo, hi); + + // Decimation in Frequency + [unroll] + for (uint32_t stride = subgroupSize >> 1; stride > 0; stride >>= 1) + FFT_loop(stride, lo, hi); + } +}; + + +// ---------------------------------------- Radix 2 inverse transform - DIT ------------------------------------------------------- + +template +struct FFT +{ + static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t) lo, NBL_REF_ARG(complex_t) hi) + { + // Get twiddle with k = subgroupInvocation mod stride, halfN = stride + fft::DIT::radix2(fft::twiddle(glsl::gl_SubgroupInvocationID() & (stride - 1), stride), lo, hi); + + const bool topHalf = bool(glsl::gl_SubgroupInvocationID() & stride); + const vector toTrade = topHalf ? 
vector (lo.real(), lo.imag()) : vector (hi.real(), hi.imag()); + const vector exchanged = glsl::subgroupShuffleXor< vector > (toTrade, stride); + if (topHalf) + { + lo.real(exchanged.x); + lo.imag(exchanged.y); + } + else + { + hi.real(exchanged.x); + hi.imag(exchanged.y); + } + } + + static void __call(NBL_REF_ARG(complex_t) lo, NBL_REF_ARG(complex_t) hi) + { + const uint32_t subgroupSize = glsl::gl_SubgroupSize(); //This is N/2 + const uint32_t doubleSubgroupSize = subgroupSize << 1; //This is N + + // Decimation in Time + [unroll] + for (uint32_t stride = 1; stride < subgroupSize; stride <<= 1) + FFT_loop(stride, lo, hi); + + // special last iteration + fft::DIT::radix2(fft::twiddle(glsl::gl_SubgroupInvocationID(), subgroupSize), lo, hi); + divides_assign< complex_t > divAss; + divAss(lo, Scalar(doubleSubgroupSize)); + divAss(hi, Scalar(doubleSubgroupSize)); + } +}; + + +} +} +} + +#endif \ No newline at end of file From c6756bbe286290a66f9a328d69b0f6f9492f59b9 Mon Sep 17 00:00:00 2001 From: Fletterio Date: Wed, 8 Apr 2026 16:37:08 -0300 Subject: [PATCH 02/15] Subgroup Interleaved FFT, workgroup indexing math --- examples_tests | 2 +- include/nbl/builtin/hlsl/fft2/common.hlsl | 41 +-- include/nbl/builtin/hlsl/math/intutil.hlsl | 97 ++++++- include/nbl/builtin/hlsl/subgroup2/fft.hlsl | 258 +++++++++++++++---- include/nbl/builtin/hlsl/workgroup2/fft.hlsl | 199 ++++++++++++++ 5 files changed, 527 insertions(+), 70 deletions(-) create mode 100644 include/nbl/builtin/hlsl/workgroup2/fft.hlsl diff --git a/examples_tests b/examples_tests index 02fac2d163..1d787b3725 160000 --- a/examples_tests +++ b/examples_tests @@ -1 +1 @@ -Subproject commit 02fac2d1633dd0405210e40a463e1180426133b9 +Subproject commit 1d787b3725970ae8c706c858fc7b675faff952a1 diff --git a/include/nbl/builtin/hlsl/fft2/common.hlsl b/include/nbl/builtin/hlsl/fft2/common.hlsl index e265ae841a..e1f7530083 100644 --- a/include/nbl/builtin/hlsl/fft2/common.hlsl +++ 
b/include/nbl/builtin/hlsl/fft2/common.hlsl @@ -1,5 +1,5 @@ -#ifndef _NBL_BUILTIN_HLSL_FFT_COMMON_INCLUDED_ -#define _NBL_BUILTIN_HLSL_FFT_COMMON_INCLUDED_ +#ifndef _NBL_BUILTIN_HLSL_FFT2_COMMON_INCLUDED_ +#define _NBL_BUILTIN_HLSL_FFT2_COMMON_INCLUDED_ #include #include @@ -15,7 +15,21 @@ namespace hlsl namespace fft2 { -template 0 && N <= 4 && mpl::is_pot_v) +uint32_t padDimension(uint32_t dimension, uint16_t subgroupSize) +{ + const uint16_t subgroupFFTSize = subgroupSize << 1; + uint32_t padded = hlsl::roundUpToPoT(dimension); + if (padded <= subgroupFFTSize) + return subgroupFFTSize; + // Consider a factor of 3 + padded = hlsl::min(padded, 3 * hlsl::roundUpToPoT(hlsl::ceilDiv(dimension, 3))); + // Do the same for a factor of 5 + padded = hlsl::min(padded, 5 * hlsl::roundUpToPoT(hlsl::ceilDiv(dimension, 5))); + // TODO: Consider if factor of 7 is viable + return padded; +} + +template 0 && N <= 4) /** * @brief Returns the size of the full FFT computed, in terms of number of complex elements. If the signal is real, you MUST provide a valid value for `firstAxis` (this is to run two real FFTs as one complex). If the signal is complex, you must NOT pass any value to `firstAxis`. @@ -27,21 +41,11 @@ template 0 && N <= 4 && * @param [in] dimensions Size of the signal. * @param [in] firstAxis Indicates which axis the FFT is performed on first. Only relevant for real-valued signals. 
*/ -inline vector padDimensions(vector dimensions, uint16_t firstAxis = N) +inline vector padDimensions(vector dimensions, uint16_t subgroupSize, uint16_t firstAxis = N) { - const subgroupFFTSize = SubgroupSize << 1; vector newDimensions; for (uint16_t i = 0u; i < N; i++) - { - newDimensions[i] = hlsl::roundUpToPoT(dimensions[i]); - if (dimensions[i] <= subgroupFFTSize) - continue; - // Consider a factor of 3 - newDimensions[i] = hlsl::min(newDimensions[i], 3 * hlsl::roundUpToPoT(hlsl::ceilDiv(dimensions[i], 3))); - // Do the same for a factor of 5 - newDimensions[i] = hlsl::min(newDimensions[i], 5 * hlsl::roundUpToPoT(hlsl::ceilDiv(dimensions[i], 5))); - // TODO: Consider if factor of 7 is viable - } + newDimensions[i] = padDimension(dimensions[i], subgroupSize); // If real, first axis gets halved since we run two real FFTs at once if (firstAxis < N) newDimensions[firstAxis] /= 2; @@ -65,12 +69,13 @@ inline uint64_t getOutputBufferSize( uint32_t numChannels, vector inputDimensions, uint16_t passIx, + uint16_t subgroupSize, vector axisPassOrder = _static_cast >(uint16_t4(0, 1, 2, 3)), bool realFFT = false, bool halfFloats = false ) { - const vector paddedDimensions = padDimensions(inputDimensions, realFFT ? axisPassOrder[0] : N); + const vector paddedDimensions = padDimensions(inputDimensions, subgroupSize, realFFT ? axisPassOrder[0] : N); vector axesDone = promote, bool>(false); for (uint16_t i = 0; i <= passIx; i++) axesDone[axisPassOrder[i]] = true; @@ -100,13 +105,13 @@ inline uint64_t getOutputBufferSizeConvolution( vector inputDimensions, vector kernelDimensions, uint16_t passIx, + uint16_t subgroupSize, vector axisPassOrder = _static_cast >(uint16_t4(0, 1, 2, 3)), bool realFFT = false, - bool halfFloats = false ) { - const vector paddedDimensions = padDimensions(inputDimensions + kernelDimensions, realFFT ? axisPassOrder[0] : N); + const vector paddedDimensions = padDimensions(inputDimensions + kernelDimensions, subgroupSize, realFFT ? 
axisPassOrder[0] : N); vector axesDone = promote, bool>(false); for (uint16_t i = 0; i <= passIx; i++) axesDone[axisPassOrder[i]] = true; diff --git a/include/nbl/builtin/hlsl/math/intutil.hlsl b/include/nbl/builtin/hlsl/math/intutil.hlsl index 1f00b53cf6..d588f17725 100644 --- a/include/nbl/builtin/hlsl/math/intutil.hlsl +++ b/include/nbl/builtin/hlsl/math/intutil.hlsl @@ -21,11 +21,17 @@ NBL_CONSTEXPR_FORCED_INLINE_FUNC bool isPoT(Integer value) return !isNPoT(value); } +template) +// Returns ceiled log2 +NBL_CONSTEXPR_FORCED_INLINE_FUNC Integer log2ceil(Integer value) +{ + return Integer(1 + hlsl::findMSB(value - Integer(1))); +} template) NBL_CONSTEXPR_FORCED_INLINE_FUNC Integer roundUpToPoT(Integer value) { - return Integer(0x1u) << Integer(1 + hlsl::findMSB(value - Integer(1))); // this wont result in constexpr because findMSB is not one + return Integer(0x1u) << log2ceil(value); } template) @@ -35,7 +41,7 @@ NBL_CONSTEXPR_FORCED_INLINE_FUNC Integer roundDownToPoT(Integer value) } template) - NBL_CONSTEXPR_FORCED_INLINE_FUNC Integer ceilDiv(Integer dividend, Integer divisor) +NBL_CONSTEXPR_FORCED_INLINE_FUNC Integer ceilDiv(Integer dividend, Integer divisor) { return (dividend + divisor - 1) / divisor; } @@ -63,6 +69,93 @@ NBL_CONSTEXPR_FORCED_INLINE_FUNC Integer align(Integer alignment, Integer size, return address = nextAlignedAddr; } +// Bitshift utils +// TODO: These can be expanded to shift by more than just one position at a time +// TODO: Can be made to wrok on uint64_t + +// Given an N-bit number stored as uint32_t, performs a circular bit shift right on the upper H bits +template +enable_if_t<(1 < H) && (H <= N) && (N <= 32), uint32_t> circularBitShiftRightHigher(uint32_t i) +{ + // Highest H bits are numbered N-1 through N - H + // N - H is then the middle bit + // Lowest bits numbered from 0 through N - H - 1 + NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1; + NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = 1 << (N - H); + 
NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = ~(lowMask | midMask); + + uint32_t low = i & lowMask; + uint32_t mid = i & midMask; + uint32_t high = i & highMask; + + high >>= 1; + mid <<= H - 1; + + return mid | high | low; +} + +// Given an N-bit number stored as uint32_t, performs a circular bit shift left on the upper H bits +template +enable_if_t<(1 < H) && (H <= N) && (N < 32), uint32_t> circularBitShiftLeftHigher(uint32_t i) +{ + // Highest H bits are numbered N-1 through N - H + // N - 1 is then the highest bit, and N - 2 through N - H are the middle bits + // Lowest bits numbered from 0 through N - H - 1 + NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1; + NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = 1 << (N - 1); + NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = ~(lowMask | highMask); + + uint32_t low = i & lowMask; + uint32_t mid = i & midMask; + uint32_t high = i & highMask; + + mid <<= 1; + high >>= H - 1; + + return mid | high | low; +} +// Perform a circular bit shift right on the lower L bits of a number +template +enable_if_t<(1 < L), uint32_t> circularBitShiftRightLower(uint32_t i) +{ + // Lowest bit is indexed 0 + // Middle bits numbered 1 to L-1 + // Highest bits numbered from L through N-1 + NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = 1; + NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = ((1 << L) - 1) ^ 1; + NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = ~(lowMask | midMask); + + uint32_t low = i & lowMask; + uint32_t mid = i & midMask; + uint32_t high = i & highMask; + + low <<= L - 1; + mid >>= 1; + + return high | low | mid; +} + +// Perform a circular bit shift left on the lower L bits of a number +template +enable_if_t<(1 < L), uint32_t> circularBitShiftLeftLower(uint32_t i) +{ + // Lowest L - 1 bits numbered 0 through L - 2 + // L - 1 is then the middle bit + // L through N-1 the higher bits + NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (L - 1)) - 1; + NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = 1 << (L 
- 1); + NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = ~(lowMask | midMask); + + uint32_t low = i & lowMask; + uint32_t mid = i & midMask; + uint32_t high = i & highMask; + + low <<= 1; + mid >>= L - 1; + + return high | low | mid; +} + // ------------------------------------- CPP ONLY ---------------------------------------------------------- #ifndef __HLSL_VERSION diff --git a/include/nbl/builtin/hlsl/subgroup2/fft.hlsl b/include/nbl/builtin/hlsl/subgroup2/fft.hlsl index 34044fb964..fffadd4aa8 100644 --- a/include/nbl/builtin/hlsl/subgroup2/fft.hlsl +++ b/include/nbl/builtin/hlsl/subgroup2/fft.hlsl @@ -1,7 +1,7 @@ -#ifndef _NBL_BUILTIN_HLSL_SUBGROUP_FFT_INCLUDED_ -#define _NBL_BUILTIN_HLSL_SUBGROUP_FFT_INCLUDED_ +#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_FFT_INCLUDED_ +#define _NBL_BUILTIN_HLSL_SUBGROUP2_FFT_INCLUDED_ -#include "nbl/builtin/hlsl/fft/common.hlsl" +#include "nbl/builtin/hlsl/fft2/common.hlsl" #include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl" #include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl" #include "nbl/builtin/hlsl/concepts/accessors/fft.hlsl" @@ -14,92 +14,252 @@ namespace subgroup2 { // ----------------------------------------------------------------------------------------------------------------------------------------------------------------- -template +template struct FFT { template - static void __call(NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor); + static void __call(uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor); }; // ---------------------------------------- Radix 2 forward transform - DIF ------------------------------------------------------- -template -struct FFT +template +struct FFT { - static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t) lo, NBL_REF_ARG(complex_t) hi) + template + static void FFT_loop(uint32_t stride, uint16_t lowChannel, uint16_t 
highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) { const bool topHalf = bool(glsl::gl_SubgroupInvocationID() & stride); - const vector toTrade = topHalf ? vector (lo.real(), lo.imag()) : vector (hi.real(), hi.imag()); - const vector exchanged = glsl::subgroupShuffleXor< vector > (toTrade, stride); - if (topHalf) + // Get twiddle with k = subgroupInvocation mod stride, halfN = stride + const complex_t twiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID() & (stride - 1), stride); + [unroll] + for (uint16_t channel = lowChannel; channel <= highChannel; channel++) { - lo.real(exchanged.x); - lo.imag(exchanged.y); + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + const vector toTrade = topHalf ? vector (lo.real(), lo.imag()) : vector (hi.real(), hi.imag()); + const vector exchanged = glsl::subgroupShuffleXor< vector >(toTrade, stride); + if (topHalf) + { + lo.real(exchanged.x); + lo.imag(exchanged.y); + } + else + { + hi.real(exchanged.x); + hi.imag(exchanged.y); + } + fft2::DIF::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); } - else + } + + template + static void __call(uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) + { + // special first iteration + const complex_t twiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID(), SubgroupSize); + [unroll] + for (uint16_t channel = lowChannel; channel <= highChannel; channel++) { - hi.real(exchanged.x); - hi.imag(exchanged.y); - } - // Get twiddle with k = subgroupInvocation mod stride, halfN = stride - fft::DIF::radix2(fft::twiddle(glsl::gl_SubgroupInvocationID() & (stride - 1), stride), lo, hi); + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIF::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } + + // 
Decimation in Frequency + [unroll] + for (uint32_t stride = SubgroupSize >> 1; stride > 0; stride >>= 1) + FFT_loop(stride, lowChannel, highChannel, loAccessor, hiAccessor); } - static void __call(NBL_REF_ARG(complex_t) lo, NBL_REF_ARG(complex_t) hi) + // Interleaved versions of the above methods, required to implement the first steps in Interleaved DIF + template + static void FFT_loop(uint32_t elementStride, uint32_t threadStride, uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) { - const uint32_t subgroupSize = glsl::gl_SubgroupSize(); //This is N/2 + const bool topHalf = bool(glsl::gl_SubgroupInvocationID() & threadStride); + // Get twiddle with k = gl_SubgroupInvocationID() * NumSubgroups + gl_SubgroupID() mod elementStride, halfN = elementStride + const uint32_t loLaneIndex = (glsl::gl_SubgroupInvocationID() << NumSubgroupsLog2) + glsl::gl_SubgroupID(); + const complex_t twiddle = fft2::twiddle(loLaneIndex & (elementStride - 1), elementStride); + [unroll] + for (uint16_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + const vector toTrade = topHalf ? 
vector (lo.real(), lo.imag()) : vector (hi.real(), hi.imag()); + const vector exchanged = glsl::subgroupShuffleXor< vector >(toTrade, threadStride); + if (topHalf) + { + lo.real(exchanged.x); + lo.imag(exchanged.y); + } + else + { + hi.real(exchanged.x); + hi.imag(exchanged.y); + } + fft2::DIF::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } + } + // Only uses subgroup methods, but is actually used at workgroup level + template + static void __callInterleaved(uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) + { // special first iteration - fft::DIF::radix2(fft::twiddle(glsl::gl_SubgroupInvocationID(), subgroupSize), lo, hi); - + // Get twiddle with k = gl_SubgroupInvocationID() * NumSubgroups + gl_SubgroupID() mod WorkgroupSize, halfN = WorkgroupSize + const uint32_t loLaneIndex = (glsl::gl_SubgroupInvocationID() << NumSubgroupsLog2) + glsl::gl_SubgroupID(); + const complex_t twiddle = fft2::twiddle(loLaneIndex, WorkgroupSize); + [unroll] + for (uint16_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIF::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } + // Decimation in Frequency [unroll] - for (uint32_t stride = subgroupSize >> 1; stride > 0; stride >>= 1) - FFT_loop(stride, lo, hi); + uint32_t threadStride = SubgroupSize >> 1; + for (uint32_t elementStride = WorkgroupSize >> 1; elementStride > SubgroupSize; elementStride >>= 1) + { + FFT_loop(elementStride, threadStride, lowChannel, highChannel, loAccessor, hiAccessor); + threadStride >>= 1; + } } }; // ---------------------------------------- Radix 2 inverse transform - DIT ------------------------------------------------------- -template -struct FFT +template +struct FFT { - static void FFT_loop(uint32_t stride, 
NBL_REF_ARG(complex_t) lo, NBL_REF_ARG(complex_t) hi) + template + static void FFT_loop(uint32_t stride, uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) { + const bool topHalf = bool(glsl::gl_SubgroupInvocationID() & stride); // Get twiddle with k = subgroupInvocation mod stride, halfN = stride - fft::DIT::radix2(fft::twiddle(glsl::gl_SubgroupInvocationID() & (stride - 1), stride), lo, hi); + const complex_t twiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID() & (stride - 1), stride); - const bool topHalf = bool(glsl::gl_SubgroupInvocationID() & stride); - const vector toTrade = topHalf ? vector (lo.real(), lo.imag()) : vector (hi.real(), hi.imag()); - const vector exchanged = glsl::subgroupShuffleXor< vector > (toTrade, stride); - if (topHalf) + [unroll] + for (uint16_t channel = lowChannel; channel <= highChannel; channel++) { - lo.real(exchanged.x); - lo.imag(exchanged.y); + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIT::radix2(twiddle, lo, hi); + + const vector toTrade = topHalf ? 
vector (lo.real(), lo.imag()) : vector (hi.real(), hi.imag()); + const vector exchanged = glsl::subgroupShuffleXor< vector >(toTrade, stride); + if (topHalf) + { + lo.real(exchanged.x); + lo.imag(exchanged.y); + } + else + { + hi.real(exchanged.x); + hi.imag(exchanged.y); + } + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); } - else + } + + template + static void __call(uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) + { + // Decimation in Time + [unroll] + for (uint32_t stride = 1; stride < SubgroupSize; stride <<= 1) + FFT_loop(stride, lowChannel, highChannel, loAccessor, hiAccessor); + + // special last iteration + const complex_t twiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID(), SubgroupSize); + [unroll] + for (uint16_t channel = lowChannel; channel <= highChannel; channel++) { - hi.real(exchanged.x); - hi.imag(exchanged.y); + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIT::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); } } - static void __call(NBL_REF_ARG(complex_t) lo, NBL_REF_ARG(complex_t) hi) + // Interleaved versions of the above methods, required to implement the last steps in Interleaved DIT + template + static void FFT_loop(uint32_t elementStride, uint32_t threadStride, uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) + { + const bool topHalf = bool(glsl::gl_SubgroupInvocationID() & threadStride); + // Get twiddle with k = gl_SubgroupInvocationID() * NumSubgroups + gl_SubgroupID() mod elementStride, halfN = elementStride + const uint32_t loLaneIndex = (glsl::gl_SubgroupInvocationID() << NumSubgroupsLog2) + glsl::gl_SubgroupID(); + const complex_t twiddle = fft2::twiddle(loLaneIndex & (elementStride - 1), elementStride); + + [unroll] + for (uint16_t channel 
= lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIT::radix2(twiddle, lo, hi); + + const vector toTrade = topHalf ? vector (lo.real(), lo.imag()) : vector (hi.real(), hi.imag()); + const vector exchanged = glsl::subgroupShuffleXor< vector >(toTrade, threadStride); + if (topHalf) + { + lo.real(exchanged.x); + lo.imag(exchanged.y); + } + else + { + hi.real(exchanged.x); + hi.imag(exchanged.y); + } + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } + } + + template + static void __callInterleaved(uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) { - const uint32_t subgroupSize = glsl::gl_SubgroupSize(); //This is N/2 - const uint32_t doubleSubgroupSize = subgroupSize << 1; //This is N - // Decimation in Time [unroll] - for (uint32_t stride = 1; stride < subgroupSize; stride <<= 1) - FFT_loop(stride, lo, hi); - + uint32_t threadStride = SubgroupSize >> (NumSubgroupsLog2 - 1); + for (uint32_t elementStride = SubgroupSize << 1; elementStride < WorkgroupSize; elementStride <<= 1) + { + FFT_loop(elementStride, threadStride, lowChannel, highChannel, loAccessor, hiAccessor); + threadStride <<= 1; + } + // special last iteration - fft::DIT::radix2(fft::twiddle(glsl::gl_SubgroupInvocationID(), subgroupSize), lo, hi); - divides_assign< complex_t > divAss; - divAss(lo, Scalar(doubleSubgroupSize)); - divAss(hi, Scalar(doubleSubgroupSize)); + // Get twiddle with k = gl_SubgroupInvocationID() * NumSubgroups + gl_SubgroupID() mod WorkgroupSize, halfN = WorkgroupSize + const uint32_t loLaneIndex = (glsl::gl_SubgroupInvocationID() << NumSubgroupsLog2) + glsl::gl_SubgroupID(); + const complex_t twiddle = fft2::twiddle(loLaneIndex, WorkgroupSize); + [unroll] + for (uint16_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); 
+ hiAccessor.get(channel, hi); + fft2::DIT::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } } }; diff --git a/include/nbl/builtin/hlsl/workgroup2/fft.hlsl b/include/nbl/builtin/hlsl/workgroup2/fft.hlsl new file mode 100644 index 0000000000..c657324d85 --- /dev/null +++ b/include/nbl/builtin/hlsl/workgroup2/fft.hlsl @@ -0,0 +1,199 @@ +#include +#include +#include + +#ifndef _NBL_BUILTIN_HLSL_WORKGROUP2_FFT_INCLUDED_ +#define _NBL_BUILTIN_HLSL_WORKGROUP2_FFT_INCLUDED_ + +// ------------------------------- COMMON ----------------------------------------- + +namespace nbl +{ +namespace hlsl +{ +namespace workgroup2 +{ +namespace fft +{ +// Minimum size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT +template +uint32_t minimumSharedMemoryDWORDs(uint16_t workgroupSizeLog2) +{ + NBL_IF_CONSTEXPR(Interleaved) + { + return (sizeof(complex_t) / sizeof(uint32_t)) << (workgroupSizeLog2 + 1); + } + else + { + return (sizeof(complex_t) / sizeof(uint32_t)) << workgroupSizeLog2; + } +} + +template 1 && !(_ElementsPerInvocation & 1) && _WorkgroupSizeLog2 >= 5) +struct ConstevalParameters +{ + using scalar_t = _Scalar; + + NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = _ElementsPerInvocation; + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = _SubgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTTotalSize = ElementsPerInvocation * (uint32_t(1) << WorkgroupSizeLog2); + NBL_CONSTEXPR_STATIC_INLINE uint16_t ShuffledElementsPerRound = _ShuffledElementsPerRound; + NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemoryDWORDs = ShuffledElementsPerRound * (sizeof(complex_t) / sizeof(uint32_t)) << WorkgroupSizeLog2; + + NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(1) << WorkgroupSizeLog2; +}; +} + +struct OptimalFFTParameters +{ + uint16_t elementsPerInvocation : 8; + uint16_t 
workgroupSizeLog2 : 8; + + // Used to check if the parameters returned by `optimalFFTParameters` are valid + bool areValid() + { + return elementsPerInvocation > 0 && workgroupSizeLog2 > 0; + } +}; + +/** +* @brief Returns the best parameters (according to our metric) to run an FFT +* +* @param [in] maxWorkgroupSize The max number of threads that can be launched in a single workgroup +* @param [in] inputArrayLength The length of the array to run an FFT on +* @param [in] subgroupSize Number of threads running in a subgroup +*/ +inline OptimalFFTParameters optimalFFTParameters(uint32_t maxWorkgroupSize, uint32_t inputArrayLength, uint32_t subgroupSize) +{ + NBL_CONSTEXPR_STATIC OptimalFFTParameters invalidParameters = { 0 , 0 }; + + if (subgroupSize < 2 || maxWorkgroupSize < subgroupSize || inputArrayLength <= subgroupSize) + return invalidParameters; + // Pad inputarrayLength to size that FFT algo handles + const uint32_t FFTLength = hlsl::fft2::padDimension(inputArrayLength, subgroupSize); + // Round maxWorkgroupSize down to PoT + const uint32_t actualMaxWorkgroupSize = hlsl::roundDownToPoT(maxWorkgroupSize); + // Max number of threads that can run the FFT + uint32_t maxThreads = FFTLength / 2; + // Factors of 3 and 5 do not contribute to the amount of max threads since those are handled in-register to keep everything PoT later + maxThreads /= maxThreads % 3 ? 1 : 3; + maxThreads /= maxThreads % 5 ? 
1 : 5; + // Both are PoT + const uint16_t workgroupSizeLog2 = findMSB(min(maxThreads, actualMaxWorkgroupSize)); + + // Parameters are valid if the workgroup size is at most half of the FFT Length and at least as big as the subgroupSize + if ((FFTLength >> workgroupSizeLog2) <= 1 || subgroupSize > (1u << workgroupSizeLog2)) + { + return invalidParameters; + } + + const uint16_t elementsPerInvocation = FFTLength >> workgroupSizeLog2; + const OptimalFFTParameters retVal = { elementsPerInvocation, workgroupSizeLog2 }; + + return retVal; +} + +namespace impl +{ +template +struct FFTIndexingUtilsHelper +{ + // Maps the lane of index `laneIdx` at the end of the DIF diagram to its corresponding frequency position as an element of the DFT. + static uint32_t mapLaneToFreq(uint32_t laneIdx) + { + NBL_IF_CONSTEXPR(ExtraPrimeFactor > 1) + { + const uint32_t radix2mask = (1 << Radix2FFTSizeLog2) - 1; + return ExtraPrimeFactor * hlsl::bitReverseAs(laneIdx, Radix2FFTSizeLog2) + (laneIdx >> Radix2FFTSizeLog2); + } + else + { + return hlsl::bitReverseAs(laneIdx, Radix2FFTSizeLog2); + } + } + + // Implements fast division by 3 or 5, needed by `mapFreqtoLane` + static uint32_t fastDiv(uint32_t x) + { + NBL_IF_CONSTEXPR(ExtraPrimeFactor == 3) + { + return (x * 43691u) >> 17; // valid for x <= 98303 + } + else // ExtraPrimeFactor == 5 + { + return (x * 52429u) >> 18; // valid for x <= 81919 + } + } + + // Inverse of `mapLaneToFreq`. Maps a frequency index `freqIdx` into the DFT to the lane in the DIF diagram that outputs it. 
+ static uint32_t mapFreqToLane(uint32_t freqIdx) + { + NBL_IF_CONSTEXPR(ExtraPrimeFactor > 1) + { + const uint32_t divByPrimeFactor = fastDiv(freqIdx); + return hlsl::bitReverseAs(divByPrimeFactor, Radix2FFTSizeLog2) + ((freqIdx - ExtraPrimeFactor * divByPrimeFactor) << Radix2FFTSizeLog2); + } + else + { + return hlsl::bitReverseAs(freqIdx, Radix2FFTSizeLog2); + } + } + + // log2(ElementsPerInvocation / ExtraPrimeFactor) + NBL_CONSTEXPR_STATIC_INLINE uint16_t Radix2FFTSizeLog2 = WorkgroupSizeLog2 + Radix2ElementsPerInvocationLog2; + // Size of the full FFT if no mixed radix used, otherwise size of the sub-FFTs computed after the first radix-3/5 step in the DIF forward FFT + NBL_CONSTEXPR_STATIC_INLINE uint16_t Radix2FFTSize = uint32_t(1) << (Radix2FFTSizeLog2); + // Total size of the FFT computed + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = ExtraPrimeFactor * Radix2FFTSize; +}; +} + +template +struct FFTIndexingUtils +{ + using helper_t = impl::FFTIndexingUtilsHelper; + + // Maps the array index 'arrayIdx' of the output of an FFT in workgroup-linear order (meaning all threads write their local element 0 contiguously + // and in ascending order by threadIndex, then element 1 and so on) to its corresponding frequency position as an element of the DFT. + static uint32_t mapArrayToFreq(uint32_t arrayIdx) + { + return helper_t::mapLaneToFreq(circularBitShiftLeftLower(arrayIdx)); + } + + // Maps a frequency index 'freqIdx' into the DFT to its corresponding position in the output array of an FFT when written in workgroup-linear order. 
+ static uint32_t mapFreqToArray(uint32_t freqIdx) + { + return circularBitShiftRightLower(helper_t::mapFreqToLane(freqIdx)); + } + + // Mirrors an index about the Nyquist frequency in the DFT order + static uint32_t getDFTMirrorIndex(uint32_t freqIdx) + { + return (FFTSize - freqIdx) & (FFTSize - 1); + } + + // Given an index `arrayIdx` of an element into the output array of an FFT, get the index into the same array of the element corresponding + // to its negative frequency + static uint32_t getNablaMirrorIndex(uint32_t arrayIdx) + { + return mapFreqToArray(getDFTMirrorIndex(mapArrayToFreq(arrayIdx))); + } + + // log2(ElementsPerInvocation / ExtraPrimeFactor) + NBL_CONSTEXPR_STATIC_INLINE uint16_t Radix2FFTSizeLog2 = helper_t::Radix2FFTSizeLog2; + // Size of the full FFT if no mixed radix used, otherwise size of the sub-FFTs computed after the first radix-3/5 step in the DIF forward FFT + NBL_CONSTEXPR_STATIC_INLINE uint16_t Radix2FFTSize = helper_t::Radix2FFTSize; + // Total size of the FFT computed + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = helper_t::FFTSize; +}; + +} +} +} +// ------------------------------- END COMMON --------------------------------------------- + + + + +#endif From 98764973cb124aff960a6fadc006ead5a8868985 Mon Sep 17 00:00:00 2001 From: Fletterio Date: Thu, 9 Apr 2026 21:34:59 -0300 Subject: [PATCH 03/15] Working, division policy missing --- examples_tests | 2 +- include/nbl/builtin/hlsl/fft2/common.hlsl | 6 +- include/nbl/builtin/hlsl/math/intutil.hlsl | 24 +- include/nbl/builtin/hlsl/mpl.hlsl | 6 + include/nbl/builtin/hlsl/subgroup2/fft.hlsl | 6 +- include/nbl/builtin/hlsl/workgroup2/fft.hlsl | 287 ++++++++++++++++++- 6 files changed, 304 insertions(+), 27 deletions(-) diff --git a/examples_tests b/examples_tests index 1d787b3725..d84bcfcb04 160000 --- a/examples_tests +++ b/examples_tests @@ -1 +1 @@ -Subproject commit 1d787b3725970ae8c706c858fc7b675faff952a1 +Subproject commit d84bcfcb04960d5d69b47fe969e10c676ac8ca86 diff --git 
a/include/nbl/builtin/hlsl/fft2/common.hlsl b/include/nbl/builtin/hlsl/fft2/common.hlsl index e1f7530083..f765c3153d 100644 --- a/include/nbl/builtin/hlsl/fft2/common.hlsl +++ b/include/nbl/builtin/hlsl/fft2/common.hlsl @@ -176,8 +176,8 @@ void unpack(NBL_REF_ARG(complex_t) lo, NBL_REF_ARG(complex_t) hi lo = x; } -} -} -} +} //namespace fft2 +} //namespace hlsl +} //namespace nbl #endif \ No newline at end of file diff --git a/include/nbl/builtin/hlsl/math/intutil.hlsl b/include/nbl/builtin/hlsl/math/intutil.hlsl index d588f17725..2a76ec6234 100644 --- a/include/nbl/builtin/hlsl/math/intutil.hlsl +++ b/include/nbl/builtin/hlsl/math/intutil.hlsl @@ -80,9 +80,9 @@ enable_if_t<(1 < H) && (H <= N) && (N <= 32), uint32_t> circularBitShiftRightHig // Highest H bits are numbered N-1 through N - H // N - H is then the middle bit // Lowest bits numbered from 0 through N - H - 1 - NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1; - NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = 1 << (N - H); - NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = ~(lowMask | midMask); + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t lowMask = (1 << (N - H)) - 1; + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t midMask = 1 << (N - H); + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t highMask = ~(lowMask | midMask); uint32_t low = i & lowMask; uint32_t mid = i & midMask; @@ -101,9 +101,9 @@ enable_if_t<(1 < H) && (H <= N) && (N < 32), uint32_t> circularBitShiftLeftHighe // Highest H bits are numbered N-1 through N - H // N - 1 is then the highest bit, and N - 2 through N - H are the middle bits // Lowest bits numbered from 0 through N - H - 1 - NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1; - NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = 1 << (N - 1); - NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = ~(lowMask | highMask); + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t lowMask = (1 << (N - H)) - 1; + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t highMask = 1 << (N - 1); + NBL_CONSTEXPR_FUNC_SCOPE_VAR 
uint32_t midMask = ~(lowMask | highMask); uint32_t low = i & lowMask; uint32_t mid = i & midMask; @@ -121,9 +121,9 @@ enable_if_t<(1 < L), uint32_t> circularBitShiftRightLower(uint32_t i) // Lowest bit is indexed 0 // Middle bits numbered 1 to L-1 // Highest bits numbered from L through N-1 - NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = 1; - NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = ((1 << L) - 1) ^ 1; - NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = ~(lowMask | midMask); + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t lowMask = 1; + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t midMask = ((1 << L) - 1) ^ 1; + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t highMask = ~(lowMask | midMask); uint32_t low = i & lowMask; uint32_t mid = i & midMask; @@ -142,9 +142,9 @@ enable_if_t<(1 < L), uint32_t> circularBitShiftLeftLower(uint32_t i) // Lowest L - 1 bits numbered 0 through L - 2 // L - 1 is then the middle bit // L through N-1 the higher bits - NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (L - 1)) - 1; - NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = 1 << (L - 1); - NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = ~(lowMask | midMask); + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t lowMask = (1 << (L - 1)) - 1; + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t midMask = 1 << (L - 1); + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t highMask = ~(lowMask | midMask); uint32_t low = i & lowMask; uint32_t mid = i & midMask; diff --git a/include/nbl/builtin/hlsl/mpl.hlsl b/include/nbl/builtin/hlsl/mpl.hlsl index 7734dea15f..9fb5372a8f 100644 --- a/include/nbl/builtin/hlsl/mpl.hlsl +++ b/include/nbl/builtin/hlsl/mpl.hlsl @@ -123,6 +123,12 @@ struct find_lsb }; template NBL_CONSTEXPR_INLINE_NSPC_SCOPE_VAR uint64_t find_lsb_v = find_lsb::value; + +template +struct ceil_div : integral_constant {}; +template +NBL_CONSTEXPR_INLINE_NSPC_SCOPE_VAR uint64_t ceil_div_v = ceil_div::value; + } } } diff --git a/include/nbl/builtin/hlsl/subgroup2/fft.hlsl b/include/nbl/builtin/hlsl/subgroup2/fft.hlsl index 
fffadd4aa8..54b116aa21 100644 --- a/include/nbl/builtin/hlsl/subgroup2/fft.hlsl +++ b/include/nbl/builtin/hlsl/subgroup2/fft.hlsl @@ -264,8 +264,8 @@ struct FFT }; -} -} -} +} //namespace subgroup2 +} //namespace hlsl +} //namespace nbl #endif \ No newline at end of file diff --git a/include/nbl/builtin/hlsl/workgroup2/fft.hlsl b/include/nbl/builtin/hlsl/workgroup2/fft.hlsl index c657324d85..fa895b38d4 100644 --- a/include/nbl/builtin/hlsl/workgroup2/fft.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/fft.hlsl @@ -13,6 +13,7 @@ namespace hlsl { namespace workgroup2 { + namespace fft { // Minimum size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT @@ -29,21 +30,51 @@ uint32_t minimumSharedMemoryDWORDs(uint16_t workgroupSizeLog2) } } -template 1 && !(_ElementsPerInvocation & 1) && _WorkgroupSizeLog2 >= 5) +// The DFT (and IDFT) have two different formulations. One of them doesn't divide in DFT and divides by N in the IDFT. This makes the determinant of the DFT +// sqrt(N) and the determinant of the IDFT as sqrt(N)^(-1). This formulation is problematic, for example, when performing FFT Convolution of images when +// using half-precision, since if N is big this can make the FFT along the second axis (and sometimes even the first!) exceed the representable range and become +// NaNs. The other formulation divides by sqrt(N) on both the DFT and IDFT, giving a determinant of 1 for both transforms. This can be used to avoid +// overflow. +// These policies describe different ways of avoiding overflow by dividing at different moments through the algorithm. +struct DivisionPolicy +{ + // No division performed at any step of the FFT + NBL_CONSTEXPR_STATIC_INLINE uint16_t NoDivision = 0; + // Divides the array by sqrt(FFTSize) at the time of the last workgroup barrier before subgroupFFT (forward) or at the time of the first workgroup barrier + // after subgroupFFT (inverse). 
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t DivBySqrtHalfway = NoDivision + 1; + // Divides the array by sqrt(FFTSize) right at the end of the algorithm + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivBySqrtAtEnd = DivBySqrtHalfway + 1; + // Divides the array by sqrt(FFTSize) by considering `sqrt(FFTSize) = a * b`, dividing by `a` halfway (as described in `DivBySqrtHalfway`) and then dividing + // by `b` at the end. `a` and `b` are chosen so that their weight is proportional to the number of butterflies before the division. + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivBySqrtByParts = DivBySqrtAtEnd + 1; + // The three following all perform divisions in the same manner as their counterparts above, but they divide the array by `FFTSize`. + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivByFullSizeHalfway = DivBySqrtByParts + 1; + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivByFullSizeAtEnd = DivByFullSizeHalfway + 1; + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivByFullSizeByParts = DivByFullSizeAtEnd + 1; +}; + +template //NBL_PRIMARY_REQUIRES(_ElementsPerInvocation > 1 && !(_ElementsPerInvocation & 1) && _WorkgroupSizeLog2 >= 5) struct ConstevalParameters { using scalar_t = _Scalar; NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = _ElementsPerInvocation; + NBL_CONSTEXPR_STATIC_INLINE uint16_t Channels = ElementsPerInvocation >> 1; NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = _SubgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = 1 << SubgroupSizeLog2; NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(1) << WorkgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t NumSubgroupsLog2 = WorkgroupSizeLog2 - SubgroupSizeLog2; NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTTotalSize = ElementsPerInvocation * (uint32_t(1) << WorkgroupSizeLog2); - NBL_CONSTEXPR_STATIC_INLINE uint16_t ShuffledElementsPerRound = _ShuffledElementsPerRound; - NBL_CONSTEXPR_STATIC_INLINE uint32_t 
SharedMemoryDWORDs = ShuffledElementsPerRound * (sizeof(complex_t) / sizeof(uint32_t)) << WorkgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t ShuffledChannelsPerRound = _ShuffledChannelsPerRound; + NBL_CONSTEXPR_STATIC_INLINE uint16_t ShuffleRounds = mpl::ceil_div_v; + NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemoryDWORDs = ShuffledChannelsPerRound * ((sizeof(complex_t) / sizeof(uint32_t)) << (WorkgroupSizeLog2 + (_Interleaved ? 1 : 0))); - NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(1) << WorkgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivisionPolicy = _DivisionPolicy; + NBL_CONSTEXPR_STATIC_INLINE uint32_t }; -} +} //namespace fft struct OptimalFFTParameters { @@ -120,8 +151,8 @@ struct FFTIndexingUtilsHelper { return (x * 43691u) >> 17; // valid for x <= 98303 } - else // ExtraPrimeFactor == 5 - { + else // ExtraPrimeFactor == 5 + { return (x * 52429u) >> 18; // valid for x <= 81919 } } @@ -147,7 +178,7 @@ struct FFTIndexingUtilsHelper // Total size of the FFT computed NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = ExtraPrimeFactor * Radix2FFTSize; }; -} +} // namespace impl template struct FFTIndexingUtils @@ -188,12 +219,252 @@ struct FFTIndexingUtils NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = helper_t::FFTSize; }; +// TODO: Implement when doing 2D FFTConv +template +struct FFTMirrorTradeUtils; + } } } // ------------------------------- END COMMON --------------------------------------------- +// ------------------------------- HLSL ONLY --------------------------------------------- + +#ifdef __HLSL_VERSION + +#include "nbl/builtin/hlsl/subgroup2/fft.hlsl" +#include "nbl/builtin/hlsl/workgroup/basic.hlsl" +#include "nbl/builtin/hlsl/glsl_compat/core.hlsl" +#include "nbl/builtin/hlsl/mpl.hlsl" +#include "nbl/builtin/hlsl/memory_accessor.hlsl" +#include "nbl/builtin/hlsl/bit.hlsl" + +namespace nbl +{ +namespace hlsl +{ +namespace workgroup2 +{ + +//-------------- ---------------------------------------- UTILS 
-------------------------------------------------------- + +namespace fft +{ +namespace impl +{ + +template +struct exchangeValues +{ + static void __call(uint32_t threadID, uint32_t ownedSmemIndex, uint32_t lowChannel, uint32_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor, NBL_REF_ARG(bool) pingPong) + { + const bool topHalf = bool(threadID & stride); + const uint32_t writeIndex = pingPong ? ownedSmemIndex ^ stride : ownedSmemIndex; + const uint32_t readIndex = pingPong ? ownedSmemIndex : ownedSmemIndex ^ stride; + // Write elements to sharedmem + uint32_t adaptorOffset = 0; + [unroll] + for (uint32_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + vector toExchange = topHalf ? vector(lo.real(), lo.imag()) : vector(hi.real(), hi.imag()); + sharedmemAdaptor.template set >(adaptorOffset | writeIndex, toExchange); + + adaptorOffset += WorkgroupSize; + } + // Wait until all writes are done before reading + sharedmemAdaptor.workgroupExecutionAndMemoryBarrier(); + + // Read elements from sharedmem + adaptorOffset = 0; + [unroll] + for (uint32_t channel = lowChannel; channel <= highChannel; channel++) + { + vector exchanged; + sharedmemAdaptor.template get >(adaptorOffset | readIndex, exchanged); + complex_t complex_exchanged = { exchanged.x, exchanged.y }; + if (topHalf) + { + loAccessor.set(channel, complex_exchanged); + } + else + { + hiAccessor.set(channel, complex_exchanged); + } + + adaptorOffset += WorkgroupSize; + } + } + + static void __callInterleaved() + { + // TODO + } +}; + +} //namespace impl +} //namespace fft + +//-------------- ------------------------------------ END UTILS -------------------------------------------------------- +template +struct FFT; + +// Non-interleaved (shuffle after every butterfly) forward FFT 
+template +struct FFT, device_capabilities> +{ + using consteval_parameters_t = fft::ConstevalParameters; + using scalar_t = typename consteval_parameters_t::scalar_t; + + template + static void FFT_loop(uint32_t stride, uint32_t threadID, NBL_REF_ARG(uint32_t) ownedSmemIndex, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor) + { + const uint32_t ShuffleRounds = consteval_parameters_t::ShuffleRounds; + const uint16_t Channels = consteval_parameters_t::Channels; + // Get twiddle with k = threadID mod stride, halfN = stride + const complex_t twiddle = hlsl::fft::twiddle(threadID & (stride - 1), stride); + + bool pingPong = false; + [unroll] + for (uint32_t round = 0; round < ShuffleRounds; round++) + { + if (round) + pingPong = !pingPong; // ping pong on sharedmem to avoid barriering - this exploits that we XOR with the same stride every consecutive round + const uint32_t lowChannel = round * ShuffledChannelsPerRound; + const uint32_t highChannel = min(Channels, lowChannel + ShuffledChannelsPerRound) - 1; + [unroll] + for (uint32_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIF::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } + + fft::impl::exchangeValues::__call(threadID, ownedSmemIndex, lowChannel, highChannel, loAccessor, hiAccessor, stride >> 1, sharedmemAdaptor, pingPong); + } + // After the last exchangeValues, the memory we just read from is now owned by us, so update + ownedSmemIndex = pingPong ? 
ownedSmemIndex : ownedSmemIndex ^ (stride >> 1); + } + + template + static void __call(NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor) + { + const uint16_t Channels = consteval_parameters_t::Channels; + const uint16_t SubgroupSize = consteval_parameters_t::SubgroupSize; + const uint16_t WorkgroupSize = consteval_parameters_t::WorkgroupSize; + + // Get workgroup threadID + const uint32_t threadID = uint32_t(workgroup::SubgroupContiguousIndex()); + + // If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps + if (WorkgroupSize > SubgroupSize) + { + // Set up the memory adaptor + using adaptor_t = accessor_adaptors::StructureOfArrays; + adaptor_t sharedmemAdaptor; + sharedmemAdaptor.accessor = sharedmemAccessor; + + uint32_t ownedSmemIndex = threadID; + [unroll] + for (uint32_t stride = WorkgroupSize; stride > SubgroupSize; stride >>= 1) + { + FFT_loop(stride, threadID, ownedSmemIndex, loAccessor, hiAccessor, sharedmemAdaptor); + } + + // Remember to update the accessor's state + sharedmemAccessor = sharedmemAdaptor.accessor; + } + // Subgroup-sized FFT + subgroup2::FFT::__call(0, Channels - 1, loAccessor, hiAccessor); + } +}; + +// Non-interleaved (shuffle after every butterfly) inverse FFT +template +struct FFT, device_capabilities> +{ + using consteval_parameters_t = fft::ConstevalParameters; + using scalar_t = typename consteval_parameters_t::scalar_t; + + template + static void FFT_loop(uint32_t stride, uint32_t threadID, NBL_REF_ARG(uint32_t) ownedSmemIndex, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor) + { + const uint32_t ShuffleRounds = consteval_parameters_t::ShuffleRounds; + const uint16_t Channels = consteval_parameters_t::Channels; + // Get twiddle with k = threadID mod (2 * stride), halfN = 2 * stride + const complex_t 
twiddle = hlsl::fft::twiddle(threadID & ((stride << 1) - 1), stride << 1); + + bool pingPong = false; + [unroll] + for (uint32_t round = 0; round < ShuffleRounds; round++) + { + if (round) + pingPong = !pingPong; // ping pong on sharedmem to avoid barriering - this exploits that we XOR with the same stride every consecutive round + const uint32_t lowChannel = round * ShuffledChannelsPerRound; + const uint32_t highChannel = min(Channels, lowChannel + ShuffledChannelsPerRound) - 1; + + fft::impl::exchangeValues::__call(threadID, ownedSmemIndex, lowChannel, highChannel, loAccessor, hiAccessor, stride, sharedmemAdaptor, pingPong); + + [unroll] + for (uint32_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIT::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } + } + // After the last exchangeValues, the memory we just read from is now owned by us, so update + ownedSmemIndex = pingPong ? 
ownedSmemIndex : ownedSmemIndex ^ (stride >> 1); + } + + template + static void __call(NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor) + { + const uint16_t Channels = consteval_parameters_t::Channels; + const uint16_t SubgroupSize = consteval_parameters_t::SubgroupSize; + const uint16_t WorkgroupSize = consteval_parameters_t::WorkgroupSize; + + // Subgroup-sized FFT at the start + subgroup2::FFT::__call(0, Channels - 1, loAccessor, hiAccessor); + + // Get workgroup threadID + const uint32_t threadID = uint32_t(workgroup::SubgroupContiguousIndex()); + + // If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps + if (WorkgroupSize > SubgroupSize) + { + // Set up the memory adaptor + using adaptor_t = accessor_adaptors::StructureOfArrays; + adaptor_t sharedmemAdaptor; + sharedmemAdaptor.accessor = sharedmemAccessor; + + uint32_t ownedSmemIndex = threadID; + [unroll] + for (uint32_t stride = SubgroupSize; stride < WorkgroupSize; stride <<= 1) + { + FFT_loop(stride, threadID, ownedSmemIndex, loAccessor, hiAccessor, sharedmemAdaptor); + } + + // Remember to update the accessor's state + sharedmemAccessor = sharedmemAdaptor.accessor; + } + } +}; + +} //namespace workgroup2 +} //namespace hlsl +} //namespace nbl + +#endif #endif From bc4ae7e6aacc02cc3d2ae7fb0b97cd3b118fc27d Mon Sep 17 00:00:00 2001 From: Fletterio Date: Sat, 11 Apr 2026 02:21:18 -0300 Subject: [PATCH 04/15] Add new radices --- include/nbl/builtin/hlsl/complex.hlsl | 72 ++++++----- include/nbl/builtin/hlsl/fft2/common.hlsl | 124 ++++++++++++++++++- include/nbl/builtin/hlsl/math/functions.hlsl | 2 +- include/nbl/builtin/hlsl/workgroup2/fft.hlsl | 2 +- 4 files changed, 161 insertions(+), 39 deletions(-) diff --git a/include/nbl/builtin/hlsl/complex.hlsl b/include/nbl/builtin/hlsl/complex.hlsl index 7e8f6526ec..96ca799b92 100644 --- a/include/nbl/builtin/hlsl/complex.hlsl 
+++ b/include/nbl/builtin/hlsl/complex.hlsl @@ -7,6 +7,7 @@ #include #include +#include using namespace nbl::hlsl; @@ -32,22 +33,6 @@ struct complex_t : public std::complex } }; -// Fast mul by i -template -complex_t rotateLeft(NBL_CONST_REF_ARG(complex_t) value) -{ - complex_t retVal = { -value.imag(), value.real() }; - return retVal; -} - -// Fast mul by -i -template -complex_t rotateRight(NBL_CONST_REF_ARG(complex_t) value) -{ - complex_t retVal = { value.imag(), -value.real() }; - return retVal; -} - } } @@ -414,28 +399,11 @@ complex_t proj(const complex_t c) template complex_t polar(const Scalar r, const Scalar theta) { - complex_t retVal = {r * cos(theta), r * sin(theta)}; + complex_t retVal = {r * nbl::hlsl::cos(theta), r * nbl::hlsl::sin(theta)}; return retVal; } -// --------------------------------------------- Some more functions that come in handy -------------------------------------- -// Fast mul by i -template -complex_t rotateLeft(NBL_CONST_REF_ARG(complex_t) value) -{ - complex_t retVal = { -value.imag(), value.real() }; - return retVal; -} - -// Fast mul by -i -template -complex_t rotateRight(NBL_CONST_REF_ARG(complex_t) value) -{ - complex_t retVal = { value.imag(), -value.real() }; - return retVal; -} - } } @@ -456,4 +424,40 @@ NBL_REGISTER_OBJ_TYPE(complex_t,::nbl::hlsl::alignment_of_v +complex_t rotateLeft(NBL_CONST_REF_ARG(complex_t) value) +{ + complex_t retVal = { -value.imag(), value.real() }; + return retVal; +} + +// Fast mul by -i +template +complex_t rotateRight(NBL_CONST_REF_ARG(complex_t) value) +{ + complex_t retVal = { value.imag(), -value.real() }; + return retVal; +} + +// Fast square +template +complex_t square(NBL_CONST_REF_ARG(complex_t) value) +{ + Scalar real = value.real() * value.real() - value.imag() * value.imag(); + Scalar imag = 2 * value.real() * value.imag(); + complex_t retVal = { real, imag }; + return retVal; +} + +} +} + #endif diff --git a/include/nbl/builtin/hlsl/fft2/common.hlsl 
b/include/nbl/builtin/hlsl/fft2/common.hlsl index f765c3153d..7158cd3ca4 100644 --- a/include/nbl/builtin/hlsl/fft2/common.hlsl +++ b/include/nbl/builtin/hlsl/fft2/common.hlsl @@ -130,11 +130,13 @@ complex_t twiddle(uint32_t k, uint32_t halfN) { complex_t retVal; const Scalar kthRootAngleRadians = numbers::pi *Scalar(k) / Scalar(halfN); - retVal.real(cos(kthRootAngleRadians)); + Scalar cosine = nbl::hlsl::cos(kthRootAngleRadians); + Scalar sine = nbl::hlsl::sin(kthRootAngleRadians); + retVal.real(cosine); if (!inverse) - retVal.imag(sin(-kthRootAngleRadians)); + retVal.imag(-sine); else - retVal.imag(sin(kthRootAngleRadians)); + retVal.imag(sine); return retVal; } @@ -157,6 +159,122 @@ struct DIX hi = twiddle * diff; } } + + static void radix3( + complex_t twiddle1, + complex_t twiddle2, + NBL_REF_ARG(complex_t) lo, + NBL_REF_ARG(complex_t) mid, + NBL_REF_ARG(complex_t) hi) + { + plus_assign< complex_t > plusAss; + NBL_CONSTEXPR_FUNC_SCOPE_VAR Scalar SQRT3_OVER_2 = Scalar(0.8660254037844386); + + // Decimation in time - inverse + if (inverse) { + // Apply twiddles first, then butterfly + complex_t w1 = twiddle1 * mid; + complex_t w2 = twiddle2 * hi; + + complex_t s = w1 + w2; + complex_t d = w1 - w2; + + // u = i * sqrt(3)/2 * d + complex_t u = complex_t(d.imag() * (-SQRT3_OVER_2), d.real() * SQRT3_OVER_2); + + complex_t t = lo - s * Scalar(0.5); + plusAss(lo, s); // lo = lo + s + mid = t + u; // inverse: swapped signs vs forward + hi = t - u; + } + // Decimation in frequency - forward + else { + // Butterfly first, then apply twiddles + complex_t s = mid + hi; + complex_t d = mid - hi; + + // u = i * sqrt(3)/2 * d + complex_t u = complex_t(d.imag() * (-SQRT3_OVER_2), d.real() * SQRT3_OVER_2); + + complex_t t = lo - s * Scalar(0.5); + plusAss(lo, s); // lo = lo + s + mid = twiddle1 * (t - u); + hi = twiddle2 * (t + u); + } + } + + static void radix5( + complex_t twiddle1, + complex_t twiddle2, + complex_t twiddle3, + complex_t twiddle4, + NBL_REF_ARG(complex_t) x0, 
+ NBL_REF_ARG(complex_t) x1, + NBL_REF_ARG(complex_t) x2, + NBL_REF_ARG(complex_t) x3, + NBL_REF_ARG(complex_t) x4) + { + plus_assign< complex_t > plusAss; + + NBL_CONSTEXPR_FUNC_SCOPE_VAR Scalar COS_2PI_5 = Scalar(0.30901699437494742); // (sqrt(5) - 1)/4 + NBL_CONSTEXPR_FUNC_SCOPE_VAR Scalar COS_4PI_5 = Scalar(-0.80901699437494742); // -(sqrt(5) + 1)/4 + NBL_CONSTEXPR_FUNC_SCOPE_VAR Scalar SIN_2PI_5 = Scalar(0.95105651629515357); + NBL_CONSTEXPR_FUNC_SCOPE_VAR Scalar SIN_4PI_5 = Scalar(0.58778525229247312); + + //Decimation in time - inverse + if (inverse) { + // Apply twiddles first + complex_t w1 = twiddle1 * x1; + complex_t w2 = twiddle2 * x2; + complex_t w3 = twiddle3 * x3; + complex_t w4 = twiddle4 * x4; + + // 5-point inverse DFT: exploit W_5 conjugate symmetry via sum/diff pairs + complex_t s1 = w1 + w4; + complex_t d1 = w1 - w4; + complex_t s2 = w2 + w3; + complex_t d2 = w2 - w3; + + complex_t tA = x0 + COS_2PI_5 * s1 + COS_4PI_5 * s2; + complex_t tB = x0 + COS_4PI_5 * s1 + COS_2PI_5 * s2; + + // uA = i * (sA * d1 + sB * d2), uB = i * (sB * d1 - sA * d2) + complex_t vA = SIN_2PI_5 * d1 + SIN_4PI_5 * d2; + complex_t vB = SIN_4PI_5 * d1 - SIN_2PI_5 * d2; + complex_t uA = rotateLeft(vA); + complex_t uB = rotateLeft(vB); + + plusAss(x0, s1 + s2); // x0 = x0 + s1 + s2 + // inverse flips the sign on u compared to forward + x1 = tA + uA; + x4 = tA - uA; + x2 = tB + uB; + x3 = tB - uB; + } + //Decimation in frequency - forward + else { + // 5-point DFT first + complex_t s1 = x1 + x4; + complex_t d1 = x1 - x4; + complex_t s2 = x2 + x3; + complex_t d2 = x2 - x3; + + complex_t tA = x0 + COS_2PI_5 * s1 + COS_4PI_5 * s2; + complex_t tB = x0 + COS_4PI_5 * s1 + COS_2PI_5 * s2; + + complex_t vA = SIN_2PI_5 * d1 + SIN_4PI_5 * d2; + complex_t vB = SIN_4PI_5 * d1 - SIN_2PI_5 * d2; + complex_t uA = rotateLeft(vA); + complex_t uB = rotateLeft(vB); + + plusAss(x0, s1 + s2); // x0 = x0 + s1 + s2 + // Apply twiddles to the DFT outputs + x1 = twiddle1 * (tA - uA); + x4 = twiddle4 
* (tA + uA); + x2 = twiddle2 * (tB - uB); + x3 = twiddle3 * (tB + uB); + } + } }; template diff --git a/include/nbl/builtin/hlsl/math/functions.hlsl b/include/nbl/builtin/hlsl/math/functions.hlsl index f7db44b9fb..f4a02833c5 100644 --- a/include/nbl/builtin/hlsl/math/functions.hlsl +++ b/include/nbl/builtin/hlsl/math/functions.hlsl @@ -91,7 +91,7 @@ scalar_type_t lpNorm(NBL_CONST_REF_ARG(T) v) return impl::lp_norm::__call(v); } - +// [Francisco] sqrt hits the SFU the same as a sin call, calling both sin and cos might just be faster and more accurate? // valid only for `theta` in [-PI,PI] template ) void sincos(T theta, NBL_REF_ARG(T) s, NBL_REF_ARG(T) c) diff --git a/include/nbl/builtin/hlsl/workgroup2/fft.hlsl b/include/nbl/builtin/hlsl/workgroup2/fft.hlsl index fa895b38d4..870942236a 100644 --- a/include/nbl/builtin/hlsl/workgroup2/fft.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/fft.hlsl @@ -54,6 +54,7 @@ struct DivisionPolicy NBL_CONSTEXPR_STATIC_INLINE uint16_t DivByFullSizeByParts = DivByFullSizeAtEnd + 1; }; +// TODO: Separate parallelFFTs from Channels as two different concepts (dictates FFTSize) template //NBL_PRIMARY_REQUIRES(_ElementsPerInvocation > 1 && !(_ElementsPerInvocation & 1) && _WorkgroupSizeLog2 >= 5) struct ConstevalParameters { @@ -72,7 +73,6 @@ struct ConstevalParameters NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemoryDWORDs = ShuffledChannelsPerRound * ((sizeof(complex_t) / sizeof(uint32_t)) << (WorkgroupSizeLog2 + (_Interleaved ? 
1 : 0))); NBL_CONSTEXPR_STATIC_INLINE uint16_t DivisionPolicy = _DivisionPolicy; - NBL_CONSTEXPR_STATIC_INLINE uint32_t }; } //namespace fft From dd50514be1f6078797716f8af540eeb82e98f784 Mon Sep 17 00:00:00 2001 From: Fletterio Date: Sat, 11 Apr 2026 18:38:06 -0300 Subject: [PATCH 05/15] Reduce register pressure from unrolling --- include/nbl/builtin/hlsl/workgroup2/fft.hlsl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/fft.hlsl b/include/nbl/builtin/hlsl/workgroup2/fft.hlsl index 870942236a..22edd9d258 100644 --- a/include/nbl/builtin/hlsl/workgroup2/fft.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/fft.hlsl @@ -328,7 +328,9 @@ struct FFT twiddle = hlsl::fft::twiddle(threadID & (stride - 1), stride); bool pingPong = false; - [unroll] + // Unrolling this loop increases register pressure. Why? Who knows. + // It's not like it can't reuse the registers, and calls to exchangeValues are inlined anyway. + //[unroll] for (uint32_t round = 0; round < ShuffleRounds; round++) { if (round) @@ -371,6 +373,7 @@ struct FFT SubgroupSize; stride >>= 1) { @@ -401,7 +404,7 @@ struct FFT twiddle = hlsl::fft::twiddle(threadID & ((stride << 1) - 1), stride << 1); bool pingPong = false; - [unroll] + //[unroll] for (uint32_t round = 0; round < ShuffleRounds; round++) { if (round) @@ -449,7 +452,7 @@ struct FFT Date: Tue, 14 Apr 2026 17:32:34 -0300 Subject: [PATCH 06/15] Add the option to run opt passes at the end of default pass --- include/nbl/asset/utils/IShaderCompiler.h | 1 + include/nbl/video/ILogicalDevice.h | 1 + src/nbl/asset/utils/CHLSLCompiler.cpp | 4 ++-- src/nbl/video/ILogicalDevice.cpp | 1 + 4 files changed, 5 insertions(+), 2 deletions(-) diff --git a/include/nbl/asset/utils/IShaderCompiler.h b/include/nbl/asset/utils/IShaderCompiler.h index 05116b8d52..3ab6feafc2 100644 --- a/include/nbl/asset/utils/IShaderCompiler.h +++ b/include/nbl/asset/utils/IShaderCompiler.h @@ -261,6 +261,7 @@ class NBL_API2
IShaderCompiler : public core::IReferenceCounted SPreprocessorOptions preprocessorOptions = {}; CCache* readCache = nullptr; CCache* writeCache = nullptr; + bool optimizerIsExtraPasses = false; // Instead of disabling the default opt passes, run the provided optimization passes at the end }; class CCache final : public IReferenceCounted diff --git a/include/nbl/video/ILogicalDevice.h b/include/nbl/video/ILogicalDevice.h index 756b417c79..234f5575f3 100644 --- a/include/nbl/video/ILogicalDevice.h +++ b/include/nbl/video/ILogicalDevice.h @@ -833,6 +833,7 @@ class NBL_API2 ILogicalDevice : public core::IReferenceCounted, public IDeviceMe std::span extraDefines = {}; hlsl::ShaderStage stage = hlsl::ShaderStage::ESS_ALL_OR_LIBRARY; core::bitflag debugInfoFlags = asset::IShaderCompiler::E_DEBUG_INFO_FLAGS::EDIF_NONE; + bool optimizerIsExtraPasses = false; // Instead of disabling the default opt passes, run the provided optimization passes at the end }; core::smart_refctd_ptr compileShader(const SShaderCreationParameters& creationParams); diff --git a/src/nbl/asset/utils/CHLSLCompiler.cpp b/src/nbl/asset/utils/CHLSLCompiler.cpp index 6fec81c8cc..dd506afc54 100644 --- a/src/nbl/asset/utils/CHLSLCompiler.cpp +++ b/src/nbl/asset/utils/CHLSLCompiler.cpp @@ -584,13 +584,13 @@ core::smart_refctd_ptr CHLSLCompiler::compileToSPIRV_impl(const std::st // TODO: add entry point to `CHLSLCompiler::SOptions` and handle it properly in `dxc_compile_flags.empty()` arguments.push_back(L"main"); } - // If a custom SPIR-V optimizer is specified, use that instead of DXC's spirv-opt. + // If a custom SPIR-V optimizer is specified and set to replace default optimization passes, use that instead of DXC's spirv-opt. // This is how we can get more optimizer options. // // Optimization is also delegated to SPIRV-Tools. Right now there are no difference between // optimization levels greater than zero; they will all invoke the same optimization recipe. 
// https://github.com/Microsoft/DirectXShaderCompiler/blob/main/docs/SPIR-V.rst#optimization - if (hlslOptions.spirvOptimizer) + if (hlslOptions.spirvOptimizer && !hlslOptions.optimizerIsExtraPasses) arguments.push_back(L"-O0"); } // diff --git a/src/nbl/video/ILogicalDevice.cpp b/src/nbl/video/ILogicalDevice.cpp index 1752717879..466a0b300f 100644 --- a/src/nbl/video/ILogicalDevice.cpp +++ b/src/nbl/video/ILogicalDevice.cpp @@ -362,6 +362,7 @@ core::smart_refctd_ptr ILogicalDevice::compileShader(const SShad commonCompileOptions.stage = creationParams.stage; commonCompileOptions.debugInfoFlags = creationParams.debugInfoFlags; commonCompileOptions.spirvOptimizer = creationParams.optimizer; + commonCompileOptions.optimizerIsExtraPasses = creationParams.optimizerIsExtraPasses; commonCompileOptions.preprocessorOptions.targetSpirvVersion = m_physicalDevice->getLimits().spirvVersion; commonCompileOptions.readCache = creationParams.readCache; From 4065f210509eb1c6ff19be83ed3d29c92b6c9659 Mon Sep 17 00:00:00 2001 From: Fletterio Date: Tue, 14 Apr 2026 20:29:40 -0300 Subject: [PATCH 07/15] Give the compiler choices to spit out preprocessed and spv files, useful for debug --- include/nbl/asset/utils/IShaderCompiler.h | 2 + include/nbl/video/ILogicalDevice.h | 2 + src/nbl/asset/utils/CHLSLCompiler.cpp | 48 +++++++++++++++++++++++ src/nbl/video/ILogicalDevice.cpp | 3 ++ 4 files changed, 55 insertions(+) diff --git a/include/nbl/asset/utils/IShaderCompiler.h b/include/nbl/asset/utils/IShaderCompiler.h index 3ab6feafc2..526a9dd80e 100644 --- a/include/nbl/asset/utils/IShaderCompiler.h +++ b/include/nbl/asset/utils/IShaderCompiler.h @@ -262,6 +262,8 @@ class NBL_API2 IShaderCompiler : public core::IReferenceCounted CCache* readCache = nullptr; CCache* writeCache = nullptr; bool optimizerIsExtraPasses = false; // Instead of disabling the default opt passes, run the provided optimization passes at the end + std::string preprocessedOutputPath = ""; + std::string spvOutputPath = ""; 
}; class CCache final : public IReferenceCounted diff --git a/include/nbl/video/ILogicalDevice.h b/include/nbl/video/ILogicalDevice.h index 234f5575f3..742cb506c6 100644 --- a/include/nbl/video/ILogicalDevice.h +++ b/include/nbl/video/ILogicalDevice.h @@ -834,6 +834,8 @@ class NBL_API2 ILogicalDevice : public core::IReferenceCounted, public IDeviceMe hlsl::ShaderStage stage = hlsl::ShaderStage::ESS_ALL_OR_LIBRARY; core::bitflag debugInfoFlags = asset::IShaderCompiler::E_DEBUG_INFO_FLAGS::EDIF_NONE; bool optimizerIsExtraPasses = false; // Instead of disabling the default opt passes, run the provided optimization passes at the end + std::string preprocessedOutputPath = ""; + std::string spvOutputPath = ""; }; core::smart_refctd_ptr compileShader(const SShaderCreationParameters& creationParams); diff --git a/src/nbl/asset/utils/CHLSLCompiler.cpp b/src/nbl/asset/utils/CHLSLCompiler.cpp index dd506afc54..661edd18b8 100644 --- a/src/nbl/asset/utils/CHLSLCompiler.cpp +++ b/src/nbl/asset/utils/CHLSLCompiler.cpp @@ -546,6 +546,30 @@ core::smart_refctd_ptr CHLSLCompiler::compileToSPIRV_impl(const std::st auto newCode = preprocessShader(std::string(code), stage, hlslOptions.preprocessorOptions, dxc_compile_flags, dependencies); if (newCode.empty()) return nullptr; + if (!options.preprocessedOutputPath.empty()) + { + core::smart_refctd_ptr preprocessedFile; + + system::ISystem::future_t> future; + m_system->deleteFile(options.preprocessedOutputPath); + m_system->createFile(future, options.preprocessedOutputPath, system::IFile::ECF_WRITE); + if (future.wait()) + { + future.acquire().move_into(preprocessedFile); + if (preprocessedFile) + { + system::IFile::success_t succ; + preprocessedFile->write(succ, newCode.data(), 0, newCode.size()); + if (!succ) + logger.log("Failed Writing To Preprocessed Output File.", nbl::system::ILogger::ELL_ERROR); + } + else + logger.log("Failed Creating Preprocessed Output File.", nbl::system::ILogger::ELL_ERROR); + } + else + logger.log("Failed 
Creating Preprocessed Output File.", nbl::system::ILogger::ELL_ERROR); + } + // Suffix is the shader model version std::wstring targetProfile(SHADER_MODEL_PROFILE); @@ -651,6 +675,30 @@ core::smart_refctd_ptr CHLSLCompiler::compileToSPIRV_impl(const std::st if (hlslOptions.spirvOptimizer) outSpirv = hlslOptions.spirvOptimizer->optimize(outSpirv.get(), logger); + if (!options.spvOutputPath.empty()) + { + core::smart_refctd_ptr spvFile; + + system::ISystem::future_t> future; + m_system->deleteFile(options.spvOutputPath); + m_system->createFile(future, options.spvOutputPath, system::IFile::ECF_WRITE); + if (future.wait()) + { + future.acquire().move_into(spvFile); + if (spvFile) + { + system::IFile::success_t succ; + spvFile->write(succ, outSpirv->getPointer(), 0, outSpirv->getSize()); + if (!succ) + logger.log("Failed Writing To SPIR-V Output File.", nbl::system::ILogger::ELL_ERROR); + } + else + logger.log("Failed Creating SPIR-V Output File.", nbl::system::ILogger::ELL_ERROR); + } + else + logger.log("Failed Creating SPIR-V Output File.", nbl::system::ILogger::ELL_ERROR); + } + return core::make_smart_refctd_ptr(std::move(outSpirv), IShader::E_CONTENT_TYPE::ECT_SPIRV, hlslOptions.preprocessorOptions.sourceIdentifier.data()); } diff --git a/src/nbl/video/ILogicalDevice.cpp b/src/nbl/video/ILogicalDevice.cpp index 466a0b300f..bee6381f7a 100644 --- a/src/nbl/video/ILogicalDevice.cpp +++ b/src/nbl/video/ILogicalDevice.cpp @@ -368,6 +368,9 @@ core::smart_refctd_ptr ILogicalDevice::compileShader(const SShad commonCompileOptions.readCache = creationParams.readCache; commonCompileOptions.writeCache = creationParams.writeCache; + commonCompileOptions.preprocessedOutputPath = creationParams.preprocessedOutputPath; + commonCompileOptions.spvOutputPath = creationParams.spvOutputPath; + if (sourceContent==asset::IShader::E_CONTENT_TYPE::ECT_HLSL) { // TODO: add specific HLSLCompiler::SOption params From 85d14c7ef8c0df324cf93373a3df27bd0d88ab1d Mon Sep 17 00:00:00 2001 From: 
Fletterio Date: Wed, 15 Apr 2026 15:01:50 -0300 Subject: [PATCH 08/15] ex pointer update for merge --- examples_tests | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples_tests b/examples_tests index d84bcfcb04..21133becb3 160000 --- a/examples_tests +++ b/examples_tests @@ -1 +1 @@ -Subproject commit d84bcfcb04960d5d69b47fe969e10c676ac8ca86 +Subproject commit 21133becb3e24e9c60dc329ca96d5be07adc89d4 From ba922ed2a8b53508554fe1ab9fcc549eb8c33ab9 Mon Sep 17 00:00:00 2001 From: Fletterio Date: Wed, 15 Apr 2026 15:02:06 -0300 Subject: [PATCH 09/15] ex pointer update for merge --- examples_tests | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples_tests b/examples_tests index 1345dae922..21133becb3 160000 --- a/examples_tests +++ b/examples_tests @@ -1 +1 @@ -Subproject commit 1345dae9220598734e73ed425225b49dc3c3cfe6 +Subproject commit 21133becb3e24e9c60dc329ca96d5be07adc89d4 From 7a1747830945177b7d8e55678abb0283669d434f Mon Sep 17 00:00:00 2001 From: Fletterio Date: Wed, 15 Apr 2026 15:03:14 -0300 Subject: [PATCH 10/15] restore examples pointer --- examples_tests | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples_tests b/examples_tests index 21133becb3..2e6ea2fb6f 160000 --- a/examples_tests +++ b/examples_tests @@ -1 +1 @@ -Subproject commit 21133becb3e24e9c60dc329ca96d5be07adc89d4 +Subproject commit 2e6ea2fb6f573c7c89414e3c60b14df034d8c822 From 27af760501dab7e23c5019eaa762d0e006d91bd8 Mon Sep 17 00:00:00 2001 From: Fletterio Date: Wed, 22 Apr 2026 00:00:12 -0300 Subject: [PATCH 11/15] solve uint16 size warnings --- include/nbl/builtin/hlsl/subgroup2/fft.hlsl | 22 +++++++++---------- include/nbl/builtin/hlsl/workgroup/basic.hlsl | 2 +- include/nbl/builtin/hlsl/workgroup2/fft.hlsl | 10 ++++----- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/include/nbl/builtin/hlsl/subgroup2/fft.hlsl b/include/nbl/builtin/hlsl/subgroup2/fft.hlsl index 54b116aa21..8e0f212afe 100644 --- 
a/include/nbl/builtin/hlsl/subgroup2/fft.hlsl +++ b/include/nbl/builtin/hlsl/subgroup2/fft.hlsl @@ -130,8 +130,8 @@ struct FFT } // Decimation in Frequency - [unroll] uint32_t threadStride = SubgroupSize >> 1; + [unroll] for (uint32_t elementStride = WorkgroupSize >> 1; elementStride > SubgroupSize; elementStride >>= 1) { FFT_loop(elementStride, threadStride, lowChannel, highChannel, loAccessor, hiAccessor); @@ -238,8 +238,8 @@ struct FFT static void __callInterleaved(uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) { // Decimation in Time - [unroll] uint32_t threadStride = SubgroupSize >> (NumSubgroupsLog2 - 1); + [unroll] for (uint32_t elementStride = SubgroupSize << 1; elementStride < WorkgroupSize; elementStride <<= 1) { FFT_loop(elementStride, threadStride, lowChannel, highChannel, loAccessor, hiAccessor); @@ -251,15 +251,15 @@ struct FFT const uint32_t loLaneIndex = (glsl::gl_SubgroupInvocationID() << NumSubgroupsLog2) + glsl::gl_SubgroupID(); const complex_t twiddle = fft2::twiddle(loLaneIndex, WorkgroupSize); [unroll] - for (uint16_t channel = lowChannel; channel <= highChannel; channel++) - { - complex_t lo, hi; - loAccessor.get(channel, lo); - hiAccessor.get(channel, hi); - fft2::DIT::radix2(twiddle, lo, hi); - loAccessor.set(channel, lo); - hiAccessor.set(channel, hi); - } + for (uint16_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIT::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } } }; diff --git a/include/nbl/builtin/hlsl/workgroup/basic.hlsl b/include/nbl/builtin/hlsl/workgroup/basic.hlsl index 3467b46407..dc69f37e17 100644 --- a/include/nbl/builtin/hlsl/workgroup/basic.hlsl +++ b/include/nbl/builtin/hlsl/workgroup/basic.hlsl @@ -29,7 +29,7 @@ uint16_t SubgroupContiguousIndex() assert(retval> 1; 
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = _SubgroupSizeLog2; - NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = 1 << SubgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(1) << SubgroupSizeLog2; NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2; NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(1) << WorkgroupSizeLog2; NBL_CONSTEXPR_STATIC_INLINE uint16_t NumSubgroupsLog2 = WorkgroupSizeLog2 - SubgroupSizeLog2; @@ -102,7 +102,7 @@ inline OptimalFFTParameters optimalFFTParameters(uint32_t maxWorkgroupSize, uint if (subgroupSize < 2 || maxWorkgroupSize < subgroupSize || inputArrayLength <= subgroupSize) return invalidParameters; // Pad inputarrayLength to size that FFT algo handles - const uint32_t FFTLength = hlsl::fft2::padDimension(inputArrayLength, subgroupSize); + const uint32_t FFTLength = hlsl::fft2::padDimension(inputArrayLength, uint16_t(subgroupSize)); // Round maxWorkgroupSize down to PoT const uint32_t actualMaxWorkgroupSize = hlsl::roundDownToPoT(maxWorkgroupSize); // Max number of threads that can run the FFT @@ -111,7 +111,7 @@ inline OptimalFFTParameters optimalFFTParameters(uint32_t maxWorkgroupSize, uint maxThreads /= maxThreads % 3 ? 1 : 3; maxThreads /= maxThreads % 5 ? 
1 : 5; // Both are PoT - const uint16_t workgroupSizeLog2 = findMSB(min(maxThreads, actualMaxWorkgroupSize)); + const uint16_t workgroupSizeLog2 = uint16_t(findMSB(min(maxThreads, actualMaxWorkgroupSize))); // Parameters are valid if the workgroup size is at most half of the FFT Length and at least as big as the subgroupSize if ((FFTLength >> workgroupSizeLog2) <= 1 || subgroupSize > (1u << workgroupSizeLog2)) @@ -119,7 +119,7 @@ inline OptimalFFTParameters optimalFFTParameters(uint32_t maxWorkgroupSize, uint return invalidParameters; } - const uint16_t elementsPerInvocation = FFTLength >> workgroupSizeLog2; + const uint16_t elementsPerInvocation = uint16_t(FFTLength >> workgroupSizeLog2); const OptimalFFTParameters retVal = { elementsPerInvocation, workgroupSizeLog2 }; return retVal; @@ -174,7 +174,7 @@ struct FFTIndexingUtilsHelper // log2(ElementsPerInvocation / ExtraPrimeFactor) NBL_CONSTEXPR_STATIC_INLINE uint16_t Radix2FFTSizeLog2 = WorkgroupSizeLog2 + Radix2ElementsPerInvocationLog2; // Size of the full FFT if no mixed radix used, otherwise size of the sub-FFTs computed after the first radix-3/5 step in the DIF forward FFT - NBL_CONSTEXPR_STATIC_INLINE uint16_t Radix2FFTSize = uint32_t(1) << (Radix2FFTSizeLog2); + NBL_CONSTEXPR_STATIC_INLINE uint16_t Radix2FFTSize = uint16_t(1) << (Radix2FFTSizeLog2); // Total size of the FFT computed NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = ExtraPrimeFactor * Radix2FFTSize; }; From 61d5f70311d9865063046c11b5070163f12f0859 Mon Sep 17 00:00:00 2001 From: Fletterio Date: Tue, 5 May 2026 00:53:25 -0300 Subject: [PATCH 12/15] Halfway added subgroup shared twiddles --- examples_tests | 2 +- include/nbl/builtin/hlsl/subgroup2/fft.hlsl | 51 ++++++++++++++------ include/nbl/builtin/hlsl/workgroup2/fft.hlsl | 18 ++++--- 3 files changed, 46 insertions(+), 25 deletions(-) diff --git a/examples_tests b/examples_tests index 2e6ea2fb6f..f466d64c96 160000 --- a/examples_tests +++ b/examples_tests @@ -1 +1 @@ -Subproject commit 
2e6ea2fb6f573c7c89414e3c60b14df034d8c822 +Subproject commit f466d64c96374cb9c80769c664c7aba99f228a69 diff --git a/include/nbl/builtin/hlsl/subgroup2/fft.hlsl b/include/nbl/builtin/hlsl/subgroup2/fft.hlsl index 8e0f212afe..9463876d62 100644 --- a/include/nbl/builtin/hlsl/subgroup2/fft.hlsl +++ b/include/nbl/builtin/hlsl/subgroup2/fft.hlsl @@ -27,11 +27,9 @@ template struct FFT { template - static void FFT_loop(uint32_t stride, uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) + static void FFT_loop(uint32_t threadStride, uint16_t lowChannel, uint16_t highChannel, NBL_CONST_REF_ARG(complex_t) twiddle, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) { - const bool topHalf = bool(glsl::gl_SubgroupInvocationID() & stride); - // Get twiddle with k = subgroupInvocation mod stride, halfN = stride - const complex_t twiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID() & (stride - 1), stride); + const bool topHalf = bool(glsl::gl_SubgroupInvocationID() & threadStride); [unroll] for (uint16_t channel = lowChannel; channel <= highChannel; channel++) { @@ -39,7 +37,7 @@ struct FFT loAccessor.get(channel, lo); hiAccessor.get(channel, hi); const vector toTrade = topHalf ? 
vector (lo.real(), lo.imag()) : vector (hi.real(), hi.imag()); - const vector exchanged = glsl::subgroupShuffleXor< vector >(toTrade, stride); + const vector exchanged = glsl::subgroupShuffleXor< vector >(toTrade, threadStride); if (topHalf) { lo.real(exchanged.x); @@ -56,11 +54,11 @@ struct FFT } } - template + template static void __call(uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) { // special first iteration - const complex_t twiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID(), SubgroupSize); + complex_t twiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID(), SubgroupSize); [unroll] for (uint16_t channel = lowChannel; channel <= highChannel; channel++) { @@ -73,19 +71,39 @@ struct FFT } // Decimation in Frequency - [unroll] - for (uint32_t stride = SubgroupSize >> 1; stride > 0; stride >>= 1) - FFT_loop(stride, lowChannel, highChannel, loAccessor, hiAccessor); + // Compute all twiddles at the start, then reshare them among threads + if (ShareTwiddles) + { + uint32_t iteration = 1; + [unroll] + for (uint32_t threadStride = SubgroupSize >> 1; threadStride > 0; threadStride >>= 1) + { + const vector toTrade = vector (twiddle.real(), twiddle.imag()); + const vector otherTwiddle = glsl::subgroupShuffle< vector >(toTrade, (glsl::gl_SubgroupInvocationID() & (threadStride - 1)) << iteration); + twiddle.real(otherTwiddle.x); + twiddle.imag(otherTwiddle.y); + FFT_loop(threadStride, lowChannel, highChannel, twiddle, loAccessor, hiAccessor); + iteration++; + } + } + // Recompute twiddles at every step + else + { + [unroll] + for (uint32_t threadStride = SubgroupSize >> 1; threadStride > 0; threadStride >>= 1) + { + // Get twiddle with k = subgroupInvocation mod threadStride, halfN = threadStride + const complex_t twiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID() & (threadStride - 1), threadStride); + FFT_loop(threadStride, lowChannel, highChannel, twiddle, 
loAccessor, hiAccessor); + } + } } // Interleaved versions of the above methods, required to implement the first steps in Interleaved DIF template - static void FFT_loop(uint32_t elementStride, uint32_t threadStride, uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) + static void FFT_loop(uint32_t threadStride, uint16_t lowChannel, uint16_t highChannel, NBL_CONST_REF_ARG(complex_t) twiddle, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) { const bool topHalf = bool(glsl::gl_SubgroupInvocationID() & threadStride); - // Get twiddle with k = gl_SubgroupInvocationID() * NumSubgroups + gl_SubgroupID() mod elementStride, halfN = elementStride - const uint32_t loLaneIndex = (glsl::gl_SubgroupInvocationID() << NumSubgroupsLog2) + glsl::gl_SubgroupID(); - const complex_t twiddle = fft2::twiddle(loLaneIndex & (elementStride - 1), elementStride); [unroll] for (uint16_t channel = lowChannel; channel <= highChannel; channel++) { @@ -117,7 +135,7 @@ struct FFT // special first iteration // Get twiddle with k = gl_SubgroupInvocationID() * NumSubgroups + gl_SubgroupID() mod WorkgroupSize, halfN = WorkgroupSize const uint32_t loLaneIndex = (glsl::gl_SubgroupInvocationID() << NumSubgroupsLog2) + glsl::gl_SubgroupID(); - const complex_t twiddle = fft2::twiddle(loLaneIndex, WorkgroupSize); + complex_t twiddle = fft2::twiddle(loLaneIndex, WorkgroupSize); [unroll] for (uint16_t channel = lowChannel; channel <= highChannel; channel++) { @@ -134,7 +152,8 @@ struct FFT [unroll] for (uint32_t elementStride = WorkgroupSize >> 1; elementStride > SubgroupSize; elementStride >>= 1) { - FFT_loop(elementStride, threadStride, lowChannel, highChannel, loAccessor, hiAccessor); + twiddle = fft2::twiddle(loLaneIndex & (elementStride - 1), elementStride); + FFT_loop(threadStride, lowChannel, highChannel, twiddle, loAccessor, hiAccessor); threadStride 
>>= 1; } } diff --git a/include/nbl/builtin/hlsl/workgroup2/fft.hlsl b/include/nbl/builtin/hlsl/workgroup2/fft.hlsl index 5d999bfad8..4dba242882 100644 --- a/include/nbl/builtin/hlsl/workgroup2/fft.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/fft.hlsl @@ -55,7 +55,7 @@ struct DivisionPolicy }; // TODO: Separate parallelFFTs from Channels as two different concepts (dictates FFTSize) -template //NBL_PRIMARY_REQUIRES(_ElementsPerInvocation > 1 && !(_ElementsPerInvocation & 1) && _WorkgroupSizeLog2 >= 5) +template //NBL_PRIMARY_REQUIRES(_ElementsPerInvocation > 1 && !(_ElementsPerInvocation & 1) && _WorkgroupSizeLog2 >= 5) struct ConstevalParameters { using scalar_t = _Scalar; @@ -72,6 +72,8 @@ struct ConstevalParameters NBL_CONSTEXPR_STATIC_INLINE uint16_t ShuffleRounds = mpl::ceil_div_v; NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemoryDWORDs = ShuffledChannelsPerRound * ((sizeof(complex_t) / sizeof(uint32_t)) << (WorkgroupSizeLog2 + (_Interleaved ? 1 : 0))); + NBL_CONSTEXPR_STATIC_INLINE bool ShareTwiddles = _ShareTwiddles; + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivisionPolicy = _DivisionPolicy; }; } //namespace fft @@ -313,10 +315,10 @@ template -struct FFT, device_capabilities> +template +struct FFT, device_capabilities> { - using consteval_parameters_t = fft::ConstevalParameters; + using consteval_parameters_t = fft::ConstevalParameters; using scalar_t = typename consteval_parameters_t::scalar_t; template @@ -384,15 +386,15 @@ struct FFT::__call(0, Channels - 1, loAccessor, hiAccessor); + subgroup2::FFT::template __call(0, Channels - 1, loAccessor, hiAccessor); } }; // Non-interleaved (shuffle after every butterfly) inverse FFT -template -struct FFT, device_capabilities> +template +struct FFT, device_capabilities> { - using consteval_parameters_t = fft::ConstevalParameters; + using consteval_parameters_t = fft::ConstevalParameters; using scalar_t = typename consteval_parameters_t::scalar_t; template From f67bd93f0ec5779c389cd420742a4cb7314e72a9 Mon Sep 17 
00:00:00 2001 From: Fletterio Date: Tue, 5 May 2026 20:19:34 -0300 Subject: [PATCH 13/15] Refactoring channels, shared subgroup twiddles done --- include/nbl/builtin/hlsl/subgroup2/fft.hlsl | 117 +++++++------------ include/nbl/builtin/hlsl/workgroup2/fft.hlsl | 10 +- 2 files changed, 46 insertions(+), 81 deletions(-) diff --git a/include/nbl/builtin/hlsl/subgroup2/fft.hlsl b/include/nbl/builtin/hlsl/subgroup2/fft.hlsl index 9463876d62..d01f332b70 100644 --- a/include/nbl/builtin/hlsl/subgroup2/fft.hlsl +++ b/include/nbl/builtin/hlsl/subgroup2/fft.hlsl @@ -99,46 +99,17 @@ struct FFT } } - // Interleaved versions of the above methods, required to implement the first steps in Interleaved DIF - template - static void FFT_loop(uint32_t threadStride, uint16_t lowChannel, uint16_t highChannel, NBL_CONST_REF_ARG(complex_t) twiddle, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) - { - const bool topHalf = bool(glsl::gl_SubgroupInvocationID() & threadStride); - [unroll] - for (uint16_t channel = lowChannel; channel <= highChannel; channel++) - { - complex_t lo, hi; - loAccessor.get(channel, lo); - hiAccessor.get(channel, hi); - const vector toTrade = topHalf ? vector (lo.real(), lo.imag()) : vector (hi.real(), hi.imag()); - const vector exchanged = glsl::subgroupShuffleXor< vector >(toTrade, threadStride); - if (topHalf) - { - lo.real(exchanged.x); - lo.imag(exchanged.y); - } - else - { - hi.real(exchanged.x); - hi.imag(exchanged.y); - } - fft2::DIF::radix2(twiddle, lo, hi); - loAccessor.set(channel, lo); - hiAccessor.set(channel, hi); - } - } - - // Only uses subgroup methods, but is actually used at workgroup level + // Only uses subgroup methods, but is actually used at workgroup level. 
Used by the interleaved workgroup FFT at bigger than subgroup strides template static void __callInterleaved(uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) { - // special first iteration - // Get twiddle with k = gl_SubgroupInvocationID() * NumSubgroups + gl_SubgroupID() mod WorkgroupSize, halfN = WorkgroupSize const uint32_t loLaneIndex = (glsl::gl_SubgroupInvocationID() << NumSubgroupsLog2) + glsl::gl_SubgroupID(); - complex_t twiddle = fft2::twiddle(loLaneIndex, WorkgroupSize); + // special first iteration [unroll] for (uint16_t channel = lowChannel; channel <= highChannel; channel++) { + // Get twiddle with k = gl_SubgroupInvocationID() * NumSubgroups + gl_SubgroupID() mod WorkgroupSize, halfN = WorkgroupSize + const complex_t twiddle = fft2::twiddle(loLaneIndex, WorkgroupSize); complex_t lo, hi; loAccessor.get(channel, lo); hiAccessor.get(channel, hi); @@ -152,8 +123,9 @@ struct FFT [unroll] for (uint32_t elementStride = WorkgroupSize >> 1; elementStride > SubgroupSize; elementStride >>= 1) { - twiddle = fft2::twiddle(loLaneIndex & (elementStride - 1), elementStride); - FFT_loop(threadStride, lowChannel, highChannel, twiddle, loAccessor, hiAccessor); + // Get twiddle with k = gl_SubgroupInvocationID() * NumSubgroups + gl_SubgroupID() mod elementStride, halfN = elementStride + const complex_t twiddle = fft2::twiddle(loLaneIndex & (elementStride - 1), elementStride); + FFT_loop(threadStride, lowChannel, highChannel, twiddle, loAccessor, hiAccessor); threadStride >>= 1; } } @@ -166,11 +138,9 @@ template struct FFT { template - static void FFT_loop(uint32_t stride, uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) + static void FFT_loop(uint32_t stride, uint16_t lowChannel, uint16_t highChannel, NBL_CONST_REF_ARG(complex_t) twiddle, 
NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) { const bool topHalf = bool(glsl::gl_SubgroupInvocationID() & stride); - // Get twiddle with k = subgroupInvocation mod stride, halfN = stride - const complex_t twiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID() & (stride - 1), stride); [unroll] for (uint16_t channel = lowChannel; channel <= highChannel; channel++) @@ -197,13 +167,36 @@ struct FFT } } - template + template static void __call(uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) { // Decimation in Time - [unroll] - for (uint32_t stride = 1; stride < SubgroupSize; stride <<= 1) - FFT_loop(stride, lowChannel, highChannel, loAccessor, hiAccessor); + // Compute all twiddles at the start, then shuffle them + if (ShareTwiddles) + { + const complex_t ownedTwiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID(), SubgroupSize); + uint32_t reverseStride = SubgroupSize; + [unroll] + for (uint32_t threadStride = 1; threadStride < SubgroupSize; threadStride <<= 1) + { + const vector toTrade = vector (ownedTwiddle.real(), ownedTwiddle.imag()); + const vector otherTwiddle = glsl::subgroupShuffle< vector >(toTrade, (glsl::gl_SubgroupInvocationID() & (threadStride - 1)) * reverseStride); + const complex_t twiddle = { otherTwiddle.x , otherTwiddle.y }; + FFT_loop(threadStride, lowChannel, highChannel, twiddle, loAccessor, hiAccessor); + reverseStride >>= 1; + } + } + // Compute each twiddle at each iteration + else + { + [unroll] + for (uint32_t threadStride = 1; threadStride < SubgroupSize; threadStride <<= 1) + { + // Get twiddle with k = subgroupInvocation mod threadStride, halfN = threadStride + const complex_t twiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID() & (threadStride - 1), threadStride); + FFT_loop(threadStride, lowChannel, highChannel, twiddle, loAccessor, hiAccessor); + } + } // special last 
iteration const complex_t twiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID(), SubgroupSize); @@ -219,55 +212,23 @@ struct FFT } } - // Interleaved versions of the above methods, required to implement the last steps in Interleaved DIT - template - static void FFT_loop(uint32_t elementStride, uint32_t threadStride, uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) - { - const bool topHalf = bool(glsl::gl_SubgroupInvocationID() & threadStride); - // Get twiddle with k = gl_SubgroupInvocationID() * NumSubgroups + gl_SubgroupID() mod elementStride, halfN = elementStride - const uint32_t loLaneIndex = (glsl::gl_SubgroupInvocationID() << NumSubgroupsLog2) + glsl::gl_SubgroupID(); - const complex_t twiddle = fft2::twiddle(loLaneIndex & (elementStride - 1), elementStride); - - [unroll] - for (uint16_t channel = lowChannel; channel <= highChannel; channel++) - { - complex_t lo, hi; - loAccessor.get(channel, lo); - hiAccessor.get(channel, hi); - fft2::DIT::radix2(twiddle, lo, hi); - - const vector toTrade = topHalf ? 
vector (lo.real(), lo.imag()) : vector (hi.real(), hi.imag()); - const vector exchanged = glsl::subgroupShuffleXor< vector >(toTrade, threadStride); - if (topHalf) - { - lo.real(exchanged.x); - lo.imag(exchanged.y); - } - else - { - hi.real(exchanged.x); - hi.imag(exchanged.y); - } - loAccessor.set(channel, lo); - hiAccessor.set(channel, hi); - } - } - template static void __callInterleaved(uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) { + const uint32_t loLaneIndex = (glsl::gl_SubgroupInvocationID() << NumSubgroupsLog2) + glsl::gl_SubgroupID(); // Decimation in Time uint32_t threadStride = SubgroupSize >> (NumSubgroupsLog2 - 1); [unroll] for (uint32_t elementStride = SubgroupSize << 1; elementStride < WorkgroupSize; elementStride <<= 1) { - FFT_loop(elementStride, threadStride, lowChannel, highChannel, loAccessor, hiAccessor); + // Get twiddle with k = gl_SubgroupInvocationID() * NumSubgroups + gl_SubgroupID() mod elementStride, halfN = elementStride + const complex_t twiddle = fft2::twiddle(loLaneIndex & (elementStride - 1), elementStride); + FFT_loop(threadStride, lowChannel, highChannel, loAccessor, hiAccessor); threadStride <<= 1; } // special last iteration // Get twiddle with k = gl_SubgroupInvocationID() * NumSubgroups + gl_SubgroupID() mod WorkgroupSize, halfN = WorkgroupSize - const uint32_t loLaneIndex = (glsl::gl_SubgroupInvocationID() << NumSubgroupsLog2) + glsl::gl_SubgroupID(); const complex_t twiddle = fft2::twiddle(loLaneIndex, WorkgroupSize); [unroll] for (uint16_t channel = lowChannel; channel <= highChannel; channel++) diff --git a/include/nbl/builtin/hlsl/workgroup2/fft.hlsl b/include/nbl/builtin/hlsl/workgroup2/fft.hlsl index 4dba242882..9e3ce90bc3 100644 --- a/include/nbl/builtin/hlsl/workgroup2/fft.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/fft.hlsl @@ -55,19 +55,23 @@ struct DivisionPolicy }; // TODO: Separate parallelFFTs from Channels 
as two different concepts (dictates FFTSize) -template //NBL_PRIMARY_REQUIRES(_ElementsPerInvocation > 1 && !(_ElementsPerInvocation & 1) && _WorkgroupSizeLog2 >= 5) +template //NBL_PRIMARY_REQUIRES(_ElementsPerInvocation > 1 && !(_ElementsPerInvocation & 1) && _WorkgroupSizeLog2 >= 5) struct ConstevalParameters { using scalar_t = _Scalar; NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = _ElementsPerInvocation; NBL_CONSTEXPR_STATIC_INLINE uint16_t Channels = ElementsPerInvocation >> 1; + NBL_CONSTEXPR_STATIC_INLINE uint16_t RealChannels = _RealChannels; + NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerRealChannel = ElementsPerInvocation / RealChannels; NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = _SubgroupSizeLog2; NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(1) << SubgroupSizeLog2; NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2; NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(1) << WorkgroupSizeLog2; NBL_CONSTEXPR_STATIC_INLINE uint16_t NumSubgroupsLog2 = WorkgroupSizeLog2 - SubgroupSizeLog2; - NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTTotalSize = ElementsPerInvocation * (uint32_t(1) << WorkgroupSizeLog2); + + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTTotalSize = ElementsPerRealChannel * (uint32_t(1) << WorkgroupSizeLog2); + NBL_CONSTEXPR_STATIC_INLINE uint16_t ShuffledChannelsPerRound = _ShuffledChannelsPerRound; NBL_CONSTEXPR_STATIC_INLINE uint16_t ShuffleRounds = mpl::ceil_div_v; NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemoryDWORDs = ShuffledChannelsPerRound * ((sizeof(complex_t) / sizeof(uint32_t)) << (WorkgroupSizeLog2 + (_Interleaved ? 
1 : 0))); @@ -439,7 +443,7 @@ struct FFT::__call(0, Channels - 1, loAccessor, hiAccessor); + subgroup2::FFT::template __call(0, Channels - 1, loAccessor, hiAccessor); // Get workgroup threadID const uint32_t threadID = uint32_t(workgroup::SubgroupContiguousIndex()); From bc9f37953b576f687cbf353c2cb60cce95b42c81 Mon Sep 17 00:00:00 2001 From: Fletterio Date: Wed, 6 May 2026 18:08:24 -0300 Subject: [PATCH 14/15] Refactored consteval and workgroup radix 2 to be inner --- examples_tests | 2 +- include/nbl/builtin/hlsl/workgroup2/fft.hlsl | 216 +++++++++++-------- 2 files changed, 130 insertions(+), 88 deletions(-) diff --git a/examples_tests b/examples_tests index f466d64c96..1d6aae3f97 160000 --- a/examples_tests +++ b/examples_tests @@ -1 +1 @@ -Subproject commit f466d64c96374cb9c80769c664c7aba99f228a69 +Subproject commit 1d6aae3f97998e36199645a6b08b76e39189b615 diff --git a/include/nbl/builtin/hlsl/workgroup2/fft.hlsl b/include/nbl/builtin/hlsl/workgroup2/fft.hlsl index 9e3ce90bc3..c20a53eef8 100644 --- a/include/nbl/builtin/hlsl/workgroup2/fft.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/fft.hlsl @@ -40,46 +40,84 @@ struct DivisionPolicy { // No division performed at any step of the FFT NBL_CONSTEXPR_STATIC_INLINE uint16_t NoDivision = 0; + // Divides the array by sqrt(FFTSize) right at the start of the algorithm + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivBySqrtAtStart = NoDivision + 1; // Divides the array by sqrt(FFTSize) at the time of the last workgroup barrier before subgroupFFT (forward) or at the time of the first workgroup barrier // after subgroupFFT (inverse). 
- NBL_CONSTEXPR_STATIC_INLINE uint16_t DivBySqrtHalfway = NoDivision + 1; + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivBySqrtHalfway = DivBySqrtAtStart + 1; // Divides the array by sqrt(FFTSize) right at the end of the algorithm NBL_CONSTEXPR_STATIC_INLINE uint16_t DivBySqrtAtEnd = DivBySqrtHalfway + 1; - // Divides the array by sqrt(FFTSize) by considering `sqrt(FFTSize) = a * b`, dividing by `a` halfway (as described in `DivBySqrtHalfway`) and then dividing - // by `b` at the end. `a` and `b` are chosen so that their weight is proportional to the number of butterflies before the division. + // Total division by sqrt(FFTSize) in three steps: Once right at the start of the algorithm, once before the first workgroupBarrier, and once after the last workgroupBarrier. + // Only valid if ElementsPerInvocation > 2. NBL_CONSTEXPR_STATIC_INLINE uint16_t DivBySqrtByParts = DivBySqrtAtEnd + 1; // The three following all perform divisions in the same manner as their counterparts above, but they divide the array by `FFTSize`. 
- NBL_CONSTEXPR_STATIC_INLINE uint16_t DivByFullSizeHalfway = DivBySqrtByParts + 1; + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivByFullSizeAtStart = DivBySqrtByParts + 1; + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivByFullSizeHalfway = DivByFullSizeAtStart + 1; NBL_CONSTEXPR_STATIC_INLINE uint16_t DivByFullSizeAtEnd = DivByFullSizeHalfway + 1; NBL_CONSTEXPR_STATIC_INLINE uint16_t DivByFullSizeByParts = DivByFullSizeAtEnd + 1; }; -// TODO: Separate parallelFFTs from Channels as two different concepts (dictates FFTSize) -template //NBL_PRIMARY_REQUIRES(_ElementsPerInvocation > 1 && !(_ElementsPerInvocation & 1) && _WorkgroupSizeLog2 >= 5) +namespace impl +{ +template +struct DivisionConstants +{ + NBL_CONSTEXPR_STATIC_INLINE uint16_t TODO = 0; +}; +} //namespace impl + +template //NBL_PRIMARY_REQUIRES(_ElementsPerInvocation > 1 && !(_ElementsPerInvocation & 1) && _WorkgroupSizeLog2 >= 5) struct ConstevalParameters { using scalar_t = _Scalar; + using DivisionConstants = impl::DivisionConstants<_ElementsPerInvocationPerChannel, _WorkgroupSizeLog2>; - NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = _ElementsPerInvocation; - NBL_CONSTEXPR_STATIC_INLINE uint16_t Channels = ElementsPerInvocation >> 1; - NBL_CONSTEXPR_STATIC_INLINE uint16_t RealChannels = _RealChannels; - NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerRealChannel = ElementsPerInvocation / RealChannels; + NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationPerChannel = _ElementsPerInvocationPerChannel; + NBL_CONSTEXPR_STATIC_INLINE uint16_t Channels = _Channels; + NBL_CONSTEXPR_STATIC_INLINE uint16_t InnerVirtualChannels = Channels * (ElementsPerInvocationPerChannel >> 1); NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = _SubgroupSizeLog2; NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(1) << SubgroupSizeLog2; NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2; NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(1) << WorkgroupSizeLog2; 
NBL_CONSTEXPR_STATIC_INLINE uint16_t NumSubgroupsLog2 = WorkgroupSizeLog2 - SubgroupSizeLog2; - NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTTotalSize = ElementsPerRealChannel * (uint32_t(1) << WorkgroupSizeLog2); - - NBL_CONSTEXPR_STATIC_INLINE uint16_t ShuffledChannelsPerRound = _ShuffledChannelsPerRound; - NBL_CONSTEXPR_STATIC_INLINE uint16_t ShuffleRounds = mpl::ceil_div_v; - NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemoryDWORDs = ShuffledChannelsPerRound * ((sizeof(complex_t) / sizeof(uint32_t)) << (WorkgroupSizeLog2 + (_Interleaved ? 1 : 0))); + NBL_CONSTEXPR_STATIC_INLINE uint16_t ShuffledVirtualChannelsPerRound = _ShuffledVirtualChannelsPerRound; + NBL_CONSTEXPR_STATIC_INLINE uint16_t ShuffleRounds = mpl::ceil_div_v; + NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemoryDWORDs = ShuffledVirtualChannelsPerRound * ((sizeof(complex_t) / sizeof(uint32_t)) << (WorkgroupSizeLog2 + (_Interleaved ? 1 : 0))); NBL_CONSTEXPR_STATIC_INLINE bool ShareTwiddles = _ShareTwiddles; NBL_CONSTEXPR_STATIC_INLINE uint16_t DivisionPolicy = _DivisionPolicy; }; + +// Takes an elements accessor (lo/hi) with signature `acc.set/get(channel, pair)` and flattens it to 1D +// Workgroup-sized FFT works on many channels of a single pair each. 
Since these are constant folded it's the easiest solution +template +struct WorkgroupRadix2AccessorAdaptor +{ + static WorkgroupRadix2AccessorAdaptor create(NBL_REF_ARG(InvocationElementsAccessor) _accessor) + { + WorkgroupRadix2AccessorAdaptor retVal; + retVal.accessor = _accessor; + return retVal; + } + + void get(uint32_t virtualChannel, NBL_REF_ARG(complex_t) value) + { + const uint32_t channel = virtualChannel / Channels; + const uint32_t pair = virtualChannel % Channels; + accessor.get(channel, pair, value); + } + + void set(uint32_t virtualChannel, NBL_CONST_REF_ARG(complex_t) value) + { + const uint32_t channel = virtualChannel / Channels; + const uint32_t pair = virtualChannel % Channels; + accessor.set(channel, pair, value); + } + + InvocationElementsAccessor accessor; +}; } //namespace fft struct OptimalFFTParameters @@ -144,9 +182,9 @@ struct FFTIndexingUtilsHelper const uint32_t radix2mask = (1 << Radix2FFTSizeLog2) - 1; return ExtraPrimeFactor * hlsl::bitReverseAs(laneIdx, Radix2FFTSizeLog2) + (laneIdx >> Radix2FFTSizeLog2); } - else - { - return hlsl::bitReverseAs(laneIdx, Radix2FFTSizeLog2); + else + { + return hlsl::bitReverseAs(laneIdx, Radix2FFTSizeLog2); } } @@ -159,7 +197,7 @@ struct FFTIndexingUtilsHelper } else // ExtraPrimeFactor == 5 { - return (x * 52429u) >> 18; // valid for x <= 81919 + return (x * 52429u) >> 18; // valid for x <= 81919 } } @@ -171,9 +209,9 @@ struct FFTIndexingUtilsHelper const uint32_t divByPrimeFactor = fastDiv(freqIdx); return hlsl::bitReverseAs(divByPrimeFactor, Radix2FFTSizeLog2) + ((freqIdx - ExtraPrimeFactor * divByPrimeFactor) << Radix2FFTSizeLog2); } - else - { - return hlsl::bitReverseAs(freqIdx, Radix2FFTSizeLog2); + else + { + return hlsl::bitReverseAs(freqIdx, Radix2FFTSizeLog2); } } @@ -315,47 +353,48 @@ struct exchangeValues //-------------- ------------------------------------ END UTILS -------------------------------------------------------- -template -struct FFT; +namespace impl +{ +template +struct 
InnerFFT; -// Non-interleaved (shuffle after every butterfly) forward FFT -template -struct FFT, device_capabilities> +// Non-interleaved (shuffle after every butterfly) inner (post thread-local butterflies) forward FFT +template +struct InnerFFT, device_capabilities> { - using consteval_parameters_t = fft::ConstevalParameters; + using consteval_parameters_t = fft::ConstevalParameters; using scalar_t = typename consteval_parameters_t::scalar_t; template static void FFT_loop(uint32_t stride, uint32_t threadID, NBL_REF_ARG(uint32_t) ownedSmemIndex, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor) { const uint32_t ShuffleRounds = consteval_parameters_t::ShuffleRounds; - const uint16_t Channels = consteval_parameters_t::Channels; + const uint16_t VirtualChannels = consteval_parameters_t::InnerVirtualChannels; // Get twiddle with k = threadID mod stride, halfN = stride const complex_t twiddle = hlsl::fft::twiddle(threadID & (stride - 1), stride); bool pingPong = false; - // Unrolling this loop increases register pressure. Why? Who knows. - // It's not like it can't reuse the registers, and calls to exchangeValues are inlined anyway. 
- //[unroll] - for (uint32_t round = 0; round < ShuffleRounds; round++) - { - if (round) - pingPong = !pingPong; // ping pong on sharedmem to avoid barriering - this eploits that we XOR with the same stride every consecutive round - const uint32_t lowChannel = round * ShuffledChannelsPerRound; - const uint32_t highChannel = min(Channels, lowChannel + ShuffledChannelsPerRound) - 1; - [unroll] - for (uint32_t channel = lowChannel; channel <= highChannel; channel++) + // If register pressure high, can avoid unroll + [unroll] + for (uint32_t round = 0; round < ShuffleRounds; round++) { - complex_t lo, hi; - loAccessor.get(channel, lo); - hiAccessor.get(channel, hi); - fft2::DIF::radix2(twiddle, lo, hi); - loAccessor.set(channel, lo); - hiAccessor.set(channel, hi); + if (round) + pingPong = !pingPong; // ping pong on sharedmem to avoid barriering - this eploits that we XOR with the same stride every consecutive round + const uint32_t lowChannel = round * ShuffledVirtualChannelsPerRound; + const uint32_t highChannel = min(VirtualChannels, lowChannel + ShuffledVirtualChannelsPerRound) - 1; + [unroll] + for (uint32_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIF::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } + + fft::impl::exchangeValues::__call(threadID, ownedSmemIndex, lowChannel, highChannel, loAccessor, hiAccessor, stride >> 1, sharedmemAdaptor, pingPong); } - - fft::impl::exchangeValues::__call(threadID, ownedSmemIndex, lowChannel, highChannel, loAccessor, hiAccessor, stride >> 1, sharedmemAdaptor, pingPong); - } // After the last exchangeValues, the memory we just read from is now owned by us, so update ownedSmemIndex = pingPong ? 
ownedSmemIndex : ownedSmemIndex ^ (stride >> 1); } @@ -363,7 +402,7 @@ struct FFT static void __call(NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor) { - const uint16_t Channels = consteval_parameters_t::Channels; + const uint16_t VirtualChannels = consteval_parameters_t::InnerVirtualChannels; const uint16_t SubgroupSize = consteval_parameters_t::SubgroupSize; const uint16_t WorkgroupSize = consteval_parameters_t::WorkgroupSize; @@ -374,63 +413,63 @@ struct FFT SubgroupSize) { // Set up the memory adaptor - using adaptor_t = accessor_adaptors::StructureOfArrays; + using adaptor_t = accessor_adaptors::StructureOfArrays; adaptor_t sharedmemAdaptor; sharedmemAdaptor.accessor = sharedmemAccessor; uint32_t ownedSmemIndex = threadID; // NOT unrolling this loop increases register pressure???? [unroll] - for (uint32_t stride = WorkgroupSize; stride > SubgroupSize; stride >>= 1) - { - FFT_loop(stride, threadID, ownedSmemIndex, loAccessor, hiAccessor, sharedmemAdaptor); - } + for (uint32_t stride = WorkgroupSize; stride > SubgroupSize; stride >>= 1) + { + FFT_loop(stride, threadID, ownedSmemIndex, loAccessor, hiAccessor, sharedmemAdaptor); + } // Remember to update the accessor's state sharedmemAccessor = sharedmemAdaptor.accessor; } // Subgroup-sized FFT - subgroup2::FFT::template __call(0, Channels - 1, loAccessor, hiAccessor); + subgroup2::FFT::template __call(0, VirtualChannels - 1, loAccessor, hiAccessor); } }; -// Non-interleaved (shuffle after every butterfly) inverse FFT -template -struct FFT, device_capabilities> +// Non-interleaved (shuffle after every butterfly) inner (post thread-local butterflies) inverse FFT +template +struct InnerFFT, device_capabilities> { - using consteval_parameters_t = fft::ConstevalParameters; + using consteval_parameters_t = fft::ConstevalParameters; using scalar_t = typename consteval_parameters_t::scalar_t; template static void 
FFT_loop(uint32_t stride, uint32_t threadID, NBL_REF_ARG(uint32_t) ownedSmemIndex, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor) { const uint32_t ShuffleRounds = consteval_parameters_t::ShuffleRounds; - const uint16_t Channels = consteval_parameters_t::Channels; + const uint16_t VirtualChannels = consteval_parameters_t::InnerVirtualChannels; // Get twiddle with k = threadID mod stride, halfN = stride const complex_t twiddle = hlsl::fft::twiddle(threadID & ((stride << 1) - 1), stride << 1); bool pingPong = false; - //[unroll] - for (uint32_t round = 0; round < ShuffleRounds; round++) - { - if (round) - pingPong = !pingPong; // ping pong on sharedmem to avoid barriering - this eploits that we XOR with the same stride every consecutive round - const uint32_t lowChannel = round * ShuffledChannelsPerRound; - const uint32_t highChannel = min(Channels, lowChannel + ShuffledChannelsPerRound) - 1; - - fft::impl::exchangeValues::__call(threadID, ownedSmemIndex, lowChannel, highChannel, loAccessor, hiAccessor, stride, sharedmemAdaptor, pingPong); - - [unroll] - for (uint32_t channel = lowChannel; channel <= highChannel; channel++) + [unroll] + for (uint32_t round = 0; round < ShuffleRounds; round++) { - complex_t lo, hi; - loAccessor.get(channel, lo); - hiAccessor.get(channel, hi); - fft2::DIT::radix2(twiddle, lo, hi); - loAccessor.set(channel, lo); - hiAccessor.set(channel, hi); + if (round) + pingPong = !pingPong; // ping pong on sharedmem to avoid barriering - this eploits that we XOR with the same stride every consecutive round + const uint32_t lowChannel = round * ShuffledVirtualChannelsPerRound; + const uint32_t highChannel = min(VirtualChannels, lowChannel + ShuffledVirtualChannelsPerRound) - 1; + + fft::impl::exchangeValues::__call(threadID, ownedSmemIndex, lowChannel, highChannel, loAccessor, hiAccessor, stride, sharedmemAdaptor, pingPong); + + [unroll] + for 
(uint32_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIT::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } } - } // After the last exchangeValues, the memory we just read from is now owned by us, so update ownedSmemIndex = pingPong ? ownedSmemIndex : ownedSmemIndex ^ (stride >> 1); } @@ -438,12 +477,12 @@ struct FFT static void __call(NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor) { - const uint16_t Channels = consteval_parameters_t::Channels; + const uint16_t VirtualChannels = consteval_parameters_t::InnerVirtualChannels; const uint16_t SubgroupSize = consteval_parameters_t::SubgroupSize; const uint16_t WorkgroupSize = consteval_parameters_t::WorkgroupSize; // Subgroup-sized FFT at the start - subgroup2::FFT::template __call(0, Channels - 1, loAccessor, hiAccessor); + subgroup2::FFT::template __call(0, VirtualChannels - 1, loAccessor, hiAccessor); // Get workgroup threadID const uint32_t threadID = uint32_t(workgroup::SubgroupContiguousIndex()); @@ -452,16 +491,16 @@ struct FFT SubgroupSize) { // Set up the memory adaptor - using adaptor_t = accessor_adaptors::StructureOfArrays; + using adaptor_t = accessor_adaptors::StructureOfArrays; adaptor_t sharedmemAdaptor; sharedmemAdaptor.accessor = sharedmemAccessor; uint32_t ownedSmemIndex = threadID; [unroll] - for (uint32_t stride = SubgroupSize; stride < WorkgroupSize; stride <<= 1) - { - FFT_loop(stride, threadID, ownedSmemIndex, loAccessor, hiAccessor, sharedmemAdaptor); - } + for (uint32_t stride = SubgroupSize; stride < WorkgroupSize; stride <<= 1) + { + FFT_loop(stride, threadID, ownedSmemIndex, loAccessor, hiAccessor, sharedmemAdaptor); + } // Remember to update the accessor's state sharedmemAccessor = sharedmemAdaptor.accessor; @@ -469,6 +508,9 @@ struct 
FFT Date: Fri, 8 May 2026 21:29:52 -0300 Subject: [PATCH 15/15] Added division options for inner FFT --- examples_tests | 2 +- include/nbl/builtin/hlsl/workgroup2/fft.hlsl | 157 +++++++++++++------ 2 files changed, 114 insertions(+), 45 deletions(-) diff --git a/examples_tests b/examples_tests index 1d6aae3f97..59d0970dc9 160000 --- a/examples_tests +++ b/examples_tests @@ -1 +1 @@ -Subproject commit 1d6aae3f97998e36199645a6b08b76e39189b615 +Subproject commit 59d0970dc95b4bd7d5aa0bcb241e29e47c7bea29 diff --git a/include/nbl/builtin/hlsl/workgroup2/fft.hlsl b/include/nbl/builtin/hlsl/workgroup2/fft.hlsl index c20a53eef8..2e984a6ebc 100644 --- a/include/nbl/builtin/hlsl/workgroup2/fft.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/fft.hlsl @@ -59,18 +59,50 @@ struct DivisionPolicy namespace impl { -template + +struct Constants +{ + NBL_CONSTEXPR_STATIC_INLINE float32_t INV_SQRT_2 = 0.707106781f; + NBL_CONSTEXPR_STATIC_INLINE float32_t INV_SQRT_3 = 0.577350269f; + NBL_CONSTEXPR_STATIC_INLINE float32_t INV_SQRT_5 = 0.447213595f; +}; + +template struct DivisionConstants { - NBL_CONSTEXPR_STATIC_INLINE uint16_t TODO = 0; + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTLength = ElementsPerInvocation * (1 << (WorkgroupSizeLog2)); + NBL_CONSTEXPR_STATIC_INLINE float32_t InvFFTLength = float32_t(1) / FFTLength; + + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTLengthA = ElementsPerInvocation >> 1; + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTLengthB = 1u << (WorkgroupSizeLog2 - SubgroupSizeLog2); + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTLengthC = 1u << (SubgroupSizeLog2 + 1); + NBL_CONSTEXPR_STATIC_INLINE float32_t InvFFTLengthA = float32_t(1) / FFTLengthA; + NBL_CONSTEXPR_STATIC_INLINE float32_t InvFFTLengthB = float32_t(1) / FFTLengthB; + NBL_CONSTEXPR_STATIC_INLINE float32_t InvFFTLengthC = float32_t(1) / FFTLengthC; + + NBL_CONSTEXPR_STATIC_INLINE float32_t PrimeFactorInvSqrt = (ElementsPerInvocation % 3 ? (ElementsPerInvocation % 5 ? 
float32_t(1) : Constants::INV_SQRT_5) : Constants::INV_SQRT_3); + + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTLengthLog2 = WorkgroupSizeLog2 + mpl::log2_v; + NBL_CONSTEXPR_STATIC_INLINE float32_t InvSqrtFFTLength = float32_t(1) / (1u << (FFTLengthLog2 / 2)) * (FFTLengthLog2 & 1 ? Constants::INV_SQRT_2 : float32_t(1)) * PrimeFactorInvSqrt; + + // Log2A part should only consider log2 of the pure radix2, ignoring extra prime factor + NBL_CONSTEXPR_STATIC_INLINE uint32_t Radix2ElementsPerInvocation = (ElementsPerInvocation % 3 ? (ElementsPerInvocation % 5 ? ElementsPerInvocation : ElementsPerInvocation / 5) : ElementsPerInvocation / 3); + + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTLengthLog2A = mpl::log2_v - 1; + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTLengthLog2B = WorkgroupSizeLog2 - SubgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTLengthLog2C = SubgroupSizeLog2 + 1; + NBL_CONSTEXPR_STATIC_INLINE float32_t InvSqrtFFTLengthA = float32_t(1) / (1u << (FFTLengthLog2A / 2)) * (FFTLengthLog2A & 1 ? Constants::INV_SQRT_2 : float32_t(1)) * PrimeFactorInvSqrt; + NBL_CONSTEXPR_STATIC_INLINE float32_t InvSqrtFFTLengthB = float32_t(1) / (1u << (FFTLengthLog2B / 2)) * (FFTLengthLog2B & 1 ? Constants::INV_SQRT_2 : float32_t(1)); + NBL_CONSTEXPR_STATIC_INLINE float32_t InvSqrtFFTLengthC = float32_t(1) / (1u << (FFTLengthLog2C / 2)) * (FFTLengthLog2C & 1 ? 
Constants::INV_SQRT_2 : float32_t(1)); }; + } //namespace impl template //NBL_PRIMARY_REQUIRES(_ElementsPerInvocation > 1 && !(_ElementsPerInvocation & 1) && _WorkgroupSizeLog2 >= 5) struct ConstevalParameters { using scalar_t = _Scalar; - using DivisionConstants = impl::DivisionConstants<_ElementsPerInvocationPerChannel, _WorkgroupSizeLog2>; + using DivisionConstants = impl::DivisionConstants<_ElementsPerInvocationPerChannel, _SubgroupSizeLog2, _WorkgroupSizeLog2>; NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationPerChannel = _ElementsPerInvocationPerChannel; NBL_CONSTEXPR_STATIC_INLINE uint16_t Channels = _Channels; @@ -376,25 +408,25 @@ struct InnerFFT lo, hi; - loAccessor.get(channel, lo); - hiAccessor.get(channel, hi); - fft2::DIF::radix2(twiddle, lo, hi); - loAccessor.set(channel, lo); - hiAccessor.set(channel, hi); - } - - fft::impl::exchangeValues::__call(threadID, ownedSmemIndex, lowChannel, highChannel, loAccessor, hiAccessor, stride >> 1, sharedmemAdaptor, pingPong); + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIF::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); } + + fft::impl::exchangeValues::__call(threadID, ownedSmemIndex, lowChannel, highChannel, loAccessor, hiAccessor, stride >> 1, sharedmemAdaptor, pingPong); + } // After the last exchangeValues, the memory we just read from is now owned by us, so update ownedSmemIndex = pingPong ? 
ownedSmemIndex : ownedSmemIndex ^ (stride >> 1); } @@ -420,14 +452,33 @@ struct InnerFFT SubgroupSize; stride >>= 1) - { - FFT_loop(stride, threadID, ownedSmemIndex, loAccessor, hiAccessor, sharedmemAdaptor); - } + for (uint32_t stride = WorkgroupSize; stride > SubgroupSize; stride >>= 1) + { + FFT_loop(stride, threadID, ownedSmemIndex, loAccessor, hiAccessor, sharedmemAdaptor); + } // Remember to update the accessor's state sharedmemAccessor = sharedmemAdaptor.accessor; } + + const float32_t DivisionFactor = (DivisionPolicy == fft::DivisionPolicy::DivByFullSizeHalfway ? consteval_parameters_t::DivisionConstants::InvFFTLength + : (DivisionPolicy == fft::DivisionPolicy::DivBySqrtHalfway ? consteval_parameters_t::DivisionConstants::InvSqrtFFTLength + : (DivisionPolicy == fft::DivisionPolicy::DivByFullSizeByParts ? consteval_parameters_t::DivisionConstants::InvFFTLengthC + : consteval_parameters_t::DivisionConstants::InvSqrtFFTLengthC))); // Assume DivBySqrtByParts, won't be used otherwise + + if (DivisionPolicy == fft::DivisionPolicy::DivByFullSizeHalfway || DivisionPolicy == fft::DivisionPolicy::DivBySqrtHalfway || DivisionPolicy == fft::DivisionPolicy::DivByFullSizeByParts || DivisionPolicy == fft::DivisionPolicy::DivBySqrtByParts) + { + [unroll] + for (uint32_t channel = 0; channel < VirtualChannels; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + loAccessor.set(channel, lo * scalar_t(DivisionFactor)); + hiAccessor.set(channel, hi * scalar_t(DivisionFactor)); + } + } + // Subgroup-sized FFT subgroup2::FFT::template __call(0, VirtualChannels - 1, loAccessor, hiAccessor); } @@ -450,26 +501,26 @@ struct InnerFFT::__call(threadID, ownedSmemIndex, lowChannel, highChannel, loAccessor, hiAccessor, stride, sharedmemAdaptor, pingPong); - - [unroll] - for (uint32_t channel = lowChannel; channel <= highChannel; channel++) - { - complex_t lo, hi; - loAccessor.get(channel, lo); - hiAccessor.get(channel, hi); - 
fft2::DIT::radix2(twiddle, lo, hi); - loAccessor.set(channel, lo); - hiAccessor.set(channel, hi); - } - } + for (uint32_t round = 0; round < ShuffleRounds; round++) + { + if (round) + pingPong = !pingPong; // ping pong on sharedmem to avoid barriering - this eploits that we XOR with the same stride every consecutive round + const uint32_t lowChannel = round * ShuffledVirtualChannelsPerRound; + const uint32_t highChannel = min(VirtualChannels, lowChannel + ShuffledVirtualChannelsPerRound) - 1; + + fft::impl::exchangeValues::__call(threadID, ownedSmemIndex, lowChannel, highChannel, loAccessor, hiAccessor, stride, sharedmemAdaptor, pingPong); + + [unroll] + for (uint32_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIT::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } + } // After the last exchangeValues, the memory we just read from is now owned by us, so update ownedSmemIndex = pingPong ? ownedSmemIndex : ownedSmemIndex ^ (stride >> 1); } @@ -484,6 +535,24 @@ struct InnerFFT::template __call(0, VirtualChannels - 1, loAccessor, hiAccessor); + const float32_t DivisionFactor = (DivisionPolicy == fft::DivisionPolicy::DivByFullSizeHalfway ? consteval_parameters_t::DivisionConstants::InvFFTLength + : (DivisionPolicy == fft::DivisionPolicy::DivBySqrtHalfway ? consteval_parameters_t::DivisionConstants::InvSqrtFFTLength + : (DivisionPolicy == fft::DivisionPolicy::DivByFullSizeByParts ? 
consteval_parameters_t::DivisionConstants::InvFFTLengthC + : consteval_parameters_t::DivisionConstants::InvSqrtFFTLengthC))); // Assume DivBySqrtByParts, won't be used otherwise + + if (DivisionPolicy == fft::DivisionPolicy::DivByFullSizeHalfway || DivisionPolicy == fft::DivisionPolicy::DivBySqrtHalfway || DivisionPolicy == fft::DivisionPolicy::DivByFullSizeByParts || DivisionPolicy == fft::DivisionPolicy::DivBySqrtByParts) + { + [unroll] + for (uint32_t channel = 0; channel < VirtualChannels; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + loAccessor.set(channel, lo * scalar_t(DivisionFactor)); + hiAccessor.set(channel, hi * scalar_t(DivisionFactor)); + } + } + // Get workgroup threadID const uint32_t threadID = uint32_t(workgroup::SubgroupContiguousIndex());