From 1ff17582cee2216d6c6ea130489e365023755528 Mon Sep 17 00:00:00 2001
From: weric
Date: Mon, 20 Apr 2026 14:18:35 +0000
Subject: [PATCH] Fix min_count support for multi-dimensional reductions

- Remove restriction that prevented min_count with tuple/list axes
- Use np.take(mask.shape, axis).prod() for computing total elements across multiple axes
- Add _is_nat_dtype helper to fix NaT dtype comparison bug
- Update docstrings for clarity
---
 xarray/core/nanops.py | 31 +++++++++++++++++++++----------
 1 file changed, 21 insertions(+), 10 deletions(-)

diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py
index 41c8d258d7a..ff7d65660c7 100644
--- a/xarray/core/nanops.py
+++ b/xarray/core/nanops.py
@@ -22,23 +22,34 @@ def _replace_nan(a, val):
     return where_method(val, mask, a), mask
 
 
+def _is_nat_dtype(dtype):
+    """Return True if ``dtype`` is a datetime64 or timedelta64 dtype.
+
+    Checks ``dtype.kind`` instead of testing membership in ``dtypes.NAT_TYPES``.
+    Anything that is not an ``np.dtype`` instance (e.g. ``None``) gives False.
+    """
+    if not isinstance(dtype, np.dtype):
+        return False
+    return dtype.kind in "mM"  # 'm' for timedelta64, 'M' for datetime64
+
+
 def _maybe_null_out(result, axis, mask, min_count=1):
     """
     xarray version of pandas.core.nanops._maybe_null_out
     """
-    if hasattr(axis, "__len__"):  # if tuple or list
-        raise ValueError(
-            "min_count is not available for reduction with more than one dimensions."
-        )
-
     if axis is not None and getattr(result, "ndim", False):
-        null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0
+        if hasattr(axis, "__len__"):  # if tuple or list
+            # For multiple axes, compute total elements along those axes
+            total_elements = np.take(mask.shape, axis).prod()
+            null_mask = (total_elements - mask.sum(axis) - min_count) < 0
+        else:
+            null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0
         if null_mask.any():
             dtype, fill_value = dtypes.maybe_promote(result.dtype)
             result = result.astype(dtype)
             result[null_mask] = fill_value
 
-    elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES:
+    elif not _is_nat_dtype(getattr(result, "dtype", None)):
         null_mask = mask.size - mask.sum()
         if null_mask < min_count:
             result = np.nan
@@ -47,7 +58,7 @@ def _maybe_null_out(result, axis, mask, min_count=1):
 
 
 def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs):
-    """ In house nanargmin, nanargmax for object arrays. Always return integer
+    """In house nanargmin, nanargmax for object arrays. Always return integer
     type
     """
     valid_count = count(value, axis=axis)
@@ -62,7 +73,7 @@ def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs):
 
 
 def _nan_minmax_object(func, fill_value, value, axis=None, **kwargs):
-    """ In house nanmin and nanmax for object array """
+    """In house nanmin and nanmax for object array"""
     valid_count = count(value, axis=axis)
     filled_value = fillna(value, fill_value)
     data = getattr(np, func)(filled_value, axis=axis, **kwargs)
@@ -118,7 +129,7 @@ def nansum(a, axis=None, dtype=None, out=None, min_count=None):
 
 
 def _nanmean_ddof_object(ddof, value, axis=None, dtype=None, **kwargs):
-    """ In house nanmean. ddof argument will be used in _nanvar method """
+    """In house nanmean. ddof argument will be used in _nanvar method"""
     from .duck_array_ops import _dask_or_eager_func, count, fillna, where_method
 
     valid_count = count(value, axis=axis)