diff --git a/examples_tests b/examples_tests index 1345dae922..59d0970dc9 160000 --- a/examples_tests +++ b/examples_tests @@ -1 +1 @@ -Subproject commit 1345dae9220598734e73ed425225b49dc3c3cfe6 +Subproject commit 59d0970dc95b4bd7d5aa0bcb241e29e47c7bea29 diff --git a/include/nbl/asset/utils/IShaderCompiler.h b/include/nbl/asset/utils/IShaderCompiler.h index 05116b8d52..526a9dd80e 100644 --- a/include/nbl/asset/utils/IShaderCompiler.h +++ b/include/nbl/asset/utils/IShaderCompiler.h @@ -261,6 +261,9 @@ class NBL_API2 IShaderCompiler : public core::IReferenceCounted SPreprocessorOptions preprocessorOptions = {}; CCache* readCache = nullptr; CCache* writeCache = nullptr; + bool optimizerIsExtraPasses = false; // Instead of disabling the default opt passes, run the provided optimization passes at the end + std::string preprocessedOutputPath = ""; + std::string spvOutputPath = ""; }; class CCache final : public IReferenceCounted diff --git a/include/nbl/builtin/hlsl/complex.hlsl b/include/nbl/builtin/hlsl/complex.hlsl index 7e8f6526ec..96ca799b92 100644 --- a/include/nbl/builtin/hlsl/complex.hlsl +++ b/include/nbl/builtin/hlsl/complex.hlsl @@ -7,6 +7,7 @@ #include #include +#include using namespace nbl::hlsl; @@ -32,22 +33,6 @@ struct complex_t : public std::complex } }; -// Fast mul by i -template -complex_t rotateLeft(NBL_CONST_REF_ARG(complex_t) value) -{ - complex_t retVal = { -value.imag(), value.real() }; - return retVal; -} - -// Fast mul by -i -template -complex_t rotateRight(NBL_CONST_REF_ARG(complex_t) value) -{ - complex_t retVal = { value.imag(), -value.real() }; - return retVal; -} - } } @@ -414,28 +399,11 @@ complex_t proj(const complex_t c) template complex_t polar(const Scalar r, const Scalar theta) { - complex_t retVal = {r * cos(theta), r * sin(theta)}; + complex_t retVal = {r * nbl::hlsl::cos(theta), r * nbl::hlsl::sin(theta)}; return retVal; } -// --------------------------------------------- Some more functions that come in handy 
-------------------------------------- -// Fast mul by i -template -complex_t rotateLeft(NBL_CONST_REF_ARG(complex_t) value) -{ - complex_t retVal = { -value.imag(), value.real() }; - return retVal; -} - -// Fast mul by -i -template -complex_t rotateRight(NBL_CONST_REF_ARG(complex_t) value) -{ - complex_t retVal = { value.imag(), -value.real() }; - return retVal; -} - } } @@ -456,4 +424,40 @@ NBL_REGISTER_OBJ_TYPE(complex_t,::nbl::hlsl::alignment_of_v +complex_t rotateLeft(NBL_CONST_REF_ARG(complex_t) value) +{ + complex_t retVal = { -value.imag(), value.real() }; + return retVal; +} + +// Fast mul by -i +template +complex_t rotateRight(NBL_CONST_REF_ARG(complex_t) value) +{ + complex_t retVal = { value.imag(), -value.real() }; + return retVal; +} + +// Fast square +template +complex_t square(NBL_CONST_REF_ARG(complex_t) value) +{ + Scalar real = value.real() * value.real() - value.imag() * value.imag(); + Scalar imag = 2 * value.real() * value.imag(); + complex_t retVal = { real, imag }; + return retVal; +} + +} +} + #endif diff --git a/include/nbl/builtin/hlsl/fft2/common.hlsl b/include/nbl/builtin/hlsl/fft2/common.hlsl new file mode 100644 index 0000000000..7158cd3ca4 --- /dev/null +++ b/include/nbl/builtin/hlsl/fft2/common.hlsl @@ -0,0 +1,301 @@ +#ifndef _NBL_BUILTIN_HLSL_FFT2_COMMON_INCLUDED_ +#define _NBL_BUILTIN_HLSL_FFT2_COMMON_INCLUDED_ + +#include +#include +#include +#include +#include +#include + +namespace nbl +{ +namespace hlsl +{ +namespace fft2 +{ + +uint32_t padDimension(uint32_t dimension, uint16_t subgroupSize) +{ + const uint16_t subgroupFFTSize = subgroupSize << 1; + uint32_t padded = hlsl::roundUpToPoT(dimension); + if (padded <= subgroupFFTSize) + return subgroupFFTSize; + // Consider a factor of 3 + padded = hlsl::min(padded, 3 * hlsl::roundUpToPoT(hlsl::ceilDiv(dimension, 3))); + // Do the same for a factor of 5 + padded = hlsl::min(padded, 5 * hlsl::roundUpToPoT(hlsl::ceilDiv(dimension, 5))); + // TODO: Consider if factor of 7 is viable + 
return padded; +} + +template 0 && N <= 4) +/** +* @brief Returns the size of the full FFT computed, in terms of number of complex elements. If the signal is real, you MUST provide a valid value for `firstAxis` (this is to run two real FFTs as one complex). + If the signal is complex, you must NOT pass any value to `firstAxis`. +* The padding rule is the following: if FFT is subgroup-sized, size is rounded up to PoT. If bigger than that, it's rounded to the smallest number that is either PoT or of the form +* 3 * 2^n or 5 * 2^n. +* +* @tparam N Number of dimensions of the signal to perform FFT on. +* +* @param [in] dimensions Size of the signal. +* @param [in] firstAxis Indicates which axis the FFT is performed on first. Only relevant for real-valued signals. +*/ +inline vector padDimensions(vector dimensions, uint16_t subgroupSize, uint16_t firstAxis = N) +{ + vector newDimensions; + for (uint16_t i = 0u; i < N; i++) + newDimensions[i] = padDimension(dimensions[i], subgroupSize); + // If real, first axis gets halved since we run two real FFTs at once + if (firstAxis < N) + newDimensions[firstAxis] /= 2; + return newDimensions; +} + +template 0 && N <= 4) +/** +* @brief Returns the size required by a buffer to hold the result of the FFT of a signal after a certain pass. +* +* @tparam N Number of dimensions of the signal to perform FFT on. +* +* @param [in] numChannels Number of channels of the signal. +* @param [in] inputDimensions Size of the signal. +* @param [in] passIx Which pass the size is being computed for. +* @param [in] axisPassOrder Order of the axis in which the FFT is computed in. Default is xyzw. +* @param [in] realFFT True if the signal is real. False by default. +* @param [in] halfFloats True if using half-precision floats. False by default. 
+*/ +inline uint64_t getOutputBufferSize( + uint32_t numChannels, + vector inputDimensions, + uint16_t passIx, + uint16_t subgroupSize, + vector axisPassOrder = _static_cast >(uint16_t4(0, 1, 2, 3)), + bool realFFT = false, + bool halfFloats = false +) +{ + const vector paddedDimensions = padDimensions(inputDimensions, subgroupSize, realFFT ? axisPassOrder[0] : N); + vector axesDone = promote, bool>(false); + for (uint16_t i = 0; i <= passIx; i++) + axesDone[axisPassOrder[i]] = true; + const vector passOutputDimension = lerp(inputDimensions, paddedDimensions, axesDone); + uint64_t numberOfComplexElements = uint64_t(numChannels); + for (uint16_t i = 0; i < N; i++) + numberOfComplexElements *= uint64_t(passOutputDimension[i]); + return numberOfComplexElements * (halfFloats ? sizeof(complex_t) : sizeof(complex_t)); +} + +template 0 && N <= 4) +/** +* @brief Returns the size required by a buffer to hold the result of the FFT of a signal after a certain pass, when using the FFT to convolve it against a kernel. +* +* @tparam N Number of dimensions of the signal to perform FFT on. +* +* @param [in] numChannels Number of channels of the signal. +* @param [in] inputDimensions Size of the signal. +* @param [in] kernelDimensions Size of the kernel. +* @param [in] passIx Which pass the size is being computed for. +* @param [in] axisPassOrder Order of the axis in which the FFT is computed in. Default is xyzw. +* @param [in] realFFT True if the signal is real. False by default. +* @param [in] halfFloats True if using half-precision floats. False by default. +*/ +inline uint64_t getOutputBufferSizeConvolution( + uint32_t numChannels, + vector inputDimensions, + vector kernelDimensions, + uint16_t passIx, + uint16_t subgroupSize, + vector axisPassOrder = _static_cast >(uint16_t4(0, 1, 2, 3)), + bool realFFT = false, + bool halfFloats = false +) +{ + const vector paddedDimensions = padDimensions(inputDimensions + kernelDimensions, subgroupSize, realFFT ? 
axisPassOrder[0] : N); + vector axesDone = promote, bool>(false); + for (uint16_t i = 0; i <= passIx; i++) + axesDone[axisPassOrder[i]] = true; + const vector passOutputDimension = lerp(inputDimensions, paddedDimensions, axesDone); + uint64_t numberOfComplexElements = uint64_t(numChannels); + for (uint16_t i = 0; i < N; i++) + numberOfComplexElements *= uint64_t(passOutputDimension[i]); + return numberOfComplexElements * (halfFloats ? sizeof(complex_t) : sizeof(complex_t)); +} + + +// Computes the kth element in the group of N roots of unity +// Notice 0 <= k < N/2, rotating counterclockwise in the forward (DIF) transform and clockwise in the inverse (DIT) +template +complex_t twiddle(uint32_t k, uint32_t halfN) +{ + complex_t retVal; + const Scalar kthRootAngleRadians = numbers::pi *Scalar(k) / Scalar(halfN); + Scalar cosine = nbl::hlsl::cos(kthRootAngleRadians); + Scalar sine = nbl::hlsl::sin(kthRootAngleRadians); + retVal.real(cosine); + if (!inverse) + retVal.imag(-sine); + else + retVal.imag(sine); + return retVal; +} + +template +struct DIX +{ + static void radix2(complex_t twiddle, NBL_REF_ARG(complex_t) lo, NBL_REF_ARG(complex_t) hi) + { + plus_assign< complex_t > plusAss; + //Decimation in time - inverse + if (inverse) { + complex_t wHi = twiddle * hi; + hi = lo - wHi; + plusAss(lo, wHi); + } + //Decimation in frequency - forward + else { + complex_t diff = lo - hi; + plusAss(lo, hi); + hi = twiddle * diff; + } + } + + static void radix3( + complex_t twiddle1, + complex_t twiddle2, + NBL_REF_ARG(complex_t) lo, + NBL_REF_ARG(complex_t) mid, + NBL_REF_ARG(complex_t) hi) + { + plus_assign< complex_t > plusAss; + NBL_CONSTEXPR_FUNC_SCOPE_VAR Scalar SQRT3_OVER_2 = Scalar(0.8660254037844386); + + // Decimation in time - inverse + if (inverse) { + // Apply twiddles first, then butterfly + complex_t w1 = twiddle1 * mid; + complex_t w2 = twiddle2 * hi; + + complex_t s = w1 + w2; + complex_t d = w1 - w2; + + // u = i * sqrt(3)/2 * d + complex_t u = 
complex_t(d.imag() * (-SQRT3_OVER_2), d.real() * SQRT3_OVER_2); + + complex_t t = lo - s * Scalar(0.5); + plusAss(lo, s); // lo = lo + s + mid = t + u; // inverse: swapped signs vs forward + hi = t - u; + } + // Decimation in frequency - forward + else { + // Butterfly first, then apply twiddles + complex_t s = mid + hi; + complex_t d = mid - hi; + + // u = i * sqrt(3)/2 * d + complex_t u = complex_t(d.imag() * (-SQRT3_OVER_2), d.real() * SQRT3_OVER_2); + + complex_t t = lo - s * Scalar(0.5); + plusAss(lo, s); // lo = lo + s + mid = twiddle1 * (t - u); + hi = twiddle2 * (t + u); + } + } + + static void radix5( + complex_t twiddle1, + complex_t twiddle2, + complex_t twiddle3, + complex_t twiddle4, + NBL_REF_ARG(complex_t) x0, + NBL_REF_ARG(complex_t) x1, + NBL_REF_ARG(complex_t) x2, + NBL_REF_ARG(complex_t) x3, + NBL_REF_ARG(complex_t) x4) + { + plus_assign< complex_t > plusAss; + + NBL_CONSTEXPR_FUNC_SCOPE_VAR Scalar COS_2PI_5 = Scalar(0.30901699437494742); // (sqrt(5) - 1)/4 + NBL_CONSTEXPR_FUNC_SCOPE_VAR Scalar COS_4PI_5 = Scalar(-0.80901699437494742); // -(sqrt(5) + 1)/4 + NBL_CONSTEXPR_FUNC_SCOPE_VAR Scalar SIN_2PI_5 = Scalar(0.95105651629515357); + NBL_CONSTEXPR_FUNC_SCOPE_VAR Scalar SIN_4PI_5 = Scalar(0.58778525229247312); + + //Decimation in time - inverse + if (inverse) { + // Apply twiddles first + complex_t w1 = twiddle1 * x1; + complex_t w2 = twiddle2 * x2; + complex_t w3 = twiddle3 * x3; + complex_t w4 = twiddle4 * x4; + + // 5-point inverse DFT: exploit W_5 conjugate symmetry via sum/diff pairs + complex_t s1 = w1 + w4; + complex_t d1 = w1 - w4; + complex_t s2 = w2 + w3; + complex_t d2 = w2 - w3; + + complex_t tA = x0 + COS_2PI_5 * s1 + COS_4PI_5 * s2; + complex_t tB = x0 + COS_4PI_5 * s1 + COS_2PI_5 * s2; + + // uA = i * (sA * d1 + sB * d2), uB = i * (sB * d1 - sA * d2) + complex_t vA = SIN_2PI_5 * d1 + SIN_4PI_5 * d2; + complex_t vB = SIN_4PI_5 * d1 - SIN_2PI_5 * d2; + complex_t uA = rotateLeft(vA); + complex_t uB = rotateLeft(vB); + + plusAss(x0, s1 
+ s2); // x0 = x0 + s1 + s2 + // inverse flips the sign on u compared to forward + x1 = tA + uA; + x4 = tA - uA; + x2 = tB + uB; + x3 = tB - uB; + } + //Decimation in frequency - forward + else { + // 5-point DFT first + complex_t s1 = x1 + x4; + complex_t d1 = x1 - x4; + complex_t s2 = x2 + x3; + complex_t d2 = x2 - x3; + + complex_t tA = x0 + COS_2PI_5 * s1 + COS_4PI_5 * s2; + complex_t tB = x0 + COS_4PI_5 * s1 + COS_2PI_5 * s2; + + complex_t vA = SIN_2PI_5 * d1 + SIN_4PI_5 * d2; + complex_t vB = SIN_4PI_5 * d1 - SIN_2PI_5 * d2; + complex_t uA = rotateLeft(vA); + complex_t uB = rotateLeft(vB); + + plusAss(x0, s1 + s2); // x0 = x0 + s1 + s2 + // Apply twiddles to the DFT outputs + x1 = twiddle1 * (tA - uA); + x4 = twiddle4 * (tA + uA); + x2 = twiddle2 * (tB - uB); + x3 = twiddle3 * (tB + uB); + } + } +}; + +template +using DIT = DIX; + +template +using DIF = DIX; + +// ------------------------------------------------- Utils --------------------------------------------------------- +// +// Util to unpack two values from the packed FFT X + iY - get outputs in the same input arguments, storing x to lo and y to hi +template +void unpack(NBL_REF_ARG(complex_t) lo, NBL_REF_ARG(complex_t) hi) +{ + complex_t x = (lo + conj(hi)) * Scalar(0.5); + hi = rotateRight(lo - conj(hi)) * Scalar(0.5); + lo = x; +} + +} //namespace fft2 +} //namespace hlsl +} //namespace nbl + +#endif \ No newline at end of file diff --git a/include/nbl/builtin/hlsl/math/functions.hlsl b/include/nbl/builtin/hlsl/math/functions.hlsl index 692b2aa594..8573856674 100644 --- a/include/nbl/builtin/hlsl/math/functions.hlsl +++ b/include/nbl/builtin/hlsl/math/functions.hlsl @@ -91,7 +91,7 @@ scalar_type_t lpNorm(NBL_CONST_REF_ARG(T) v) return impl::lp_norm::__call(v); } - +// [Francisco] sqrt hits the SFU the same as a sin call, calling both sin and cos might just be faster and more accurate? 
// valid only for `theta` in [-PI,PI] template ) void sincos(T theta, NBL_REF_ARG(T) s, NBL_REF_ARG(T) c) diff --git a/include/nbl/builtin/hlsl/math/intutil.hlsl b/include/nbl/builtin/hlsl/math/intutil.hlsl index 7394e03ae4..2a76ec6234 100644 --- a/include/nbl/builtin/hlsl/math/intutil.hlsl +++ b/include/nbl/builtin/hlsl/math/intutil.hlsl @@ -21,11 +21,17 @@ NBL_CONSTEXPR_FORCED_INLINE_FUNC bool isPoT(Integer value) return !isNPoT(value); } +template) +// Returns ceiled log2 +NBL_CONSTEXPR_FORCED_INLINE_FUNC Integer log2ceil(Integer value) +{ + return Integer(1 + hlsl::findMSB(value - Integer(1))); +} template) NBL_CONSTEXPR_FORCED_INLINE_FUNC Integer roundUpToPoT(Integer value) { - return Integer(0x1u) << Integer(1 + hlsl::findMSB(value - Integer(1))); // this wont result in constexpr because findMSB is not one + return Integer(0x1u) << log2ceil(value); } template) @@ -34,11 +40,16 @@ NBL_CONSTEXPR_FORCED_INLINE_FUNC Integer roundDownToPoT(Integer value) return Integer(0x1u) << hlsl::findMSB(value); } +template) +NBL_CONSTEXPR_FORCED_INLINE_FUNC Integer ceilDiv(Integer dividend, Integer divisor) +{ + return (dividend + divisor - 1) / divisor; +} + template) NBL_CONSTEXPR_FORCED_INLINE_FUNC Integer roundUp(Integer value, Integer multiple) { - Integer tmp = (value + multiple - 1u) / multiple; - return tmp * multiple; + return ceilDiv(value, multiple) * multiple; } template) @@ -58,6 +69,93 @@ NBL_CONSTEXPR_FORCED_INLINE_FUNC Integer align(Integer alignment, Integer size, return address = nextAlignedAddr; } +// Bitshift utils +// TODO: These can be expanded to shift by more than just one position at a time +// TODO: Can be made to work on uint64_t + +// Given an N-bit number stored as uint32_t, performs a circular bit shift right on the upper H bits +template +enable_if_t<(1 < H) && (H <= N) && (N <= 32), uint32_t> circularBitShiftRightHigher(uint32_t i) +{ + // Highest H bits are numbered N-1 through N - H + // N - H is then the middle bit + // Lowest bits numbered 
from 0 through N - H - 1 + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t lowMask = (1 << (N - H)) - 1; + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t midMask = 1 << (N - H); + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t highMask = ~(lowMask | midMask); + + uint32_t low = i & lowMask; + uint32_t mid = i & midMask; + uint32_t high = i & highMask; + + high >>= 1; + mid <<= H - 1; + + return mid | high | low; +} + +// Given an N-bit number stored as uint32_t, performs a circular bit shift left on the upper H bits +template +enable_if_t<(1 < H) && (H <= N) && (N < 32), uint32_t> circularBitShiftLeftHigher(uint32_t i) +{ + // Highest H bits are numbered N-1 through N - H + // N - 1 is then the highest bit, and N - 2 through N - H are the middle bits + // Lowest bits numbered from 0 through N - H - 1 + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t lowMask = (1 << (N - H)) - 1; + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t highMask = 1 << (N - 1); + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t midMask = ~(lowMask | highMask); + + uint32_t low = i & lowMask; + uint32_t mid = i & midMask; + uint32_t high = i & highMask; + + mid <<= 1; + high >>= H - 1; + + return mid | high | low; +} +// Perform a circular bit shift right on the lower L bits of a number +template +enable_if_t<(1 < L), uint32_t> circularBitShiftRightLower(uint32_t i) +{ + // Lowest bit is indexed 0 + // Middle bits numbered 1 to L-1 + // Highest bits numbered from L through N-1 + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t lowMask = 1; + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t midMask = ((1 << L) - 1) ^ 1; + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t highMask = ~(lowMask | midMask); + + uint32_t low = i & lowMask; + uint32_t mid = i & midMask; + uint32_t high = i & highMask; + + low <<= L - 1; + mid >>= 1; + + return high | low | mid; +} + +// Perform a circular bit shift left on the lower L bits of a number +template +enable_if_t<(1 < L), uint32_t> circularBitShiftLeftLower(uint32_t i) +{ + // Lowest L - 1 bits numbered 0 through L - 2 + // L - 1 is then the 
middle bit + // L through N-1 the higher bits + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t lowMask = (1 << (L - 1)) - 1; + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t midMask = 1 << (L - 1); + NBL_CONSTEXPR_FUNC_SCOPE_VAR uint32_t highMask = ~(lowMask | midMask); + + uint32_t low = i & lowMask; + uint32_t mid = i & midMask; + uint32_t high = i & highMask; + + low <<= 1; + mid >>= L - 1; + + return high | low | mid; +} + // ------------------------------------- CPP ONLY ---------------------------------------------------------- #ifndef __HLSL_VERSION diff --git a/include/nbl/builtin/hlsl/mpl.hlsl b/include/nbl/builtin/hlsl/mpl.hlsl index 7734dea15f..9fb5372a8f 100644 --- a/include/nbl/builtin/hlsl/mpl.hlsl +++ b/include/nbl/builtin/hlsl/mpl.hlsl @@ -123,6 +123,12 @@ struct find_lsb }; template NBL_CONSTEXPR_INLINE_NSPC_SCOPE_VAR uint64_t find_lsb_v = find_lsb::value; + +template +struct ceil_div : integral_constant {}; +template +NBL_CONSTEXPR_INLINE_NSPC_SCOPE_VAR uint64_t ceil_div_v = ceil_div::value; + } } } diff --git a/include/nbl/builtin/hlsl/subgroup2/fft.hlsl b/include/nbl/builtin/hlsl/subgroup2/fft.hlsl new file mode 100644 index 0000000000..d01f332b70 --- /dev/null +++ b/include/nbl/builtin/hlsl/subgroup2/fft.hlsl @@ -0,0 +1,251 @@ +#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_FFT_INCLUDED_ +#define _NBL_BUILTIN_HLSL_SUBGROUP2_FFT_INCLUDED_ + +#include "nbl/builtin/hlsl/fft2/common.hlsl" +#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl" +#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl" +#include "nbl/builtin/hlsl/concepts/accessors/fft.hlsl" + +namespace nbl +{ +namespace hlsl +{ +namespace subgroup2 +{ + +// ----------------------------------------------------------------------------------------------------------------------------------------------------------------- +template +struct FFT +{ + template + static void __call(uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, 
NBL_REF_ARG(InvocationElementsAccessor) hiAccessor); +}; + +// ---------------------------------------- Radix 2 forward transform - DIF ------------------------------------------------------- + +template +struct FFT +{ + template + static void FFT_loop(uint32_t threadStride, uint16_t lowChannel, uint16_t highChannel, NBL_CONST_REF_ARG(complex_t) twiddle, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) + { + const bool topHalf = bool(glsl::gl_SubgroupInvocationID() & threadStride); + [unroll] + for (uint16_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + const vector toTrade = topHalf ? vector (lo.real(), lo.imag()) : vector (hi.real(), hi.imag()); + const vector exchanged = glsl::subgroupShuffleXor< vector >(toTrade, threadStride); + if (topHalf) + { + lo.real(exchanged.x); + lo.imag(exchanged.y); + } + else + { + hi.real(exchanged.x); + hi.imag(exchanged.y); + } + fft2::DIF::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } + } + + template + static void __call(uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) + { + // special first iteration + complex_t twiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID(), SubgroupSize); + [unroll] + for (uint16_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIF::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } + + // Decimation in Frequency + // Compute all twiddles at the start, then reshare them among threads + if (ShareTwiddles) + { + uint32_t iteration = 1; + [unroll] + for (uint32_t threadStride = SubgroupSize >> 1; threadStride > 0; threadStride >>= 1) + { + const vector toTrade = vector 
(twiddle.real(), twiddle.imag()); + const vector otherTwiddle = glsl::subgroupShuffle< vector >(toTrade, (glsl::gl_SubgroupInvocationID() & (threadStride - 1)) << iteration); + twiddle.real(otherTwiddle.x); + twiddle.imag(otherTwiddle.y); + FFT_loop(threadStride, lowChannel, highChannel, twiddle, loAccessor, hiAccessor); + iteration++; + } + } + // Recompute twiddles at every step + else + { + [unroll] + for (uint32_t threadStride = SubgroupSize >> 1; threadStride > 0; threadStride >>= 1) + { + // Get twiddle with k = subgroupInvocation mod threadStride, halfN = threadStride + const complex_t twiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID() & (threadStride - 1), threadStride); + FFT_loop(threadStride, lowChannel, highChannel, twiddle, loAccessor, hiAccessor); + } + } + } + + // Only uses subgroup methods, but is actually used at workgroup level. Used by the interleaved workgroup FFT at bigger than subgroup strides + template + static void __callInterleaved(uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) + { + const uint32_t loLaneIndex = (glsl::gl_SubgroupInvocationID() << NumSubgroupsLog2) + glsl::gl_SubgroupID(); + // special first iteration + [unroll] + for (uint16_t channel = lowChannel; channel <= highChannel; channel++) + { + // Get twiddle with k = gl_SubgroupInvocationID() * NumSubgroups + gl_SubgroupID() mod WorkgroupSize, halfN = WorkgroupSize + const complex_t twiddle = fft2::twiddle(loLaneIndex, WorkgroupSize); + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIF::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } + + // Decimation in Frequency + uint32_t threadStride = SubgroupSize >> 1; + [unroll] + for (uint32_t elementStride = WorkgroupSize >> 1; elementStride > SubgroupSize; elementStride >>= 1) + { + // Get twiddle with k = gl_SubgroupInvocationID() * NumSubgroups 
+ gl_SubgroupID() mod elementStride, halfN = elementStride + const complex_t twiddle = fft2::twiddle(loLaneIndex & (elementStride - 1), elementStride); + FFT_loop(threadStride, lowChannel, highChannel, twiddle, loAccessor, hiAccessor); + threadStride >>= 1; + } + } +}; + + +// ---------------------------------------- Radix 2 inverse transform - DIT ------------------------------------------------------- + +template +struct FFT +{ + template + static void FFT_loop(uint32_t stride, uint16_t lowChannel, uint16_t highChannel, NBL_CONST_REF_ARG(complex_t) twiddle, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) + { + const bool topHalf = bool(glsl::gl_SubgroupInvocationID() & stride); + + [unroll] + for (uint16_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIT::radix2(twiddle, lo, hi); + + const vector toTrade = topHalf ? vector (lo.real(), lo.imag()) : vector (hi.real(), hi.imag()); + const vector exchanged = glsl::subgroupShuffleXor< vector >(toTrade, stride); + if (topHalf) + { + lo.real(exchanged.x); + lo.imag(exchanged.y); + } + else + { + hi.real(exchanged.x); + hi.imag(exchanged.y); + } + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } + } + + template + static void __call(uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) + { + // Decimation in Time + // Compute all twiddles at the start, then shuffle them + if (ShareTwiddles) + { + const complex_t ownedTwiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID(), SubgroupSize); + uint32_t reverseStride = SubgroupSize; + [unroll] + for (uint32_t threadStride = 1; threadStride < SubgroupSize; threadStride <<= 1) + { + const vector toTrade = vector (ownedTwiddle.real(), ownedTwiddle.imag()); + const vector otherTwiddle = glsl::subgroupShuffle< 
vector >(toTrade, (glsl::gl_SubgroupInvocationID() & (threadStride - 1)) * reverseStride); + const complex_t twiddle = { otherTwiddle.x , otherTwiddle.y }; + FFT_loop(threadStride, lowChannel, highChannel, twiddle, loAccessor, hiAccessor); + reverseStride >>= 1; + } + } + // Compute each twiddle at each iteration + else + { + [unroll] + for (uint32_t threadStride = 1; threadStride < SubgroupSize; threadStride <<= 1) + { + // Get twiddle with k = subgroupInvocation mod threadStride, halfN = threadStride + const complex_t twiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID() & (threadStride - 1), threadStride); + FFT_loop(threadStride, lowChannel, highChannel, twiddle, loAccessor, hiAccessor); + } + } + + // special last iteration + const complex_t twiddle = fft2::twiddle(glsl::gl_SubgroupInvocationID(), SubgroupSize); + [unroll] + for (uint16_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIT::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } + } + + template + static void __callInterleaved(uint16_t lowChannel, uint16_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor) + { + const uint32_t loLaneIndex = (glsl::gl_SubgroupInvocationID() << NumSubgroupsLog2) + glsl::gl_SubgroupID(); + // Decimation in Time + uint32_t threadStride = SubgroupSize >> (NumSubgroupsLog2 - 1); + [unroll] + for (uint32_t elementStride = SubgroupSize << 1; elementStride < WorkgroupSize; elementStride <<= 1) + { + // Get twiddle with k = gl_SubgroupInvocationID() * NumSubgroups + gl_SubgroupID() mod elementStride, halfN = elementStride + const complex_t twiddle = fft2::twiddle(loLaneIndex & (elementStride - 1), elementStride); + FFT_loop(threadStride, lowChannel, highChannel, twiddle, loAccessor, hiAccessor); + threadStride <<= 1; + } + + // special last iteration + // Get twiddle with k = 
gl_SubgroupInvocationID() * NumSubgroups + gl_SubgroupID() mod WorkgroupSize, halfN = WorkgroupSize + const complex_t twiddle = fft2::twiddle(loLaneIndex, WorkgroupSize); + [unroll] + for (uint16_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIT::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } + } +}; + + +} //namespace subgroup2 +} //namespace hlsl +} //namespace nbl + +#endif \ No newline at end of file diff --git a/include/nbl/builtin/hlsl/workgroup/basic.hlsl b/include/nbl/builtin/hlsl/workgroup/basic.hlsl index 3467b46407..dc69f37e17 100644 --- a/include/nbl/builtin/hlsl/workgroup/basic.hlsl +++ b/include/nbl/builtin/hlsl/workgroup/basic.hlsl @@ -29,7 +29,7 @@ uint16_t SubgroupContiguousIndex() assert(retval +#include +#include + +#ifndef _NBL_BUILTIN_HLSL_WORKGROUP2_FFT_INCLUDED_ +#define _NBL_BUILTIN_HLSL_WORKGROUP2_FFT_INCLUDED_ + +// ------------------------------- COMMON ----------------------------------------- + +namespace nbl +{ +namespace hlsl +{ +namespace workgroup2 +{ + +namespace fft +{ +// Minimum size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT +template +uint32_t minimumSharedMemoryDWORDs(uint16_t workgroupSizeLog2) +{ + NBL_IF_CONSTEXPR(Interleaved) + { + return (sizeof(complex_t) / sizeof(uint32_t)) << (workgroupSizeLog2 + 1); + } + else + { + return (sizeof(complex_t) / sizeof(uint32_t)) << workgroupSizeLog2; + } +} + +// The DFT (and IDFT) have two different formulations. One of them doesn't divide in DFT and divides by N in the IDFT. This makes the determinant of the DFT +// sqrt(N) and the determinant of the IDFT as sqrt(N)^(-1). This formulation is problematic, for example, when performing FFT Convolution of images when +// using half-precision, since if N is big this can make the FFT along the second axis (and sometimes even the first!) 
exceed the representable range and become +// NaNs. The other formulation divides by sqrt(N) on both the DFT and IDFT, giving a determinant of 1 for both transforms. This can be used to avoid +// overflow. +// These policies describe different ways of avoiding overflow by dividing at different moments through the algorithm. +struct DivisionPolicy +{ + // No division performed at any step of the FFT + NBL_CONSTEXPR_STATIC_INLINE uint16_t NoDivision = 0; + // Divides the array by sqrt(FFTSize) right at the start of the algorithm + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivBySqrtAtStart = NoDivision + 1; + // Divides the array by sqrt(FFTSize) at the time of the last workgroup barrier before subgroupFFT (forward) or at the time of the first workgroup barrier + // after subgroupFFT (inverse). + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivBySqrtHalfway = DivBySqrtAtStart + 1; + // Divides the array by sqrt(FFTSize) right at the end of the algorithm + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivBySqrtAtEnd = DivBySqrtHalfway + 1; + // Total division by sqrt(FFTSize) in three steps: Once right at the start of the algorithm, once before the first workgroupBarrier, and once after the last workgroupBarrier. + // Only valid if ElementsPerInvocation > 2. + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivBySqrtByParts = DivBySqrtAtEnd + 1; + // The three following all perform divisions in the same manner as their counterparts above, but they divide the array by `FFTSize`. 
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t DivByFullSizeAtStart = DivBySqrtByParts + 1; + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivByFullSizeHalfway = DivByFullSizeAtStart + 1; + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivByFullSizeAtEnd = DivByFullSizeHalfway + 1; + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivByFullSizeByParts = DivByFullSizeAtEnd + 1; +}; + +namespace impl +{ + +struct Constants +{ + NBL_CONSTEXPR_STATIC_INLINE float32_t INV_SQRT_2 = 0.707106781f; + NBL_CONSTEXPR_STATIC_INLINE float32_t INV_SQRT_3 = 0.577350269f; + NBL_CONSTEXPR_STATIC_INLINE float32_t INV_SQRT_5 = 0.447213595f; +}; + +template +struct DivisionConstants +{ + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTLength = ElementsPerInvocation * (1 << (WorkgroupSizeLog2)); + NBL_CONSTEXPR_STATIC_INLINE float32_t InvFFTLength = float32_t(1) / FFTLength; + + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTLengthA = ElementsPerInvocation >> 1; + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTLengthB = 1u << (WorkgroupSizeLog2 - SubgroupSizeLog2); + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTLengthC = 1u << (SubgroupSizeLog2 + 1); + NBL_CONSTEXPR_STATIC_INLINE float32_t InvFFTLengthA = float32_t(1) / FFTLengthA; + NBL_CONSTEXPR_STATIC_INLINE float32_t InvFFTLengthB = float32_t(1) / FFTLengthB; + NBL_CONSTEXPR_STATIC_INLINE float32_t InvFFTLengthC = float32_t(1) / FFTLengthC; + + NBL_CONSTEXPR_STATIC_INLINE float32_t PrimeFactorInvSqrt = (ElementsPerInvocation % 3 ? (ElementsPerInvocation % 5 ? float32_t(1) : Constants::INV_SQRT_5) : Constants::INV_SQRT_3); + + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTLengthLog2 = WorkgroupSizeLog2 + mpl::log2_v; + NBL_CONSTEXPR_STATIC_INLINE float32_t InvSqrtFFTLength = float32_t(1) / (1u << (FFTLengthLog2 / 2)) * (FFTLengthLog2 & 1 ? Constants::INV_SQRT_2 : float32_t(1)) * PrimeFactorInvSqrt; + + // Log2A part should only consider log2 of the pure radix2, ignoring extra prime factor + NBL_CONSTEXPR_STATIC_INLINE uint32_t Radix2ElementsPerInvocation = (ElementsPerInvocation % 3 ? 
(ElementsPerInvocation % 5 ? ElementsPerInvocation : ElementsPerInvocation / 5) : ElementsPerInvocation / 3); + + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTLengthLog2A = mpl::log2_v - 1; + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTLengthLog2B = WorkgroupSizeLog2 - SubgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTLengthLog2C = SubgroupSizeLog2 + 1; + NBL_CONSTEXPR_STATIC_INLINE float32_t InvSqrtFFTLengthA = float32_t(1) / (1u << (FFTLengthLog2A / 2)) * (FFTLengthLog2A & 1 ? Constants::INV_SQRT_2 : float32_t(1)) * PrimeFactorInvSqrt; + NBL_CONSTEXPR_STATIC_INLINE float32_t InvSqrtFFTLengthB = float32_t(1) / (1u << (FFTLengthLog2B / 2)) * (FFTLengthLog2B & 1 ? Constants::INV_SQRT_2 : float32_t(1)); + NBL_CONSTEXPR_STATIC_INLINE float32_t InvSqrtFFTLengthC = float32_t(1) / (1u << (FFTLengthLog2C / 2)) * (FFTLengthLog2C & 1 ? Constants::INV_SQRT_2 : float32_t(1)); +}; + +} //namespace impl + +template //NBL_PRIMARY_REQUIRES(_ElementsPerInvocation > 1 && !(_ElementsPerInvocation & 1) && _WorkgroupSizeLog2 >= 5) +struct ConstevalParameters +{ + using scalar_t = _Scalar; + using DivisionConstants = impl::DivisionConstants<_ElementsPerInvocationPerChannel, _SubgroupSizeLog2, _WorkgroupSizeLog2>; + + NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationPerChannel = _ElementsPerInvocationPerChannel; + NBL_CONSTEXPR_STATIC_INLINE uint16_t Channels = _Channels; + NBL_CONSTEXPR_STATIC_INLINE uint16_t InnerVirtualChannels = Channels * (ElementsPerInvocationPerChannel >> 1); + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = _SubgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(1) << SubgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(1) << WorkgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t NumSubgroupsLog2 = WorkgroupSizeLog2 - SubgroupSizeLog2; + + NBL_CONSTEXPR_STATIC_INLINE uint16_t ShuffledVirtualChannelsPerRound = 
_ShuffledVirtualChannelsPerRound; + NBL_CONSTEXPR_STATIC_INLINE uint16_t ShuffleRounds = mpl::ceil_div_v; + NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemoryDWORDs = ShuffledVirtualChannelsPerRound * ((sizeof(complex_t) / sizeof(uint32_t)) << (WorkgroupSizeLog2 + (_Interleaved ? 1 : 0))); + + NBL_CONSTEXPR_STATIC_INLINE bool ShareTwiddles = _ShareTwiddles; + + NBL_CONSTEXPR_STATIC_INLINE uint16_t DivisionPolicy = _DivisionPolicy; +}; + +// Takes an elements accessor (lo/hi) with signature `acc.set/get(channel, pair)` and flattens it to 1D +// Workgroup-sized FFT works on many channels of a single pair each. Since these are constant folded it's the easiest solution +// NOTE(review): `get`/`set` below derive `channel` as `virtualChannel / Channels` and `pair` as `virtualChannel % Channels`, so `pair` is the index bounded by `Channels`; confirm this matches the `(channel, pair)` argument order the wrapped accessor expects +template +struct WorkgroupRadix2AccessorAdaptor +{ + static WorkgroupRadix2AccessorAdaptor create(NBL_REF_ARG(InvocationElementsAccessor) _accessor) + { + WorkgroupRadix2AccessorAdaptor retVal; + retVal.accessor = _accessor; + return retVal; + } + + void get(uint32_t virtualChannel, NBL_REF_ARG(complex_t) value) + { + const uint32_t channel = virtualChannel / Channels; + const uint32_t pair = virtualChannel % Channels; + accessor.get(channel, pair, value); + } + + void set(uint32_t virtualChannel, NBL_CONST_REF_ARG(complex_t) value) + { + const uint32_t channel = virtualChannel / Channels; + const uint32_t pair = virtualChannel % Channels; + accessor.set(channel, pair, value); + } + + InvocationElementsAccessor accessor; +}; +} //namespace fft + +struct OptimalFFTParameters +{ + uint16_t elementsPerInvocation : 8; + uint16_t workgroupSizeLog2 : 8; + + // Used to check if the parameters returned by `optimalFFTParameters` are valid + bool areValid() + { + return elementsPerInvocation > 0 && workgroupSizeLog2 > 0; + } +}; + +/** +* @brief Returns the best parameters (according to our metric) to run an FFT +* +* @param [in] maxWorkgroupSize The max number of threads that can be launched in a single workgroup +* @param [in] inputArrayLength The length of the array to run an FFT on +* @param [in] subgroupSize 
Number of threads running in a subgroup +*/ +inline OptimalFFTParameters optimalFFTParameters(uint32_t maxWorkgroupSize, uint32_t inputArrayLength, uint32_t subgroupSize) +{ + NBL_CONSTEXPR_STATIC OptimalFFTParameters invalidParameters = { 0 , 0 }; + + if (subgroupSize < 2 || maxWorkgroupSize < subgroupSize || inputArrayLength <= subgroupSize) + return invalidParameters; + // Pad inputarrayLength to size that FFT algo handles + const uint32_t FFTLength = hlsl::fft2::padDimension(inputArrayLength, uint16_t(subgroupSize)); + // Round maxWorkgroupSize down to PoT + const uint32_t actualMaxWorkgroupSize = hlsl::roundDownToPoT(maxWorkgroupSize); + // Max number of threads that can run the FFT + uint32_t maxThreads = FFTLength / 2; + // Factors of 3 and 5 do not contribute to the amount of max threads since those are handled in-register to keep everything PoT later + maxThreads /= maxThreads % 3 ? 1 : 3; + maxThreads /= maxThreads % 5 ? 1 : 5; + // Both are PoT + const uint16_t workgroupSizeLog2 = uint16_t(findMSB(min(maxThreads, actualMaxWorkgroupSize))); + + // Parameters are valid if the workgroup size is at most half of the FFT Length and at least as big as the subgroupSize + if ((FFTLength >> workgroupSizeLog2) <= 1 || subgroupSize > (1u << workgroupSizeLog2)) + { + return invalidParameters; + } + + const uint16_t elementsPerInvocation = uint16_t(FFTLength >> workgroupSizeLog2); + const OptimalFFTParameters retVal = { elementsPerInvocation, workgroupSizeLog2 }; + + return retVal; +} + +namespace impl +{ +template +struct FFTIndexingUtilsHelper +{ + // Maps the lane of index `laneIdx` at the end of the DIF diagram to its corresponding frequency position as an element of the DFT. 
+ static uint32_t mapLaneToFreq(uint32_t laneIdx) + { + NBL_IF_CONSTEXPR(ExtraPrimeFactor > 1) + { + const uint32_t radix2mask = (1 << Radix2FFTSizeLog2) - 1; + return ExtraPrimeFactor * hlsl::bitReverseAs(laneIdx, Radix2FFTSizeLog2) + (laneIdx >> Radix2FFTSizeLog2); + } + else + { + return hlsl::bitReverseAs(laneIdx, Radix2FFTSizeLog2); + } + } + + // Implements fast division by 3 or 5, needed by `mapFreqtoLane` + static uint32_t fastDiv(uint32_t x) + { + NBL_IF_CONSTEXPR(ExtraPrimeFactor == 3) + { + return (x * 43691u) >> 17; // valid for x <= 98303 + } + else // ExtraPrimeFactor == 5 + { + return (x * 52429u) >> 18; // valid for x <= 81919 + } + } + + // Inverse of `mapLaneToFreq`. Maps a frequency index `freqIdx` into the DFT to the lane in the DIF diagram that outputs it. + static uint32_t mapFreqToLane(uint32_t freqIdx) + { + NBL_IF_CONSTEXPR(ExtraPrimeFactor > 1) + { + const uint32_t divByPrimeFactor = fastDiv(freqIdx); + return hlsl::bitReverseAs(divByPrimeFactor, Radix2FFTSizeLog2) + ((freqIdx - ExtraPrimeFactor * divByPrimeFactor) << Radix2FFTSizeLog2); + } + else + { + return hlsl::bitReverseAs(freqIdx, Radix2FFTSizeLog2); + } + } + + // log2(ElementsPerInvocation / ExtraPrimeFactor) + NBL_CONSTEXPR_STATIC_INLINE uint16_t Radix2FFTSizeLog2 = WorkgroupSizeLog2 + Radix2ElementsPerInvocationLog2; + // Size of the full FFT if no mixed radix used, otherwise size of the sub-FFTs computed after the first radix-3/5 step in the DIF forward FFT + NBL_CONSTEXPR_STATIC_INLINE uint16_t Radix2FFTSize = uint16_t(1) << (Radix2FFTSizeLog2); + // Total size of the FFT computed + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = ExtraPrimeFactor * Radix2FFTSize; +}; +} // namespace impl + +template +struct FFTIndexingUtils +{ + using helper_t = impl::FFTIndexingUtilsHelper; + + // Maps the array index 'arrayIdx' of the output of an FFT in workgroup-linear order (meaning all threads write their local element 0 contiguously + // and in ascending order by threadIndex, then 
element 1 and so on) to its corresponding frequency position as an element of the DFT. + static uint32_t mapArrayToFreq(uint32_t arrayIdx) + { + return helper_t::mapLaneToFreq(circularBitShiftLeftLower(arrayIdx)); + } + + // Maps a frequency index 'freqIdx' into the DFT to its corresponding position in the output array of an FFT when written in workgroup-linear order. + static uint32_t mapFreqToArray(uint32_t freqIdx) + { + return circularBitShiftRightLower(helper_t::mapFreqToLane(freqIdx)); + } + + // Mirrors an index about the Nyquist frequency in the DFT order, i.e. returns `(FFTSize - freqIdx) mod FFTSize` for `freqIdx` in `[0, FFTSize)`. + // `FFTSize` comes from `helper_t::FFTSize` and is not a power of two when an extra prime factor is in use, so the wrap cannot be done with `& (FFTSize - 1)`; + // within the valid range only `freqIdx == 0` wraps, and it is its own mirror. + static uint32_t getDFTMirrorIndex(uint32_t freqIdx) + { + return freqIdx != 0u ? (FFTSize - freqIdx) : 0u; + } + + // Given an index `arrayIdx` of an element into the output array of an FFT, get the index into the same array of the element corresponding + // to its negative frequency + static uint32_t getNablaMirrorIndex(uint32_t arrayIdx) + { + return mapFreqToArray(getDFTMirrorIndex(mapArrayToFreq(arrayIdx))); + } + + // log2(ElementsPerInvocation / ExtraPrimeFactor) + NBL_CONSTEXPR_STATIC_INLINE uint16_t Radix2FFTSizeLog2 = helper_t::Radix2FFTSizeLog2; + // Size of the full FFT if no mixed radix used, otherwise size of the sub-FFTs computed after the first radix-3/5 step in the DIF forward FFT + NBL_CONSTEXPR_STATIC_INLINE uint16_t Radix2FFTSize = helper_t::Radix2FFTSize; + // Total size of the FFT computed + NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = helper_t::FFTSize; +}; + +// TODO: Implement when doing 2D FFTConv +template +struct FFTMirrorTradeUtils; + +} +} +} +// ------------------------------- END COMMON --------------------------------------------- + +// ------------------------------- HLSL ONLY --------------------------------------------- + +#ifdef __HLSL_VERSION + +#include "nbl/builtin/hlsl/subgroup2/fft.hlsl" +#include "nbl/builtin/hlsl/workgroup/basic.hlsl" +#include "nbl/builtin/hlsl/glsl_compat/core.hlsl" +#include "nbl/builtin/hlsl/mpl.hlsl" +#include 
"nbl/builtin/hlsl/memory_accessor.hlsl" +#include "nbl/builtin/hlsl/bit.hlsl" + +namespace nbl +{ +namespace hlsl +{ +namespace workgroup2 +{ + +//-------------- ---------------------------------------- UTILS -------------------------------------------------------- + +namespace fft +{ +namespace impl +{ + +template +struct exchangeValues +{ + static void __call(uint32_t threadID, uint32_t ownedSmemIndex, uint32_t lowChannel, uint32_t highChannel, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor, NBL_REF_ARG(bool) pingPong) + { + const bool topHalf = bool(threadID & stride); + const uint32_t writeIndex = pingPong ? ownedSmemIndex ^ stride : ownedSmemIndex; + const uint32_t readIndex = pingPong ? ownedSmemIndex : ownedSmemIndex ^ stride; + // Write elements to sharedmem + uint32_t adaptorOffset = 0; + [unroll] + for (uint32_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + vector toExchange = topHalf ? 
vector(lo.real(), lo.imag()) : vector(hi.real(), hi.imag()); + sharedmemAdaptor.template set >(adaptorOffset | writeIndex, toExchange); + + adaptorOffset += WorkgroupSize; + } + // Wait until all writes are done before reading + sharedmemAdaptor.workgroupExecutionAndMemoryBarrier(); + + // Read elements from sharedmem + adaptorOffset = 0; + [unroll] + for (uint32_t channel = lowChannel; channel <= highChannel; channel++) + { + vector exchanged; + sharedmemAdaptor.template get >(adaptorOffset | readIndex, exchanged); + complex_t complex_exchanged = { exchanged.x, exchanged.y }; + if (topHalf) + { + loAccessor.set(channel, complex_exchanged); + } + else + { + hiAccessor.set(channel, complex_exchanged); + } + + adaptorOffset += WorkgroupSize; + } + } + + static void __callInterleaved() + { + // TODO + } +}; + +} //namespace impl +} //namespace fft + +//-------------- ------------------------------------ END UTILS -------------------------------------------------------- + +namespace impl +{ +template +struct InnerFFT; + +// Non-interleaved (shuffle after every butterfly) inner (post thread-local butterflies) forward FFT +template +struct InnerFFT, device_capabilities> +{ + using consteval_parameters_t = fft::ConstevalParameters; + using scalar_t = typename consteval_parameters_t::scalar_t; + + template + static void FFT_loop(uint32_t stride, uint32_t threadID, NBL_REF_ARG(uint32_t) ownedSmemIndex, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor) + { + const uint32_t ShuffleRounds = consteval_parameters_t::ShuffleRounds; + const uint16_t VirtualChannels = consteval_parameters_t::InnerVirtualChannels; + // Get twiddle with k = threadID mod stride, halfN = stride + const complex_t twiddle = hlsl::fft::twiddle(threadID & (stride - 1), stride); + + bool pingPong = false; + // If register pressure high, can avoid unroll + [unroll] + for (uint32_t round = 0; round < 
ShuffleRounds; round++) + { + if (round) + pingPong = !pingPong; // ping pong on sharedmem to avoid barriering - this eploits that we XOR with the same stride every consecutive round + const uint32_t lowChannel = round * ShuffledVirtualChannelsPerRound; + const uint32_t highChannel = min(VirtualChannels, lowChannel + ShuffledVirtualChannelsPerRound) - 1; + [unroll] + for (uint32_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIF::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } + + fft::impl::exchangeValues::__call(threadID, ownedSmemIndex, lowChannel, highChannel, loAccessor, hiAccessor, stride >> 1, sharedmemAdaptor, pingPong); + } + // After the last exchangeValues, the memory we just read from is now owned by us, so update + ownedSmemIndex = pingPong ? ownedSmemIndex : ownedSmemIndex ^ (stride >> 1); + } + + template + static void __call(NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor) + { + const uint16_t VirtualChannels = consteval_parameters_t::InnerVirtualChannels; + const uint16_t SubgroupSize = consteval_parameters_t::SubgroupSize; + const uint16_t WorkgroupSize = consteval_parameters_t::WorkgroupSize; + + // Get workgroup threadID + const uint32_t threadID = uint32_t(workgroup::SubgroupContiguousIndex()); + + // If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps + if (WorkgroupSize > SubgroupSize) + { + // Set up the memory adaptor + using adaptor_t = accessor_adaptors::StructureOfArrays; + adaptor_t sharedmemAdaptor; + sharedmemAdaptor.accessor = sharedmemAccessor; + + uint32_t ownedSmemIndex = threadID; + // NOT unrolling this loop increases register pressure???? 
+ [unroll] + for (uint32_t stride = WorkgroupSize; stride > SubgroupSize; stride >>= 1) + { + FFT_loop(stride, threadID, ownedSmemIndex, loAccessor, hiAccessor, sharedmemAdaptor); + } + + // Remember to update the accessor's state + sharedmemAccessor = sharedmemAdaptor.accessor; + } + + const float32_t DivisionFactor = (DivisionPolicy == fft::DivisionPolicy::DivByFullSizeHalfway ? consteval_parameters_t::DivisionConstants::InvFFTLength + : (DivisionPolicy == fft::DivisionPolicy::DivBySqrtHalfway ? consteval_parameters_t::DivisionConstants::InvSqrtFFTLength + : (DivisionPolicy == fft::DivisionPolicy::DivByFullSizeByParts ? consteval_parameters_t::DivisionConstants::InvFFTLengthC + : consteval_parameters_t::DivisionConstants::InvSqrtFFTLengthC))); // Assume DivBySqrtByParts, won't be used otherwise + + if (DivisionPolicy == fft::DivisionPolicy::DivByFullSizeHalfway || DivisionPolicy == fft::DivisionPolicy::DivBySqrtHalfway || DivisionPolicy == fft::DivisionPolicy::DivByFullSizeByParts || DivisionPolicy == fft::DivisionPolicy::DivBySqrtByParts) + { + [unroll] + for (uint32_t channel = 0; channel < VirtualChannels; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + loAccessor.set(channel, lo * scalar_t(DivisionFactor)); + hiAccessor.set(channel, hi * scalar_t(DivisionFactor)); + } + } + + // Subgroup-sized FFT + subgroup2::FFT::template __call(0, VirtualChannels - 1, loAccessor, hiAccessor); + } +}; + +// Non-interleaved (shuffle after every butterfly) inner (post thread-local butterflies) inverse FFT +template +struct InnerFFT, device_capabilities> +{ + using consteval_parameters_t = fft::ConstevalParameters; + using scalar_t = typename consteval_parameters_t::scalar_t; + + template + static void FFT_loop(uint32_t stride, uint32_t threadID, NBL_REF_ARG(uint32_t) ownedSmemIndex, NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor, NBL_REF_ARG(SharedMemoryAdaptor) 
sharedmemAdaptor) + { + const uint32_t ShuffleRounds = consteval_parameters_t::ShuffleRounds; + const uint16_t VirtualChannels = consteval_parameters_t::InnerVirtualChannels; + // Get twiddle with k = threadID mod (2 * stride), halfN = 2 * stride + const complex_t twiddle = hlsl::fft::twiddle(threadID & ((stride << 1) - 1), stride << 1); + + bool pingPong = false; + [unroll] + for (uint32_t round = 0; round < ShuffleRounds; round++) + { + if (round) + pingPong = !pingPong; // ping pong on sharedmem to avoid barriering - this exploits that we XOR with the same stride every consecutive round + const uint32_t lowChannel = round * ShuffledVirtualChannelsPerRound; + const uint32_t highChannel = min(VirtualChannels, lowChannel + ShuffledVirtualChannelsPerRound) - 1; + + fft::impl::exchangeValues::__call(threadID, ownedSmemIndex, lowChannel, highChannel, loAccessor, hiAccessor, stride, sharedmemAdaptor, pingPong); + + [unroll] + for (uint32_t channel = lowChannel; channel <= highChannel; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + fft2::DIT::radix2(twiddle, lo, hi); + loAccessor.set(channel, lo); + hiAccessor.set(channel, hi); + } + } + // After the last exchangeValues, the memory we just read from is now owned by us, so update. + // The exchange above is called with `stride` itself, so when the last round was not a ping-pong round the slot we last read is `ownedSmemIndex ^ stride` + ownedSmemIndex = pingPong ? 
ownedSmemIndex : ownedSmemIndex ^ stride; + } + + template + static void __call(NBL_REF_ARG(InvocationElementsAccessor) loAccessor, NBL_REF_ARG(InvocationElementsAccessor) hiAccessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor) + { + const uint16_t VirtualChannels = consteval_parameters_t::InnerVirtualChannels; + const uint16_t SubgroupSize = consteval_parameters_t::SubgroupSize; + const uint16_t WorkgroupSize = consteval_parameters_t::WorkgroupSize; + + // Subgroup-sized FFT at the start + subgroup2::FFT::template __call(0, VirtualChannels - 1, loAccessor, hiAccessor); + + const float32_t DivisionFactor = (DivisionPolicy == fft::DivisionPolicy::DivByFullSizeHalfway ? consteval_parameters_t::DivisionConstants::InvFFTLength + : (DivisionPolicy == fft::DivisionPolicy::DivBySqrtHalfway ? consteval_parameters_t::DivisionConstants::InvSqrtFFTLength + : (DivisionPolicy == fft::DivisionPolicy::DivByFullSizeByParts ? consteval_parameters_t::DivisionConstants::InvFFTLengthC + : consteval_parameters_t::DivisionConstants::InvSqrtFFTLengthC))); // Assume DivBySqrtByParts, won't be used otherwise + + if (DivisionPolicy == fft::DivisionPolicy::DivByFullSizeHalfway || DivisionPolicy == fft::DivisionPolicy::DivBySqrtHalfway || DivisionPolicy == fft::DivisionPolicy::DivByFullSizeByParts || DivisionPolicy == fft::DivisionPolicy::DivBySqrtByParts) + { + [unroll] + for (uint32_t channel = 0; channel < VirtualChannels; channel++) + { + complex_t lo, hi; + loAccessor.get(channel, lo); + hiAccessor.get(channel, hi); + loAccessor.set(channel, lo * scalar_t(DivisionFactor)); + hiAccessor.set(channel, hi * scalar_t(DivisionFactor)); + } + } + + // Get workgroup threadID + const uint32_t threadID = uint32_t(workgroup::SubgroupContiguousIndex()); + + // If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps + if (WorkgroupSize > SubgroupSize) + { + // Set up the memory adaptor + using adaptor_t = accessor_adaptors::StructureOfArrays; + adaptor_t 
sharedmemAdaptor; + sharedmemAdaptor.accessor = sharedmemAccessor; + + uint32_t ownedSmemIndex = threadID; + [unroll] + for (uint32_t stride = SubgroupSize; stride < WorkgroupSize; stride <<= 1) + { + FFT_loop(stride, threadID, ownedSmemIndex, loAccessor, hiAccessor, sharedmemAdaptor); + } + + // Remember to update the accessor's state + sharedmemAccessor = sharedmemAdaptor.accessor; + } + } +}; + +} //namespace impl + + +} //namespace workgroup2 +} //namespace hlsl +} //namespace nbl + +#endif + + +#endif diff --git a/include/nbl/video/ILogicalDevice.h b/include/nbl/video/ILogicalDevice.h index 756b417c79..742cb506c6 100644 --- a/include/nbl/video/ILogicalDevice.h +++ b/include/nbl/video/ILogicalDevice.h @@ -833,6 +833,9 @@ class NBL_API2 ILogicalDevice : public core::IReferenceCounted, public IDeviceMe std::span extraDefines = {}; hlsl::ShaderStage stage = hlsl::ShaderStage::ESS_ALL_OR_LIBRARY; core::bitflag debugInfoFlags = asset::IShaderCompiler::E_DEBUG_INFO_FLAGS::EDIF_NONE; + bool optimizerIsExtraPasses = false; // Instead of disabling the default opt passes, run the provided optimization passes at the end + std::string preprocessedOutputPath = ""; + std::string spvOutputPath = ""; }; core::smart_refctd_ptr compileShader(const SShaderCreationParameters& creationParams); diff --git a/src/nbl/asset/utils/CHLSLCompiler.cpp b/src/nbl/asset/utils/CHLSLCompiler.cpp index 6fec81c8cc..661edd18b8 100644 --- a/src/nbl/asset/utils/CHLSLCompiler.cpp +++ b/src/nbl/asset/utils/CHLSLCompiler.cpp @@ -546,6 +546,30 @@ core::smart_refctd_ptr CHLSLCompiler::compileToSPIRV_impl(const std::st auto newCode = preprocessShader(std::string(code), stage, hlslOptions.preprocessorOptions, dxc_compile_flags, dependencies); if (newCode.empty()) return nullptr; + if (!options.preprocessedOutputPath.empty()) + { + core::smart_refctd_ptr preprocessedFile; + + system::ISystem::future_t> future; + m_system->deleteFile(options.preprocessedOutputPath); + m_system->createFile(future, 
options.preprocessedOutputPath, system::IFile::ECF_WRITE); + if (future.wait()) + { + future.acquire().move_into(preprocessedFile); + if (preprocessedFile) + { + system::IFile::success_t succ; + preprocessedFile->write(succ, newCode.data(), 0, newCode.size()); + if (!succ) + logger.log("Failed Writing To Preprocessed Output File.", nbl::system::ILogger::ELL_ERROR); + } + else + logger.log("Failed Creating Preprocessed Output File.", nbl::system::ILogger::ELL_ERROR); + } + else + logger.log("Failed Creating Preprocessed Output File.", nbl::system::ILogger::ELL_ERROR); + } + // Suffix is the shader model version std::wstring targetProfile(SHADER_MODEL_PROFILE); @@ -584,13 +608,13 @@ core::smart_refctd_ptr CHLSLCompiler::compileToSPIRV_impl(const std::st // TODO: add entry point to `CHLSLCompiler::SOptions` and handle it properly in `dxc_compile_flags.empty()` arguments.push_back(L"main"); } - // If a custom SPIR-V optimizer is specified, use that instead of DXC's spirv-opt. + // If a custom SPIR-V optimizer is specified and set to replace default optimization passes, use that instead of DXC's spirv-opt. // This is how we can get more optimizer options. // // Optimization is also delegated to SPIRV-Tools. Right now there are no difference between // optimization levels greater than zero; they will all invoke the same optimization recipe. 
// https://github.com/Microsoft/DirectXShaderCompiler/blob/main/docs/SPIR-V.rst#optimization - if (hlslOptions.spirvOptimizer) + if (hlslOptions.spirvOptimizer && !hlslOptions.optimizerIsExtraPasses) arguments.push_back(L"-O0"); } // @@ -651,6 +675,30 @@ core::smart_refctd_ptr CHLSLCompiler::compileToSPIRV_impl(const std::st if (hlslOptions.spirvOptimizer) outSpirv = hlslOptions.spirvOptimizer->optimize(outSpirv.get(), logger); + if (!options.spvOutputPath.empty()) + { + core::smart_refctd_ptr spvFile; + + system::ISystem::future_t> future; + m_system->deleteFile(options.spvOutputPath); + m_system->createFile(future, options.spvOutputPath, system::IFile::ECF_WRITE); + if (future.wait()) + { + future.acquire().move_into(spvFile); + if (spvFile) + { + system::IFile::success_t succ; + spvFile->write(succ, outSpirv->getPointer(), 0, outSpirv->getSize()); + if (!succ) + logger.log("Failed Writing To SPIR-V Output File.", nbl::system::ILogger::ELL_ERROR); + } + else + logger.log("Failed Creating SPIR-V Output File.", nbl::system::ILogger::ELL_ERROR); + } + else + logger.log("Failed Creating SPIR-V Output File.", nbl::system::ILogger::ELL_ERROR); + } + return core::make_smart_refctd_ptr(std::move(outSpirv), IShader::E_CONTENT_TYPE::ECT_SPIRV, hlslOptions.preprocessorOptions.sourceIdentifier.data()); } diff --git a/src/nbl/video/ILogicalDevice.cpp b/src/nbl/video/ILogicalDevice.cpp index 1752717879..bee6381f7a 100644 --- a/src/nbl/video/ILogicalDevice.cpp +++ b/src/nbl/video/ILogicalDevice.cpp @@ -362,11 +362,15 @@ core::smart_refctd_ptr ILogicalDevice::compileShader(const SShad commonCompileOptions.stage = creationParams.stage; commonCompileOptions.debugInfoFlags = creationParams.debugInfoFlags; commonCompileOptions.spirvOptimizer = creationParams.optimizer; + commonCompileOptions.optimizerIsExtraPasses = creationParams.optimizerIsExtraPasses; commonCompileOptions.preprocessorOptions.targetSpirvVersion = m_physicalDevice->getLimits().spirvVersion; 
commonCompileOptions.readCache = creationParams.readCache; commonCompileOptions.writeCache = creationParams.writeCache; + commonCompileOptions.preprocessedOutputPath = creationParams.preprocessedOutputPath; + commonCompileOptions.spvOutputPath = creationParams.spvOutputPath; + if (sourceContent==asset::IShader::E_CONTENT_TYPE::ECT_HLSL) { // TODO: add specific HLSLCompiler::SOption params