diff --git a/CMakeLists.txt b/CMakeLists.txt index c91128d..01444ac 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -148,6 +148,15 @@ if(PIXIE_TESTS) gtest gtest_main ${PIXIE_DIAGNOSTICS_LIBS}) + + add_executable(excess_positions_tests + src/tests/excess_positions_tests.cpp) + target_include_directories(excess_positions_tests + PUBLIC include) + target_link_libraries(excess_positions_tests + gtest + gtest_main + ${PIXIE_DIAGNOSTICS_LIBS}) endif() # --------------------------------------------------------------------------- @@ -200,6 +209,15 @@ if(PIXIE_BENCHMARKS) benchmark benchmark_main ${PIXIE_DIAGNOSTICS_LIBS}) + + add_executable(excess_positions_benchmarks + src/benchmarks/excess_positions_benchmarks.cpp) + target_include_directories(excess_positions_benchmarks + PUBLIC include) + target_link_libraries(excess_positions_benchmarks + benchmark + benchmark_main + ${PIXIE_DIAGNOSTICS_LIBS}) endif() # --------------------------------------------------------------------------- diff --git a/CMakePresets.json b/CMakePresets.json index dabbdb0..cae9c15 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -6,9 +6,17 @@ "patch": 0 }, "configurePresets": [ + { + "name": "base", + "hidden": true, + "cacheVariables": { + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON" + } + }, { "name": "debug", "displayName": "Debug", + "inherits": "base", "binaryDir": "${sourceDir}/build/debug", "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" @@ -17,6 +25,7 @@ { "name": "release", "displayName": "Release", + "inherits": "base", "binaryDir": "${sourceDir}/build/release", "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" @@ -25,6 +34,7 @@ { "name": "benchmarks-all", "displayName": "Benchmarks", + "inherits": "base", "binaryDir": "${sourceDir}/build/release", "cacheVariables": { "CMAKE_BUILD_TYPE": "Release", @@ -34,6 +44,7 @@ { "name": "benchmarks-diagnostic", "displayName": "Benchmarks diagnostic build", + "inherits": "base", "binaryDir": "${sourceDir}/build/release-with-deb", 
"cacheVariables": { "BENCHMARK_ENABLE_LIBPFM": "ON", @@ -45,6 +56,7 @@ { "name": "docs", "displayName": "Docs", + "inherits": "base", "binaryDir": "${sourceDir}/build/docs", "cacheVariables": { "CMAKE_BUILD_TYPE": "Release", @@ -54,6 +66,7 @@ { "name": "coverage", "displayName": "Coverage", + "inherits": "base", "binaryDir": "${sourceDir}/build/coverage", "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", @@ -64,6 +77,7 @@ { "name": "asan", "displayName": "AddressSanitizer", + "inherits": "base", "binaryDir": "${sourceDir}/build/asan", "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", diff --git a/include/pixie/bits.h b/include/pixie/bits.h index d1a170b..e8826a2 100644 --- a/include/pixie/bits.h +++ b/include/pixie/bits.h @@ -632,6 +632,231 @@ void popcount_32x8(const uint8_t* x, uint8_t* result) { * @param * result Pointer to store the resulting 32 8-bit integers */ +#ifdef PIXIE_AVX2_SUPPORT +static inline __m256i excess_bit_masks_16x() noexcept { + return _mm256_setr_epi16(0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, + 0x0040, 0x0080, 0x0100, 0x0200, 0x0400, 0x0800, + 0x1000, 0x2000, 0x4000, (int16_t)0x8000); +} + +static inline __m256i excess_prefix_sum_16x_i16(__m256i v) noexcept { + __m256i x = v; + __m256i t = _mm256_slli_si256(x, 2); + x = _mm256_add_epi16(x, t); + t = _mm256_slli_si256(x, 4); + x = _mm256_add_epi16(x, t); + t = _mm256_slli_si256(x, 8); + x = _mm256_add_epi16(x, t); + + __m128i lo = _mm256_extracti128_si256(x, 0); + __m128i hi = _mm256_extracti128_si256(x, 1); + const int16_t carry = (int16_t)_mm_extract_epi16(lo, 7); + hi = _mm_add_epi16(hi, _mm_set1_epi16(carry)); + + __m256i out = _mm256_castsi128_si256(lo); + out = _mm256_inserti128_si256(out, hi, 1); + return out; +} + +static inline int16_t excess_last_prefix_16x_i16(__m256i pref) noexcept { + __m128i hi = _mm256_extracti128_si256(pref, 1); + return (int16_t)_mm_extract_epi16(hi, 7); +} +#endif + +/** + * @brief Find every prefix whose excess equals target_x in a 512-bit bitstring. 
+ * + * Excess(i) = 2*popcount(bits[0..i-1]) - i for i in [0..512]. + * Bit (w*64 + b) of out[w] is set iff excess(w*64 + b + 1) == target_x. + * I.e. out bit index b corresponds to prefix length (b+1). + * + * @param s 8 little-endian uint64_t words (bit 0 of s[0] is the first bit). + * @param target_x Target excess value in [-512, 512]; outside this range out is + * zeroed. + * @param out 8 uint64_t words receiving the result bitmask. + */ +static inline void excess_positions_512(const uint64_t* s, + int target_x, + uint64_t* out) noexcept { + out[0] = out[1] = out[2] = out[3] = 0; + out[4] = out[5] = out[6] = out[7] = 0; + + if (target_x < -512 || target_x > 512) { + return; + } + +#ifdef PIXIE_AVX2_SUPPORT + static const __m256i masks = excess_bit_masks_16x(); + static const __m256i vzero = _mm256_setzero_si256(); + static const __m256i vallones = _mm256_cmpeq_epi16(vzero, vzero); + static const __m256i vminus1 = _mm256_set1_epi16(-1); + static const __m256i vtwo = _mm256_set1_epi16(2); + const __m256i vtarget = _mm256_set1_epi16((int16_t)target_x); + + int cur = 0; + for (int k = 0; k < 32; ++k) { + const size_t bit_off = size_t(k) * 16; + const size_t word_idx = bit_off >> 6; + const size_t shift = bit_off & 63; + + uint16_t bits16 = (uint16_t)((s[word_idx] >> shift) & 0xFFFFull); + if (shift > 48 && word_idx + 1 < 8) { + bits16 |= (uint16_t)(s[word_idx + 1] << (64 - shift)); + } + + const __m256i vb = _mm256_set1_epi16((int16_t)bits16); + const __m256i m = _mm256_and_si256(vb, masks); + const __m256i is_zero = _mm256_cmpeq_epi16(m, vzero); + const __m256i is_set = _mm256_andnot_si256(is_zero, vallones); + const __m256i steps = + _mm256_add_epi16(vminus1, _mm256_and_si256(is_set, vtwo)); + + const __m256i pref_rel = excess_prefix_sum_16x_i16(steps); + const __m256i base = _mm256_set1_epi16((int16_t)cur); + const __m256i pref_abs = _mm256_add_epi16(pref_rel, base); + const __m256i cmp = _mm256_cmpeq_epi16(pref_abs, vtarget); + + const uint32_t m32 = 
(uint32_t)_mm256_movemask_epi8(cmp); + const uint16_t m16 = (uint16_t)_pext_u32(m32, 0xAAAAAAAAu); + + const size_t out_word = bit_off >> 6; + const size_t out_shift = bit_off & 63; + out[out_word] |= uint64_t(m16) << out_shift; + if (out_shift > 48 && out_word + 1 < 8) { + out[out_word + 1] |= uint64_t(m16) >> (64 - out_shift); + } + + cur += (int)excess_last_prefix_16x_i16(pref_rel); + } +#else + int cur = 0; + for (size_t i = 0; i < 512; ++i) { + const uint64_t w = s[i >> 6]; + const int bit = int((w >> (i & 63)) & 1ull); + cur += bit ? +1 : -1; + if (cur == target_x) { + out[i >> 6] |= (uint64_t{1} << (i & 63)); + } + } +#endif +} + +#ifdef PIXIE_AVX2_SUPPORT +static inline __m128i excess_nibble_delta_lut() noexcept { + alignas(16) static const int8_t lut[16] = {-4, -2, -2, 0, -2, 0, 0, 2, + -2, 0, 0, 2, 0, 2, 2, 4}; + return _mm_load_si128((const __m128i*)lut); +} + +static inline __m128i excess_nibble_pos0_lut() noexcept { + alignas(16) static const int8_t lut[16] = {-1, 1, -1, 1, -1, 1, -1, 1, + -1, 1, -1, 1, -1, 1, -1, 1}; + return _mm_load_si128((const __m128i*)lut); +} + +static inline __m128i excess_nibble_pos1_lut() noexcept { + alignas(16) static const int8_t lut[16] = {-2, 0, 0, 2, -2, 0, 0, 2, + -2, 0, 0, 2, -2, 0, 0, 2}; + return _mm_load_si128((const __m128i*)lut); +} + +static inline __m128i excess_nibble_pos2_lut() noexcept { + alignas(16) static const int8_t lut[16] = {-3, -1, -1, 1, -1, 1, 1, 3, + -3, -1, -1, 1, -1, 1, 1, 3}; + return _mm_load_si128((const __m128i*)lut); +} +#endif + +static inline void excess_positions_512_lut(const uint64_t* s, + int target_x, + uint64_t* out) noexcept { + out[0] = out[1] = out[2] = out[3] = 0; + out[4] = out[5] = out[6] = out[7] = 0; + + if (target_x < -512 || target_x > 512) { + return; + } + +#ifdef PIXIE_AVX2_SUPPORT + const __m128i vdelta = excess_nibble_delta_lut(); + const __m128i vpos0 = excess_nibble_pos0_lut(); + const __m128i vpos1 = excess_nibble_pos1_lut(); + const __m128i vpos2 = 
excess_nibble_pos2_lut(); + const __m128i vnibble_mask = _mm_set1_epi8(0x0F); + const __m128i vzero = _mm_setzero_si128(); + + int cur = 0; + for (int w = 0; w < 8; ++w) { + const uint64_t word = s[w]; + const int word_delta = 2 * (int)std::popcount(word) - 64; + const int target_local = target_x - cur; + + if (target_local < -64 || target_local > 64) { + out[w] = 0; + cur += word_delta; + continue; + } + + __m128i bytes = _mm_cvtsi64_si128(static_cast(word)); + __m128i lo = _mm_and_si128(bytes, vnibble_mask); + __m128i hi = _mm_and_si128(_mm_srli_epi16(bytes, 4), vnibble_mask); + __m128i nibbles = _mm_unpacklo_epi8(lo, hi); + + __m128i deltas = _mm_shuffle_epi8(vdelta, nibbles); + + __m128i ps = deltas; + ps = _mm_add_epi8(ps, _mm_slli_si128(ps, 1)); + ps = _mm_add_epi8(ps, _mm_slli_si128(ps, 2)); + ps = _mm_add_epi8(ps, _mm_slli_si128(ps, 4)); + ps = _mm_add_epi8(ps, _mm_slli_si128(ps, 8)); + + __m128i excl = _mm_slli_si128(ps, 1); + + __m128i vtarget_local = _mm_set1_epi8(static_cast(target_local)); + // Overflow safety: excl[i] ∈ [-60, 60], target_local ∈ [-64, 64]. + // base = excl - target_local ∈ [-124, 124], fits in int8. + // base + pos_j ∈ [-128, 128]. The boundary value 128 wraps to -128 in + // int8, but -128 ≠ 0 so cmpeq_epi8 produces no false positive. + // The value -128 is exactly representable and also ≠ 0. 
+ __m128i base = _mm_sub_epi8(excl, vtarget_local); + + __m128i cmp0 = _mm_cmpeq_epi8( + _mm_add_epi8(base, _mm_shuffle_epi8(vpos0, nibbles)), vzero); + uint16_t bits0 = static_cast(_mm_movemask_epi8(cmp0)); + + __m128i cmp1 = _mm_cmpeq_epi8( + _mm_add_epi8(base, _mm_shuffle_epi8(vpos1, nibbles)), vzero); + uint16_t bits1 = static_cast(_mm_movemask_epi8(cmp1)); + + __m128i cmp2 = _mm_cmpeq_epi8( + _mm_add_epi8(base, _mm_shuffle_epi8(vpos2, nibbles)), vzero); + uint16_t bits2 = static_cast(_mm_movemask_epi8(cmp2)); + + __m128i cmp3 = _mm_cmpeq_epi8( + _mm_add_epi8(base, _mm_shuffle_epi8(vdelta, nibbles)), vzero); + uint16_t bits3 = static_cast(_mm_movemask_epi8(cmp3)); + + out[w] = _pdep_u64(bits0, 0x1111111111111111ULL) | + _pdep_u64(bits1, 0x2222222222222222ULL) | + _pdep_u64(bits2, 0x4444444444444444ULL) | + _pdep_u64(bits3, 0x8888888888888888ULL); + + cur += word_delta; + } +#else + int cur = 0; + for (size_t i = 0; i < 512; ++i) { + const uint64_t w = s[i >> 6]; + const int bit = int((w >> (i & 63)) & 1ull); + cur += bit ? 
+1 : -1; + if (cur == target_x) { + out[i >> 6] |= (uint64_t{1} << (i & 63)); + } + } +#endif +} + void rank_32x8(const uint8_t* x, uint8_t* result) { #ifdef PIXIE_AVX512_SUPPORT // Step 1: Calculate popcount of each byte diff --git a/scripts/coverage_report.sh b/scripts/coverage_report.sh index 0820856..d1fdd1b 100755 --- a/scripts/coverage_report.sh +++ b/scripts/coverage_report.sh @@ -8,6 +8,7 @@ cmake --preset coverage cmake --build --preset coverage "${BUILD_DIR}/unittests" +"${BUILD_DIR}/excess_positions_tests" "${BUILD_DIR}/louds_tree_tests" "${BUILD_DIR}/test_rmm" diff --git a/src/benchmarks/excess_positions_benchmarks.cpp b/src/benchmarks/excess_positions_benchmarks.cpp new file mode 100644 index 0000000..cf52f2b --- /dev/null +++ b/src/benchmarks/excess_positions_benchmarks.cpp @@ -0,0 +1,108 @@ +#include +#include + +#include +#include +#include +#include +#include + +static std::vector> make_blocks( + size_t num_blocks = 4096) { + std::mt19937_64 rng(42); + std::vector> blocks(num_blocks); + for (auto& b : blocks) { + for (auto& w : b) { + w = rng(); + } + } + return blocks; +} + +static void BM_ExcessPositions512(benchmark::State& state) { + const int target_x = state.range(0); + const auto blocks = make_blocks(); + const size_t num_blocks = blocks.size(); + + alignas(64) uint64_t out[8]; + size_t idx = 0; + + for (auto _ : state) { + const auto& s = blocks[idx % num_blocks]; + excess_positions_512(s.data(), target_x, out); + benchmark::DoNotOptimize(out); + ++idx; + } + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_ExcessPositions512) + ->ArgNames({"X"}) + ->Args({-64}) + ->Args({-8}) + ->Args({0}) + ->Args({8}) + ->Args({64}); + +static void BM_ExcessPositions512_Scalar(benchmark::State& state) { + const int target_x = state.range(0); + const auto blocks = make_blocks(); + const size_t num_blocks = blocks.size(); + + alignas(64) uint64_t out[8]; + size_t idx = 0; + + for (auto _ : state) { + const auto& s = blocks[idx % 
num_blocks]; + for (int w = 0; w < 8; ++w) { + out[w] = 0; + } + int cur = 0; + for (size_t i = 0; i < 512; ++i) { + const int bit = int((s[i >> 6] >> (i & 63)) & 1ull); + cur += bit ? +1 : -1; + if (cur == target_x) { + out[i >> 6] |= (uint64_t{1} << (i & 63)); + } + } + benchmark::DoNotOptimize(out); + ++idx; + } + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_ExcessPositions512_Scalar) + ->ArgNames({"X"}) + ->Args({-64}) + ->Args({-8}) + ->Args({0}) + ->Args({8}) + ->Args({64}); + +static void BM_ExcessPositions512_LUT(benchmark::State& state) { + const int target_x = state.range(0); + const auto blocks = make_blocks(); + const size_t num_blocks = blocks.size(); + + alignas(64) uint64_t out[8]; + size_t idx = 0; + + for (auto _ : state) { + const auto& s = blocks[idx % num_blocks]; + excess_positions_512_lut(s.data(), target_x, out); + benchmark::DoNotOptimize(out); + ++idx; + } + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_ExcessPositions512_LUT) + ->ArgNames({"X"}) + ->Args({-64}) + ->Args({-8}) + ->Args({0}) + ->Args({8}) + ->Args({64}); diff --git a/src/tests/excess_positions_tests.cpp b/src/tests/excess_positions_tests.cpp new file mode 100644 index 0000000..f9ed763 --- /dev/null +++ b/src/tests/excess_positions_tests.cpp @@ -0,0 +1,359 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +static void naive_excess_positions_512(const uint64_t* s, + int target_x, + uint64_t* out) { + for (int w = 0; w < 8; ++w) { + out[w] = 0; + } + if (target_x < -512 || target_x > 512) { + return; + } + int cur = 0; + for (size_t i = 0; i < 512; ++i) { + const uint64_t w = s[i >> 6]; + const int bit = int((w >> (i & 63)) & 1ull); + cur += bit ? 
+1 : -1; + if (cur == target_x) { + out[i >> 6] |= (uint64_t{1} << (i & 63)); + } + } +} + +static size_t count_matches(const uint64_t* out) { + size_t cnt = 0; + for (int w = 0; w < 8; ++w) { + cnt += std::popcount(out[w]); + } + return cnt; +} + +template +static void check_matches_naive(Fn fn, + const char* fn_name, + const uint64_t* s, + int target_x, + int case_id = 0) { + alignas(64) uint64_t out[8]; + alignas(64) uint64_t ref[8]; + fn(s, target_x, out); + naive_excess_positions_512(s, target_x, ref); + for (int w = 0; w < 8; ++w) { + ASSERT_EQ(out[w], ref[w]) + << fn_name << " case=" << case_id << " x=" << target_x << " word=" << w; + } + ASSERT_EQ(count_matches(out), count_matches(ref)) + << fn_name << " case=" << case_id << " x=" << target_x; +} + +TEST(ExcessPositions512, AllZeros) { + alignas(64) uint64_t s[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + alignas(64) uint64_t out[8]; + alignas(64) uint64_t ref[8]; + + for (int x = -8; x <= 0; ++x) { + excess_positions_512(s, x, out); + naive_excess_positions_512(s, x, ref); + for (int w = 0; w < 8; ++w) { + EXPECT_EQ(out[w], ref[w]) << "x=" << x << " word=" << w; + } + } + + excess_positions_512(s, 1, out); + for (int w = 0; w < 8; ++w) { + EXPECT_EQ(out[w], 0u); + } +} + +TEST(ExcessPositions512, AllOnes) { + alignas(64) uint64_t s[8] = {UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, + UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX}; + alignas(64) uint64_t out[8]; + alignas(64) uint64_t ref[8]; + + for (int x = 0; x <= 8; ++x) { + excess_positions_512(s, x, out); + naive_excess_positions_512(s, x, ref); + for (int w = 0; w < 8; ++w) { + EXPECT_EQ(out[w], ref[w]) << "x=" << x << " word=" << w; + } + } + + excess_positions_512(s, -1, out); + for (int w = 0; w < 8; ++w) { + EXPECT_EQ(out[w], 0u); + } +} + +TEST(ExcessPositions512, Alternating) { + alignas(64) uint64_t s[8]; + for (int w = 0; w < 8; ++w) { + s[w] = 0xAAAAAAAAAAAAAAAAull; + } + alignas(64) uint64_t out[8]; + alignas(64) uint64_t ref[8]; + + for (int x = 
-2; x <= 2; ++x) { + excess_positions_512(s, x, out); + naive_excess_positions_512(s, x, ref); + for (int w = 0; w < 8; ++w) { + EXPECT_EQ(out[w], ref[w]) << "x=" << x << " word=" << w; + } + } +} + +TEST(ExcessPositions512, OutOfRange) { + alignas(64) uint64_t s[8] = {UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, + UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX}; + alignas(64) uint64_t out[8]; + excess_positions_512(s, 513, out); + for (int w = 0; w < 8; ++w) { + EXPECT_EQ(out[w], 0u); + } + excess_positions_512(s, -513, out); + for (int w = 0; w < 8; ++w) { + EXPECT_EQ(out[w], 0u); + } +} + +TEST(ExcessPositions512, ExhaustiveSmall16) { + alignas(64) uint64_t s[8]; + alignas(64) uint64_t out[8]; + alignas(64) uint64_t ref[8]; + + for (uint64_t pattern = 0; pattern < (1ull << 16); ++pattern) { + s[0] = pattern; + for (int w = 1; w < 8; ++w) { + s[w] = 0; + } + for (int x = -20; x <= 20; ++x) { + excess_positions_512(s, x, out); + naive_excess_positions_512(s, x, ref); + for (int w = 0; w < 8; ++w) { + ASSERT_EQ(out[w], ref[w]) + << "pattern=" << pattern << " x=" << x << " word=" << w; + } + } + } +} + +TEST(ExcessPositions512, Random) { + const int cases = [] { + const char* env = std::getenv("EXCESS_POS_CASES"); + return env ? std::atoi(env) : 1000; + }(); + const uint64_t seed = [] { + const char* env = std::getenv("EXCESS_POS_SEED"); + return env ? 
std::stoull(env) : 42ull; + }(); + + std::mt19937_64 rng(static_cast(seed)); + std::uniform_int_distribution x_dist(-512, 512); + + alignas(64) uint64_t s[8]; + alignas(64) uint64_t out[8]; + alignas(64) uint64_t ref[8]; + + for (int t = 0; t < cases; ++t) { + for (int w = 0; w < 8; ++w) { + s[w] = rng(); + } + const int x = x_dist(rng); + + excess_positions_512(s, x, out); + naive_excess_positions_512(s, x, ref); + + for (int w = 0; w < 8; ++w) { + ASSERT_EQ(out[w], ref[w]) << "case=" << t << " x=" << x << " word=" << w; + } + + ASSERT_EQ(count_matches(out), count_matches(ref)) + << "case=" << t << " x=" << x; + } +} + +TEST(ExcessPositions512, TargetZero) { + const uint64_t seed = 12345; + std::mt19937_64 rng(seed); + + alignas(64) uint64_t s[8]; + alignas(64) uint64_t out[8]; + alignas(64) uint64_t ref[8]; + + for (int t = 0; t < 500; ++t) { + for (int w = 0; w < 8; ++w) { + s[w] = rng(); + } + excess_positions_512(s, 0, out); + naive_excess_positions_512(s, 0, ref); + for (int w = 0; w < 8; ++w) { + ASSERT_EQ(out[w], ref[w]) << "case=" << t << " word=" << w; + } + } +} + +TEST(ExcessPositions512LUT, AllZeros) { + alignas(64) uint64_t s[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + + for (int x = -8; x <= 0; ++x) { + check_matches_naive(excess_positions_512_lut, "lut", s, x); + } + + alignas(64) uint64_t out[8]; + excess_positions_512_lut(s, 1, out); + for (int w = 0; w < 8; ++w) { + EXPECT_EQ(out[w], 0u); + } +} + +TEST(ExcessPositions512LUT, AllOnes) { + alignas(64) uint64_t s[8] = {UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, + UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX}; + + for (int x = 0; x <= 8; ++x) { + check_matches_naive(excess_positions_512_lut, "lut", s, x); + } + + alignas(64) uint64_t out[8]; + excess_positions_512_lut(s, -1, out); + for (int w = 0; w < 8; ++w) { + EXPECT_EQ(out[w], 0u); + } +} + +TEST(ExcessPositions512LUT, Alternating) { + alignas(64) uint64_t s[8]; + for (int w = 0; w < 8; ++w) { + s[w] = 0xAAAAAAAAAAAAAAAAull; + } + + for (int x = 
-2; x <= 2; ++x) { + check_matches_naive(excess_positions_512_lut, "lut", s, x); + } +} + +TEST(ExcessPositions512LUT, OutOfRange) { + alignas(64) uint64_t s[8] = {UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, + UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX}; + alignas(64) uint64_t out[8]; + excess_positions_512_lut(s, 513, out); + for (int w = 0; w < 8; ++w) { + EXPECT_EQ(out[w], 0u); + } + excess_positions_512_lut(s, -513, out); + for (int w = 0; w < 8; ++w) { + EXPECT_EQ(out[w], 0u); + } +} + +TEST(ExcessPositions512LUT, ExhaustiveSmall16) { + alignas(64) uint64_t s[8]; + + for (uint64_t pattern = 0; pattern < (1ull << 16); ++pattern) { + s[0] = pattern; + for (int w = 1; w < 8; ++w) { + s[w] = 0; + } + for (int x = -20; x <= 20; ++x) { + check_matches_naive(excess_positions_512_lut, "lut", s, x, + static_cast(pattern)); + } + } +} + +TEST(ExcessPositions512LUT, Random) { + const int cases = [] { + const char* env = std::getenv("EXCESS_POS_CASES"); + return env ? std::atoi(env) : 1000; + }(); + const uint64_t seed = [] { + const char* env = std::getenv("EXCESS_POS_SEED"); + return env ? 
std::stoull(env) : 42ull; + }(); + + std::mt19937_64 rng(static_cast(seed)); + std::uniform_int_distribution x_dist(-512, 512); + + alignas(64) uint64_t s[8]; + + for (int t = 0; t < cases; ++t) { + for (int w = 0; w < 8; ++w) { + s[w] = rng(); + } + const int x = x_dist(rng); + check_matches_naive(excess_positions_512_lut, "lut", s, x, t); + } +} + +TEST(ExcessPositions512LUT, TargetZero) { + const uint64_t seed = 12345; + std::mt19937_64 rng(seed); + + alignas(64) uint64_t s[8]; + + for (int t = 0; t < 500; ++t) { + for (int w = 0; w < 8; ++w) { + s[w] = rng(); + } + check_matches_naive(excess_positions_512_lut, "lut", s, 0, t); + } +} + +TEST(ExcessPositions512LUT, MatchesExpand) { + const int cases = 500; + std::mt19937_64 rng(99999); + std::uniform_int_distribution x_dist(-512, 512); + + alignas(64) uint64_t s[8]; + alignas(64) uint64_t out_expand[8]; + alignas(64) uint64_t out_lut[8]; + + for (int t = 0; t < cases; ++t) { + for (int w = 0; w < 8; ++w) { + s[w] = rng(); + } + const int x = x_dist(rng); + + excess_positions_512(s, x, out_expand); + excess_positions_512_lut(s, x, out_lut); + + for (int w = 0; w < 8; ++w) { + ASSERT_EQ(out_expand[w], out_lut[w]) + << "case=" << t << " x=" << x << " word=" << w; + } + } +} + +TEST(ExcessPositions512LUT, OverflowBoundary) { + alignas(64) uint64_t s[8]; + alignas(64) uint64_t out_expand[8]; + alignas(64) uint64_t out_lut[8]; + + for (int x = -64; x <= 64; ++x) { + for (uint64_t hi_pattern = 0; hi_pattern < 256; ++hi_pattern) { + for (int fill = 0; fill <= 7; ++fill) { + for (int w = 0; w < 8; ++w) { + s[w] = (w < fill) ? UINT64_MAX : 0; + } + s[fill] = hi_pattern; + + excess_positions_512(s, x, out_expand); + excess_positions_512_lut(s, x, out_lut); + + for (int w = 0; w < 8; ++w) { + ASSERT_EQ(out_expand[w], out_lut[w]) + << "x=" << x << " fill=" << fill << " hi=" << hi_pattern + << " word=" << w; + } + } + } + } +}