Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion xarray/computation/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def _maybe_null_out(result, axis, mask, min_count=1):
dtype, fill_value = dtypes.maybe_promote(result.dtype)
result = where(null_mask, fill_value, astype(result, dtype))

elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES:
elif (dtype := getattr(result, "dtype", None)) and getattr(
dtype, "kind", None
) not in {"m", "M"}:
null_mask = mask.size - duck_array_ops.sum(mask)
result = where(null_mask < min_count, np.nan, result)

Expand Down
16 changes: 11 additions & 5 deletions xarray/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from xarray.compat import array_api_compat, npcompat
from xarray.compat.npcompat import HAS_STRING_DTYPE
from xarray.core import utils
from xarray.core.types import PDDatetimeUnitOptions

if TYPE_CHECKING:
from typing import Any
Expand Down Expand Up @@ -88,7 +89,11 @@ def maybe_promote(dtype: T_dtype) -> tuple[T_dtype, Any]:
# See https://github.com/numpy/numpy/issues/10685
# np.timedelta64 is a subclass of np.integer
# Check np.timedelta64 before np.integer
fill_value = np.timedelta64("NaT")
unit, _ = np.datetime_data(dtype)
# np.datetime_data returns a generic str for the unit so we need to
# cast it to a valid time unit for mypy purposes.
unit = cast(PDDatetimeUnitOptions, unit)
fill_value = np.timedelta64("NaT", unit)
dtype_ = dtype
elif isdtype(dtype, "integral"):
dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64
Expand All @@ -97,8 +102,12 @@ def maybe_promote(dtype: T_dtype) -> tuple[T_dtype, Any]:
dtype_ = dtype
fill_value = np.nan + np.nan * 1j
elif np.issubdtype(dtype, np.datetime64):
unit, _ = np.datetime_data(dtype)
# np.datetime_data returns a generic str for the unit so we need to
# cast it to a valid time unit for mypy purposes.
unit = cast(PDDatetimeUnitOptions, unit)
dtype_ = dtype
fill_value = np.datetime64("NaT")
fill_value = np.datetime64("NaT", unit)
else:
dtype_ = object
fill_value = np.nan
Expand All @@ -108,9 +117,6 @@ def maybe_promote(dtype: T_dtype) -> tuple[T_dtype, Any]:
return dtype_out, fill_value


NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype}


def get_fill_value(dtype):
"""Return an appropriate fill value for this dtype.

Expand Down
17 changes: 13 additions & 4 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,15 @@ def _determine_cmap_params(
else:
mpl = attempt_import("matplotlib")

if plot_data.dtype.kind == "m":
unit, _ = np.datetime_data(plot_data.dtype)
zero = np.timedelta64(0, unit)
elif plot_data.dtype.kind == "M":
unit, _ = np.datetime_data(plot_data.dtype)
zero = np.datetime64(0, unit)
else:
zero = 0.0

if isinstance(levels, Iterable):
levels = sorted(levels)

Expand All @@ -197,15 +206,15 @@ def _determine_cmap_params(
# Handle all-NaN input data gracefully
if calc_data.size == 0:
# Arbitrary default for when all values are NaN
calc_data = np.array(0.0)
calc_data = np.array(zero)

# Setting center=False prevents a divergent cmap
possibly_divergent = center is not False

# Set center to 0 so math below makes sense but remember its state
center_is_none = False
if center is None:
center = 0
center = zero
center_is_none = True

# Setting both vmin and vmax prevents a divergent cmap
Expand Down Expand Up @@ -240,10 +249,10 @@ def _determine_cmap_params(

if possibly_divergent:
levels_are_divergent = (
isinstance(levels, Iterable) and levels[0] * levels[-1] < 0
isinstance(levels, Iterable) and levels[0] * levels[-1] < zero
)
# kwargs not specific about divergent or not: infer defaults from data
divergent = (vmin < 0 < vmax) or not center_is_none or levels_are_divergent
divergent = (vmin < zero < vmax) or not center_is_none or levels_are_divergent
else:
divergent = False

Expand Down
10 changes: 2 additions & 8 deletions xarray/tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def test_inf(obj) -> None:
("I", (np.float64, "nan")), # dtype('uint32')
("l", (np.float64, "nan")), # dtype('int64')
("L", (np.float64, "nan")), # dtype('uint64')
("m", (np.timedelta64, "NaT")), # dtype('<m8')
("M", (np.datetime64, "NaT")), # dtype('<M8')
("<m8[ns]", (np.dtype("<m8[ns]"), "NaT")), # dtype('<m8[ns]')
("<M8[ns]", (np.dtype("<M8[ns]"), "NaT")), # dtype('<M8[ns]')
("O", (np.dtype("O"), "nan")), # dtype('O')
("p", (np.float64, "nan")), # dtype('int64')
("P", (np.float64, "nan")), # dtype('uint64')
Expand All @@ -123,12 +123,6 @@ def test_maybe_promote(kind, expected) -> None:
assert str(actual[1]) == expected[1]


def test_nat_types_membership() -> None:
assert np.datetime64("NaT").dtype in dtypes.NAT_TYPES
assert np.timedelta64("NaT").dtype in dtypes.NAT_TYPES
assert np.float64 not in dtypes.NAT_TYPES


@pytest.mark.parametrize(
["dtype", "kinds", "xp", "expected"],
(
Expand Down
Loading