# This file is part of the Open Data Cube, see https://opendatacube.org for more information
#
# Copyright (c) 2015-2020 ODC Contributors
# SPDX-License-Identifier: Apache-2.0
"""
Add ``.odc.`` extension to :py:class:`xarray.Dataset` and :class:`xarray.DataArray`.
"""
from __future__ import annotations
import functools
import json
import math
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import (
Any,
Callable,
Dict,
Hashable,
List,
Literal,
Optional,
Set,
Tuple,
TypeVar,
Union,
)
import numpy
import xarray
from affine import Affine
from ._interop import have, is_dask_collection
from ._rgba import colorize, to_rgba
from .crs import CRS, CRSError, SomeCRS, norm_crs_or_error
from .gcp import GCPGeoBox, GCPMapping
from .geobox import Coordinate, GeoBox, GeoboxAnchor
from .geom import Geometry
from .math import (
affine_from_axis,
approx_equal_affine,
is_affine_st,
is_nodata_empty,
maybe_int,
resolution_from_affine,
resolve_fill_value,
resolve_nodata,
)
from .masking import (
bits_to_bool,
enum_to_bool,
mask_invalid_data,
mask_clouds,
mask_ls,
mask_s2,
scale_and_offset,
)
from .overlap import compute_output_geobox
from .roi import roi_is_empty
from .types import Nodata, Resolution, SomeNodata, SomeResolution, SomeShape, xy_
# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-lines
if have.rasterio:
from ._compress import compress
from ._map import add_to, explore
from .cog import to_cog, write_cog
from .warp import rio_reproject
XarrayObject = Union[xarray.DataArray, xarray.Dataset]
XrT = TypeVar("XrT", xarray.DataArray, xarray.Dataset)
F = TypeVar("F", bound=Callable)
SomeGeoBox = Union[GeoBox, GCPGeoBox]
_DEFAULT_CRS_COORD_NAME = "spatial_ref"
# these attributes are pruned during reproject
SPATIAL_ATTRIBUTES = ("crs", "crs_wkt", "grid_mapping", "gcps", "epsg")
NODATA_ATTRIBUTES = ("nodata", "_FillValue")
REPROJECT_SKIP_ATTRS: set[str] = set(SPATIAL_ATTRIBUTES + NODATA_ATTRIBUTES)
# dimensions with these names are considered spatial
STANDARD_SPATIAL_DIMS = [
("y", "x"),
("yc", "xc"),
("latitude", "longitude"),
("lat", "lon"),
]
@dataclass
class GeoState:
"""
Geospatial information for xarray object.
"""
spatial_dims: Optional[Tuple[str, str]] = None
crs_coord: Optional[xarray.DataArray] = None
transform: Optional[Affine] = None
crs: Optional[CRS] = None
geobox: Optional[SomeGeoBox] = None
gcp: Optional[GCPMapping] = None
def _get_crs_from_attrs(obj: XarrayObject, sdims: Tuple[str, str]) -> Optional[CRS]:
"""
Looks for attribute named ``crs`` containing CRS string.
- Checks spatials coords attrs
- Checks data variable attrs
- Checks dataset attrs
Returns
=======
Content for `.attrs[crs]` usually it's a string
None if not present in any of the places listed above
"""
crs_set: Set[CRS] = set()
def _add_candidate(crs):
if crs is None:
return
if isinstance(crs, str):
try:
crs_set.add(CRS(crs))
except CRSError:
warnings.warn(f"Failed to parse CRS: {crs}")
elif isinstance(crs, CRS):
# support current bad behaviour of injecting CRS directly into
# attributes in example notebooks
crs_set.add(crs)
else:
warnings.warn(f"Ignoring crs attribute of type: {type(crs)}")
def process_attrs(attrs):
_add_candidate(attrs.get("crs", None))
_add_candidate(attrs.get("crs_wkt", None))
def process_datavar(x):
process_attrs(x.attrs)
for dim in sdims:
if dim in x.coords:
process_attrs(x.coords[dim].attrs)
if isinstance(obj, xarray.Dataset):
process_attrs(obj.attrs)
for dv in obj.data_vars.values():
process_datavar(dv)
else:
process_datavar(obj)
crs = None
if len(crs_set) >= 1:
crs = crs_set.pop()
if len(crs_set) > 0:
if any(other != crs for other in crs_set):
warnings.warn("Have several candidates for a CRS")
return crs
[docs]
def spatial_dims(
xx: Union[xarray.DataArray, xarray.Dataset], relaxed: bool = False
) -> Optional[Tuple[str, str]]:
"""
Find spatial dimensions of ``xx``.
Checks for presence of dimensions named:
``y, x | latitude, longitude | lat, lon``
If ``relaxed=True`` and none of the above dimension names are found,
assume that last two dimensions are spatial dimensions.
:returns: ``None`` if no dimensions with expected names are found
:returns: ``('y', 'x') | ('latitude', 'longitude') | ('lat', 'lon')``
"""
def skip_dim(dim: str) -> bool:
if dim in ("time", "band", "bands", "wavelength", "wavelengths"):
return True
# skip dimensions without coord of the same name
if dim not in xx.coords:
return True
coord = xx.coords[dim]
# Primary coordinate for spatial dimension must have floating point type
if coord.dtype.kind != "f":
return True
return False
_dims = [str(dim) for dim in xx.dims]
dims = set(_dims)
for guess in STANDARD_SPATIAL_DIMS:
if dims.issuperset(guess):
return guess
_dims = [dim for dim in _dims if not skip_dim(str(dim))]
if relaxed and len(_dims) >= 2:
return _dims[-2], _dims[-1]
return None
def _mk_crs_coord(
crs: CRS,
name: str = _DEFAULT_CRS_COORD_NAME,
gcps=None,
transform: Optional[Affine] = None,
) -> xarray.DataArray:
# pylint: disable=protected-access
cf = crs.proj.to_cf()
epsg = 0 if crs.epsg is None else crs.epsg
crs_wkt = cf.get("crs_wkt", None) or crs.wkt
if gcps is not None:
# Store as string
cf["gcps"] = json.dumps(_gcps_to_json(gcps))
if transform is not None:
cf["GeoTransform"] = _render_geo_transform(transform, precision=24)
return xarray.DataArray(
numpy.asarray(epsg, "int32"),
name=name,
dims=(),
attrs={"spatial_ref": crs_wkt, **cf},
)
def _gcps_to_json(gcps):
def _to_feature(p):
coords = [p.x, p.y] if p.z is None else [p.x, p.y, p.z]
return {
"type": "Feature",
"properties": {
"id": str(p.id),
"info": (p.info or ""),
"row": p.row,
"col": p.col,
},
"geometry": {"type": "Point", "coordinates": coords},
}
return {"type": "FeatureCollection", "features": list(map(_to_feature, gcps))}
def _coord_to_xr(name: str, c: Coordinate, **attrs) -> xarray.DataArray:
"""
Construct xr.DataArray from named Coordinate object.
This can then be used to define coordinates for ``xr.Dataset|xr.DataArray``
"""
attrs = {"units": c.units, "resolution": c.resolution, **attrs}
return xarray.DataArray(
c.values, coords={name: c.values}, dims=(name,), attrs=attrs
)
[docs]
def assign_crs(
xx: XrT,
crs: SomeCRS,
crs_coord_name: str = _DEFAULT_CRS_COORD_NAME,
) -> XrT:
"""
Assign CRS for a non-georegistered array or dataset.
Returns a new object with CRS information populated.
.. code-block:: python
xx = xr.open_rasterio("some-file.tif")
print(xx.odc.crs)
print(xx.astype("float32").crs)
:param xx: :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray`
:param crs: CRS to assign
:param crs_coord_name: how to name crs coordinate (defaults to ``spatial_ref``)
"""
crs = norm_crs_or_error(crs)
crs_coord = _mk_crs_coord(crs, name=crs_coord_name)
xx = xx.assign_coords({crs_coord_name: crs_coord})
if isinstance(xx, xarray.DataArray):
xx.encoding.update(grid_mapping=crs_coord_name)
elif isinstance(xx, xarray.Dataset):
for band in xx.data_vars.values():
band.encoding.update(grid_mapping=crs_coord_name)
return xx
[docs]
def mask(
xx: XrT, poly: Geometry, invert: bool = False, all_touched: bool = True
) -> XrT:
"""
Apply a polygon geometry as a mask, setting all
:py:class:`xarray.Dataset` or :py:class:`xarray.DataArray` pixels
outside the rasterized polygon to ``NaN``.
:param xx:
:py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray`.
:param poly:
A :py:class:`odc.geo.geom.Geometry` polygon used to mask ``xx``.
:param invert:
Whether to invert the mask before applying it to ``xx``. If
``True``, only pixels inside of ``poly`` will be masked.
:param all_touched:
If ``True``, the rasterize step will burn in all pixels touched
by ``poly``. If ``False``, only pixels whose centers are within
the polygon or that are selected by Bresenham's line algorithm
will be burned in.
:return:
A :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray`
masked by ``poly``.
.. seealso:: :py:meth:`odc.geo.xr.rasterize`
"""
# Rasterise `poly` into geobox of `xx`
rasterized = rasterize(
poly=poly,
how=xx.odc.geobox,
all_touched=all_touched,
value_inside=not invert,
)
# Mask data outside rasterized `poly`
xx_masked = xx.where(rasterized.data)
# Remove nodata attribute from arrays
if isinstance(xx_masked, xarray.Dataset):
for var in xx_masked.data_vars:
xx_masked[var].attrs.pop("nodata", None)
else:
xx_masked.attrs.pop("nodata", None)
return xx_masked
[docs]
def crop(
xx: XrT, poly: Geometry, apply_mask: bool = True, all_touched: bool = True
) -> XrT:
"""
Crops and optionally mask an :py:class:`xarray.Dataset` or
:py:class:`xarray.DataArray` to the spatial extent of a geometry.
:param xx:
:py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray`.
:param poly:
A :py:class:`odc.geo.geom.Geometry` polygon used to crop ``xx``.
:param apply_mask:
Whether to mask out pixels outside of the rasterized extent of
``poly`` by setting them to ``NaN``.
:param all_touched:
If ``True`` and ``apply_mask=True``, the rasterize step will
burn in all pixels touched by ``poly``. If ``False``, only
pixels whose centers are within the polygon or that are selected
by Bresenham's line algorithm will be burned in.
:return:
A :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray`
cropped and optionally masked to the spatial extent of ``poly``.
.. seealso:: :py:meth:`odc.geo.xr.mask`
"""
meta: ODCExtension = xx.odc
sdims = meta.spatial_dims
gbox = meta.geobox
if sdims is None or gbox is None:
raise ValueError("Can't locate spatial dimensions")
if not isinstance(gbox, GeoBox):
raise ValueError("Can't crop GCPGeoBox")
# Create new geobox with pixel grid of `xx` but enclosing `poly`.
poly_geobox = gbox.enclosing(poly)
# Calculate ROI slices into `xx` for intersection between both geoboxes.
roi = gbox.overlap_roi(poly_geobox)
# Verify that `poly` overlaps with `xx` by checking if the returned
# ROI is empty
if roi_is_empty(roi):
raise ValueError(
"The supplied `poly` must overlap spatially with the extent of `xx`."
)
# Crop spatial dims of `xx` using ROI
xx_cropped = xx.isel({sdims[0]: roi[0], sdims[1]: roi[1]})
# Optionally mask data outside rasterized `poly`
if apply_mask:
xx_cropped = mask(xx_cropped, poly, all_touched=all_touched)
return xx_cropped
[docs]
def xr_coords(
gbox: SomeGeoBox,
crs_coord_name: Optional[str] = _DEFAULT_CRS_COORD_NAME,
always_yx: bool = False,
dims: Optional[Tuple[str, str]] = None,
) -> Dict[Hashable, xarray.DataArray]:
"""
Dictionary of Coordinates in xarray format.
:param gbox:
:py:class:`~odc.geo.geobox.GeoBox` or :py:class:`~odc.geo.gcp.GCPGeoBox`
:param crs_coord_name:
Use custom name for CRS coordinate, default is "spatial_ref". Set to
``None`` to not generate CRS coordinate at all.
:param always_yx:
If True, always use names ``y,x`` for spatial coordinates even for
geographic geoboxes.
:param dims:
Use custom names for spatial dimensions, default is to use ``y,x`` or
``latitude, longitude`` based on projection used. Dimensions are supplied
in "array" order, i.e. ``('y', 'x')``.
:returns:
Dictionary ``name:str -> xr.DataArray``. Where names are either as
supplied by ``dims=`` or otherwise ``y,x`` for projected or
``latitude, longitude`` for geographic.
"""
if dims is None:
if always_yx:
dims = ("y", "x")
else:
dims = gbox.dimensions
attrs = {}
crs = gbox.crs
if crs is not None:
attrs["crs"] = str(crs)
gcps = None
transform: Optional[Affine] = None
if isinstance(gbox, GCPGeoBox):
coords: Dict[Hashable, xarray.DataArray] = {
name: _mk_pixel_coord(name, sz) for name, sz in zip(dims, gbox.shape)
}
gcps = gbox.gcps()
else:
transform = gbox.transform
if gbox.axis_aligned:
coords = {
name: _coord_to_xr(name, coord, **attrs)
for name, coord in zip(dims, gbox.coordinates.values())
}
else:
coords = {
name: _mk_pixel_coord(name, sz) for name, sz in zip(dims, gbox.shape)
}
if crs_coord_name is not None and crs is not None:
coords[crs_coord_name] = _mk_crs_coord(
crs, crs_coord_name, gcps=gcps, transform=transform
)
return coords
def _mk_pixel_coord(
name: str,
sz: int,
) -> xarray.DataArray:
data = numpy.arange(0.5, sz, dtype="float32")
xx = xarray.DataArray(
data, coords={name: data}, dims=(name,), attrs={"units": "pixel"}
)
return xx
def _is_spatial_ref(coord) -> bool:
return coord.ndim == 0 and (
"spatial_ref" in coord.attrs or "crs_wkt" in coord.attrs
)
def _locate_crs_coords(xx: XarrayObject) -> List[xarray.DataArray]:
grid_mapping = xx.encoding.get("grid_mapping", None)
if grid_mapping is None:
grid_mapping = xx.attrs.get("grid_mapping")
if grid_mapping is not None:
# Specific mapping is defined via NetCDF/CF convention
coord = xx.coords.get(grid_mapping, None)
if coord is None:
warnings.warn(
f"grid_mapping={grid_mapping} is not pointing to valid coordinate"
)
return []
return [coord]
# Find all dimensionless coordinates with `spatial_ref|crs_wkt` attribute present
return [coord for coord in xx.coords.values() if _is_spatial_ref(coord)]
def _extract_crs(crs_coord: xarray.DataArray) -> Optional[CRS]:
_wkt = crs_coord.attrs.get("spatial_ref", None) # GDAL convention?
if _wkt is None:
_wkt = crs_coord.attrs.get("crs_wkt", None) # CF convention
if _wkt is None:
return None
try:
return CRS(_wkt)
except CRSError:
return None
def _extract_gcps(crs_coord: xarray.DataArray) -> Optional[GCPMapping]:
gcps = crs_coord.attrs.get("gcps", None)
if gcps is None:
return None
crs = _extract_crs(crs_coord)
try:
if isinstance(gcps, str):
gcps = json.loads(gcps)
wld = Geometry(gcps, crs=crs)
pix = [
xy_(f["properties"]["col"], f["properties"]["row"])
for f in gcps["features"]
]
return GCPMapping(pix, wld)
except (IndexError, KeyError, ValueError, json.JSONDecodeError):
return None
def _extract_geo_transform(crs_coord: xarray.DataArray) -> Optional[Affine]:
geo_transform_parts = crs_coord.attrs.get("GeoTransform", "").split(" ")
if len(geo_transform_parts) != 6:
return None
try:
c, a, b, f, d, e = map(float, geo_transform_parts)
except ValueError:
return None
return Affine.from_gdal(c, a, b, f, d, e)
def _render_geo_transform(transform: Affine, precision: int = 24) -> str:
return " ".join(
map(lambda x: f"{x:.{precision}f}".rstrip("0").rstrip("."), transform.to_gdal())
)
def _extract_transform(
src: XarrayObject,
sdims: Tuple[str, str],
crs_coord: Optional[xarray.DataArray],
gcp: bool,
) -> Optional[Affine]:
if any(dim not in src.coords for dim in sdims):
# special case of no spatial dims at all
# happens for GCP/rotated sources loaded by rioxarray
if gcp or crs_coord is None:
return None
return _extract_geo_transform(crs_coord)
_yy, _xx = (src[dim] for dim in sdims)
original_transform: Affine | None = None
if crs_coord is not None:
original_transform = _extract_geo_transform(crs_coord)
# First try to compute from 1-D X/Y coords
try:
transform = affine_from_axis(_xx.values, _yy.values)
except ValueError:
# This can fail when any dimension is shorter than 2 elements
# Figure out fallback resolution if possible and try again
if crs_coord is None or original_transform is None:
return None
try:
transform = affine_from_axis(
_xx.values,
_yy.values,
resolution_from_affine(original_transform),
)
except ValueError:
return None
if original_transform is not None:
if not is_affine_st(original_transform):
# non-axis aligned geobox detected
# adjust transform
# world <- pix' <- pix
transform = original_transform * transform
if any(map(math.isnan, transform)):
transform = original_transform
if approx_equal_affine(transform, original_transform):
transform = original_transform
return transform
def _locate_geo_info(src: XarrayObject) -> GeoState:
# pylint: disable=too-many-locals
if len(src.dims) < 2:
return GeoState()
sdims = spatial_dims(src, relaxed=True)
if sdims is None:
return GeoState()
crs_coord: Optional[xarray.DataArray] = None
crs: Optional[CRS] = None
geobox: Optional[SomeGeoBox] = None
gcp: Optional[GCPMapping] = None
ny, nx = (src.coords[dim].shape[0] for dim in sdims)
_crs_coords = _locate_crs_coords(src)
num_candidates = len(_crs_coords)
if num_candidates > 0:
if num_candidates > 1:
warnings.warn("Multiple CRS coordinates are present")
crs_coord = _crs_coords[0]
crs = _extract_crs(crs_coord)
gcp = _extract_gcps(crs_coord)
else:
# try looking in attributes
crs = _get_crs_from_attrs(src, sdims)
transform = _extract_transform(src, sdims, crs_coord, gcp is not None)
if gcp is not None:
geobox = GCPGeoBox((ny, nx), gcp, transform)
elif transform is not None:
geobox = GeoBox((ny, nx), transform, crs)
return GeoState(
spatial_dims=sdims,
crs_coord=crs_coord,
transform=transform,
crs=crs,
geobox=geobox,
gcp=gcp,
)
def _wrap_op(method: F) -> F:
@functools.wraps(method, assigned=("__doc__",))
def wrapped(*args, **kw):
# pylint: disable=protected-access
_self, *rest = args
return method(_self._xx, *rest, **kw)
return wrapped # type: ignore
[docs]
def xr_reproject(
src: XrT,
how: Union[SomeCRS, GeoBox],
*,
resampling: Union[str, int] = "nearest",
dst_nodata: SomeNodata = "auto",
dtype=None,
resolution: Union[SomeResolution, Literal["auto", "fit", "same"]] = "auto",
shape: Union[SomeShape, int, None] = None,
tight: bool = False,
anchor: GeoboxAnchor = "default",
tol: float = 0.01,
round_resolution: Union[None, bool, Callable[[float, str], float]] = None,
**kw,
) -> XrT:
"""
Reproject raster to different projection/resolution.
:param src:
:py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` to reproject.
:param how:
How to reproject the raster. Can be a GeoBox or a CRS (e.g. CRS object or
an "ESPG:XXXX" string/integer). If a CRS is provided, the output pixel
grid can be customised further via ``resolution``, ``shape``, ``tight``,
``anchor``, ``tol``, ``round_resolution``.
:param resampling:
Resampling method to use when reprojecting the raster. Defaults to
"nearest", also supports "average", "bilinear", "cubic", "cubic_spline",
"lanczos", "mode", "gauss", "max", "min", "med", "q1", "q3".
:param dst_nodata:
Set a custom nodata value for the output resampled raster.
:param resolution:
* "same" use exactly the same resolution as src
* "fit" use center pixel to determine scale change between the two
* | "auto" is to use the same resolution on the output if CRS units are
| the same between the source and destination and otherwise use "fit"
* Ignored if ``shape=`` is supplied
* Else resolution in the units of the output crs
:param shape:
Span that many pixels, if it's a single number then span that many pixels
along the longest dimension, other dimension will be computed to maintain
roughly square pixels. Takes precedence over ``resolution=`` parameter.
:param tight:
By default output pixel grid is adjusted to align pixel edges to X/Y axis,
suppling ``tight=True`` produces unaligned geobox on the output.
:param anchor:
Control pixel snapping, default is to snap pixel edge to ``X=0,Y=0``.
Ignored when ``tight=True`` is supplied.
:param tol:
Fraction of the output pixel that can be ignored, defaults to 1/100.
Bounding box of the output geobox is allowed to be smaller by that amount
than transformed footprint of the original.
:param round_resolution:
``round_resolution(res: float, units: str) -> float``
This method uses :py:mod:`rasterio`.
.. seealso:: :py:meth:`odc.geo.overlap.compute_output_geobox`
"""
kw = {
"shape": shape,
"resolution": resolution,
"tight": tight,
"anchor": anchor,
"tol": tol,
"round_resolution": round_resolution,
**kw,
}
if isinstance(src, xarray.DataArray):
return _xr_reproject_da(
src, how, resampling=resampling, dst_nodata=dst_nodata, dtype=dtype, **kw
)
return _xr_reproject_ds(
src, how, resampling=resampling, dst_nodata=dst_nodata, dtype=dtype, **kw
)
def _extract_output_geobox_params(kw):
# NOTE: modifies input, removes keys
out = {}
for k in ("tight", "anchor", "resolution", "shape", "tol", "round_resolution"):
if k in kw:
out[k] = kw.pop(k)
return out
def _xr_reproject_ds(
src: Any,
how: Union[SomeCRS, GeoBox],
*,
resampling: Union[str, int] = "nearest",
dst_nodata: SomeNodata = "auto",
dtype=None,
**kw,
) -> xarray.Dataset:
assert isinstance(src, xarray.Dataset)
if have.rasterio is False: # pragma: nocover
raise RuntimeError("Please install `rasterio` to use this method")
assert isinstance(src.odc, ODCExtensionDs)
if src.odc.geobox is None:
raise ValueError("Can not reproject non-georegistered array.")
kw_gbox = _extract_output_geobox_params(kw)
if isinstance(how, GeoBox):
dst_geobox = how
else:
dst_geobox = src.odc.output_geobox(how, **kw_gbox)
def _maybe_reproject(dv: xarray.DataArray):
if dv.odc.geobox is None:
# pass-through data variables without a geobox
strip_coords = [str(c.name) for c in _locate_crs_coords(dv)]
if len(strip_coords) > 0:
dv = dv.drop_vars(strip_coords)
return dv
return _xr_reproject_da(
dv,
how=dst_geobox,
resampling=resampling,
dst_nodata=dst_nodata,
dtype=dtype,
**kw,
)
return src.map(_maybe_reproject)
def _xr_reproject_da(
src: Any,
how: Union[SomeCRS, GeoBox],
*,
resampling: Union[str, int] = "nearest",
dst_nodata: SomeNodata = "auto",
dtype=None,
**kw,
) -> xarray.DataArray:
# pylint: disable=too-many-locals
assert isinstance(src, xarray.DataArray)
if have.rasterio is False: # pragma: nocover
raise RuntimeError("Please install `rasterio` to use this method")
assert isinstance(src.odc, ODCExtensionDa) # for mypy sake
src_gbox = src.odc.geobox
if src_gbox is None or src_gbox.crs is None:
raise ValueError("Can not reproject non-georegistered array.")
kw_gbox = _extract_output_geobox_params(kw)
if isinstance(how, GeoBox):
dst_geobox = how
else:
dst_geobox = src.odc.output_geobox(how, **kw_gbox)
if dtype is None:
dtype = src.dtype
# compute destination shape by replacing spatial dimensions shape
ydim = src.odc.ydim
assert ydim + 1 == src.odc.xdim
dst_shape = (*src.shape[:ydim], *dst_geobox.shape, *src.shape[ydim + 2 :])
src_nodata = resolve_nodata(kw.pop("src_nodata", "auto"), src.dtype, src.odc.nodata)
dst_nodata = resolve_nodata(dst_nodata, dtype, src_nodata)
fill_value = resolve_fill_value(dst_nodata, src_nodata, dtype)
if is_dask_collection(src):
from ._dask import dask_rio_reproject
dst: Any = dask_rio_reproject(
src.data,
src_gbox,
dst_geobox,
resampling=resampling,
src_nodata=src_nodata,
dst_nodata=fill_value,
ydim=ydim,
dtype=dtype,
**kw,
)
else:
dst = numpy.full(dst_shape, fill_value, dtype=dtype)
dst = rio_reproject(
src.values,
dst,
src_gbox,
dst_geobox,
resampling=resampling,
src_nodata=src_nodata,
dst_nodata=fill_value,
ydim=ydim,
dtype=dtype,
**kw,
)
attrs = {k: v for k, v in src.attrs.items() if k not in REPROJECT_SKIP_ATTRS}
if not is_nodata_empty(dst_nodata):
assert dst_nodata is not None
attrs.update({k: maybe_int(float(dst_nodata), 1e-6) for k in NODATA_ATTRIBUTES})
# new set of coords (replace x,y dims)
# discard all coords that reference spatial dimensions
sdims = src.odc.spatial_dims
assert sdims is not None
sdims = set(sdims)
def should_keep(coord):
if _is_spatial_ref(coord):
return False
return sdims.isdisjoint(coord.dims)
coords = {k: coord for k, coord in src.coords.items() if should_keep(coord)}
coords.update(xr_coords(dst_geobox))
dims = (*src.dims[:ydim], *dst_geobox.dimensions, *src.dims[ydim + 2 :])
out = xarray.DataArray(dst, coords=coords, dims=dims, attrs=attrs)
out.encoding["grid_mapping"] = _DEFAULT_CRS_COORD_NAME
return out
[docs]
class ODCExtension:
"""
ODC extension base class.
Common accessors for both Array/Dataset.
"""
[docs]
def __init__(self, state: GeoState):
self._state = state
@property
def spatial_dims(self) -> Optional[Tuple[str, str]]:
"""Return names of spatial dimensions, or ``None``."""
return self._state.spatial_dims
@property
def transform(self) -> Optional[Affine]:
return self._state.transform
affine = transform
@property
def crs(self) -> Optional[CRS]:
"""Query :py:class:`~odc.geo.crs.CRS`."""
return self._state.crs
@property
def geobox(self) -> Optional[SomeGeoBox]:
"""Query :py:class:`~odc.geo.geobox.GeoBox` or :py:class:`~odc.geo.gcp.GCPGeoBox`."""
return self._state.geobox
@property
def aspect(self) -> float:
gbox = self._state.geobox
if gbox is None:
return 1
return gbox.aspect
[docs]
def output_geobox(self, crs: SomeCRS, **kw) -> GeoBox:
"""
Compute geobox of this data in other projection.
.. seealso:: :py:meth:`odc.geo.overlap.compute_output_geobox`
"""
gbox = self.geobox
if gbox is None:
raise ValueError("Not geo registered")
return compute_output_geobox(gbox, crs, **kw)
[docs]
def map_bounds(self) -> Tuple[Tuple[float, float], Tuple[float, float]]:
"""See :py:meth:`odc.geo.geobox.GeoBox.map_bounds`."""
gbox = self.geobox
if gbox is None:
raise ValueError("Not geo registered")
return gbox.map_bounds()
@property
def crs_coord(self) -> xarray.DataArray | None:
"""Return CRS coordinate DataArray."""
return self._state.crs_coord
@property
def grid_mapping(self) -> str | None:
"""Return name of the grid mapping coordinate."""
if c := self.crs_coord:
return str(c.name)
return None
mask = _wrap_op(mask)
crop = _wrap_op(crop)
if have.rasterio:
explore = _wrap_op(explore)
reproject = _wrap_op(xr_reproject)
[docs]
@xarray.register_dataarray_accessor("odc")
class ODCExtensionDa(ODCExtension):
"""
ODC extension for :py:class:`xarray.DataArray`.
"""
[docs]
def __init__(self, xx: xarray.DataArray):
ODCExtension.__init__(self, _locate_geo_info(xx))
self._xx = xx
@property
def uncached(self) -> "ODCExtensionDa":
return ODCExtensionDa(self._xx)
[docs]
def reload(self) -> xarray.DataArray:
"""Reload geospatial state info in-place."""
self._state = _locate_geo_info(self._xx)
return self._xx
@property
def ydim(self) -> int:
"""Index of the Y dimension."""
if (sdims := self.spatial_dims) is not None:
return self._xx.dims.index(sdims[0])
raise ValueError("Can't locate spatial dimensions")
@property
def xdim(self) -> int:
"""Index of the X dimension."""
if (sdims := self.spatial_dims) is not None:
return self._xx.dims.index(sdims[1])
raise ValueError("Can't locate spatial dimensions")
[docs]
def assign_crs(
self, crs: SomeCRS, crs_coord_name: str = _DEFAULT_CRS_COORD_NAME
) -> xarray.DataArray:
"""See :py:meth:`odc.geo.xr.assign_crs`."""
return assign_crs(self._xx, crs=crs, crs_coord_name=crs_coord_name)
@property
def nodata(self) -> Nodata:
"""Extract ``nodata/_FillValue`` attribute if set."""
attrs = self._xx.attrs
encoding = self._xx.encoding
for k in ["nodata", "_FillValue"]:
nodata = attrs.get(k, ())
if nodata == ():
nodata = encoding.get(k, ())
if nodata == ():
continue
if nodata is None:
return None
return float(nodata)
return None
@nodata.setter
def nodata(self, value: Nodata):
nodata = resolve_nodata(value, self._xx.dtype)
if nodata is None:
for k in ["nodata", "_FillValue"]:
self._xx.attrs.pop(k, None)
self._xx.encoding.pop(k, None)
return
self._xx.attrs["nodata"] = nodata
self._xx.encoding["_FillValue"] = nodata
colorize = _wrap_op(colorize)
scale_and_offset = _wrap_op(scale_and_offset)
bits_to_bool = _wrap_op(bits_to_bool)
enum_to_bool = _wrap_op(enum_to_bool)
mask_invalid_data = _wrap_op(mask_invalid_data)
if have.rasterio:
write_cog = _wrap_op(write_cog)
to_cog = _wrap_op(to_cog)
compress = _wrap_op(compress)
add_to = _wrap_op(add_to)
[docs]
@xarray.register_dataset_accessor("odc")
class ODCExtensionDs(ODCExtension):
"""
ODC extension for :py:class:`xarray.Dataset`.
"""
[docs]
def __init__(self, ds: xarray.Dataset):
ODCExtension.__init__(self, _locate_geo_info(ds))
self._xx = ds
[docs]
def reload(self) -> xarray.Dataset:
"""Reload geospatial state info in-place."""
self._state = _locate_geo_info(self._xx)
return self._xx
@property
def uncached(self) -> "ODCExtensionDs":
return ODCExtensionDs(self._xx)
def assign_crs(
self, crs: SomeCRS, crs_coord_name: str = _DEFAULT_CRS_COORD_NAME
) -> xarray.Dataset:
return assign_crs(self._xx, crs=crs, crs_coord_name=crs_coord_name)
[docs]
def to_rgba(
self,
bands: Optional[Tuple[str, str, str]] = None,
*,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
) -> xarray.DataArray:
return to_rgba(self._xx, bands=bands, vmin=vmin, vmax=vmax)
scale_and_offset = _wrap_op(scale_and_offset)
mask_invalid_data = _wrap_op(mask_invalid_data)
mask_clouds = _wrap_op(mask_clouds)
mask_ls = _wrap_op(mask_ls)
mask_s2 = _wrap_op(mask_s2)
ODCExtensionDs.to_rgba.__doc__ = to_rgba.__doc__
def _xarray_geobox(xx: XarrayObject) -> Optional[GeoBox]:
if isinstance(xx, xarray.DataArray):
return xx.odc.geobox
for dv in xx.data_vars.values():
geobox = dv.odc.geobox
if geobox is not None:
return geobox
return None
def register_geobox():
"""
Backwards compatiblity layer for datacube ``.geobox`` property.
"""
xarray.Dataset.geobox = property(_xarray_geobox) # type: ignore
xarray.DataArray.geobox = property(_xarray_geobox) # type: ignore
[docs]
def wrap_xr(
im: Any,
gbox: SomeGeoBox,
*,
time=None,
nodata: SomeNodata = "auto",
crs_coord_name: Optional[str] = _DEFAULT_CRS_COORD_NAME,
always_yx: bool = False,
dims: Optional[Tuple[str, ...]] = None,
axis: Optional[int] = None,
**attrs,
) -> xarray.DataArray:
"""
Wrap xarray around numpy array with CRS and x,y coords.
:param im: numpy array to wrap, last two axes are Y,X
:param gbox: Geobox, must same shape as last two axis of ``im``
:param time: optional time axis value(s), defaults to None
:param nodata: optional `nodata` value, defaults to None
:param crs_coord_name: allows to change name of the crs coordinate variable
:param always_yx: If True, always use names ``y,x`` for spatial coordinates
:param dims: Custom names for spatial dimensions
:param axis: Which axis of the input array corresponds to Y,X
:param attrs: Any other attributes to set on the result
:return: xarray DataArray
"""
# pylint: disable=too-many-locals,too-many-arguments
assert dims is None or len(dims) == im.ndim
if axis is None:
axis = 1 if time is not None else 0
elif axis < 0: # handle numpy style negative axis
axis = int(im.ndim) + axis
if im.ndim == 2 and axis == 1:
im = im[numpy.newaxis, ...]
assert axis >= 0
assert im.ndim - axis - 2 >= 0
assert im.shape[axis : axis + 2] == gbox.shape
def _prefix_dims(n):
if n == 0:
return ()
if n == 1:
return ("time",)
return ("time", *[f"dim_{i}" for i in range(n - 1)])
def _postfix_dims(n):
if n == 0:
return ()
if n == 1:
return ("band",)
return (f"b_{i}" for i in range(n))
sdims: Optional[Tuple[str, str]] = None
if dims is None:
sdims = ("y", "x") if always_yx else gbox.dimensions
dims = (*_prefix_dims(axis), *sdims, *_postfix_dims(im.ndim - axis - 2))
else:
sdims = dims[axis], dims[axis + 1]
prefix_dims = dims[:axis]
postfix_dims = dims[axis + 2 :]
coords = xr_coords(
gbox,
crs_coord_name=crs_coord_name,
always_yx=always_yx,
dims=sdims,
)
if time is not None:
if not isinstance(time, xarray.DataArray):
if len(prefix_dims) > 0 and isinstance(time, (str, datetime)):
time = [time]
time = xarray.DataArray(time, dims=prefix_dims[:1]).astype("datetime64[ns]")
coords["time"] = time
if postfix_dims:
for a, dim in enumerate(postfix_dims):
nb = im.shape[axis + 2 + a]
coords[dim] = xarray.DataArray(
[f"b{i}" for i in range(nb)], dims=(dim,), name=dim
)
_nodata = resolve_nodata(nodata, im.dtype)
if not is_nodata_empty(_nodata) or nodata != "auto":
attrs = {"nodata": _nodata, **attrs}
out = xarray.DataArray(im, coords=coords, dims=dims, attrs=attrs)
if crs_coord_name is not None:
out.encoding["grid_mapping"] = crs_coord_name
return out
[docs]
def xr_zeros(
geobox: SomeGeoBox,
dtype="float64",
*,
chunks: Optional[Union[Tuple[int, int], Tuple[int, int, int]]] = None,
time=None,
crs_coord_name: Optional[str] = _DEFAULT_CRS_COORD_NAME,
**kw,
) -> xarray.DataArray:
"""
Construct geo-registered xarray from a :py:class:`~odc.geo.geobox.GeoBox`.
:param gbox: Desired footprint and resolution
:param dtype: Pixel data type
:param chunks: Create a dask array instead of numpy array
:param time: When set adds time dimension
:param crs_coord_name: allows to change name of the crs coordinate variable
:return: :py:class:`xarray.DataArray` filled with zeros (numpy or dask)
.. seealso:: :py:meth:`odc.geo.xr.wrap_xr`
"""
if time is not None:
_shape: Tuple[int, ...] = (len(time), *geobox.shape.yx)
else:
_shape = geobox.shape.yx
if chunks is not None:
from dask import array as da # pylint: disable=import-outside-toplevel
return wrap_xr(
da.zeros(_shape, dtype=dtype, chunks=chunks),
geobox,
crs_coord_name=crs_coord_name,
time=time,
**kw,
)
return wrap_xr(
numpy.zeros(_shape, dtype=dtype),
geobox,
crs_coord_name=crs_coord_name,
time=time,
**kw,
)
[docs]
def rasterize(
poly: Geometry,
how: Union[float, int, Resolution, GeoBox],
*,
value_inside: bool = True,
all_touched: bool = False,
) -> xarray.DataArray:
"""
Generate raster from geometry.
This method is a wrapper for :py:meth:`rasterio.features.make_mask`.
:param poly:
Geometry shape to rasterize.
:param how:
This could be either just resolution or a GeoBox that fully defines output
raster extent/resolution/projection.
:param all_touched:
If ``True``, all pixels touched by geometries will be burned in. If
``False``, only pixels whose center is within the polygon or that
are selected by Bresenham's line algorithm will be burned in.
:param value_inside:
By default pixels inside a polygon will have value of ``True`` and ``False``
outside, but this can be flipped.
:return: geo-registered data array
"""
# pylint: disable=import-outside-toplevel
if have.rasterio is False: # pragma: nocover
raise RuntimeError("Please install `rasterio` to use this method")
from rasterio.features import geometry_mask
if isinstance(how, GeoBox):
geobox = how
else:
geobox = GeoBox.from_geopolygon(poly, resolution=how)
if poly.crs != geobox.crs and geobox.crs is not None:
poly = poly.to_crs(geobox.crs)
pix = geometry_mask(
[poly.geom],
geobox.shape,
geobox.transform,
all_touched=all_touched,
invert=value_inside,
)
return wrap_xr(pix, geobox)