"""Base classes for all :mod:`faninsar` datasets.
The base class RasterDataset in this script is modified from the torchgeo package.
"""
from __future__ import annotations
import abc
import contextlib
import functools
import re
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal, overload
import numpy as np
import pandas as pd
import pyproj
import rasterio
import shapely
import xarray as xr
from rasterio import features, fill, plot
from rasterio import mask as rio_mask
from rasterio.crs import CRS
from rasterio.dtypes import dtype_ranges, get_minimum_dtype
from rasterio.enums import Resampling
from rasterio.transform import rowcol as tf_rowcol
from rasterio.transform import xy as tf_xy
from rasterio.vrt import WarpedVRT
from rasterio.warp import calculate_default_transform
from rasterio.warp import transform as warp_transform
from rtree.index import Index, Property
from shapely import ops
from tqdm import tqdm
from typing_extensions import Self
from faninsar._core import geo_tools
from faninsar._core.geo_tools import Profile, array2kml, array2kmz, geoinfo_from_latlon
from faninsar._core.sar.pairs import Pairs
from faninsar.logging import setup_logger
from faninsar.query import BoundingBox, GeoQuery, Points, Polygons, QueryResult
if TYPE_CHECKING:
from collections.abc import Sequence
from rasterio.io import DatasetReader
__all__ = ("ApsDataset", "GeoDataset", "PairDataset", "RasterDataset")

# Logger shared by the dataset classes defined in this module.
logger = setup_logger(
    log_format="%(levelname)s - %(message)s",
    log_name="FanInSAR.datasets.base",
)

# Candidate coordinate names for the latitude / longitude axes.
# NOTE(review): their consumers are not visible in this chunk.
lat_names = ["latitude", "lat", "y"]
lon_names = ["longitude", "lon", "x", "long", "lng"]
[docs]
class GeoDataset(abc.ABC):
    """Abstract base class for all :mod:`faninsar` datasets.

    This class is used to represent a geospatial dataset and provides methods to
    index the dataset and retrieve information about the dataset, such as CRS,
    resolution, data type, no data value, and bounds.

    Files are tracked in a 2-D R-tree (``self.index``) built with
    ``interleaved=True``, i.e. entry bounds are (xmin, ymin, xmax, ymax) in
    the CRS of the dataset.
    """

    # following attributes should be set by the subclass
    _crs: CRS | None = None
    _res: tuple[float, float] = (0.0, 0.0)
    _dtype: np.dtype | None = None
    _count: int = 0
    _roi: BoundingBox | None = None
    _nodata: Any = None
    _valid: np.ndarray
    # Whether all files already share the dataset CRS. Subclasses update it;
    # the default keeps ``same_crs`` usable before a subclass has set it.
    _same_crs: bool = True

    def __init__(self) -> None:
        """Initialize a new GeoDataset instance with an empty spatial index."""
        self.index = Index(interleaved=True, properties=Property(dimension=2))

    def __repr__(self) -> str:
        """Return a string representation of the dataset."""
        return f"""\
{self.__class__.__name__} Dataset
bbox: {self.bounds}
file count: {len(self)}"""

    def __str__(self) -> str:
        """Return a string representation of the dataset."""
        return self.__repr__()

    def __len__(self) -> int:
        """Return the number of files in the dataset.

        Returns
        -------
        int
            number of entries in the spatial index
        """
        return len(self.index)

    def __getstate__(
        self,
    ) -> tuple[dict[str, Any], list[tuple[Any, Any, Any]]]:
        """Define how instances are pickled.

        The index entries are exported explicitly as ``(id, bbox, object)``
        tuples so that :meth:`__setstate__` can re-insert them.

        Returns
        -------
        tuple
            the instance ``__dict__`` and the list of index entries
        """
        if len(self.index) == 0:
            # An empty rtree index reports inverted bounds, which cannot be
            # used as an intersection query.
            return self.__dict__, []
        hits = self.index.intersection(self.index.bounds, objects=True)
        # ``item.bounds`` is always non-interleaved (xmin, xmax, ymin, ymax);
        # ``item.bbox`` is the interleaved (xmin, ymin, xmax, ymax) order that
        # this interleaved index expects on re-insertion.
        tuples = [(item.id, item.bbox, item.object) for item in hits]
        return self.__dict__, tuples

    def __setstate__(
        self,
        state: tuple[
            dict[Any, Any],
            list[tuple[int, tuple[float, float, float, float], Any]],
        ],
    ) -> None:
        """Define how to unpickle an instance.

        Parameters
        ----------
        state : tuple
            the state produced by :meth:`__getstate__`: the instance
            ``__dict__`` and the ``(id, bbox, object)`` index entries
        """
        attrs, tuples = state
        self.__dict__.update(attrs)
        for item in tuples:
            self.index.insert(*item)

    @overload
    def _ensure_query_crs(self, query: BoundingBox) -> BoundingBox: ...

    @overload
    def _ensure_query_crs(self, query: Points) -> Points: ...

    @overload
    def _ensure_query_crs(self, query: Polygons) -> Polygons: ...

    def _ensure_query_crs(
        self,
        query: Points | BoundingBox | Polygons,
    ) -> Points | BoundingBox | Polygons:
        """Ensure that the query has the same CRS as the dataset.

        A query without CRS is assumed (with a warning) to already be in the
        dataset CRS; otherwise it is reprojected with ``query.to_crs``.
        """
        if query.crs is None:
            warnings.warn(
                f"No CRS is specified for the {query}, assuming they are in the"
                f" same CRS as the dataset ({self.crs}).",
                stacklevel=2,
            )
        elif query.crs != self.crs:
            query = query.to_crs(self.crs)
        return query

    @property
    def crs(self) -> CRS | None:
        """Coordinate reference system (:term:`CRS`) of the dataset.

        Returns
        -------
        The coordinate reference system (:term:`CRS`).
        """
        return self._crs

    @crs.setter
    def crs(self, new_crs: CRS | str) -> None:
        """Change the coordinate reference system :term:`(CRS)` of a GeoDataset.

        If ``new_crs == self.crs``, does nothing. Otherwise the resolution is
        re-derived for the new CRS, the R-tree index is reprojected and a
        previously set region of interest is reprojected too, so that it stays
        expressed in the dataset CRS.

        Parameters
        ----------
        new_crs: CRS or str
            New coordinate reference system :term:`(CRS)`. It can be a CRS object
            or a string, which will be parsed to a CRS object. The string can be
            in any format supported by :meth:`pyproj.crs.CRS.from_user_input`.
        """
        if not isinstance(new_crs, CRS):
            new_crs = CRS.from_user_input(new_crs)
        if new_crs == self.crs:
            return

        old_crs = self.crs
        if old_crs is not None and len(self) > 0:
            # update the resolution
            profile = self.get_profile("bounds")
            tf, *_ = calculate_default_transform(
                old_crs,
                new_crs,
                profile["width"],
                profile["height"],
                self.bounds[0],
                self.bounds[1],
                self.bounds[2],
                self.bounds[3],
            )
            new_res = (abs(float(tf.a)), abs(float(tf.e)))
            if new_res[0] != self.res[0] or new_res[1] != self.res[1]:
                msg = (
                    "the resolution of the dataset has been changed "
                    f"from {self.res} to {new_res}."
                )
                logger.warning(msg)
                self.res = new_res

            # reproject the index
            new_index = Index(interleaved=True, properties=Property(dimension=2))
            project = pyproj.Transformer.from_crs(
                pyproj.CRS(str(old_crs)),
                pyproj.CRS(str(new_crs)),
                always_xy=True,
            ).transform
            for hit in self.index.intersection(self.index.bounds, objects=True):
                # ``hit.bounds`` is non-interleaved: (xmin, xmax, ymin, ymax)
                old_xmin, old_xmax, old_ymin, old_ymax = hit.bounds
                old_box = shapely.geometry.box(old_xmin, old_ymin, old_xmax, old_ymax)
                new_box = ops.transform(project, old_box)
                # shapely bounds are (minx, miny, maxx, maxy) == interleaved
                new_index.insert(hit.id, tuple(new_box.bounds), hit.object)
            self.index = new_index

        self._crs = new_crs
        if old_crs is not None and self._roi is not None:
            # keep the region of interest in the (new) dataset CRS
            self._roi = self._check_roi(self._roi)

    @property
    def same_crs(self) -> bool:
        """Whether all files in the dataset have the same CRS with the desired CRS."""
        return self._same_crs

    @property
    def res(self) -> tuple[float, float]:
        """Return the resolution of the dataset.

        Returns
        -------
        res: tuple of floats
            resolution of the dataset in x and y directions.
        """
        return self._res

    @res.setter
    def res(self, new_res: float | tuple[float, float]) -> None:
        """Set the resolution of the dataset.

        Parameters
        ----------
        new_res : float or tuple of floats (x_res, y_res)
            resolution of the dataset. If a float is given, the same resolution
            will be used in both x and y directions.

        Raises
        ------
        ValueError
            if a sequence of length other than 2 is given
        TypeError
            if the elements cannot be converted with ``float``
        """
        if isinstance(new_res, (int, float, np.integer, np.floating)):
            new_res = (float(new_res), float(new_res))
        if len(new_res) != 2:
            msg = f"Resolution must be a float or a tuple of length 2, got {new_res}"
            raise ValueError(msg)
        if not all(isinstance(i, float) for i in new_res):
            try:
                new_res = tuple(float(i) for i in new_res)
            except TypeError as e:
                msg = "Resolution must be a float or a tuple of floats"
                raise TypeError(msg) from e
        self._res = new_res

    @property
    def roi(self) -> BoundingBox | None:
        """Return the region of interest of the dataset.

        Returns
        -------
        roi: BoundingBox object
            region of interest of the dataset. If None, the bounds of
            entire dataset will be used.
        """
        if self._roi is not None:
            return self._roi
        return self.bounds

    @roi.setter
    def roi(self, new_roi: BoundingBox) -> None:
        """Set the region of interest of the dataset.

        Parameters
        ----------
        new_roi : BoundingBox object, optional
            region of interest of the dataset in the CRS of the dataset. If the
            crs of the new_roi is different from the crs of the dataset, the new_roi
            will be reprojected to the crs of the dataset. If None, the current
            region of interest (the dataset bounds when unset) is stored.
        """
        self._roi = self._check_roi(new_roi)

    def _check_roi(self, roi: BoundingBox | None) -> BoundingBox:
        """Check the roi and return a valid roi.

        Parameters
        ----------
        roi : BoundingBox object, optional
            region of interest of the dataset in the CRS of the dataset. If the
            crs of the roi is different from the crs of the dataset, the roi
            will be reprojected to the crs of the dataset; a roi without crs is
            tagged with the dataset crs.

        Returns
        -------
        roi: BoundingBox object
            region of interest of the dataset. If None, the current region of
            interest (falling back to the dataset bounds) is returned.

        Raises
        ------
        TypeError
            if ``roi`` is neither None nor a BoundingBox
        """
        if roi is None:
            return self.roi
        if not isinstance(roi, BoundingBox):
            msg = f"roi must be a BoundingBox object, got {type(roi)} instead."
            raise TypeError(msg)
        if roi.crs != self.crs:
            if roi.crs is None:
                roi = BoundingBox(*roi, crs=self.crs)
            else:
                roi = roi.to_crs(self.crs)
        return roi

    @property
    def dtype(self) -> np.dtype | None:
        """Data type of the dataset.

        Returns
        -------
        dtype: numpy.dtype object or None
            data type of the dataset
        """
        return self._dtype

    @dtype.setter
    def dtype(self, new_dtype: np.dtype) -> None:
        """Set the data type of the dataset.

        Parameters
        ----------
        new_dtype : numpy.dtype
            data type of the dataset
        """
        self._dtype = new_dtype

    @property
    def nodata(self) -> float | None:
        """No data value of the dataset.

        Returns
        -------
        nodata: float or int
            no data value of the dataset
        """
        return self._nodata

    @nodata.setter
    def nodata(self, new_nodata: float) -> None:
        """Set the no data value of the dataset.

        Parameters
        ----------
        new_nodata : float or int
            no data value of the dataset
        """
        self._nodata = new_nodata

    @property
    def valid(self) -> np.ndarray:
        """Return a boolean array indicating which files are valid.

        Returns
        -------
        valid: numpy.ndarray
            boolean array indicating which files are valid. True means the file
            is valid and can be read by rasterio, False means the file is invalid.
        """
        return self._valid

    @property
    def bounds(self) -> BoundingBox:
        """Bounds of the overall dataset.

        It is the union of all the files in the dataset.

        Returns
        -------
        bounds: BoundingBox object
            (left, bottom, right, top) of the dataset, in the dataset CRS
        """
        return BoundingBox(*self.index.bounds, crs=self.crs)

    @property
    def shape(self) -> tuple[int, int]:
        """Shape of the dataset.

        Returns
        -------
        shape: tuple of ints
            shape of the dataset in (height, width) format
        """
        profile = self.get_profile("bounds")
        return profile["height"], profile["width"]

    def _ensure_bbox(
        self,
        bbox: BoundingBox | Literal["roi", "bounds"] = "roi",
    ) -> BoundingBox | None:
        """Return the bounds of the dataset for the given bounding box type.

        Parameters
        ----------
        bbox : BoundingBox | Literal["roi", "bounds"], optional
            the bounding box used to calculate the bounds of the dataset.
            Default is 'roi'.

        Returns
        -------
        bounds: BoundingBox | None
            bounds of the dataset for the given bounding box type.

        Raises
        ------
        ValueError
            if ``bbox`` is neither 'roi', 'bounds' nor a BoundingBox
        """
        if isinstance(bbox, BoundingBox):
            return self._check_roi(bbox)
        if bbox == "bounds":
            return self.bounds
        if bbox == "roi":
            return self.roi
        msg = f"bbox must be one of ['bounds', 'roi'] or a BoundingBox, but got {bbox}"
        raise ValueError(msg)
[docs]
class RasterDataset(GeoDataset):
"""A base class for raster datasets.
Examples
--------
>>> from pathlib import Path
>>> from faninsar.datasets import RasterDataset
>>> from faninsar.query import BoundingBox, GeoQuery, Points
>>> home_dir = Path("./work/data")
>>> files = list(home_dir.rglob("*unw_phase.tif"))
initialize a RasterDataset and GeoQuery object
>>> ds = RasterDataset(paths=files)
>>> points = Points(
[(490357, 4283413),
(491048, 4283411),
(490317, 4284829)]
)
>>> query = GeoQuery(points=points, boxes=[ds.bounds, ds.bounds])
use the GeoQuery object to index the RasterDataset
>>> sample = ds[query]
output the shapes of the samples:
>>> print("boxes result shape:", sample.boxes.data.shape)
boxes result shape: (2, 7, 68, 80)
>>> print("points result shape:", sample.points.data.shape)
points result shape: (7, 3)
of course, you can also use the BoundingBox or Points directly to index the
RasterDataset. Those two types will be automatically converted to GeoQuery
object.
>>> sample = ds[points]
>>> sample
{'query': GeoQuery(
boxes=None
points=Points(count=3)
),
'boxes': None,
'points': array([...], dtype=float32)}
>>> sample = ds[ds.bounds]
{'query': GeoQuery(
boxes=[1 BoundingBox]
points=None
),
'boxes': array([...], dtype=float32),
'points': None}
"""
#: Glob expression used to search for files.
#:
#: This expression should be specific enough that it will not pick up files from
#: other datasets. It should not include a file extension, as the dataset may be in
#: a different file format than what it was originally downloaded as.
pattern = "*"
#: When :attr:`~RasterDataset.separate_files` is True, the following additional
#: groups are searched for to find other files:
#:
#: * ``band``: replaced with requested band name
filename_regex = ".*"
#: Date format string used to parse date from filename.
#:
#: Not used if :attr:`filename_regex` does not contain a ``date`` group.
date_format = "%Y%m%d"
#: Names of all available bands in the dataset
all_bands: ClassVar[list[str]] = []
#: Names of RGB bands in the dataset, used for plotting
rgb_bands: ClassVar[list[str]] = []
#: Color map for the dataset, used for plotting
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {}
[docs]
def __init__(  # noqa: PLR0912, PLR0915
    self,
    root_dir: str = "data",
    paths: Sequence[str] | None = None,
    crs: CRS | None = None,
    res: float | tuple[float, float] | None = None,
    dtype: np.dtype | None = None,
    nodata: float | None = None,
    roi: BoundingBox | None = None,
    bands: Sequence[str] | None = None,
    cache: bool = True,
    resampling: Resampling = Resampling.nearest,
    fill_nodata: bool = False,
    verbose: bool = True,
    ds_name: str = "",
) -> None:
    """Initialize a new raster dataset instance.

    Parameters
    ----------
    root_dir : str or Path
        root_dir directory where dataset can be found.
    paths : list of str, optional
        list of file paths to use instead of searching for files in ``root_dir``.
        If None, files will be searched for in ``root_dir``.
    crs : CRS, optional
        the output term:`coordinate reference system (CRS)` of the dataset.
        If None, the CRS of the first file found will be used.
    res : float, optional
        resolution of the output dataset in units of CRS. If None, the resolution
        of the first file found will be used.
    dtype : numpy.dtype, optional
        data type of the output dataset. If None, the data type of the first file
        found will be used.
    nodata : float or int, optional
        no data value of the dataset. If None, the no data value of the first
        file found will be used. This parameter is useful when the no data value
        is not stored in the file.
    roi : BoundingBox, optional
        region of interest to load from the dataset. If None, the union of all files
        bounds in the dataset will be used.
    bands : list of str, optional
        names of bands to return (defaults to all bands)
    cache : bool, optional
        if True, cache file handle to speed up repeated sampling
    resampling : Resampling, optional
        Resampling algorithm used when reading input files.
        Default: `Resampling.nearest`.
    fill_nodata : bool, optional
        Whether to fill holes in the queried data by interpolating them using
        inverse distance weighting method provided by the
        :func:`rasterio.fill.fillnodata`. Default: False.

        .. note::
            This parameter is only used when sampling data using bounding
            boxes or polygons queries, and will not work for points queries.
    verbose : bool, optional
        if True, print verbose output, default: True
    ds_name : str, optional
        name of the dataset. used for printing verbose output, default: ""

    Raises
    ------
    FileNotFoundError
        if no readable files are found
    AssertionError
        if ``bands`` is given but the class defines no ``all_bands``
    """
    super().__init__()
    self.root_dir = Path(root_dir)
    self.bands = bands or self.all_bands
    self.cache = cache
    self.resampling = resampling
    self.fill_nodata = fill_nodata
    self.verbose = verbose
    self.ds_name = ds_name

    if paths is None:
        filename_regex = re.compile(self.filename_regex, re.VERBOSE)
        paths = [
            file_path
            for file_path in sorted(self.root_dir.rglob(self.pattern))
            if filename_regex.match(file_path.name) is not None
        ]
    else:
        paths = [Path(path) for path in paths]

    # Populate the dataset index; only readable files get an index id
    count = 0
    files_valid = []
    self._same_crs = True
    for file_path in paths:
        try:
            with rasterio.open(file_path) as src:
                # See if file has a color map
                if len(self.cmap) == 0:
                    with contextlib.suppress(ValueError):
                        self.cmap = src.colormap(1)
                # unspecified properties are taken from the first readable file
                if crs is None:
                    crs = src.crs
                if dtype is None:
                    dtype = src.dtypes[0]
                if nodata is None:
                    nodata = src.nodata
                with WarpedVRT(src, crs=crs) as vrt:
                    bounds = tuple(vrt.bounds)
                    if res is None:
                        res = vrt.res
                if crs != src.crs:
                    self._same_crs = False
        except Exception as e:  # noqa: PERF203
            # Skip files that rasterio is unable to read
            warnings.warn(f"Unable to read {file_path}: \n--> : {e}", stacklevel=2)
            files_valid.append(False)
            continue
        else:
            self.index.insert(count, bounds, file_path)
            files_valid.append(True)
            count += 1

    if count == 0:
        msg = (
            f"No {self.__class__.__name__} data was found in "
            f"`root_dir='{self.root_dir}'`"
        )
        if self.bands:
            msg += f" with `bands={self.bands}`"
        raise FileNotFoundError(msg)

    self._files = pd.DataFrame({"paths": paths, "valid": files_valid})
    self._valid = np.array(files_valid)
    if not self._files.valid.all():
        files_invalid = [str(i) for i in self._files.paths[~self._files.valid]]
        files_invalid_str = "\t" + "\n\t".join(files_invalid)
        # a plain string (no trailing comma): warnings.warn would otherwise
        # render the repr of a 1-tuple
        msg = (
            f"Unable to read {len(files_invalid)} files in "
            f"{self.__class__.__name__} dataset:\n{files_invalid_str}"
        )
        warnings.warn(msg, stacklevel=2)

    self.band_indexes = None
    if self.bands:
        if self.all_bands:
            # rasterio band indexes are 1-based
            self.band_indexes = [self.all_bands.index(i) + 1 for i in self.bands]
        else:
            msg = (
                f"{self.__class__.__name__} is missing an `all_bands` "
                "attribute, so `bands` cannot be specified."
            )
            raise AssertionError(msg)

    self.crs = crs
    self.res = res
    self.dtype = dtype
    self.nodata = nodata
    self.count = count
    self.roi = roi
def __getitem__(
    self,
    query: GeoQuery | Points | BoundingBox | Polygons,
) -> QueryResult:
    """Sample the dataset with a geometric query.

    Parameters
    ----------
    query : GeoQuery | Points | BoundingBox | Polygons
        what to sample. A bare :class:`Points`, :class:`BoundingBox` or
        :class:`Polygons` is wrapped into a :class:`GeoQuery` (the
        recommended form) before sampling.

    Returns
    -------
    result : QueryResult
        results of the individual queries, sampled from the valid files only.
    """
    wrappers = (
        (Points, "points"),
        (BoundingBox, "boxes"),
        (Polygons, "polygons"),
    )
    for query_type, keyword in wrappers:
        if isinstance(query, query_type):
            query = GeoQuery(**{keyword: query})
            break

    table = self.files
    valid_paths = table.paths[table.valid]
    return self._sample_files(valid_paths, query)
def _ensure_bands_idx(self, vrt_fh: DatasetReader) -> list[int] | int:
"""Return the proper band indexes to use for the dataset.
The band indexes is a list of integers if multiple bands are requested,
otherwise it is an integer.
"""
bands = self.band_indexes or vrt_fh.indexes
# If only one band is requested, return a 2D array
if len(bands) == 1:
bands = bands[0]
return bands
def _ensure_dtype(self, data: np.ndarray) -> np.ndarray:
"""Ensure that the data has the same dtype as the dataset."""
if data.dtype != self.dtype:
data = data.astype(self.dtype)
return data
def _points_query(self, points: Points, vrt_fh: DatasetReader) -> np.ndarray:
    """Sample the dataset at the given points.

    Samples are requested with ``masked=True``, so points outside the dataset
    come back masked. The per-point samples are concatenated with
    ``np.ma.hstack`` and cast to the dataset dtype.
    """
    points = self._ensure_query_crs(points)
    band_selection = self._ensure_bands_idx(vrt_fh)
    per_point = vrt_fh.sample(points.values, band_selection, masked=True)
    stacked = np.ma.hstack(list(per_point))
    return self._ensure_dtype(stacked)
def _bbox_query(self, bbox: BoundingBox, vrt_fh: DatasetReader) -> np.ndarray:
    """Return the values of the dataset inside the given bounding box.

    The box is read on a grid of ``self.res`` using ``self.resampling``. If
    the read returns no per-pixel mask, one is derived from ``self.nodata``
    (NaN nodata is matched with ``np.isnan``). Holes are optionally filled
    with :func:`rasterio.fill.fillnodata` when ``self.fill_nodata`` is set.
    """
    bbox = self._ensure_query_crs(bbox)
    win = vrt_fh.window(*bbox)
    bands_idx = self._ensure_bands_idx(vrt_fh)
    out_shape = [
        round((bbox.top - bbox.bottom) / self.res[1]),
        round((bbox.right - bbox.left) / self.res[0]),
    ]
    if isinstance(bands_idx, list):
        # multi-band read: prepend the band axis
        out_shape.insert(0, len(bands_idx))
    data = vrt_fh.read(
        out_shape=tuple(out_shape),
        resampling=self.resampling,
        indexes=bands_idx,
        window=win,
        masked=True,
        # WarpedVRT not supports boundless: https://github.com/rasterio/rasterio/issues/2084
        boundless=self.same_crs,
    )
    if data.mask.ndim == 0:
        # no per-pixel mask was returned: build one from the nodata value
        values = np.ma.getdata(data)
        if self.nodata is not None and np.isnan(self.nodata):
            # NaN never compares equal to itself, so ``==`` cannot find it
            invalid = np.isnan(values)
        else:
            invalid = values == self.nodata
        data = np.ma.masked_array(values, invalid)
    if self.fill_nodata:
        data = fill.fillnodata(data)
    return self._ensure_dtype(data)
def _polygons_query(
    self,
    polygons: Polygons,
    vrt_fh: DatasetReader,
) -> tuple[list[np.ma.MaskedArray], list[Any], list[np.ndarray]]:
    """Return the values of the dataset inside the given polygons.

    When ``polygons.desired`` is non-empty, each polygon is cropped and
    returned as its own entry. Otherwise all polygons are rasterized together
    on the uncropped grid and a single entry is returned.

    Returns
    -------
    data_ls : list of numpy.ma.MaskedArray
        values for each entry; pixels outside the polygon(s) are masked
    transform_ls : list
        output transform returned by :func:`rasterio.mask.mask` for each entry
    mask_ls : list of numpy.ndarray
        boolean rasterized polygon mask (True inside) for each entry
    """
    polygons = self._ensure_query_crs(polygons)
    bands_idx = self._ensure_bands_idx(vrt_fh)
    mask_params = {
        "filled": False,
        "pad": polygons.pad,
        "all_touched": polygons.all_touched,
        "invert": False,
        "crop": True,
        "indexes": bands_idx,
    }
    rasterize_params = {
        "all_touched": polygons.all_touched,
        "fill": 0,
        "default_value": 1,
    }

    def _read_masked(geoms: list, **mask_overrides: Any) -> tuple:
        """Read ``geoms`` from ``vrt_fh`` and mask pixels outside them."""
        data, out_transform = rio_mask.mask(
            vrt_fh,
            geoms,
            **{**mask_params, **mask_overrides},
        )
        # the last two axes are (rows, cols) for both 2-D (single band) and
        # 3-D (bands, rows, cols) outputs
        inside = features.rasterize(
            geoms,
            out_shape=data.shape[-2:],
            transform=out_transform,
            **rasterize_params,
        ).astype(bool)
        if self.fill_nodata:
            data = fill.fillnodata(data)
        values = np.ma.getdata(data)
        # broadcast the 2-D polygon mask over any leading band axis
        outside = np.broadcast_to(~inside, values.shape).copy()
        masked = self._ensure_dtype(np.ma.masked_array(values, outside))
        return masked, out_transform, inside

    shapes = polygons.frame.geometry.to_list()
    if len(polygons.desired) > 0:
        data_ls, transform_ls, mask_ls = [], [], []
        for shp in shapes:
            data, out_transform, inside = _read_masked([shp])
            data_ls.append(data)
            transform_ls.append(out_transform)
            mask_ls.append(inside)
    else:
        data, out_transform, inside = _read_masked(shapes, invert=True, crop=False)
        data_ls, transform_ls, mask_ls = [data], [out_transform], [inside]
    return data_ls, transform_ls, mask_ls
def _ensure_loading_verbose(self, sequence: Sequence) -> Sequence:
if self.verbose:
sequence = tqdm(
sequence,
desc=f"Loading {self.ds_name} Files",
unit=" files",
)
return sequence
def _ensure_saving_verbose(
self,
sequence: Sequence,
ds_name: str,
unit: str = " files",
) -> Sequence:
if self.verbose:
sequence = tqdm(sequence, desc=f"Saving {ds_name} Files", unit=unit)
return sequence
def _safe_close(self, vrt_fhs: DatasetReader) -> None:
"""Close the file handles if not caching."""
if not self.cache:
for vrt_fh in vrt_fhs:
vrt_fh.close()
def _sample_files(self, paths: Sequence[str], query: GeoQuery) -> QueryResult:  # noqa: PLR0912, PLR0915
    """Sample or retrieve values from the dataset for the given query.

    Parameters
    ----------
    paths : list of str
        list of paths for files to stack
    query : GeoQuery
        a GeoQuery instance containing the desired queries.

    Returns
    -------
    result : QueryResult
        a QueryResult instance containing the results of the various queries.
    """
    if self.cache:
        vrt_fhs = [self._cached_load_warp_file(fp) for fp in paths]
    else:
        vrt_fhs = [self._load_warp_file(fp) for fp in paths]

    files_points_list = []
    files_bbox_list = []
    files_polygons_list = []
    transform_ls, mask_ls = [], []
    try:
        for vrt_fh in self._ensure_loading_verbose(vrt_fhs):
            # Get the points values
            if query.points is not None:
                data = self._points_query(query.points, vrt_fh)
                files_points_list.append(data)

            # Get the bounding boxes values
            if query.boxes is not None:
                if len(query.boxes) == 1:
                    data = self._bbox_query(query.boxes[0], vrt_fh)
                    files_bbox_list.append(data)
                else:
                    bbox_list = [self._bbox_query(bbox, vrt_fh) for bbox in query.boxes]
                    files_bbox_list.append(bbox_list)

            # Get the polygons values
            if query.polygons is not None:
                data_ls, transform_ls, mask_ls = self._polygons_query(
                    query.polygons,
                    vrt_fh,
                )
                files_polygons_list.append(data_ls)
    finally:
        # Uncached handles are opened per call and must be released here;
        # ``_safe_close`` does nothing when caching is enabled.
        self._safe_close(vrt_fhs)

    # Stack the points values
    points_result = None
    if len(files_points_list) > 0:
        points_values = np.ma.asarray(files_points_list)
        dims, points_values = parse_1d_dims(points_values)
        points_result = {"data": points_values, "dims": dims}

    # parse bounding boxes results
    bbox_result = None
    if len(files_bbox_list) > 0:
        n_boxes = len(query.boxes)
        if n_boxes == 1:
            boxes_values = np.ma.asarray(files_bbox_list)
            dims = parse_2d_dims(boxes_values)
        else:
            # stack the files for each box
            boxes_ls = [[] for _ in range(n_boxes)]
            for files_box in files_bbox_list:
                for i, box in enumerate(files_box):
                    boxes_ls[i].append(box)
            boxes_values = [np.ma.asarray(arr) for arr in boxes_ls]
            # get the dims from the stacked (files, ...) array of the first
            # box, as done for polygons below
            dims = parse_2d_dims(boxes_values[0], details=False)
            dims = f"boxes:{n_boxes}, ({dims})"
        bbox_result = {"data": boxes_values, "dims": f"({dims})"}

    # parse polygons results
    polygons_result = None
    if len(files_polygons_list) > 0:
        n_polygons = len(query.polygons)
        if n_polygons == 1:
            polygons_values = np.ma.asarray(files_polygons_list)
            dims = parse_2d_dims(polygons_values)
        else:
            # stack the files for each polygon
            poly_list = [[] for _ in range(n_polygons)]
            for file_data in files_polygons_list:
                for i, poly_i in enumerate(file_data):
                    poly_list[i].append(poly_i)
            polygons_values = [np.ma.asarray(arr) for arr in poly_list]
            # get the dims
            dims = parse_2d_dims(polygons_values[0], details=False)
        # transforms/masks come from the last file: all files share the grid
        polygons_result = {
            "data": polygons_values,
            "dims": f"(n_polygons:{n_polygons}, ({dims}))",
            "transforms": transform_ls,
            "masks": mask_ls,
        }

    return QueryResult(points_result, bbox_result, polygons_result, query)
@functools.lru_cache(maxsize=128)  # noqa: B019
def _cached_load_warp_file(self, file_path: str) -> DatasetReader:
    """Memoized variant of :meth:`_load_warp_file`.

    The cache key is ``(self, file_path)``. Up to 128 handles are kept open
    and reused across queries.

    Parameters
    ----------
    file_path: str
        file to load and warp

    Returns
    -------
    file handle of warped VRT
    """
    handle = self._load_warp_file(file_path)
    return handle
def _load_warp_file(self, file_path: str) -> DatasetReader:
    """Open a file, wrapping it in a WarpedVRT when its CRS differs from the dataset's.

    Parameters
    ----------
    file_path : str
        file to load and warp

    Returns
    -------
    DatasetReader
        the opened file, or a WarpedVRT in ``self.crs`` built on top of it
    """
    handle = rasterio.open(file_path)
    needs_warp = handle.crs != self.crs
    if not needs_warp:
        return handle
    warped = WarpedVRT(handle, crs=self.crs)
    handle.close()
    return warped
@property
def count(self) -> int:
"""Number of valid files in the dataset.
.. Note::
This is different from the length of the dataset ``len(GeoDataset)``,
which is the total number of files in the dataset, including invalid
files that cannot be read by rasterio.
Returns
-------
count: int
number of valid files in the dataset
"""
return self._count
@count.setter
def count(self, new_count: int) -> None:
"""Set the number of files in the dataset.
Parameters
----------
new_count : int
number of files in the dataset
"""
self._count = int(new_count)
@property
def files(self) -> pd.DataFrame:
"""Return a list of all files in the dataset.
Returns
-------
list of all files in the dataset
"""
return self._files
[docs]
def get_profile(
    self,
    bbox: BoundingBox | Literal["roi", "bounds"] = "roi",
) -> Profile | None:
    """Build the raster profile of the dataset for a region.

    Parameters
    ----------
    bbox : BoundingBox or {"roi", "bounds"}, optional
        region described by the profile. Default is 'roi'.

    Returns
    -------
    Profile or None
        profile built from the region and ``self.res``, with ``count``,
        ``dtype``, ``nodata`` and ``crs`` taken from the dataset; None if the
        region resolves to None.
    """
    region = self._ensure_bbox(bbox)
    if region is None:
        return None
    profile = Profile.from_bounds_res(region, self.res)
    dataset_fields = (
        ("count", self.count),
        ("dtype", self.dtype),
        ("nodata", self.nodata),
        ("crs", self.crs),
    )
    for key, value in dataset_fields:
        profile[key] = value
    return profile
[docs]
def row_col(
self,
xy: Sequence,
crs: CRS | str | None = None,
bbox: BoundingBox | Literal["roi", "bounds"] = "roi",
) -> np.ndarray:
"""Convert x, y coordinates to row, col in the dataset.
Parameters
----------
xy: Sequence
Pairs of x, y coordinates (floats)
crs: CRS or str, optional
The CRS of the points. If None, the CRS of the dataset will be used.
allowed CRS formats are the same as those supported by rasterio.
bbox : str, one of ['bounds', 'roi'], optional
the bounding box used to calculate the ``width``, ``height``
and ``transform`` of the dataset for the profile. Default is 'roi'.
Returns
-------
row_col: np.ndarray
row, col in the dataset for the given points(xy)
"""
xy = np.asarray(xy)
if xy.ndim == 1:
xy = xy.reshape(1, -1)
if xy.ndim != 2 or xy.shape[1] != 2:
msg = f"Expected xy to be an array of shape (n, 2), got {xy.shape}"
raise ValueError(
msg,
)
if crs is not None:
crs = CRS.from_user_input(crs)
if crs != self.crs:
xs, ys = warp_transform(crs, self.crs, xy[:, 0], xy[:, 1])
xy = np.column_stack((xs, ys))
profile = self.get_profile(bbox)
rows, cols = tf_rowcol(profile["transform"], xy[:, 0], xy[:, 1])
return np.column_stack((rows, cols)).astype(np.int64)
[docs]
def xy(
self,
row_col: Sequence,
crs: CRS | str | None = None,
bbox: BoundingBox | Literal["roi", "bounds"] = "roi",
) -> np.ndarray:
"""Convert row, col in the dataset to x, y coordinates.
Parameters
----------
row_col: Sequence
Pairs of row, col in the dataset (floats)
crs: CRS or str, optional
The CRS of output points. If None, the CRS of the dataset will be used.
Can be any of the formats supported by :meth:`pyproj.CRS.from_user_input`.
bbox : str, one of ['bounds', 'roi'], optional
the bounding box used to calculate the ``width``, ``height``
and ``transform`` of the dataset for the profile. Default is 'roi'.
Returns
-------
xy: np.ndarray
x, y coordinates in the given CRS (default is the CRS of the dataset)
"""
row_col = np.asarray(row_col)
if row_col.ndim == 1:
row_col = row_col.reshape(1, -1)
if row_col.ndim != 2 or row_col.shape[1] != 2:
msg = (
f"Expected row_col to be an array of shape (n, 2), got {row_col.shape}"
)
raise ValueError(
msg,
)
profile = self.get_profile(bbox)
xs, ys = tf_xy(profile["transform"], row_col[:, 0], row_col[:, 1])
if crs is not None:
crs = CRS.from_user_input(crs)
if crs != self.crs:
xs, ys = warp_transform(self.crs, crs, xs, ys)
return np.column_stack((xs, ys))
[docs]
def parse_mask(
    self,
    percent: float,
    bbox: BoundingBox | Literal["roi", "bounds"] = "roi",
    seed: int = 0,
) -> np.ndarray:
    """Parse the mask of the dataset.

    The mask is a boolean array where True indicates valid data and False
    indicates invalid data, which keeps in line with the GDAL/rasterio strategy.

    Parameters
    ----------
    percent : float
        Percentage (0,1] of files to be used for parsing the mask. The files are
        randomly selected; at least one file is always used.
    bbox : str, one of ['bounds', 'roi'], optional
        the desired region of mask. Default is 'roi'.
    seed : int, optional
        Seed for the random number generator. Default is 0.

    Returns
    -------
    numpy.ndarray
        boolean array of shape (height, width) of the requested region; a
        pixel is True when it is unmasked in at least one sampled file.

    Raises
    ------
    ValueError
        if ``percent`` is not in the range (0, 1].
    """
    if not 0 < percent <= 1:
        msg = f"percent must be in the range (0, 1], got {percent}"
        raise ValueError(msg)

    # randomly select a subset of valid files; a tiny ``percent`` must not
    # yield zero files (which would silently mark everything invalid)
    n_selected = max(1, int(percent * self.count))
    rng = np.random.default_rng(seed)
    idx = rng.choice(np.arange(self.count), n_selected, replace=False)
    paths = self.files.paths[self.valid].values[idx]

    # get the profile of the dataset
    profile = self.get_profile(bbox)
    width, height = profile["width"], profile["height"]
    # resolve the region once rather than once per file
    region = self._ensure_bbox(bbox)

    # True = masked in every file read so far
    mask = np.ones((height, width), dtype=bool)
    if self.verbose:
        paths = tqdm(paths, desc="Parsing Mask", unit=" files")
    for path in paths:
        with rasterio.open(path) as src:
            win = None if region is None else src.window(*region)
            mask &= src.read(1, masked=True, window=win).mask
    return ~mask
[docs]
def load_mask(
    self,
    mask_path: str | Path,
    bbox: BoundingBox | Literal["roi", "bounds"] = "roi",
) -> np.ndarray:
    """Load a mask from a tiff mask file (.msk).

    The mask file must cover the full dataset bounds on the dataset grid;
    it is then cropped to ``bbox`` on that same grid.

    Parameters
    ----------
    mask_path : str or Path
        path to the mask file of tiff format (.msk)
    bbox : str, one of ['bounds', 'roi'], optional
        the desired region of mask. Default is 'roi'.

    Returns
    -------
    numpy.ndarray
        first band of the mask file, cropped to ``bbox``

    Raises
    ------
    ValueError
        if the mask shape does not match the (height, width) of the dataset
    """
    bbox = self._ensure_bbox(bbox)
    profile = self.get_profile("bounds")
    with rasterio.open(mask_path) as src:
        mask = src.read(1)

    height, width = profile["height"], profile["width"]
    if mask.shape != (height, width):
        msg = (
            f"The shape of the mask {mask.shape} does not match the shape "
            f"of the dataset {(height, width)}."
        )
        raise ValueError(msg)
    if bbox is None:
        return mask

    # crop the mask to the desired region: map the bbox corners to fractional
    # (col, row) positions on the dataset grid with the inverse transform
    inverse = ~profile["transform"]
    col_start, row_start = inverse * (bbox.left, bbox.top)
    col_stop, row_stop = inverse * (bbox.right, bbox.bottom)
    rows = slice(max(int(round(row_start)), 0), max(int(round(row_stop)), 0))
    cols = slice(max(int(round(col_start)), 0), max(int(round(col_stop)), 0))
    return mask[rows, cols]
[docs]
def reproject(
    self,
    new_crs: CRS | str,
    resampling: Resampling = Resampling.nearest,
    nodata: float | None = None,
) -> Self:
    """Return the dataset warped into another coordinate reference system.

    The pixel size of the new grid is chosen so that the extent of
    ``self.bounds`` transformed into ``new_crs`` is split into
    ``self.shape[1]`` columns and ``self.shape[0]`` rows. The new dataset uses
    those transformed bounds as its region of interest.

    Parameters
    ----------
    new_crs : CRS or str
        target coordinate reference system (:term:`CRS`). Anything accepted
        by :meth:`pyproj.crs.CRS.from_user_input` may be given as a string.
    resampling : Resampling, optional
        resampling method used when warping. Default is `Resampling.nearest`.
    nodata : float or int, optional
        no data value of the new dataset. Falls back to ``self.nodata`` when
        None.

    Returns
    -------
    Self
        ``self`` unchanged when ``new_crs`` equals the current CRS, otherwise
        a new instance of the same class.
    """
    target_crs = new_crs if isinstance(new_crs, CRS) else CRS.from_user_input(new_crs)
    if target_crs == self.crs:
        # already in the requested CRS: nothing to rebuild
        return self

    fill_value = self.nodata if nodata is None else nodata

    target_bounds: BoundingBox = self.bounds.to_crs(target_crs)
    x_size = abs(target_bounds.right - target_bounds.left) / self.shape[1]
    y_size = abs(target_bounds.top - target_bounds.bottom) / self.shape[0]

    dataset_cls = self.__class__
    return dataset_cls(
        root_dir=self.root_dir,
        paths=self.files.paths,
        crs=target_crs,
        res=(x_size, y_size),
        dtype=self.dtype,
        nodata=fill_value,
        roi=target_bounds,
        bands=self.bands,
        cache=self.cache,
        resampling=resampling,
        fill_nodata=self.fill_nodata,
        verbose=self.verbose,
        ds_name=self.ds_name,
    )
[docs]
def resample(
    self,
    new_res: float | tuple[float, float],
    resampling: Resampling = Resampling.nearest,
    nodata: float | None = None,
) -> Self:
    """Build a copy of the dataset on a grid with a different pixel size.

    Every construction option of the current instance is forwarded to a new
    instance of the same class. Only ``res``, ``resampling`` and possibly
    ``nodata`` differ. Note that the new instance receives ``self.bounds``
    (not ``self.roi``) as its region of interest.

    Parameters
    ----------
    new_res : float or tuple of float
        new resolution of the dataset in units of CRS. If a single float is
        provided, it will be used for both x and y dimensions.
    resampling : Resampling, optional
        resampling method used when resampling. Default is
        `Resampling.nearest`.
    nodata : float or int, optional
        no data value of the new dataset. Falls back to ``self.nodata`` when
        None.

    Returns
    -------
    Self
        the resampled dataset.
    """
    options = {
        "root_dir": self.root_dir,
        "paths": self.files.paths,
        "crs": self.crs,
        "res": new_res,
        "dtype": self.dtype,
        "nodata": nodata if nodata is not None else self.nodata,
        "roi": self.bounds,
        "bands": self.bands,
        "cache": self.cache,
        "resampling": resampling,
        "fill_nodata": self.fill_nodata,
        "verbose": self.verbose,
        "ds_name": self.ds_name,
    }
    return self.__class__(**options)
[docs]
def show(
    self,
    arr: np.ndarray,
    **kwargs,
) -> Any:
    """Show the array using the dataset's geo information.

    Parameters
    ----------
    arr : np.ndarray
        The array with same shape as the dataset to show. The geo information
        of the dataset will be used to plot the array.
    kwargs : key value pairs, optional
        Additional keyword arguments to pass to the :func:`rasterio.plot.show`
        function. A user-supplied ``transform`` is ignored (with a warning).

    Returns
    -------
    Any
        whatever :func:`rasterio.plot.show` returns (the matplotlib Axes
        holding the plot).
    """
    # ``**kwargs`` is always a dict, never None, so no default is needed.
    if "transform" in kwargs:
        msg = (
            "show() function does not support `transform` argument, since "
            "the `transform` of the dataset will be used to plot the array."
        )
        warnings.warn(msg, stacklevel=2)
    # the dataset transform always wins so the plot is georeferenced correctly
    kwargs["transform"] = self.get_profile().transform
    return plot.show(arr, **kwargs)
[docs]
def to_tiffs(
    self,
    out_dir: str | Path,
    roi: BoundingBox | None = None,
) -> None:
    """Save the dataset to a directory of tiff files for given region of interest.

    One single-band tiff is written per valid file, named after the source
    file. The output directory is created if it does not exist.

    Parameters
    ----------
    out_dir : str or Path
        path to the directory to save the tiff files
    roi : BoundingBox, optional
        region of interest to save. If None, the roi of the dataset will be used.
    """
    out_dir = Path(out_dir)
    # rasterio cannot create missing parent directories itself
    out_dir.mkdir(parents=True, exist_ok=True)

    roi = self._check_roi(roi)
    profile = self.get_profile(roi)
    profile["count"] = 1
    for f in self.files.paths[self.valid]:
        out_file = out_dir / Path(f).name
        src = self._load_warp_file(f)
        # _bbox_query result has a leading length-1 axis; drop it to get 2D
        dest_arr = self._bbox_query(roi, src).squeeze(0)
        with rasterio.open(out_file, "w", **profile.profile) as dst:
            dst.write(dest_arr, 1)
[docs]
def to_netcdf(
    self,
    filename: str | Path,
    roi: BoundingBox | None = None,
) -> None:
    """Write the dataset, clipped to a region of interest, to a netCDF file.

    The sampled boxes are stored in a single variable named ``image`` with
    dimensions ``(band, lat, lon)``. The CRS is attached via
    :func:`geo_tools.write_geoinfo_into_ds`.

    Parameters
    ----------
    filename : str or Path
        path to the netCDF file to save
    roi : BoundingBox, optional
        region of interest to save. If None, the roi of the dataset will be used.
    """
    region = self.roi if roi is None else roi
    profile = self.get_profile(region)
    lat, lon = profile.to_latlon()
    image = self[region].boxes.data

    band_ids = list(range(profile["count"]))
    dataset = xr.Dataset(
        data_vars={"image": (["band", "lat", "lon"], image)},
        coords={"band": band_ids, "lat": lat, "lon": lon},
    )
    dataset = geo_tools.write_geoinfo_into_ds(
        dataset,
        "image",
        crs=self.crs,
        x_dim="lon",
        y_dim="lat",
    )
    dataset.to_netcdf(filename)
[docs]
def array2tiff(
    self,
    arr: np.ndarray,
    filename: str | Path,
    bounds: BoundingBox | None = None,
    bbox: BoundingBox | None = None,
    band_names: Sequence[str] | None = None,
    arr_type: Literal["data", "mask"] = "data",
    nodata: float | None = None,
    overwrite: bool = False,
) -> None:
    """Save a numpy array to a tiff file using the geoinformation of dataset.

    Parameters
    ----------
    arr : numpy.ndarray
        numpy array to save. arr can be a 2D array or a 3D array. If arr is a
        3D array, the first dimension should be the band dimension.
    filename : str or Path
        path to the tiff file to save
    bounds : BoundingBox, optional
        the bounds of the arr. Default is None, which means the roi of the
        dataset will be used.
    bbox : BoundingBox, optional
        if specified, the input array will be saved to the given part/bbox of
        dataset. Default is None, which means the array will be saved to the
        entire dataset. Applies to both ``data`` and ``mask`` arrays.
    band_names : Sequence of str, optional
        names of bands to save. Default is None, which will use the band indexes.
        When given, the names are also written to ``<filename>.band_name.txt``
        separated by ``;``.
    arr_type : str, one of ['data', 'mask'], optional
        type of the array to save. Default is 'data'.
    nodata : float or int, optional
        no data value of the dataset. If None, a proper no data value is
        chosen automatically for the array.
    overwrite : bool, optional
        if True, overwrite the existing file. Default is False, which means
        that an existing file is opened in update mode (r+ mode) and the
        array is written into it.

    Raises
    ------
    ValueError
        if ``arr_type`` is invalid, ``arr`` is not 2D/3D, or ``band_names``
        does not match the number of bands.
    """
    # validate before touching the file system so nothing is created on error
    if arr_type not in ("data", "mask"):
        msg = f"arr_type must be one of ['data', 'mask'], got {arr_type!r}"
        raise ValueError(msg)
    # check arr dimension
    if arr.ndim == 2:
        indexes = [1]
        arr = arr[np.newaxis, :, :]
    elif arr.ndim == 3:
        indexes = [i + 1 for i in range(arr.shape[0])]
    else:
        msg = (
            f"Expected arr to be an array with shape of (n_lat, n_lon) or "
            f"(n_band, n_lat, n_lon), got {arr.shape}"
        )
        raise ValueError(msg)
    # check length of band_names
    if band_names is not None and len(band_names) != arr.shape[0]:
        msg = (
            f"Expected band_names to be of length {arr.shape[0]}, "
            f"got {len(band_names)}"
        )
        raise ValueError(msg)
    # parse profile
    if bounds is None:
        bounds = self.roi
    profile = self.get_profile(bounds)
    profile["count"] = arr.shape[0]
    profile["driver"] = "GTiff"
    profile["dtype"] = get_minimum_dtype(arr)
    profile["nodata"] = get_nodata(arr, nodata, profile["dtype"])
    filename = Path(filename)
    mode = "r+" if filename.exists() and not overwrite else "w"
    with rasterio.open(filename, mode, **profile.to_dict()) as dst:
        # window inside the file for partial writes (None = whole raster)
        win = None if bbox is None else dst.window(*bbox)
        if arr_type == "mask":
            # a single-band mask is squeezed back to 2D before writing
            mask = arr[0] if arr.shape[0] == 1 else arr
            dst.write_mask(mask, window=win)
        else:
            dst.write(arr, indexes, window=win)
        # update band names
        if band_names is not None:
            dst.descriptions = band_names
            band_names_str = ";".join(band_names)
            band_names_file = filename.with_suffix(".band_name.txt")
            with band_names_file.open("w") as f:
                f.write(band_names_str)
[docs]
def array2kml(
    self,
    arr: np.ndarray,
    out_file: str | Path,
    bounds: BoundingBox | None = None,
    img_kwargs: dict | None = None,
    cbar_kwargs: dict | None = None,
    verbose: bool = True,
) -> None:
    """Render a numpy array into a kml file.

    When the dataset is not in WGS84 (EPSG:4326), the array is first
    reprojected to WGS84 with rioxarray, and the bounds are recomputed from
    the reprojected coordinates.

    Parameters
    ----------
    arr: numpy.ndarray
        the 2D array to render into the kml file.
    out_file: str or Path
        destination path of the kml file.
    bounds : BoundingBox, optional
        the bounds of the arr. Default is None, which means the roi of the
        dataset will be used.
    img_kwargs: dict, optional
        keyword arguments for :func:`matplotlib.pyplot.imshow`.
    cbar_kwargs: dict, optional
        keyword arguments for :func:`save_colorbar`, except for the
        out_file and mappable arguments.
    verbose: bool, optional
        whether to print the information of the kml file. Default is True.
    """
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs
    img_kwargs = {} if img_kwargs is None else img_kwargs
    bounds = self.roi if bounds is None else bounds

    target_crs = CRS.from_epsg(4326)
    if self.crs != target_crs:
        lat, lon = self.get_profile(bounds).to_latlon()
        fill_value = get_nodata(arr, None, get_minimum_dtype(arr))

        grid = xr.DataArray(arr, coords=[lat, lon], dims=["y", "x"])
        grid.rio.set_spatial_dims("x", "y", inplace=True)
        grid.rio.write_crs(self.crs, inplace=True)
        warped = grid.rio.reproject(target_crs, nodata=fill_value)

        # continue with the reprojected pixels and their geographic extent
        arr = warped.values
        bounds, *_ = geoinfo_from_latlon(warped.y, warped.x)
        bounds.set_crs(target_crs)

    array2kml(arr, out_file, bounds, img_kwargs, cbar_kwargs, verbose)
[docs]
def array2kmz(
    self,
    arr: np.ndarray,
    out_file: str | Path,
    bounds: BoundingBox | None = None,
    img_kwargs: dict | None = None,
    cbar_kwargs: dict | None = None,
    keep_kml: bool = False,
    verbose: bool = True,
) -> None:
    """Render a numpy array into a kmz file.

    When the dataset CRS differs from WGS84 (EPSG:4326), the array is
    reprojected to WGS84 first, and its bounds are rebuilt from the new
    coordinates.

    Parameters
    ----------
    arr: numpy.ndarray
        the 2D array to render into the kmz file.
    out_file: str or Path
        destination path of the kmz file.
    bounds : BoundingBox, optional
        the bounds of the arr. Default is None, which means the roi of the
        dataset will be used.
    img_kwargs: dict, optional
        keyword arguments for :func:`matplotlib.pyplot.imshow`.
    cbar_kwargs: dict, optional
        keyword arguments for :func:`save_colorbar`, except for the
        out_file and mappable arguments.
    keep_kml: bool, optional
        whether to keep the intermediate kml file. Default is False.
    verbose: bool, optional
        whether to print the information of the kmz file. Default is True.
    """
    if cbar_kwargs is None:
        cbar_kwargs = {}
    if img_kwargs is None:
        img_kwargs = {}
    region = self.roi if bounds is None else bounds

    geographic = CRS.from_epsg(4326)
    values = arr
    if self.crs != geographic:
        profile = self.get_profile(region)
        y_coords, x_coords = profile.to_latlon()
        out_dtype = get_minimum_dtype(values)
        out_nodata = get_nodata(values, None, out_dtype)

        source = xr.DataArray(values, coords=[y_coords, x_coords], dims=["y", "x"])
        source.rio.set_spatial_dims("x", "y", inplace=True)
        source.rio.write_crs(self.crs, inplace=True)
        reprojected = source.rio.reproject(geographic, nodata=out_nodata)

        values = reprojected.values
        region, *_ = geoinfo_from_latlon(reprojected.y, reprojected.x)
        region.set_crs(geographic)

    array2kmz(
        values, out_file, region, img_kwargs, cbar_kwargs, keep_kml, verbose
    )
class HierarchicalDataset(GeoDataset):
    """A base class for hierarchical dataset, like h5 and nc files.

    .. note::
        This class is used to load and sample data from a single file. If you
        want to load and sample data from multiple files, you should use
        :class:`MultiHierarchicalDataset`.

    .. warning::
        This class is still in development. Points and polygons queries are
        not implemented yet (:meth:`_points_query` and
        :meth:`_polygons_query` are stubs).
    """

    #: Name of the latitude (y) coordinate variable. If it and ``lon_name``
    #: are not both present in the file, the names are detected from
    #: ``lat_names`` / ``lon_names``.
    lat_name: str = "lat"
    #: Name of the longitude (x) coordinate variable (see ``lat_name``).
    lon_name: str = "lon"

    def __init__(
        self,
        path: str | Path,
        group: str | None = None,
        roi: BoundingBox | None = None,
    ) -> None:
        """Open a hierarchical file and parse its geoinformation.

        Parameters
        ----------
        path : str or Path
            path to the hierarchical file (e.g. netCDF/HDF5).
        group : str, optional
            group inside the file to read. Default is None (root group).
        roi : BoundingBox, optional
            region of interest of the dataset.
        """
        super().__init__()
        self._path = Path(path)
        self._group = group
        self._roi = roi
        self._update_geo_info()
        warnings.warn(
            "HierarchicalDataset is still in development and may not work as expected.",
            stacklevel=2,
        )

    def __repr__(self) -> str:
        return self._repr_str

    def _update_geo_info(self) -> None:
        """(Re)read bounds, resolution, shape, CRS and coordinates from the file."""
        bound, res, shape, crs, ds_info = self._parse_geo_info(self._path)
        self._bound = bound
        self._res = res
        self._crs = crs
        self._shape = shape
        self._lat = ds_info[0]
        self._lon = ds_info[1]
        self._variables = ds_info[2]
        self._repr_str = ds_info[3]

    def _parse_lat_lon_name(self, ds: xr.Dataset) -> tuple[str, str] | None:
        """Parse the name of the latitude and longitude variables.

        Returns None when the configured ``lat_name``/``lon_name`` are both
        present; otherwise the detected ``(lat_name, lon_name)``.

        Raises
        ------
        ValueError
            if no latitude/longitude variables can be detected.
        """
        lat_name = None
        lon_name = None
        if self.lat_name in ds.variables and self.lon_name in ds.variables:
            return None
        for name in ds.variables:
            if name.lower() in lat_names:
                lat_name = name
            if name.lower() in lon_names:
                lon_name = name
        if lat_name is None or lon_name is None:
            msg = (
                "The dataset does not contain latitude and longitude variables. "
                "Please specify the names of the latitude and longitude variables."
            )
            raise ValueError(msg)
        return lat_name, lon_name

    def _parse_geo_info(
        self,
        path: str | Path,
    ) -> tuple[
        BoundingBox,
        tuple[float, float],
        tuple[int, int],
        CRS,
        tuple[np.ndarray, np.ndarray, list, str],
    ]:
        """Parse the geoinformation of the dataset.

        Returns ``(bound, res, shape, crs, (lat, lon, variables, repr_str))``.
        If the file has no CRS and all lat/lon values fall into the WGS84
        range, EPSG:4326 is assumed (with a warning).
        """
        # read the same group that every other reader of this class uses
        with xr.open_dataset(path, group=self._group) as ds:
            coord_names = self._parse_lat_lon_name(ds)
            if coord_names is not None:
                self.lat_name, self.lon_name = coord_names
            repr_str = ds.__repr__()
            variables = list(ds.variables)
            lat = ds[self.lat_name].values
            lon = ds[self.lon_name].values
            crs = ds.rio.crs
            # parse geo-information
            if crs is None:
                if (
                    np.all(lat >= -90)
                    and np.all(lat <= 90)
                    and np.all(lon >= -180)
                    and np.all(lon <= 180)
                ):
                    warnings.warn(
                        "No CRS is specified for the dataset, assuming the lat/lon values "
                        "are in the range of WGS84.",
                        stacklevel=2,
                    )
                    crs = CRS.from_epsg(4326)
                else:
                    msg = (
                        "No CRS is specified for the dataset, and the lat/lon values are "
                        "not in the range of WGS84. Please specify the CRS of the dataset "
                        "using the :meth:`set_crs` method later."
                    )
                    raise ValueError(msg)
            else:
                crs = CRS.from_user_input(ds.rio.crs)
        # parse bound, resolution, shape
        bound, res, shape = geoinfo_from_latlon(lat, lon)
        bound.set_crs(crs)
        return bound, res, shape, crs, (lat, lon, variables, repr_str)

    def __getitem__(self, var: str) -> xr.DataArray | xr.Dataset:
        """Get the variable from the dataset."""
        with xr.open_dataset(self.path, group=self.group) as ds:
            return ds[var]

    def flush_geo_info(self) -> None:
        """Flush the geoinformation of the dataset to the given file."""
        # NOTE(review): ``write_crs`` returns a new object that is discarded
        # here, and nothing is written back to disk (no ``to_netcdf`` call),
        # so the file does not appear to be modified. Confirm and fix.
        with xr.open_dataset(self.path, group=self.group, mode="a") as ds:
            ds.rio.write_crs(self.crs)
            ds.rio.set_spatial_dims(x_dim=self.lon_name, y_dim=self.lat_name)
        self._update_geo_info()

    def _bbox_query(
        self,
        bbox: BoundingBox,
        variable: str | None = None,
        **kwargs,
    ) -> xr.DataArray | xr.Dataset:
        """Retrieve the data of the dataset for the given bounding box.

        ``kwargs`` are forwarded to :func:`xarray.open_dataarray` (when
        ``variable`` is None) or :func:`xarray.open_dataset`.
        """
        bbox = self._ensure_query_crs(bbox)
        # label-based slices must follow the storage order of the latitudes
        if self.lat[0] < self.lat[-1]:
            slice_lat = slice(bbox.bottom, bbox.top)
        else:
            slice_lat = slice(bbox.top, bbox.bottom)
        slice_lon = slice(bbox.left, bbox.right)
        # open and read the dataset
        if variable is None:
            ds = xr.open_dataarray(self.path, group=self.group, **kwargs)
        else:
            ds = xr.open_dataset(self.path, group=self.group, **kwargs)[variable]
        if "y" not in ds.coords or "x" not in ds.coords:
            ds = ds.rename({self.lat_name: "y", self.lon_name: "x"})
        data = ds.sel(y=slice_lat, x=slice_lon)
        # close dataset
        ds.close()
        return data

    def _points_query(
        self,
        points: Points,
        variable: str | None = None,
    ) -> np.ndarray:
        """Return the values of dataset at given points.

        Points that outside the dataset will be masked.
        """
        # NOTE: not implemented yet; currently returns None.

    def _polygons_query(
        self,
        polygons: Polygons,
        variable: str | None = None,
    ) -> np.ndarray:
        """Return the values of the dataset at the given polygons."""
        # NOTE: not implemented yet; currently returns None.

    def query(
        self,
        query: GeoQuery | Points | BoundingBox | Polygons,
        variable: str | None = None,
        **kwargs,
    ) -> QueryResult:
        """Retrieve images values for given query.

        Parameters
        ----------
        query : GeoQuery | Points | BoundingBox | Polygons
            query to index the dataset. It can be :class:`Points`,
            :class:`BoundingBox`, :class:`Polygons`,
            or a composite :class:`GeoQuery` (recommended) object.
        variable : str, optional
            name of the variable to retrieve. If None, all variables will be retrieved.
        **kwargs : dict
            keyword arguments to pass to :meth:`xarray.open_dataarray` if
            variable is None, otherwise to :meth:`xarray.open_dataset`.
        """
        if isinstance(query, Points):
            query = GeoQuery(points=query)
        elif isinstance(query, BoundingBox):
            query = GeoQuery(boxes=query)
        elif isinstance(query, Polygons):
            query = GeoQuery(polygons=query)
        return self._sample_data(query, variable, **kwargs)

    def _sample_data(
        self,
        query: GeoQuery,
        variable: str | None = None,
        **kwargs,
    ) -> QueryResult:
        """Sample data from the dataset for the given query.

        ``kwargs`` are only used by the bounding-box query (file opening).
        """
        # TODO: refine points and polygons query
        # parse points result
        points_result = None
        if query.points is not None:
            # _points_query takes no open-kwargs
            points_values = self._points_query(query.points, variable)
            # keep the values returned by parse_1d_dims so data matches ``dims``
            dims, points_values = parse_1d_dims(points_values, multi_files=False)
            points_result = {"data": points_values, "dims": dims}
        # parse bounding boxes result
        boxes_result = None
        if query.boxes is not None:
            # this class reads a single file, so dims are parsed in
            # single-file mode (2D/3D arrays)
            if len(query.boxes) == 1:
                boxes_values = self._bbox_query(query.boxes[0], variable, **kwargs)
                dims = parse_2d_dims(boxes_values, multi_files=False)
            else:
                boxes_values = [
                    self._bbox_query(bbox, variable, **kwargs) for bbox in query.boxes
                ]
                dims = parse_2d_dims(boxes_values[0], details=False, multi_files=False)
                dims = f"boxes:{len(boxes_values)}, ({dims})"
            boxes_result = {"data": boxes_values, "dims": f"({dims})"}
        # parse polygons result
        polygons_result = None
        if query.polygons is not None:
            # NOTE: result not used yet; _polygons_query is a stub
            self._polygons_query(query.polygons, variable)
        return QueryResult(points_result, boxes_result, polygons_result, query)

    def sel(
        self,
        variable: str | None = None,
        **kwargs,
    ) -> xr.DataArray | xr.Dataset:
        """Select a variable from the dataset.

        This method is a wrapper of :meth:`xarray.Dataset.sel` or
        :meth:`xarray.DataArray.sel`.

        Parameters
        ----------
        variable : str, optional
            name of the variable to select. If None, the entire dataset will
            be selected.
        **kwargs : dict
            keyword arguments to pass to :meth:`xarray.Dataset.sel` or
            :meth:`xarray.DataArray.sel`.
        """
        with xr.open_dataset(self.path, group=self.group) as ds:
            return ds.sel(**kwargs) if variable is None else ds[variable].sel(**kwargs)

    def isel(
        self,
        variable: str | None = None,
        **kwargs,
    ) -> xr.DataArray | xr.Dataset:
        """Index a variable from the dataset.

        This method is a wrapper of :meth:`xarray.Dataset.isel` or
        :meth:`xarray.DataArray.isel`.

        Parameters
        ----------
        variable : str, optional
            name of the variable to index. If None, the entire dataset will be indexed.
        **kwargs : dict
            keyword arguments to pass to :meth:`xarray.Dataset.isel` or
            :meth:`xarray.DataArray.isel`.
        """
        with xr.open_dataset(self.path, group=self.group) as ds:
            if variable is None:
                data = ds.isel(**kwargs)
            else:
                data = ds[variable].isel(**kwargs)
        return data

    def set_crs(self, crs: CRS | str) -> None:
        """Set the CRS of the dataset.

        .. note::
            This method is used to set the CRS of the dataset if it is not
            specified in the dataset. If the CRS is already specified in the
            dataset, this method will overwrite the CRS.
        """
        self._crs = CRS.from_user_input(crs)
        # bounds are stored in ``self._bound`` by ``_update_geo_info``
        self._bound.set_crs(self._crs)

    @property
    def path(self) -> Path:
        """The path of the dataset."""
        return self._path

    @property
    def group(self) -> str | None:
        """The group of the dataset."""
        return self._group

    @property
    def shape(self) -> tuple[int, int]:
        """The shape of the dataset in (height, width)."""
        return self._shape

    @property
    def bounds(self) -> BoundingBox:
        """The bounds of the dataset."""
        return self._bound

    @property
    def lat(self) -> np.ndarray:
        """The latitudes of the dataset."""
        return self._lat

    @property
    def lon(self) -> np.ndarray:
        """The longitudes of the dataset."""
        return self._lon

    @property
    def variables(self) -> list[str]:
        """The variables of the dataset."""
        return self._variables

    def get_profile(
        self,
        bbox: Literal["roi", "bounds"] | BoundingBox = "roi",
    ) -> Profile | None:
        """Return the raster profile of ``bbox`` at the dataset resolution.

        Returns None when ``_ensure_bbox`` resolves ``bbox`` to None.
        """
        bbox = self._ensure_bbox(bbox)
        if bbox is None:
            return None
        profile = Profile.from_bounds_res(bbox, self.res)
        profile["crs"] = self.crs
        return profile

    def array2tiff(  # noqa: PLR0912
        self,
        arr: np.ndarray,
        filename: str | Path,
        bounds: BoundingBox | None = None,
        bbox: BoundingBox | None = None,
        band_names: Sequence[str] | None = None,
        arr_type: Literal["data", "mask"] = "data",
        nodata: float | None = None,
        overwrite: bool = False,
    ) -> None:
        """Save a numpy array to a tiff file using the geoinformation of dataset.

        Parameters
        ----------
        arr : numpy.ndarray
            numpy array to save. arr can be a 2D array or a 3D array. If arr is a
            3D array, the first dimension should be the band dimension.
        filename : str or Path
            path to the tiff file to save
        bounds : BoundingBox, optional
            the bounds of the output dataset. Default is None, which means the
            roi of the dataset will be used.
        bbox : BoundingBox, optional
            if specified, the input array will be saved to the given part/bbox of
            dataset. Default is None, which means the array will be saved to the
            entire dataset. Applies to both ``data`` and ``mask`` arrays.
        band_names : Sequence of str, optional
            names of bands to save. Default is None, which will use the band indexes.
            Names are written as band tags only if the file has no descriptions.
        arr_type : str, one of ['data', 'mask'], optional
            type of the array to save. Default is 'data'.
        nodata : float or int, optional
            no data value of the dataset. If None, a proper no data value is
            chosen automatically for the array.
        overwrite : bool, optional
            if True, overwrite the existing file. Default is False, which means
            an existing file is opened in update mode (r+ mode).
        """
        # check arr dimension
        if arr.ndim == 2:
            indexes = [1]
            arr = arr[np.newaxis, :, :]
        elif arr.ndim == 3:
            indexes = [i + 1 for i in range(arr.shape[0])]
        else:
            msg = (
                f"Expected arr to be an array with shape of (n_lat, n_lon) or "
                f"(n_band, n_lat, n_lon), got {arr.shape}"
            )
            raise ValueError(msg)
        # check length of band_names
        if band_names is not None and len(band_names) != arr.shape[0]:
            msg = (
                "Expected band_names to be of length "
                f"{arr.shape[0]}, got {len(band_names)}"
            )
            raise ValueError(msg)
        # parse profile
        if bounds is None:
            bounds = self.roi
        profile = self.get_profile(bounds)
        profile["count"] = arr.shape[0]
        profile["driver"] = "GTiff"
        profile["dtype"] = get_minimum_dtype(arr)
        if nodata is None:
            if np.issubdtype(arr.dtype, np.floating):
                nodata = np.nan
            else:
                rng = dtype_ranges[profile["dtype"]]
                nodata = rng[1] - 1 if np.any(arr == rng[0]) else rng[0]
        profile["nodata"] = nodata
        mode = "w"
        if Path(filename).exists() and not overwrite:
            mode = "r+"
        # context manager guarantees the file is closed even if a write fails
        with rasterio.open(filename, mode, **profile) as dst:
            # only tag band names if the file has no descriptions yet
            desc = np.asarray(dst.descriptions, dtype="str")
            update_tags = band_names is not None and np.all(desc == "None")
            # parse window
            win = None if bbox is None else dst.window(*bbox)
            # write array to tiff
            if arr_type == "mask":
                # a single-band mask is squeezed back to 2D before writing
                mask = arr[0] if arr.shape[0] == 1 else arr
                dst.write_mask(mask, window=win)
            elif arr_type == "data":
                dst.write(arr, indexes, window=win)
            if update_tags:
                for i, name in enumerate(band_names, start=1):
                    dst.update_tags(i, NAME=name)
class MultiHierarchicalDataset(GeoDataset):
    """Placeholder for a dataset spread over several hierarchical files.

    Not implemented yet: construction accepts the arguments below but does
    not store or read anything.
    """

    def __init__(self, paths: Sequence[str | Path], **kwargs) -> None:
        """Accept ``paths`` and ``kwargs`` without using them (stub)."""
[docs]
class PairDataset(RasterDataset):
    """A base class for pair datasets.

    Subclasses must implement :meth:`parse_pairs` and :meth:`parse_datetime`.
    """

    # Pairs parsed from the file names; None until a subclass assigns it.
    _pairs: Pairs | None = None
    # Datetime index parsed from the file names; None until a subclass assigns it.
    _datetime: pd.DatetimeIndex | None = None

    def query(
        self,
        query: GeoQuery | Points | BoundingBox | Polygons,
        pairs: Pairs | None = None,
    ) -> QueryResult:
        """Retrieve images values for given query.

        This method is a more flexible implementation compared to
        :meth:`__getitem__`, which can retrieve images only for the given pairs.

        Parameters
        ----------
        query : GeoQuery | Points | BoundingBox | Polygons
            query to index the dataset. It can be :class:`Points`,
            :class:`BoundingBox`, :class:`Polygons`, or a composite
            :class:`GeoQuery` (recommended) object.
        pairs : Pairs, optional
            pairs to use for the query. If None, all pairs will be used.

        Returns
        -------
        result : QueryResult
            a QueryResult instance containing the results of the various queries.
        """
        # wrap a bare query object into a composite GeoQuery
        if isinstance(query, Points):
            query = GeoQuery(points=query)
        elif isinstance(query, BoundingBox):
            query = GeoQuery(boxes=query)
        elif isinstance(query, Polygons):
            query = GeoQuery(polygons=query)

        # restrict to valid files, and optionally to the requested pairs
        mask = self.files.valid
        if pairs is not None:
            mask = mask * self.pairs.where(pairs, return_type="mask")
        paths = self.files[mask].paths
        return self._sample_files(paths, query)

    @classmethod
    def parse_pairs(cls, paths: list[Path]) -> Pairs:
        """Parse pairs from filenames. *Must be implemented in subclass*.

        Parameters
        ----------
        paths : list of pathlib.Path
            list of file paths to parse pairs

        Returns
        -------
        pairs : Pairs object
            pairs parsed from filenames

        Examples
        --------
        for the HyP3 dataset, pairs are parsed from the filenames as follows:

        >>> names = [f.name for f in paths]
        >>> pair_names = ["_".join(i.split("_")[1:3]) for i in names]

        for the HyP3 dataset, the pair names are the second and third parts of the
        filename, separated by an underscore. After parsing the pair names, the
        :class:`Pairs` object can be created by using the ``from_names`` method.

        >>> pairs = Pairs.from_names(pair_names)
        """
        msg = "parse_pairs method must be implemented in subclass"
        raise NotImplementedError(msg)

    @classmethod
    def parse_datetime(cls, paths: list[Path]) -> pd.DatetimeIndex:
        """Parse datetime from filenames. *Must be implemented in subclass*.

        Parameters
        ----------
        paths : list of pathlib.Path
            list of file paths to parse datetime

        Returns
        -------
        datetime : pd.DatetimeIndex
            datetime parsed from filenames
        """
        msg = "parse_datetime method must be implemented in subclass"
        raise NotImplementedError(msg)

    @property
    def pairs(self) -> Pairs:
        """Return Pairs parsed from filenames."""
        return self._pairs

    @property
    def datetime(self) -> pd.DatetimeIndex:
        """Return the datetime for each pair in the dataset."""
        return self._datetime
[docs]
class ApsDataset(RasterDataset):
    """A base class for aps (atmospheric phase screen) datasets."""

    #: This expression is used to find the APS files.
    pattern = "*"
    _pairs = None

    def to_pair_files(
        self,
        out_dir: str | Path,
        pairs: Pairs,
        ref_points: Points,
        roi: BoundingBox | None = None,
        overwrite: bool = False,
        prefix: str = "APS",
    ) -> None:
        """Generate aps-pair files for given pairs and reference points.

        For each pair ``primary_secondary`` the written array is
        ``APS(primary) - APS(secondary)`` over ``roi``, minus the mean of that
        same difference at ``ref_points``. Pairs that contain a date without
        an APS file are dropped (with a warning). The output directory is
        created if it does not exist.

        Parameters
        ----------
        out_dir : str or Path
            path to the directory to save the aps-pair files
        pairs : Pairs
            pairs to generate aps-pair files
        ref_points : Points
            reference points which values are subtracted for all aps-pair files
        roi : BoundingBox, optional
            region of interest to save. If None, the roi of the dataset will be used.
        overwrite : bool, optional
            if True, overwrite existing files, default: False
        prefix : str, optional
            prefix of the aps-pair files, default: "APS"
        """
        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        if roi is None:
            roi = self.roi
        profile = self.get_profile(roi)
        profile["count"] = 1

        dates = self.parse_dates(self.files.paths)
        dates_missing = np.setdiff1d(pairs.dates, dates)
        if len(dates_missing) > 0:
            # no trailing comma here: the message must be a str, not a tuple
            msg = (
                f"Following dates are missing in the {self.ds_name} "
                f"dataset. \n{dates_missing}"
            )
            warnings.warn(msg, stacklevel=2)

        df_paths = pd.Series(self.files.paths.values, index=dates)
        # drop pairs in which either date has no APS file
        mask = ~np.any(np.isin(pairs.values, dates_missing), axis=1)
        pairs = pairs[mask]

        pairs_names = self._ensure_saving_verbose(
            pairs.to_names(),
            ds_name=f"{self.ds_name} Pair",
            unit=" pairs",
        )
        for pair_name in pairs_names:
            primary, secondary = pair_name.split("_")
            out_file = out_dir / f"{prefix}_{pair_name}.tif"
            if out_file.exists() and not overwrite:
                msg = f"{out_file.name} already exists, skipping"
                logger.info(msg)
                continue
            with rasterio.open(out_file, "w", **profile.profile) as dst:
                src_primary = self._load_warp_file(df_paths[primary])
                src_secondary = self._load_warp_file(df_paths[secondary])
                dest_arr = (
                    self._bbox_query(roi, src_primary).squeeze(0)
                    - self._bbox_query(roi, src_secondary).squeeze(0)
                    - (
                        self._points_query(ref_points, src_primary)
                        - self._points_query(ref_points, src_secondary)
                    ).mean()
                )
                dst.write(dest_arr, 1)

    @classmethod
    @abc.abstractmethod
    def parse_dates(cls, paths: Sequence[str] | None = None) -> pd.DatetimeIndex:
        """Parse acquisition dates from filenames.

        *Must be implemented in subclass*.

        Parameters
        ----------
        paths : list of pathlib.Path
            list of file paths to parse datetime

        Returns
        -------
        datetime : pd.DatetimeIndex
            datetime parsed from filenames
        """
class ApsPairs(PairDataset):
    """A dataset that manages the data of APS pairs.

    File stems are expected to look like ``<prefix>_<primary>_<secondary>...``:
    the second and third ``_``-separated fields form the pair name.
    """

    #: This expression is used to find the APS-pair files.
    pattern = "*.tif"

    def __init__(
        self,
        root_dir: str = "data",
        paths: Sequence[str] | None = None,
        crs: CRS | None = None,
        res: float | tuple[float, float] | None = None,
        dtype: np.dtype | None = None,
        nodata: float | None = None,
        roi: BoundingBox | None = None,
        bands: Sequence[str] | None = None,
        cache: bool = True,
        resampling: Resampling = Resampling.nearest,
        fill_nodata: bool = False,
        verbose: bool = True,
        ds_name: str = "",
    ) -> None:
        """Initialize a new ApsPairs instance.

        Parameters
        ----------
        root_dir : str or Path
            root_dir directory where dataset can be found.
        paths : list of str, optional
            list of file paths to use instead of searching for files in ``root_dir``.
            If None, files will be searched for in ``root_dir``.
        crs : CRS, optional
            the output :term:`coordinate reference system (CRS)` of the dataset.
            If None, the CRS of the first file found will be used.
        res : float or tuple of float, optional
            resolution of the output dataset in units of CRS. If None, the resolution
            of the first file found will be used.
        dtype : numpy.dtype, optional
            data type of the output dataset. If None, the data type of the first file
            found will be used.
        nodata : float or int, optional
            no data value of the output dataset. If None, the no data value of
            the first file found will be used. This parameter is useful when the
            no data value is not stored in the file.
        roi : BoundingBox, optional
            region of interest to load from the dataset. If None, the union of all files
            bounds in the dataset will be used.
        bands : list of str, optional
            names of bands to return (defaults to all bands)
        cache : bool, optional
            if True, cache file handle to speed up repeated sampling
        resampling : Resampling, optional
            Resampling algorithm used when reading input files.
            Default: `Resampling.nearest`.
        fill_nodata : bool, optional
            Whether to fill holes in the queried data by interpolating them using
            inverse distance weighting method provided by the
            :func:`rasterio.fill.fillnodata`. Default: False.

            .. note::
                This parameter is only used when sampling data using bounding
                boxes or polygons queries, and will not work for points queries.
        verbose : bool, optional
            if True, print verbose output, default: True
        ds_name : str, optional
            name of the dataset. used for printing verbose output, default: ""

        Raises
        ------
        FileNotFoundError: if no files are found in ``root_dir``
        """
        super().__init__(
            root_dir=root_dir,
            paths=paths,
            crs=crs,
            res=res,
            dtype=dtype,
            nodata=nodata,
            roi=roi,
            bands=bands,
            cache=cache,
            resampling=resampling,
            fill_nodata=fill_nodata,
            verbose=verbose,
            ds_name=ds_name,
        )
        valid_paths = self.files.paths[self.valid]
        self._pairs = self.parse_pairs(valid_paths)
        self._datetime = self.parse_datetime(valid_paths)

    @staticmethod
    def _pair_names(paths: Sequence[str | Path]) -> list[str]:
        """Return ``primary_secondary`` built from fields 2-3 of each file stem."""
        stems = [Path(f).stem for f in paths]
        return ["_".join(stem.split("_")[1:3]) for stem in stems]

    @classmethod
    def parse_pairs(cls, paths: list[Path]) -> Pairs:
        """Parse pairs from a list of APS-pair file paths."""
        return Pairs.from_names(cls._pair_names(paths))

    @classmethod
    def parse_datetime(cls, paths: list[Path]) -> pd.DatetimeIndex:
        """Parse the unique, sorted acquisition dates of a list of APS-pair files."""
        # np.unique flattens the [primary, secondary] lists and sorts the dates
        date_names = np.unique([name.split("_") for name in cls._pair_names(paths)])
        return pd.DatetimeIndex(date_names)

    @property
    def dates(self) -> pd.DatetimeIndex:
        """Return the dates of the dataset."""
        return self._datetime
def get_nodata(
arr: np.ndarray,
nodata: float | None,
dtype: np.dtype,
) -> float:
"""Get a proper no data value for the array."""
if nodata is None:
if np.issubdtype(arr.dtype, np.floating):
nodata = np.nan
else:
rng = dtype_ranges[dtype]
nodata = rng[1] if np.any(arr == rng[0]) else rng[0] - 1
return nodata
def parse_1d_dims(
    values_1d: np.ndarray,
    multi_files: bool = True,
) -> tuple[str, np.ndarray]:
    """Parse the dimensions of 1D array. (used by points).

    Parameters
    ----------
    values_1d : numpy.ndarray
        sampled point values. With ``multi_files`` the shape is
        ``(files, points)`` or ``(files, points, bands)``; otherwise
        ``(points,)`` or ``(points, bands)``.
    multi_files : bool, optional
        whether the first axis indexes files. Default is True.

    Returns
    -------
    dims : str
        human-readable description of the dimensions.
    values : numpy.ndarray
        the values, with the band axis moved in front of the points axis when
        a band axis exists.

    Raises
    ------
    ValueError
        if the number of dimensions is not supported.
    """
    if multi_files:
        if values_1d.ndim == 2:
            n_files, n_points = values_1d.shape
            dims = f"(files:{n_files}, points:{n_points})"
        elif values_1d.ndim == 3:
            n_files, n_points, n_bands = values_1d.shape
            values_1d = values_1d.transpose(0, 2, 1)
            dims = f"(files:{n_files}, bands:{n_bands}, points:{n_points})"
        else:
            msg = f"values_1d must be 2D or 3D, got {values_1d.ndim}"
            raise ValueError(msg)
    elif values_1d.ndim == 1:
        n_points = values_1d.shape[0]
        dims = f"points:{n_points}"
    elif values_1d.ndim == 2:
        n_points, n_bands = values_1d.shape
        values_1d = values_1d.T
        dims = f"bands:{n_bands}, points:{n_points}"
    else:
        # previously fell through with ``dims`` unbound (UnboundLocalError)
        msg = f"values_1d must be 1D or 2D, got {values_1d.ndim}"
        raise ValueError(msg)
    return dims, values_1d
def parse_2d_dims(
    values_2d: np.ndarray,
    details: bool = True,
    multi_files: bool = True,
) -> str:
    """Parse the dimensions of 2D array. (used by bbox, polygons).

    Only ``ndim`` and ``shape`` are read, so numpy arrays and xarray
    DataArrays are both accepted.

    Parameters
    ----------
    values_2d : numpy.ndarray or xarray.DataArray
        sampled values. With ``multi_files`` the shape is
        ``(files, height, width)`` or ``(files, bands, height, width)``;
        otherwise ``(height, width)`` or ``(bands, height, width)``.
    details : bool, optional
        whether to include the height and width sizes. Default is True.
    multi_files : bool, optional
        whether the first axis indexes files. Default is True.

    Returns
    -------
    str
        human-readable description of the dimensions.

    Raises
    ------
    ValueError
        if the number of dimensions is not supported.
    """
    ndim = values_2d.ndim
    shape = values_2d.shape
    if multi_files:
        if ndim == 4:
            n_files, n_bands, height, width = shape
            prefix = f"files:{n_files}, bands:{n_bands}, "
        elif ndim == 3:
            n_files, height, width = shape
            prefix = f"files:{n_files}, "
        else:
            msg = f"values_2d must be 3D or 4D, got {ndim}"
            raise ValueError(msg)
    elif ndim == 3:
        n_bands, height, width = shape
        prefix = f"bands:{n_bands}, "
    elif ndim == 2:
        height, width = shape
        prefix = ""
    else:
        msg = f"values_2d must be 2D or 3D, got {ndim}"
        raise ValueError(msg)
    # only a description is returned; the values themselves are not modified
    if details:
        return f"{prefix}height:{height}, width:{width}"
    return f"{prefix}height, width"