Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ mod within;
pub use error::MaxMindDbError;
pub use metadata::Metadata;
pub use reader::Reader;
pub use result::{LookupResult, PathElement};
pub use within::{Within, WithinOptions};
pub use result::{LookupResult, OwnedLookupResult, PathElement};
pub use within::{OwnedWithin, Within, WithinOptions};

#[cfg(feature = "mmap")]
pub use memmap2::Mmap;
Expand Down
77 changes: 76 additions & 1 deletion src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::collections::HashSet;
use std::fs;
use std::net::IpAddr;
use std::path::Path;
use std::sync::Arc;

use ipnetwork::IpNetwork;
use serde::Deserialize;
Expand All @@ -19,7 +20,7 @@ use crate::decoder;
use crate::error::MaxMindDbError;
use crate::metadata::Metadata;
use crate::result::{LookupResult, LookupSource, NetworkKind};
use crate::within::{IpInt, Within, WithinNode, WithinOptions};
use crate::within::{IpInt, OwnedWithin, Within, WithinNode, WithinOptions};

/// Size of the data section separator (16 zero bytes).
const DATA_SECTION_SEPARATOR_SIZE: usize = 16;
Expand Down Expand Up @@ -273,6 +274,23 @@ impl<'de, S: AsRef<[u8]>> Reader<S> {
self.within(cidr, options)
}

/// Iterate over all networks in the database with an owned reader.
///
/// This is the owned-reader counterpart to [`networks()`](Self::networks).
/// It consumes an [`Arc<Reader<_>>`](Arc), and the returned iterator keeps
/// the reader alive internally.
pub fn networks_owned(
self: Arc<Self>,
options: WithinOptions,
) -> Result<OwnedWithin<S>, MaxMindDbError> {
let cidr = if self.metadata.ip_version == 6 {
IpNetwork::V6("::/0".parse().unwrap())
} else {
IpNetwork::V4("0.0.0.0/0".parse().unwrap())
};
self.within_owned(cidr, options)
}

/// Iterate over IP networks within a CIDR range.
///
/// Returns an iterator that yields [`LookupResult`] for each network in the
Expand Down Expand Up @@ -398,6 +416,63 @@ impl<'de, S: AsRef<[u8]>> Reader<S> {
Ok(within)
}

/// Iterate over IP networks within a CIDR range with an owned reader.
///
/// This is the owned-reader counterpart to [`within()`](Self::within).
/// It consumes an [`Arc<Reader<_>>`](Arc), and the returned iterator keeps
/// the reader alive internally.
pub fn within_owned(
self: Arc<Self>,
cidr: IpNetwork,
options: WithinOptions,
) -> Result<OwnedWithin<S>, MaxMindDbError> {
if self.metadata.ip_version == 4 && matches!(cidr, IpNetwork::V6(_)) {
return Err(MaxMindDbError::invalid_input(
"cannot iterate IPv6 network in IPv4-only database",
));
}
let ip_address = cidr.network();
let prefix_len = cidr.prefix() as usize;
let ip_int = IpInt::new(ip_address);
let bit_count = ip_int.bit_count();

let mut node = self.start_node(bit_count);
let node_count = self.node_count;

let mut stack: Vec<WithinNode> = Vec::with_capacity(bit_count - prefix_len);

// Traverse down the tree to the level that matches the cidr mark
let mut depth = 0_usize;
for i in 0..prefix_len {
let bit = ip_int.get_bit(i);
node = self.read_node(node, bit as usize);
depth = i + 1; // We've now traversed i+1 bits (bits 0 through i)

if node >= node_count {
// We've hit a data node or dead end before we exhausted our prefix.
// This means the requested CIDR is contained in a single record.
break;
}
}

// Always push the node - it could be:
// - A data node (> node_count): will be yielded as a single record
// - The empty node (== node_count): will be skipped unless include_networks_without_data
// - An internal node (< node_count): will be traversed to find all contained records
stack.push(WithinNode {
node,
ip_int,
prefix_len: depth,
});

Ok(OwnedWithin {
reader: self,
node_count,
stack,
options,
})
}

// Pointer 0 means "not found" because normalize_lookup_result collapses both
// the placeholder empty node (`node == node_count`) and an unfinished internal
// terminal (`node < node_count`, i.e. bits exhausted while still on a tree
Expand Down
58 changes: 57 additions & 1 deletion src/reader_test.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::net::IpAddr;
use std::sync::Arc;

use ipnetwork::IpNetwork;
use serde::Deserialize;
use serde_json::json;

use crate::geoip2;
use crate::{MaxMindDbError, Reader, Within, WithinOptions};
use crate::{MaxMindDbError, OwnedWithin, Reader, Within, WithinOptions};

const TEST_DATABASE_CONFIGS: &[(usize, usize)] =
&[(24, 4), (28, 4), (32, 4), (24, 6), (28, 6), (32, 6)];
Expand All @@ -30,6 +31,17 @@ fn collect_networks<S: AsRef<[u8]>>(iter: Within<'_, S>) -> Vec<String> {
.collect()
}

fn collect_owned_networks<S: AsRef<[u8]>>(iter: OwnedWithin<S>) -> Vec<String> {
iter.map(|result| {
result
.unwrap_or_else(|e| panic!("unexpected iterator error: {e}"))
.network()
.unwrap_or_else(|e| panic!("failed to build network from lookup result: {e}"))
.to_string()
})
.collect()
}

#[allow(clippy::float_cmp)]
#[test]
fn test_decoder() {
Expand Down Expand Up @@ -649,6 +661,50 @@ fn test_networks() {
}
}

/// Test networks_owned() keeps the reader alive and matches networks().
#[test]
fn test_networks_owned() {
init_logger();

let reader = open_test_data_reader("MaxMind-DB-test-ipv4-24.mmdb");
let expected = collect_networks(reader.networks(Default::default()).unwrap());

let reader = Arc::new(reader);
let networks = collect_owned_networks(
Arc::clone(&reader)
.networks_owned(Default::default())
.unwrap(),
);

assert_eq!(networks, expected);
}

/// Test owned iterator results can decode after the caller drops its Arc.
#[test]
fn test_networks_owned_decode_after_original_arc_drop() {
init_logger();

#[derive(Deserialize)]
struct IpRecord {
ip: String,
}

let reader = Arc::new(open_test_data_reader("MaxMind-DB-test-ipv4-24.mmdb"));
let mut iter = Arc::clone(&reader)
.networks_owned(Default::default())
.unwrap();
drop(reader);

let lookup = iter
.next()
.expect("expected at least one network")
.expect("unexpected iterator error");
let network = lookup.network().unwrap();
let record: IpRecord = lookup.decode().unwrap().unwrap();

assert_eq!(record.ip, network.ip().to_string());
}

/// Test that default options skip aliased networks
#[test]
fn test_default_skips_aliases() {
Expand Down
Loading
Loading