From 5eaea3f9259e9277ddf87bef30252fbfbe168b03 Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Sat, 13 Jun 2026 20:21:46 -0700 Subject: [PATCH] Add owned network iterators --- src/lib.rs | 4 +- src/reader.rs | 77 +++++++++++++++++- src/reader_test.rs | 58 +++++++++++++- src/result.rs | 192 +++++++++++++++++++++++++++++++++++++++++++++ src/within.rs | 147 +++++++++++++++++++++++++++++++++- 5 files changed, 473 insertions(+), 5 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b19c723d..daefe4c1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -86,8 +86,8 @@ mod within; pub use error::MaxMindDbError; pub use metadata::Metadata; pub use reader::Reader; -pub use result::{LookupResult, PathElement}; -pub use within::{Within, WithinOptions}; +pub use result::{LookupResult, OwnedLookupResult, PathElement}; +pub use within::{OwnedWithin, Within, WithinOptions}; #[cfg(feature = "mmap")] pub use memmap2::Mmap; diff --git a/src/reader.rs b/src/reader.rs index 08eb3ec2..3a89261c 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -4,6 +4,7 @@ use std::collections::HashSet; use std::fs; use std::net::IpAddr; use std::path::Path; +use std::sync::Arc; use ipnetwork::IpNetwork; use serde::Deserialize; @@ -19,7 +20,7 @@ use crate::decoder; use crate::error::MaxMindDbError; use crate::metadata::Metadata; use crate::result::{LookupResult, LookupSource, NetworkKind}; -use crate::within::{IpInt, Within, WithinNode, WithinOptions}; +use crate::within::{IpInt, OwnedWithin, Within, WithinNode, WithinOptions}; /// Size of the data section separator (16 zero bytes). const DATA_SECTION_SEPARATOR_SIZE: usize = 16; @@ -273,6 +274,23 @@ impl<'de, S: AsRef<[u8]>> Reader { self.within(cidr, options) } + /// Iterate over all networks in the database with an owned reader. + /// + /// This is the owned-reader counterpart to [`networks()`](Self::networks). + /// It consumes an [`Arc>`](Arc), and the returned iterator keeps + /// the reader alive internally. + pub fn networks_owned( + self: Arc, + options: WithinOptions, + ) -> Result, MaxMindDbError> { + let cidr = if self.metadata.ip_version == 6 { + IpNetwork::V6("::/0".parse().unwrap()) + } else { + IpNetwork::V4("0.0.0.0/0".parse().unwrap()) + }; + self.within_owned(cidr, options) + } + /// Iterate over IP networks within a CIDR range. /// /// Returns an iterator that yields [`LookupResult`] for each network in the @@ -398,6 +416,63 @@ impl<'de, S: AsRef<[u8]>> Reader { Ok(within) } + /// Iterate over IP networks within a CIDR range with an owned reader. + /// + /// This is the owned-reader counterpart to [`within()`](Self::within). + /// It consumes an [`Arc>`](Arc), and the returned iterator keeps + /// the reader alive internally. + pub fn within_owned( + self: Arc, + cidr: IpNetwork, + options: WithinOptions, + ) -> Result, MaxMindDbError> { + if self.metadata.ip_version == 4 && matches!(cidr, IpNetwork::V6(_)) { + return Err(MaxMindDbError::invalid_input( + "cannot iterate IPv6 network in IPv4-only database", + )); + } + let ip_address = cidr.network(); + let prefix_len = cidr.prefix() as usize; + let ip_int = IpInt::new(ip_address); + let bit_count = ip_int.bit_count(); + + let mut node = self.start_node(bit_count); + let node_count = self.node_count; + + let mut stack: Vec = Vec::with_capacity(bit_count - prefix_len); + + // Traverse down the tree to the level that matches the cidr mark + let mut depth = 0_usize; + for i in 0..prefix_len { + let bit = ip_int.get_bit(i); + node = self.read_node(node, bit as usize); + depth = i + 1; // We've now traversed i+1 bits (bits 0 through i) + + if node >= node_count { + // We've hit a data node or dead end before we exhausted our prefix. + // This means the requested CIDR is contained in a single record. + break; + } + } + + // Always push the node - it could be: + // - A data node (> node_count): will be yielded as a single record + // - The empty node (== node_count): will be skipped unless include_networks_without_data + // - An internal node (< node_count): will be traversed to find all contained records + stack.push(WithinNode { + node, + ip_int, + prefix_len: depth, + }); + + Ok(OwnedWithin { + reader: self, + node_count, + stack, + options, + }) + } + // Pointer 0 means "not found" because normalize_lookup_result collapses both // the placeholder empty node (`node == node_count`) and an unfinished internal // terminal (`node < node_count`, i.e. bits exhausted while still on a tree diff --git a/src/reader_test.rs b/src/reader_test.rs index 329e7cd8..583bdd27 100644 --- a/src/reader_test.rs +++ b/src/reader_test.rs @@ -1,11 +1,12 @@ use std::net::IpAddr; +use std::sync::Arc; use ipnetwork::IpNetwork; use serde::Deserialize; use serde_json::json; use crate::geoip2; -use crate::{MaxMindDbError, Reader, Within, WithinOptions}; +use crate::{MaxMindDbError, OwnedWithin, Reader, Within, WithinOptions}; const TEST_DATABASE_CONFIGS: &[(usize, usize)] = &[(24, 4), (28, 4), (32, 4), (24, 6), (28, 6), (32, 6)]; @@ -30,6 +31,17 @@ fn collect_networks>(iter: Within<'_, S>) -> Vec { .collect() } +fn collect_owned_networks>(iter: OwnedWithin) -> Vec { + iter.map(|result| { + result + .unwrap_or_else(|e| panic!("unexpected iterator error: {e}")) + .network() + .unwrap_or_else(|e| panic!("failed to build network from lookup result: {e}")) + .to_string() + }) + .collect() +} + #[allow(clippy::float_cmp)] #[test] fn test_decoder() { @@ -649,6 +661,50 @@ fn test_networks() { } } +/// Test networks_owned() keeps the reader alive and matches networks(). +#[test] +fn test_networks_owned() { + init_logger(); + + let reader = open_test_data_reader("MaxMind-DB-test-ipv4-24.mmdb"); + let expected = collect_networks(reader.networks(Default::default()).unwrap()); + + let reader = Arc::new(reader); + let networks = collect_owned_networks( + Arc::clone(&reader) + .networks_owned(Default::default()) + .unwrap(), + ); + + assert_eq!(networks, expected); +} + +/// Test owned iterator results can decode after the caller drops its Arc. +#[test] +fn test_networks_owned_decode_after_original_arc_drop() { + init_logger(); + + #[derive(Deserialize)] + struct IpRecord { + ip: String, + } + + let reader = Arc::new(open_test_data_reader("MaxMind-DB-test-ipv4-24.mmdb")); + let mut iter = Arc::clone(&reader) + .networks_owned(Default::default()) + .unwrap(); + drop(reader); + + let lookup = iter + .next() + .expect("expected at least one network") + .expect("unexpected iterator error"); + let network = lookup.network().unwrap(); + let record: IpRecord = lookup.decode().unwrap().unwrap(); + + assert_eq!(record.ip, network.ip().to_string()); +} + /// Test that default options skip aliased networks #[test] fn test_default_skips_aliases() { diff --git a/src/result.rs b/src/result.rs index 8b3253bb..fa03d94f 100644 --- a/src/result.rs +++ b/src/result.rs @@ -6,6 +6,7 @@ //! selectively via paths. use std::net::IpAddr; +use std::sync::Arc; use ipnetwork::IpNetwork; use serde::Deserialize; @@ -57,6 +58,22 @@ pub struct LookupResult<'a, S: AsRef<[u8]>> { network_kind: NetworkKind, } +/// The result of looking up or iterating an IP network with an owned reader. +/// +/// This is the owned-reader counterpart to [`LookupResult`]. It keeps the +/// backing [`Reader`] alive with an [`Arc`], which allows owned iterators to +/// yield lazy lookup handles without borrowing from the iterator itself. +#[derive(Debug, Clone)] +pub struct OwnedLookupResult> { + reader: Arc>, + /// Offset into the data section, or None if not found. + data_offset: Option, + prefix_len: u8, + ip: IpAddr, + source: LookupSource, + network_kind: NetworkKind, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum LookupSource { Lookup, @@ -327,6 +344,181 @@ impl<'a, S: AsRef<[u8]>> LookupResult<'a, S> { } } +impl> OwnedLookupResult { + #[inline] + fn decoder(&self, offset: usize) -> super::decoder::Decoder<'_> { + let buf = &self.reader.buf.as_ref()[self.reader.pointer_base..]; + super::decoder::Decoder::new(buf, offset) + } + + /// Creates a new OwnedLookupResult for a found IP. + pub(crate) fn new_found( + reader: Arc>, + data_offset: usize, + prefix_len: u8, + ip: IpAddr, + source: LookupSource, + network_kind: NetworkKind, + ) -> Self { + OwnedLookupResult { + reader, + data_offset: Some(data_offset), + prefix_len, + ip, + source, + network_kind, + } + } + + /// Creates a new OwnedLookupResult for an IP not in the database. + pub(crate) fn new_not_found( + reader: Arc>, + prefix_len: u8, + ip: IpAddr, + source: LookupSource, + network_kind: NetworkKind, + ) -> Self { + OwnedLookupResult { + reader, + data_offset: None, + prefix_len, + ip, + source, + network_kind, + } + } + + /// Returns true if the database contains data for this IP address. + #[inline] + pub fn has_data(&self) -> bool { + self.data_offset.is_some() + } + + /// Returns the network containing the looked-up IP address. + pub fn network(&self) -> Result { + let (ip, prefix) = match (self.source, self.network_kind, self.ip) { + (_, NetworkKind::V4, IpAddr::V4(v4)) => (IpAddr::V4(v4), self.prefix_len), + (_, NetworkKind::V4InV6Subtree, IpAddr::V4(v4)) => ( + IpAddr::V4(v4), + self.prefix_len - self.reader.ipv4_start_bit_depth as u8, + ), + (LookupSource::Lookup, NetworkKind::V6, IpAddr::V4(_)) => { + use std::net::Ipv6Addr; + (IpAddr::V6(Ipv6Addr::UNSPECIFIED), self.prefix_len) + } + (_, NetworkKind::V6, IpAddr::V6(v6)) => (IpAddr::V6(v6), self.prefix_len), + (_, _, ip) => unreachable!("unexpected lookup result state for network: {ip:?}"), + }; + + let network_ip = mask_ip(ip, prefix); + IpNetwork::new(network_ip, prefix).map_err(MaxMindDbError::InvalidNetwork) + } + + /// Returns the data section offset if found, for use as a cache key. + #[inline] + pub fn offset(&self) -> Option { + self.data_offset + } + + /// Decodes the full record into the specified type. + pub fn decode<'a, T>(&'a self) -> Result, MaxMindDbError> + where + T: Deserialize<'a>, + { + let Some(offset) = self.data_offset else { + return Ok(None); + }; + + let mut decoder = self.decoder(offset); + T::deserialize(&mut decoder).map(Some) + } + + /// Decodes a value at a specific path within the record. + pub fn decode_path<'a, T>( + &'a self, + path: &[PathElement<'_>], + ) -> Result, MaxMindDbError> + where + T: Deserialize<'a>, + { + let Some(offset) = self.data_offset else { + return Ok(None); + }; + + let mut decoder = self.decoder(offset); + + for (i, element) in path.iter().enumerate() { + let with_path = |e| add_path_context(e, &path[..=i]); + + match *element { + PathElement::Key(key) => { + let (_, type_num) = decoder.peek_type().map_err(with_path)?; + if type_num != TYPE_MAP { + return Err(MaxMindDbError::decoding_at_path( + format!("expected map for Key(\"{key}\"), got type {type_num}"), + decoder.offset(), + render_path(&path[..=i]), + )); + } + + let size = decoder.consume_map_header().map_err(with_path)?; + + let mut found = false; + let key_bytes = key.as_bytes(); + for _ in 0..size { + let k = decoder.read_str_as_bytes().map_err(with_path)?; + if k == key_bytes { + found = true; + break; + } else { + decoder.skip_value().map_err(with_path)?; + } + } + + if !found { + return Ok(None); + } + } + PathElement::Index(idx) | PathElement::IndexFromEnd(idx) => { + let (_, type_num) = decoder.peek_type().map_err(with_path)?; + if type_num != TYPE_ARRAY { + let elem = match *element { + PathElement::Index(i) => format!("Index({i})"), + PathElement::IndexFromEnd(i) => format!("IndexFromEnd({i})"), + PathElement::Key(_) => unreachable!(), + }; + return Err(MaxMindDbError::decoding_at_path( + format!("expected array for {elem}, got type {type_num}"), + decoder.offset(), + render_path(&path[..=i]), + )); + } + + let size = decoder.consume_array_header().map_err(with_path)?; + + if idx >= size { + return Ok(None); + } + + let actual_idx = match *element { + PathElement::Index(i) => i, + PathElement::IndexFromEnd(i) => size - 1 - i, + PathElement::Key(_) => unreachable!(), + }; + + for _ in 0..actual_idx { + decoder.skip_value().map_err(with_path)?; + } + } + } + } + + T::deserialize(&mut decoder) + .map(Some) + .map_err(|e| add_path_context(e, path)) + } +} + /// Adds path context to a Decoding error if it doesn't already have one. fn add_path_context(err: MaxMindDbError, path: &[PathElement<'_>]) -> MaxMindDbError { match err { diff --git a/src/within.rs b/src/within.rs index 4aaf75e4..09987afe 100644 --- a/src/within.rs +++ b/src/within.rs @@ -2,11 +2,12 @@ use std::cmp::Ordering; use std::net::IpAddr; +use std::sync::Arc; use crate::decoder; use crate::error::MaxMindDbError; use crate::reader::Reader; -use crate::result::{LookupResult, LookupSource, NetworkKind}; +use crate::result::{LookupResult, LookupSource, NetworkKind, OwnedLookupResult}; /// Options for network iteration. /// @@ -115,6 +116,20 @@ pub struct Within<'de, S: AsRef<[u8]>> { pub(crate) options: WithinOptions, } +/// Owned iterator over IP networks within a CIDR range. +/// +/// Created by [`Reader::within_owned()`](crate::Reader::within_owned) or +/// [`Reader::networks_owned()`](crate::Reader::networks_owned). This iterator +/// owns an [`Arc`] to the reader, which makes it suitable for APIs that need to +/// store and return the iterator without a separate reader borrow. +#[derive(Debug)] +pub struct OwnedWithin> { + pub(crate) reader: Arc>, + pub(crate) node_count: usize, + pub(crate) stack: Vec, + pub(crate) options: WithinOptions, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum IpInt { V4(u32), @@ -262,6 +277,108 @@ impl<'de, S: AsRef<[u8]>> Iterator for Within<'de, S> { } } +impl> Iterator for OwnedWithin { + type Item = Result, MaxMindDbError>; + + fn next(&mut self) -> Option { + while let Some(current) = self.stack.pop() { + let bit_count = current.ip_int.bit_count(); + + // Skip networks that are aliases for the IPv4 network (unless option is set) + if !self.options.include_aliased_networks + && self.reader.ipv4_start != 0 + && current.node == self.reader.ipv4_start + && bit_count == 128 + && !current.ip_int.is_ipv4_in_ipv6() + { + continue; + } + + match current.node.cmp(&self.node_count) { + Ordering::Greater => { + // This is a data node, emit it and we're done (until the following next call) + let ip_addr = ip_int_to_addr(¤t.ip_int); + + // Resolve the pointer to a data offset + let data_offset = match self.reader.resolve_data_pointer(current.node) { + Ok(offset) => offset, + Err(e) => return Some(Err(e)), + }; + + // Check if we should skip empty values + if self.options.skip_empty_values { + match self.is_empty_value_at(data_offset) { + Ok(true) => continue, // Skip empty value + Ok(false) => {} // Not empty, proceed + Err(e) => return Some(Err(e)), + } + } + + let network_kind = match current.ip_int { + IpInt::V4(_) => NetworkKind::V4, + IpInt::V6(_) + if current.ip_int.is_ipv4_in_ipv6() + && self.reader.has_ipv4_subtree() + && current.prefix_len >= self.reader.ipv4_start_bit_depth => + { + NetworkKind::V4InV6Subtree + } + IpInt::V6(_) => NetworkKind::V6, + }; + + return Some(Ok(OwnedLookupResult::new_found( + Arc::clone(&self.reader), + data_offset, + current.prefix_len as u8, + ip_addr, + LookupSource::Iter, + network_kind, + ))); + } + Ordering::Equal => { + // Dead end (no data) - include if option is set + if self.options.include_networks_without_data { + let ip_addr = ip_int_to_addr(¤t.ip_int); + let network_kind = match current.ip_int { + IpInt::V4(_) => NetworkKind::V4, + IpInt::V6(_) + if current.ip_int.is_ipv4_in_ipv6() + && self.reader.has_ipv4_subtree() + && current.prefix_len >= self.reader.ipv4_start_bit_depth => + { + NetworkKind::V4InV6Subtree + } + IpInt::V6(_) => NetworkKind::V6, + }; + return Some(Ok(OwnedLookupResult::new_not_found( + Arc::clone(&self.reader), + current.prefix_len as u8, + ip_addr, + LookupSource::Iter, + network_kind, + ))); + } + // Otherwise skip (current behavior) + } + Ordering::Less => { + // In order traversal of our children + // right/1-bit + let mut right_ip_int = current.ip_int; + + if current.prefix_len < bit_count { + right_ip_int.set_bit(current.prefix_len); + } + + self.push_child(current.node, 1, right_ip_int, current.prefix_len + 1); + // left/0-bit + self.push_child(current.node, 0, current.ip_int, current.prefix_len + 1); + } + } + } + None + } +} + impl<'de, S: AsRef<[u8]>> Within<'de, S> { fn push_child( &mut self, @@ -290,6 +407,34 @@ impl<'de, S: AsRef<[u8]>> Within<'de, S> { } } +impl> OwnedWithin { + fn push_child( + &mut self, + parent_node: usize, + direction: usize, + ip_int: IpInt, + prefix_len: usize, + ) { + let node = self.reader.read_node(parent_node, direction); + self.stack.push(WithinNode { + node, + ip_int, + prefix_len, + }); + } + + /// Check if the value at the given data offset is an empty map or array. + fn is_empty_value_at(&self, data_offset: usize) -> Result { + let buf = &self.reader.buf.as_ref()[self.reader.pointer_base..]; + let mut dec = decoder::Decoder::new(buf, data_offset); + let (size, type_num) = dec.peek_type()?; + match type_num { + decoder::TYPE_MAP | decoder::TYPE_ARRAY => Ok(size == 0), + _ => Ok(false), // Non-container types are never "empty" + } + } +} + /// Convert IpInt to IpAddr pub(crate) fn ip_int_to_addr(ip_int: &IpInt) -> IpAddr { match ip_int {