diff --git a/doc/api.rst b/doc/api.rst index 63427447d53..f31fce3fda9 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -250,6 +250,7 @@ Reshaping and reorganizing Dataset.roll Dataset.pad Dataset.sortby + Dataset.shuffle_by Dataset.broadcast_like DataArray @@ -590,6 +591,7 @@ Reshaping and reorganizing DataArray.roll DataArray.pad DataArray.sortby + DataArray.shuffle_by DataArray.broadcast_like DataTree @@ -1096,6 +1098,7 @@ Dataset DatasetGroupBy.var DatasetGroupBy.dims DatasetGroupBy.groups + DatasetGroupBy.shuffle DataArray --------- @@ -1127,6 +1130,7 @@ DataArray DataArrayGroupBy.var DataArrayGroupBy.dims DataArrayGroupBy.groups + DataArrayGroupBy.shuffle Grouper Objects --------------- diff --git a/doc/user-guide/groupby.rst b/doc/user-guide/groupby.rst index 98bd7b4833b..defe05e1b26 100644 --- a/doc/user-guide/groupby.rst +++ b/doc/user-guide/groupby.rst @@ -321,3 +321,41 @@ Different groupers can be combined to construct sophisticated GroupBy operations from xarray.groupers import BinGrouper ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum() + + +Shuffling +~~~~~~~~~ + +Shuffling is a generalization of sorting a DataArray or Dataset by another DataArray, named ``label`` for example, that follows from the idea of grouping by ``label``. +Shuffling reorders the DataArray or the DataArrays in a Dataset such that all members of a group occur sequentially. For example, + +.. ipython:: python + + da = xr.DataArray( + dims="x", + data=[1, 2, 3, 4, 5, 6], + coords={"label": ("x", "a b c a b c".split(" "))}, + ) + da.shuffle_by("label") + + +:py:meth:`Dataset.shuffle_by` and :py:meth:`DataArray.shuffle_by` can also take Grouper objects: + +.. ipython:: python + + from xarray.groupers import UniqueGrouper + + da.shuffle_by(label=UniqueGrouper()) + + +Shuffling can also be performed on :py:class:`DatasetGroupBy` and :py:class:`DataArrayGroupBy` objects. +The :py:meth:`DatasetGroupBy.shuffle` and :py:meth:`DataArrayGroupBy.shuffle` methods return new :py:class:`DatasetGroupBy` and :py:class:`DataArrayGroupBy` objects that operate on the shuffled Dataset or DataArray respectively. + + +.. ipython:: python + + da.groupby(label=UniqueGrouper()).shuffle() + + +For chunked array types (e.g. dask or cubed), shuffle may result in a more optimized communication pattern when compared to direct indexing by the appropriate indexer. +Shuffling also makes GroupBy operations on chunked arrays an embarrassingly parallel problem, and may significantly improve workloads that use :py:meth:`DatasetGroupBy.map` or :py:meth:`DataArrayGroupBy.map`. diff --git a/xarray/core/common.py b/xarray/core/common.py index 9a6807faad2..1b0b2e578d3 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -52,7 +52,7 @@ T_Variable, ) from xarray.core.variable import Variable - from xarray.groupers import Resampler + from xarray.groupers import Grouper, Resampler DTypeMaybeMapping = Union[DTypeLikeSave, Mapping[Any, DTypeLikeSave]] @@ -891,6 +891,68 @@ def rolling_exp( return rolling_exp.RollingExp(self, window, window_type) + def shuffle_by( + self, + group: Hashable | DataArray | Mapping[Any, Grouper] | None = None, + chunks: T_Chunks = None, + **groupers: Grouper, + ) -> Self: + """ + Sort or "shuffle" this object by a Grouper. + + "Shuffle" means the object is sorted so that all group members occur sequentially, + in the same chunk. Multiple groups may occur in the same chunk. + This method is particularly useful for chunked arrays (e.g. dask, cubed). + For chunked array types, the order of appearance is not guaranteed, but will depend on + the input chunking. + + Parameters + ---------- + group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper + Array whose unique values should be used to group this array. If a + Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary, + must map an existing variable name to a :py:class:`Grouper` instance. + chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional + How to adjust chunks along dimensions not present in the array being grouped by. + **groupers : Grouper + Grouper objects using which to shuffle the data. + + Examples + -------- + >>> import dask + >>> from xarray.groupers import UniqueGrouper + >>> da = xr.DataArray( + ... dims="x", + ... data=dask.array.arange(10, chunks=1), + ... coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]}, + ... name="a", + ... ) + >>> da + Size: 80B + dask.array + Coordinates: + * x (x) int64 80B 1 2 3 1 2 3 1 2 3 0 + + >>> da.shuffle_by(x=UniqueGrouper()) + Size: 80B + dask.array + Coordinates: + * x (x) int64 80B 0 1 1 1 2 2 2 3 3 3 + + Returns + ------- + DataArray or Dataset + The same type as this object + + See Also + -------- + DataArrayGroupBy.shuffle + DatasetGroupBy.shuffle + dask.dataframe.DataFrame.shuffle + dask.array.shuffle + """ + return self.groupby(group=group, **groupers)._shuffle_obj(chunks) + def _resample( self, resample_cls: type[T_Resample], diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index c6bc082f5ed..8e7c87fed72 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -117,6 +117,7 @@ Self, SideOptions, T_ChunkDimFreq, + T_Chunks, T_ChunksFreq, T_Xarray, ) @@ -661,6 +662,12 @@ def _to_dataset_whole( coord_names = set(self._coords) return Dataset._construct_direct(variables, coord_names, indexes=indexes) + def _shuffle(self, dim, *, indices: list[list[int]], chunks: T_Chunks) -> None: + shuffled = self._to_temp_dataset()._shuffle( + dim=dim, indices=indices, chunks=chunks + ) + return self._from_temp_dataset(shuffled) + def to_dataset( self, dim: Hashable = None, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index bc9360a809d..13133546b7a 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -166,6 +166,7 @@ ResampleCompatible, SideOptions, T_ChunkDimFreq, + T_Chunks, T_DatasetPadConstantValues, T_Xarray, ) @@ -3236,6 +3237,30 @@ def sel( result = self.isel(indexers=query_results.dim_indexers, drop=drop) return result._overwrite_indexes(*query_results.as_tuple()[1:]) + def _shuffle(self, dim, *, indices: list[list[int]], chunks: T_Chunks) -> Self: + # Shuffling is only different from `isel` for chunked arrays. + # Extract them out, and treat them specially. The rest, we route through isel. + # This makes it easy to ensure correct handling of indexes. + is_chunked = { + name: var + for name, var in self._variables.items() + if is_chunked_array(var._data) + } + subset = self[[name for name in self._variables if name not in is_chunked]] + + shuffled = ( + subset + if dim not in subset.dims + else subset.isel({dim: np.concatenate(indices)}) + ) + for name, var in is_chunked.items(): + shuffled[name] = var._shuffle( + indices=indices, + dim=dim, + chunks=chunks, + ) + return shuffled + def head( self, indexers: Mapping[Any, int] | int | None = None, diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 5a5a241f6c1..082638c9206 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -6,6 +6,7 @@ import warnings from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence from dataclasses import dataclass, field +from functools import partial from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast import numpy as np @@ -29,6 +30,7 @@ ) from xarray.core.merge import merge_coords from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.core.parallel import map_blocks from xarray.core.types import ( Dims, QuantileMethods, @@ -54,7 +56,13 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.types import GroupIndex, GroupIndices, GroupInput, GroupKey + from xarray.core.types import ( + GroupIndex, + GroupIndices, + GroupInput, + GroupKey, + T_Chunks, + ) from xarray.core.utils import Frozen from xarray.groupers import EncodedGroups, Grouper @@ -80,6 +88,24 @@ def _codes_to_group_indices(codes: np.ndarray, N: int) -> GroupIndices: return groups +def _infer_map_blocks_template(shuffled: GroupBy, func: Callable, *args, **kwargs): + template = shuffled.map(func, *args, **kwargs) + name = shuffled.group1d.name + chunksizes = shuffled._obj.chunksizes[shuffled._group_dim] + output_group = template[name] + out_group_lens = output_group.groupby(name).count().data + block_ids = np.repeat(np.arange(len(chunksizes)), chunksizes) + frame = pd.DataFrame( + {"block_id": pd.Index(block_ids), "codes": shuffled.encoded.codes} + ) + groups_in_chunk = frame["codes"].groupby(block_ids).unique() + out_chunks = tuple( + itertools.chain(*[out_group_lens[group].tolist() for group in groups_in_chunk]) + ) + template = template.chunk({name: out_chunks}) + return template, name + + def _dummy_copy(xarray_obj): from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset @@ -614,6 +640,142 @@ def sizes(self) -> Mapping[Hashable, int]: self._sizes = self._obj.isel({self._group_dim: index}).sizes return self._sizes + def shuffle(self, chunks: T_Chunks = None): + """ + Sort or "shuffle" the underlying object. + + "Shuffle" means the object is sorted so that all group members occur sequentially, + in the same chunk. Multiple groups may occur in the same chunk. + This method is particularly useful for chunked arrays (e.g. dask, cubed). + particularly when you need to map a function that requires all members of a group + to be present in a single chunk. For chunked array types, the order of appearance + is not guaranteed, but will depend on the input chunking. + + Parameters + ---------- + chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional + How to adjust chunks along dimensions not present in the array being grouped by. + + Returns + ------- + DataArrayGroupBy or DatasetGroupBy + + Examples + -------- + >>> import dask + >>> da = xr.DataArray( + ... dims="x", + ... data=dask.array.arange(10, chunks=3), + ... coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]}, + ... name="a", + ... ) + >>> shuffled = da.groupby("x").shuffle() + >>> shuffled.quantile(q=0.5).compute() + Size: 32B + array([9., 3., 4., 5.]) + Coordinates: + quantile float64 8B 0.5 + * x (x) int64 32B 0 1 2 3 + + See Also + -------- + dask.dataframe.DataFrame.shuffle + dask.array.shuffle + """ + new_groupers = { + # Using group.name handles the BinGrouper case + # It does *not* handle the TimeResampler case, + # so we just override this method in Resample + grouper.group.name: grouper.grouper.reset() + for grouper in self.groupers + } + return self._shuffle_obj(chunks).groupby( + new_groupers, + restore_coord_dims=self._restore_coord_dims, + ) + + def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray: + from xarray.core.dataarray import DataArray + + was_array = isinstance(self._obj, DataArray) + as_dataset = self._obj._to_temp_dataset() if was_array else self._obj + + size = self._obj.sizes[self._group_dim] + no_slices: list[list[int]] = [ + list(range(*idx.indices(size))) if isinstance(idx, slice) else idx + for idx in self.encoded.group_indices + ] + no_slices = [idx for idx in no_slices if idx] + + for grouper in self.groupers: + if grouper.name not in as_dataset._variables: + as_dataset.coords[grouper.name] = grouper.group + + shuffled = as_dataset._shuffle( + dim=self._group_dim, indices=no_slices, chunks=chunks + ) + shuffled = self._maybe_unstack(shuffled) + new_obj = self._obj._from_temp_dataset(shuffled) if was_array else shuffled + return new_obj + + def _map_shuffled(self, func, args, kwargs) -> None: + def wrapper(x, func, groupers, renamer, *args, **kwargs): + return x.groupby(groupers).map(func, *args, **kwargs).rename(renamer) + + shuffled = self.shuffle() + obj = shuffled._obj.copy(deep=False) + try: + template, concat_dim = _infer_map_blocks_template( + shuffled, func, *args, **kwargs + ) + except Exception as e: + raise ValueError("Could not infer template automatically.") from e + group_dim = shuffled._group_dim + + groupers = {} + for grouper in shuffled.groupers: + name = grouper.group.name + if name not in obj: + obj.coords[name] = grouper.group + groupers[name] = grouper.grouper.reset() + + # map_blocks does not support adding new dimensions that are multiply-chunked + # For example, even renaming an existing dimension to a new name will not work. + # This would be needed for grouped reductions where at least one dimension is destroyed. + # So we engage in a renaming game. + result = map_blocks( + # 1. This renamer renames dimensions named after the grouping variable to the + # dimension we are grouping over. + # For example .groupby("label") where label.dims == ("x",); we rename the + # output "label" dimension back to "x" + partial( + wrapper, func=func, groupers=groupers, renamer={concat_dim: group_dim} + ), + obj, + args=args, + kwargs=kwargs, + # 2. Again do the same renaming transform on the template + template=template.rename({concat_dim: group_dim}), + ) + + if ( + group_dim == concat_dim + and self._obj.sizes[group_dim] == template.sizes[group_dim] + ): + # invert the shuffling + inverse = _inverse_permutation_indices(self.encoded.group_indices) + # output chunk sizes are the same as the input's + indices = [ + arr.tolist() + for arr in np.split( + inverse, np.cumsum(self._obj.chunksizes[self._group_dim])[:-1] + ) + ] + result = result._shuffle(dim=group_dim, indices=indices, chunks="auto") + + # 3. Now invert the renaming + return result.rename({group_dim: concat_dim}) + def map( self, func: Callable, @@ -824,7 +986,9 @@ def _maybe_unstack(self, obj): # and `inserted_dims` # if multiple groupers all share the same single dimension, then # we don't stack/unstack. Do that manually now. - obj = obj.unstack(*self.encoded.unique_coord.dims) + dims_to_unstack = self.encoded.unique_coord.dims + if all(dim in obj.dims for dim in dims_to_unstack): + obj = obj.unstack(*dims_to_unstack) to_drop = [ grouper.name for grouper in self.groupers @@ -1304,6 +1468,8 @@ def map( func: Callable[..., DataArray], args: tuple[Any, ...] = (), shortcut: bool | None = None, + *, + shuffle: bool = False, **kwargs: Any, ) -> DataArray: """Apply a function to each array in the group and concatenate them @@ -1347,9 +1513,16 @@ def map( applied : DataArray The result of splitting, applying and combining this array. """ - grouped = self._iter_grouped_shortcut() if shortcut else self._iter_grouped() - applied = (maybe_wrap_array(arr, func(arr, *args, **kwargs)) for arr in grouped) - return self._combine(applied, shortcut=shortcut) + if shuffle and self._obj.chunksizes: + return self._map_shuffled(func, args=args, kwargs=kwargs) + else: + grouped = ( + self._iter_grouped_shortcut() if shortcut else self._iter_grouped() + ) + applied = ( + maybe_wrap_array(arr, func(arr, *args, **kwargs)) for arr in grouped + ) + return self._combine(applied, shortcut=shortcut) def apply(self, func, shortcut=False, args=(), **kwargs): """ @@ -1473,6 +1646,8 @@ def map( func: Callable[..., Dataset], args: tuple[Any, ...] = (), shortcut: bool | None = None, + *, + shuffle: bool = False, **kwargs: Any, ) -> Dataset: """Apply a function to each Dataset in the group and concatenate them @@ -1504,9 +1679,12 @@ def map( applied : Dataset The result of splitting, applying and combining this dataset. """ - # ignore shortcut if set (for now) - applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped()) - return self._combine(applied) + if shuffle and self._obj.chunksizes: + return self._map_shuffled(func, args=args, kwargs=kwargs) + else: + # ignore shortcut if set (for now) + applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped()) + return self._combine(applied) def apply(self, func, args=(), shortcut=None, **kwargs): """ diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index a0dfe56807b..41e334e22e1 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -488,7 +488,7 @@ def _wrapper( " Please construct a template with appropriately chunked dask arrays." ) - new_indexes = set(template.xindexes) - set(merged_coordinates) + new_indexes = set(template.xindexes) - set(merged_coordinates.xindexes) modified_indexes = set( name for name, xindex in coordinates.xindexes.items() diff --git a/xarray/core/resample.py b/xarray/core/resample.py index 9dd91d86a47..9c7ed179e84 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -2,7 +2,7 @@ import warnings from collections.abc import Callable, Hashable, Iterable, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from xarray.core._aggregations import ( DataArrayResampleAggregations, @@ -14,6 +14,8 @@ if TYPE_CHECKING: from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset + from xarray.core.types import T_Chunks + from xarray.groupers import Resampler from xarray.groupers import RESAMPLE_DIM @@ -58,6 +60,60 @@ def _flox_reduce( result = result.rename({RESAMPLE_DIM: self._group_dim}) return result + def shuffle(self, chunks: T_Chunks = None): + """ + Sort or "shuffle" the underlying object. + + "Shuffle" means the object is sorted so that all group members occur sequentially, + in the same chunk. Multiple groups may occur in the same chunk. + This method is particularly useful for chunked arrays (e.g. dask, cubed). + particularly when you need to map a function that requires all members of a group + to be present in a single chunk. For chunked array types, the order of appearance + is not guaranteed, but will depend on the input chunking. + + .. warning:: + + With resampling it is a lot better to use ``.chunk`` instead of ``.shuffle``, + since one can only resample a sorted time coordinate. + + Parameters + ---------- + chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional + How to adjust chunks along dimensions not present in the array being grouped by. + + Returns + ------- + DataArrayGroupBy or DatasetGroupBy + + Examples + -------- + >>> import dask + >>> da = xr.DataArray( + ... dims="x", + ... data=dask.array.arange(10, chunks=3), + ... coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]}, + ... name="a", + ... ) + >>> shuffled = da.groupby("x").shuffle() + >>> shuffled.quantile(q=0.5).compute() + Size: 32B + array([9., 3., 4., 5.]) + Coordinates: + quantile float64 8B 0.5 + * x (x) int64 32B 0 1 2 3 + + See Also + -------- + dask.dataframe.DataFrame.shuffle + dask.array.shuffle + """ + (grouper,) = self.groupers + shuffled = self._shuffle_obj(chunks).drop_vars(RESAMPLE_DIM) + return shuffled.resample( + {self._group_dim: cast("Resampler", grouper.grouper.reset())}, + restore_coord_dims=self._restore_coord_dims, + ) + def _drop_coords(self) -> T_Xarray: """Drop non-dimension coordinates along the resampled dimension.""" obj = self._obj diff --git a/xarray/core/types.py b/xarray/core/types.py index 2e7572a3858..e161e0721f6 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -321,7 +321,7 @@ def __reversed__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]: ... ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"] GroupKey = Any -GroupIndex = Union[int, slice, list[int]] +GroupIndex = Union[slice, list[int]] GroupIndices = tuple[GroupIndex, ...] Bins = Union[ int, Sequence[int], Sequence[float], Sequence[pd.Timestamp], np.ndarray, pd.Index diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 13053faff58..402520c8b4b 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -45,7 +45,13 @@ maybe_coerce_to_str, ) from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions -from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array +from xarray.namedarray.parallelcompat import get_chunked_array_type +from xarray.namedarray.pycompat import ( + integer_types, + is_0d_dask_array, + is_chunked_array, + to_duck_array, +) from xarray.util.deprecation_helpers import deprecate_dims NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( @@ -1013,6 +1019,24 @@ def compute(self, **kwargs): new = self.copy(deep=False) return new.load(**kwargs) + def _shuffle( + self, indices: list[list[int]], dim: Hashable, chunks: T_Chunks + ) -> Self: + # TODO (dcherian): consider making this public API + array = self._data + if is_chunked_array(array): + chunkmanager = get_chunked_array_type(array) + return self._replace( + data=chunkmanager.shuffle( + array, + indexer=indices, + axis=self.get_axis_num(dim), + chunks=chunks, + ) + ) + else: + return self.isel({dim: np.concatenate(indices)}) + def isel( self, indexers: Mapping[Any, Any] | None = None, diff --git a/xarray/groupers.py b/xarray/groupers.py index 996f86317b9..610b99822cd 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -26,6 +26,7 @@ DatetimeLike, GroupIndices, ResampleCompatible, + Self, SideOptions, ) from xarray.core.variable import Variable @@ -128,6 +129,13 @@ def factorize(self, group: T_Group) -> EncodedGroups: """ pass + @abstractmethod + def reset(self) -> Self: + """ + Creates a new version of this Grouper clearing any caches. + """ + pass + class Resampler(Grouper): """ @@ -155,6 +163,9 @@ def group_as_index(self) -> pd.Index: self._group_as_index = pd.Index(np.array(self.group).ravel()) return self._group_as_index + def reset(self) -> Self: + return type(self)() + def factorize(self, group: T_Group) -> EncodedGroups: self.group = group @@ -276,6 +287,16 @@ class BinGrouper(Grouper): include_lowest: bool = False duplicates: Literal["raise", "drop"] = "raise" + def reset(self) -> Self: + return type(self)( + bins=self.bins, + right=self.right, + labels=self.labels, + precision=self.precision, + include_lowest=self.include_lowest, + duplicates=self.duplicates, + ) + def __post_init__(self) -> None: if duck_array_ops.isnull(self.bins).all(): raise ValueError("All bin edges are NaN.") @@ -362,6 +383,15 @@ class TimeResampler(Resampler): index_grouper: CFTimeGrouper | pd.Grouper = field(init=False, repr=False) group_as_index: pd.Index = field(init=False, repr=False) + def reset(self) -> Self: + return type(self)( + freq=self.freq, + closed=self.closed, + label=self.label, + origin=self.origin, + offset=self.offset, + ) + def _init_properties(self, group: T_Group) -> None: from xarray import CFTimeIndex diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 90c442d2e1f..95e7d7adfc3 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -78,7 +78,8 @@ def dtype(self) -> _DType_co: ... _Chunks = tuple[_Shape, ...] _NormalizedChunks = tuple[tuple[int, ...], ...] # FYI in some cases we don't allow `None`, which this doesn't take account of. -T_ChunkDim: TypeAlias = int | Literal["auto"] | None | tuple[int, ...] +# # FYI the `str` is for a size string, e.g. "16MB", supported by dask. +T_ChunkDim: TypeAlias = str | int | Literal["auto"] | None | tuple[int, ...] # We allow the tuple form of this (though arguably we could transition to named dims only) T_Chunks: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDim] diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index 55e78450067..a360fd4dff7 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -256,3 +256,18 @@ def store( targets=targets, **kwargs, ) + + def shuffle( + self, x: DaskArray, indexer: list[list[int]], axis: int, chunks: T_Chunks + ) -> DaskArray: + import dask.array + + if not module_available("dask", minversion="2024.08.1"): + raise ValueError( + "This method is very inefficient on dask<2024.08.1. Please upgrade." + ) + if chunks is None: + chunks = "auto" + if chunks != "auto": + raise NotImplementedError("Only chunks='auto' is supported at present.") + return dask.array.shuffle(x, indexer, axis, chunks="auto") diff --git a/xarray/namedarray/parallelcompat.py b/xarray/namedarray/parallelcompat.py index baa0b92bdb7..ebc16b77f14 100644 --- a/xarray/namedarray/parallelcompat.py +++ b/xarray/namedarray/parallelcompat.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from xarray.namedarray._typing import ( + T_Chunks, _Chunks, _DType, _DType_co, @@ -357,6 +358,11 @@ def compute( """ raise NotImplementedError() + def shuffle( + self, x: T_ChunkedArray, indexer: list[list[int]], axis: int, chunks: T_Chunks + ) -> T_ChunkedArray: + raise NotImplementedError() + def persist( self, *data: T_ChunkedArray | Any, **kwargs: Any ) -> tuple[T_ChunkedArray | Any, ...]: diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 7293a6fd931..7b8795cc09e 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -107,6 +107,7 @@ def _importorskip( has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf") has_cftime, requires_cftime = _importorskip("cftime") has_dask, requires_dask = _importorskip("dask") +has_dask_ge_2024_08_1, _ = _importorskip("dask", minversion="2024.08.1") with warnings.catch_warnings(): warnings.filterwarnings( "ignore", @@ -151,7 +152,7 @@ def _importorskip( not has_numbagg_or_bottleneck, reason="requires numbagg or bottleneck" ) has_numpy_2, requires_numpy_2 = _importorskip("numpy", "2.0.0") -_, requires_flox_0_9_12 = _importorskip("flox", "0.9.12") +has_flox_0_9_12, requires_flox_0_9_12 = _importorskip("flox", "0.9.12") has_array_api_strict, requires_array_api_strict = _importorskip("array_api_strict") diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index cc795b75118..c87dd61a6e5 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1811,3 +1811,27 @@ def test_minimize_graph_size(): # all the other dimensions. # e.g. previously for 'x', actual == numchunks['y'] * numchunks['z'] assert actual == numchunks[var], (actual, numchunks[var]) + + +@pytest.mark.parametrize( + "chunks, expected_chunks", + [ + ((1,), (1, 3, 3, 3)), + ((10,), (10,)), + ], +) +def test_shuffle_by(chunks, expected_chunks): + from xarray.groupers import UniqueGrouper + + da = xr.DataArray( + dims="x", + data=dask.array.arange(10, chunks=chunks), + coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]}, + name="a", + ) + ds = da.to_dataset() + + for obj in [ds, da]: + actual = obj.shuffle_by(x=UniqueGrouper()) + assert_identical(actual, obj.sortby("x")) + assert actual.chunksizes["x"] == expected_chunks diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 3d948e7840e..1f2b85825d0 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3,6 +3,7 @@ import datetime import operator import warnings +from typing import Literal from unittest import mock import numpy as np @@ -29,8 +30,11 @@ assert_identical, create_test_data, has_cftime, + has_dask, + has_dask_ge_2024_08_1, has_flox, has_pandas_ge_2_2, + raise_if_dask_computes, requires_cftime, requires_dask, requires_flox, @@ -215,6 +219,14 @@ def test_groupby_indexvariable(use_flox: bool) -> None: assert_identical(expected, actual) +@pytest.mark.parametrize("shuffle", [True, False]) +@pytest.mark.parametrize( + "chunk", + [ + pytest.param(True, marks=pytest.mark.skipif(not has_dask, reason="no dask")), + False, + ], +) @pytest.mark.parametrize( "obj", [ @@ -222,12 +234,22 @@ def test_groupby_indexvariable(use_flox: bool) -> None: xr.Dataset({"foo": ("x", [1, 2, 3, 4, 5, 6])}, {"x": [1, 1, 1, 2, 2, 2]}), ], ) -def test_groupby_map_shrink_groups(obj) -> None: +def test_groupby_map_shrink_groups(obj, chunk: bool, shuffle: bool) -> None: expected = obj.isel(x=[0, 1, 3, 4]) - actual = obj.groupby("x").map(lambda f: f.isel(x=[0, 1])) + if chunk: + obj = obj.chunk(x=2) + actual = obj.groupby("x").map(lambda f: f.isel(x=[0, 1]), shuffle=shuffle) assert_identical(expected, actual) +@pytest.mark.parametrize("shuffle", [True, False]) +@pytest.mark.parametrize( + "chunk", + [ + pytest.param(True, marks=pytest.mark.skipif(not has_dask, reason="no dask")), + False, + ], +) @pytest.mark.parametrize( "obj", [ @@ -235,7 +257,7 @@ def test_groupby_map_shrink_groups(obj) -> None: xr.Dataset({"foo": ("x", [1, 2, 3])}, {"x": [1, 2, 2]}), ], ) -def test_groupby_map_change_group_size(obj) -> None: +def test_groupby_map_change_group_size(obj, chunk: bool, shuffle: bool) -> None: def func(group): if group.sizes["x"] == 1: result = group.isel(x=[0, 0]) @@ -244,7 +266,9 @@ def func(group): return result expected = obj.isel(x=[0, 0, 1]) - actual = obj.groupby("x").map(func) + if chunk: + obj = obj.chunk(x=2) + actual = obj.groupby("x").map(func, shuffle=shuffle) assert_identical(expected, actual) @@ -628,10 +652,25 @@ def test_groupby_repr_datetime(obj) -> None: assert actual == expected +@pytest.mark.filterwarnings("ignore:No index created for dimension id:UserWarning") @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") @pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning") -@pytest.mark.filterwarnings("ignore:No index created for dimension id:UserWarning") -def test_groupby_drops_nans() -> None: +@pytest.mark.parametrize("shuffle", [True, False]) +@pytest.mark.parametrize( + "chunk", + [ + pytest.param( + dict(lat=1), marks=pytest.mark.skipif(not has_dask, reason="no dask") + ), + pytest.param( + dict(lat=2, lon=2), marks=pytest.mark.skipif(not has_dask, reason="no dask") + ), + False, + ], +) +def test_groupby_drops_nans(shuffle: bool, chunk: Literal[False] | dict) -> None: + if shuffle and chunk and not has_dask_ge_2024_08_1: + pytest.skip() # GH2383 # nan in 2D data variable (requires stacking) ds = xr.Dataset( @@ -646,13 +685,17 @@ def test_groupby_drops_nans() -> None: ds["id"].values[3, 0] = np.nan ds["id"].values[-1, -1] = np.nan + if chunk: + ds = ds.chunk(chunk) grouped = ds.groupby(ds.id) + if shuffle: + grouped = grouped.shuffle() # non reduction operation expected1 = ds.copy() - expected1.variable.values[0, 0, :] = np.nan - expected1.variable.values[-1, -1, :] = np.nan - expected1.variable.values[3, 0, :] = np.nan + expected1.variable.data[0, 0, :] = np.nan + expected1.variable.data[-1, -1, :] = np.nan + expected1.variable.data[3, 0, :] = np.nan actual1 = grouped.map(lambda x: x).transpose(*ds.variable.dims) assert_identical(actual1, expected1) @@ -1351,11 +1394,27 @@ def test_groupby_sum(self) -> None: assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, "y")) assert_allclose(expected_sum_axis1, grouped.sum("y")) + @pytest.mark.parametrize("use_flox", [True, False]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize( + "chunk", + [ + pytest.param( + True, marks=pytest.mark.skipif(not has_dask, reason="no dask") + ), + False, + ], + ) @pytest.mark.parametrize("method", ["sum", "mean", "median"]) - def test_groupby_reductions(self, method) -> None: - array = self.da - grouped = array.groupby("abc") + def test_groupby_reductions( + self, use_flox: bool, method: str, shuffle: bool, chunk: bool + ) -> None: + if shuffle and chunk and not has_dask_ge_2024_08_1: + pytest.skip() + array = self.da + if chunk: + array.data = array.chunk({"y": 5}).data reduction = getattr(np, method) expected = Dataset( { @@ -1373,14 +1432,14 @@ def test_groupby_reductions(self, method) -> None: } )["foo"] - with xr.set_options(use_flox=False): - actual_legacy = getattr(grouped, method)(dim="y") - - with xr.set_options(use_flox=True): - actual_npg = getattr(grouped, method)(dim="y") + with raise_if_dask_computes(): + grouped = array.groupby("abc") + if shuffle: + grouped = grouped.shuffle() - assert_allclose(expected, actual_legacy) - assert_allclose(expected, actual_npg) + with xr.set_options(use_flox=use_flox): + actual = getattr(grouped, method)(dim="y") + assert_allclose(expected, actual) def test_groupby_count(self) -> None: array = DataArray( @@ -1644,13 +1703,14 @@ def test_groupby_bins( ) with xr.set_options(use_flox=use_flox): - actual = array.groupby_bins("dim_0", bins=bins, **cut_kwargs).sum() + gb = array.groupby_bins("dim_0", bins=bins, **cut_kwargs) + actual = gb.sum() assert_identical(expected, actual) + assert_identical(expected, gb.shuffle().sum()) - actual = array.groupby_bins("dim_0", bins=bins, **cut_kwargs).map( - lambda x: x.sum() - ) + actual = gb.map(lambda x: x.sum()) assert_identical(expected, actual) + assert_identical(expected, gb.shuffle().map(lambda x: x.sum())) # make sure original array dims are unchanged assert len(array.dim_0) == 4 @@ -1795,6 +1855,7 @@ def test_groupby_fastpath_for_monotonic(self, use_flox: bool) -> None: class TestDataArrayResample: + @pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("use_cftime", [True, False]) @pytest.mark.parametrize( "resample_freq", @@ -1809,7 +1870,7 @@ class TestDataArrayResample: ], ) def test_resample( - self, use_cftime: bool, resample_freq: ResampleCompatible + self, use_cftime: bool, shuffle: bool, resample_freq: ResampleCompatible ) -> None: if use_cftime and not has_cftime: pytest.skip() @@ -1832,16 +1893,21 @@ def resample_as_pandas(array, *args, **kwargs): array = DataArray(np.arange(10), [("time", times)]) - actual = array.resample(time=resample_freq).mean() + rs = array.resample(time=resample_freq) + actual = rs.mean() expected = resample_as_pandas(array, resample_freq) assert_identical(expected, actual) + assert_identical(expected, rs.shuffle().mean()) - actual = array.resample(time=resample_freq).reduce(np.mean) - assert_identical(expected, actual) + assert_identical(expected, rs.reduce(np.mean)) + assert_identical(expected, rs.shuffle().reduce(np.mean)) - actual = array.resample(time=resample_freq, closed="right").mean() - expected = resample_as_pandas(array, resample_freq, closed="right") + rs = array.resample(time="24h", closed="right") + actual = rs.mean() + shuffled = rs.shuffle().mean() + expected = resample_as_pandas(array, "24h", closed="right") assert_identical(expected, actual) + assert_identical(expected, shuffled) with pytest.raises(ValueError, match=r"Index must be monotonic"): array[[2, 0, 1]].resample(time=resample_freq) @@ -2667,6 +2733,9 @@ def factorize(self, group) -> EncodedGroups: codes = group.copy(data=codes_).rename("year") return EncodedGroups(codes=codes, full_index=pd.Index(uniques)) + def reset(self): + return type(self)() + da = xr.DataArray( dims="time", data=np.arange(20), @@ -2763,8 +2832,9 @@ def test_multiple_groupers_string(as_dataset) -> None: obj.groupby("labels1", foo=UniqueGrouper()) +@pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("use_flox", [True, False]) -def test_multiple_groupers(use_flox) -> None: +def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None: da = DataArray( np.array([1, 2, 3, 0, 2, np.nan]), dims="d", @@ -2776,6 +2846,8 @@ def test_multiple_groupers(use_flox) -> None: ) gb = da.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper()) + if shuffle: + gb = gb.shuffle() repr(gb) expected = DataArray( @@ -2795,6 +2867,8 @@ def test_multiple_groupers(use_flox) -> None: coords = {"a": ("x", [0, 0, 1, 1]), "b": ("y", [0, 0, 1, 1])} square = DataArray(np.arange(16).reshape(4, 4), coords=coords, dims=["x", "y"]) gb = square.groupby(a=UniqueGrouper(), b=UniqueGrouper()) + if shuffle: + gb = gb.shuffle() repr(gb) with xr.set_options(use_flox=use_flox): actual = gb.mean() @@ -2818,11 +2892,15 @@ def test_multiple_groupers(use_flox) -> None: dims=["x", "y", "z"], ) gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper()) + if shuffle: + gb = gb.shuffle() repr(gb) with xr.set_options(use_flox=use_flox): assert_identical(gb.mean("z"), b.mean("z")) gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper()) + if shuffle: + gb = gb.shuffle() repr(gb) with xr.set_options(use_flox=use_flox): actual = gb.mean() @@ -2837,13 +2915,16 @@ def test_multiple_groupers(use_flox) -> None: @pytest.mark.parametrize("use_flox", [True, False]) -def test_multiple_groupers_mixed(use_flox) -> None: +@pytest.mark.parametrize("shuffle", [True, False]) +def test_multiple_groupers_mixed(use_flox: bool, shuffle: bool) -> None: # This groupby has missing groups ds = xr.Dataset( {"foo": (("x", "y"), np.arange(12).reshape((4, 3)))}, coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))}, ) gb = ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()) + if shuffle: + gb = gb.shuffle() expected_data = np.array( [ [[0.0, np.nan], [np.nan, 3.0]], @@ -2985,3 +3066,15 @@ def test_groupby_multiple_bin_grouper_missing_groups(): }, ) assert_identical(actual, expected) + + +@requires_dask +def test_shuffle_by_simple() -> None: + da = xr.DataArray( + dims="x", + data=[1, 2, 3, 4, 5, 6], + coords={"label": ("x", "a b c a b c".split(" "))}, + ) + actual = da.chunk(x=2).shuffle_by(label=UniqueGrouper()) + expected = da.shuffle_by(label=UniqueGrouper()) + assert_identical(actual, expected)