From 595ccb9005dc29bc24f4d2d4eb0e7d4ccf988137 Mon Sep 17 00:00:00 2001 From: Lachlan Deakin Date: Sun, 3 May 2026 13:37:57 +1000 Subject: [PATCH] feat: add next_up and next_down --- CHANGELOG.md | 8 ++ src/bits.rs | 102 +++++++++++++++++-- src/formats.rs | 10 ++ src/micro.rs | 10 +- tests/format_edges.rs | 221 ++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 341 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e402cb0..0aebcf7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased](https://github.com/LDeakin/microfloat/compare/v0.1.2...HEAD) +### Added + +- Add const `next_up()` and `next_down()` methods to all float types + +### Changed + +- `classify_bits`, `is_infinity_bits` are now `const` + ## [0.1.2](https://github.com/LDeakin/microfloat/releases/tag/v0.1.2) - 2026-04-29 ### Added diff --git a/src/bits.rs b/src/bits.rs index 27f695e..76c1685 100644 --- a/src/bits.rs +++ b/src/bits.rs @@ -12,7 +12,7 @@ pub struct Class { pub is_nan: bool, } -pub fn classify_bits(bits: u8) -> Class { +pub const fn classify_bits(bits: u8) -> Class { if is_nan_bits::(bits) { return Class { is_zero: false, @@ -29,13 +29,16 @@ pub fn classify_bits(bits: u8) -> Class { is_nan: false, }; } - if F::ZERO == ZeroMode::None { - return Class { - is_zero: false, - is_subnormal: false, - is_infinite: false, - is_nan: false, - }; + match F::ZERO { + ZeroMode::None => { + return Class { + is_zero: false, + is_subnormal: false, + is_infinite: false, + is_nan: false, + }; + } + ZeroMode::Signed | ZeroMode::Unsigned => {} } let mag = magnitude_bits::(bits); @@ -174,7 +177,7 @@ pub const fn is_nan_bits(bits: u8) -> bool { } } -pub fn is_infinity_bits(bits: u8) -> bool { +pub const fn is_infinity_bits(bits: u8) -> bool { F::HAS_INF && exponent_field::(bits) == F::MAX_EXPONENT_FIELD && mantissa_field::(bits) == 0 @@ -208,6 +211,87 @@ pub const fn abs_bits(bits: u8) -> u8 { } } +pub const fn next_up_bits(bits: u8) -> u8 { + let bits = bits & F::STORAGE_MASK; + let class = classify_bits::(bits); + if class.is_nan { + return bits; + } + if class.is_infinite { + return if is_negative_bits::(bits) { + max_finite_bits::(true) + } else { + bits + }; + } + if class.is_zero { + return 1; + } + match F::SIGN { + SignMode::Unsigned => { + return if bits == max_finite_bits::(false) { + infinity_bits::(false) + } else { + bits + 1 + }; + } + SignMode::Signed => {} + } + if is_negative_bits::(bits) { + if magnitude_bits::(bits) == 1 { + neg_zero_bits::() + } else { + bits - 1 + } + } else if bits == max_finite_bits::(false) { + infinity_bits::(false) + } else { + bits + 1 + } +} + +pub const fn next_down_bits(bits: u8) -> u8 { + let bits = bits & F::STORAGE_MASK; + let class = classify_bits::(bits); + if class.is_nan { + return bits; + } + if class.is_infinite { + return if is_negative_bits::(bits) { + bits + } else { + max_finite_bits::(false) + }; + } + if class.is_zero { + return match F::SIGN { + SignMode::Signed => negate_bits::(1), + SignMode::Unsigned => infinity_bits::(true), + }; + } + match F::SIGN { + SignMode::Unsigned => { + return if bits == 0 { + infinity_bits::(true) + } else { + bits - 1 + }; + } + SignMode::Signed => {} + } + if is_negative_bits::(bits) { + if bits == max_finite_bits::(true) { + infinity_bits::(true) + } else { + bits + 1 + } + } else if magnitude_bits::(bits) == 1 { + F::ZERO_BITS + } else { + bits - 1 + } +} + pub fn decode_f32(bits: u8) -> f32 { if F::ZERO == ZeroMode::None { return if bits == 0xff { diff --git a/src/formats.rs b/src/formats.rs index 6d03ef7..4a6e4b6 100644 --- a/src/formats.rs +++ b/src/formats.rs @@ -370,6 +370,16 @@ macro_rules! define_format { Self(self.0.abs()) } + /// Returns the least representable value greater than `self`. + pub const fn next_up(self) -> Self { + Self(self.0.next_up()) + } + + /// Returns the greatest representable value less than `self`. + pub const fn next_down(self) -> Self { + Self(self.0.next_down()) + } + /// Returns the greatest integer less than or equal to `self`, rounded to this format. pub fn floor(self) -> Self { Self(self.0.floor()) diff --git a/src/micro.rs b/src/micro.rs index e14335d..e153e44 100644 --- a/src/micro.rs +++ b/src/micro.rs @@ -4,7 +4,7 @@ use core::num::FpCategory; use crate::bits::{ abs_bits, classify_bits, decode_f32, encode_f32, infinity_bits, nan_bits, neg_zero_bits, - negate_bits, one_bits, total_key, + negate_bits, next_down_bits, next_up_bits, one_bits, total_key, }; use crate::format::{Format, NanEncoding, SignMode}; @@ -156,6 +156,14 @@ impl MicroFloat { Self::from_bits(abs_bits::(self.bits)) } + pub const fn next_up(self) -> Self { + Self::from_bits(next_up_bits::(self.bits)) + } + + pub const fn next_down(self) -> Self { + Self::from_bits(next_down_bits::(self.bits)) + } + pub fn floor(self) -> Self { unary_result(self, libm::floorf(self.to_f32())) } diff --git a/tests/format_edges.rs b/tests/format_edges.rs index 9b1fd66..c198175 100644 --- a/tests/format_edges.rs +++ b/tests/format_edges.rs @@ -3,6 +3,67 @@ use microfloat::{ f8e5m2fnuz, f8e8m0fnu, }; +trait StepFloat: Copy + core::fmt::Debug + PartialEq { + const INFINITY: Self; + const NEG_INFINITY: Self; + + fn from_bits(bits: u8) -> Self; + fn to_bits(self) -> u8; + fn to_f32(self) -> f32; + fn is_nan(self) -> bool; + fn next_up(self) -> Self; + fn next_down(self) -> Self; +} + +macro_rules! impl_step_float { + ($($type:ty),* $(,)?) => { + $( + impl StepFloat for $type { + const INFINITY: Self = <$type>::INFINITY; + const NEG_INFINITY: Self = <$type>::NEG_INFINITY; + + fn from_bits(bits: u8) -> Self { + <$type>::from_bits(bits) + } + + fn to_bits(self) -> u8 { + <$type>::to_bits(self) + } + + fn to_f32(self) -> f32 { + <$type>::to_f32(self) + } + + fn is_nan(self) -> bool { + <$type>::is_nan(self) + } + + fn next_up(self) -> Self { + <$type>::next_up(self) + } + + fn next_down(self) -> Self { + <$type>::next_down(self) + } + } + )* + }; +} + +impl_step_float!( + f8e3m4, + f8e4m3, + f8e4m3b11fnuz, + f8e4m3fn, + f8e4m3fnuz, + f8e5m2, + f8e5m2fnuz, + f8e8m0fnu, + f4e2m1fn, + f6e2m3fn, + f6e3m2fn, +); + #[test] fn special_constants_cover_format_modes() { assert_eq!(f8e8m0fnu::ONE.to_bits(), 0x7f); @@ -88,3 +149,163 @@ fn format_query_helpers() { assert!(!f6e3m2fn::has_nan()); assert!(f6e3m2fn::is_finite_only()); } + +#[test] +fn next_methods_cover_every_canonical_value() { + assert_next_methods::("f8e3m4", 0xff); + assert_next_methods::("f8e4m3", 0xff); + assert_next_methods::("f8e4m3b11fnuz", 0xff); + assert_next_methods::("f8e4m3fn", 0xff); + assert_next_methods::("f8e4m3fnuz", 0xff); + assert_next_methods::("f8e5m2", 0xff); + assert_next_methods::("f8e5m2fnuz", 0xff); + assert_next_methods::("f8e8m0fnu", 0xff); + assert_next_methods::("f4e2m1fn", 0x0f); + assert_next_methods::("f6e2m3fn", 0x3f); + assert_next_methods::("f6e3m2fn", 0x3f); +} + +#[test] +fn next_methods_cover_named_edges() { + assert_eq!(f8e4m3::NEG_ZERO.next_up().to_bits(), 0x01); + assert_eq!(f8e4m3::ZERO.next_down().to_bits(), 0x81); + assert_eq!( + f8e4m3::INFINITY.next_down().to_bits(), + f8e4m3::MAX.to_bits() + ); + assert_eq!( + f8e4m3::NEG_INFINITY.next_up().to_bits(), + f8e4m3::MIN.to_bits() + ); + assert_eq!( + f8e4m3::INFINITY.next_up().to_bits(), + f8e4m3::INFINITY.to_bits() + ); + assert_eq!( + f8e4m3::NEG_INFINITY.next_down().to_bits(), + f8e4m3::NEG_INFINITY.to_bits() + ); + assert_eq!(f8e4m3::NAN.next_up().to_bits(), f8e4m3::NAN.to_bits()); + assert_eq!(f8e4m3::NAN.next_down().to_bits(), f8e4m3::NAN.to_bits()); + + assert_eq!( + f8e4m3fnuz::MAX.next_up().to_bits(), + f8e4m3fnuz::NAN.to_bits() + ); + assert_eq!( + f8e4m3fnuz::MIN.next_down().to_bits(), + f8e4m3fnuz::NAN.to_bits() + ); + assert_eq!(f4e2m1fn::MAX.next_up().to_bits(), f4e2m1fn::MAX.to_bits()); + assert_eq!(f4e2m1fn::MIN.next_down().to_bits(), f4e2m1fn::MIN.to_bits()); + assert_eq!(f8e8m0fnu::MAX.next_up().to_bits(), f8e8m0fnu::NAN.to_bits()); + assert_eq!( + f8e8m0fnu::from_bits(0x00).next_down().to_bits(), + f8e8m0fnu::NAN.to_bits() + ); +} + +fn assert_next_methods(name: &str, max_bits: u8) { + let sorted = sorted_representable_values::(max_bits); + for raw in 0..=max_bits { + let value = T::from_bits(raw); + let next_up = value.next_up(); + let expected_up = expected_next_up(value, &sorted); + assert_eq!( + next_up.to_bits(), + expected_up.to_bits(), + "{name} next_up raw {raw:#04x}: got {:#04x}, expected {:#04x}", + next_up.to_bits(), + expected_up.to_bits() + ); + + let next_down = value.next_down(); + let expected_down = expected_next_down(value, &sorted); + assert_eq!( + next_down.to_bits(), + expected_down.to_bits(), + "{name} next_down raw {raw:#04x}: got {:#04x}, expected {:#04x}", + next_down.to_bits(), + expected_down.to_bits() + ); + } +} + +fn sorted_representable_values(max_bits: u8) -> Vec { + let mut values = Vec::new(); + for raw in 0..=max_bits { + let value = T::from_bits(raw); + if !value.is_nan() { + values.push(value); + } + } + values.sort_by(|lhs, rhs| lhs.to_f32().total_cmp(&rhs.to_f32())); + values +} + +fn expected_next_up(value: T, sorted: &[T]) -> T { + if value.is_nan() { + return value; + } + + let as_f32 = value.to_f32(); + if is_positive_infinity(as_f32) { + return value; + } + if is_zero(as_f32) { + return first_greater_than_zero(sorted).unwrap_or(T::INFINITY); + } + let index = sorted_index(value, sorted); + sorted.get(index + 1).copied().unwrap_or(T::INFINITY) +} + +fn expected_next_down(value: T, sorted: &[T]) -> T { + if value.is_nan() { + return value; + } + + let as_f32 = value.to_f32(); + if is_negative_infinity(as_f32) { + return value; + } + if is_zero(as_f32) { + return last_less_than_zero(sorted).unwrap_or(T::NEG_INFINITY); + } + let index = sorted_index(value, sorted); + index + .checked_sub(1) + .and_then(|previous| sorted.get(previous)) + .copied() + .unwrap_or(T::NEG_INFINITY) +} + +fn sorted_index(value: T, sorted: &[T]) -> usize { + sorted + .iter() + .position(|candidate| candidate.to_bits() == value.to_bits()) + .expect("non-NaN canonical value must be present in sorted representable values") +} + +fn first_greater_than_zero(sorted: &[T]) -> Option { + sorted.iter().copied().find(|value| value.to_f32() > 0.0) +} + +fn last_less_than_zero(sorted: &[T]) -> Option { + sorted + .iter() + .copied() + .rev() + .find(|value| value.to_f32() < 0.0) +} + +fn is_zero(value: f32) -> bool { + matches!(value.classify(), core::num::FpCategory::Zero) +} + +fn is_positive_infinity(value: f32) -> bool { + value.is_infinite() && value.is_sign_positive() +} + +fn is_negative_infinity(value: f32) -> bool { + value.is_infinite() && value.is_sign_negative() +}