# Source code for faninsar.samplers.patch

"""Define the patch samplers used for sampling data from a dataset."""

from __future__ import annotations

import abc
import math
from typing import TYPE_CHECKING

import numpy as np

from faninsar.logging import setup_logger
from faninsar.query import BoundingBox

if TYPE_CHECKING:
    from collections.abc import Iterator

    from faninsar.datasets import GeoDataset

# Module-level logger used by the samplers below to report when a requested
# row/column count has to be clamped to the dataset size.
# NOTE(review): the log name ends in "batch" while this module is
# ``faninsar.samplers.patch`` -- confirm whether that is intentional.
logger = setup_logger(
    log_name="FanInSAR.samplers.batch",
    log_format="%(levelname)s - %(message)s",
)


class PatchSampler(abc.ABC):
    """Abstract base class for patch samplers.

    The accessors defined here only *read* ``_boxes``, ``_length`` and
    ``_shape``; a concrete sampler must assign all three (normally in its
    ``__init__``) before any of them is used.
    """

    # Array of the bounding boxes of all patches.
    _boxes: np.ndarray
    # Total number of patches.
    _length: int
    # Layout of the patches, e.g. ``(row_num,)`` or ``(row_num, col_num)``.
    _shape: tuple[int, ...]

    def __len__(self) -> int:
        """Return the total number of patches of the sampler."""
        return self._length

    def __str__(self) -> str:
        """Return ``"<ClassName>[boxes=<shape>]"``."""
        return f"{self.__class__.__name__}[boxes={self.shape}]"

    def __repr__(self) -> str:
        """Return the same text as :meth:`__str__`."""
        # Dispatch through ``self`` so that a subclass overriding ``__str__``
        # automatically gets a matching ``__repr__``.
        return self.__str__()

    def __getitem__(
        self,
        index: int | slice | tuple | np.ndarray,
    ) -> BoundingBox | np.ndarray:
        """Get the bounding boxes/patches at the given index.

        Parameters
        ----------
        index : int, slice, tuple or array-like
            The index of the patch. It can be any form of index that can be used
            to index a numpy array. See numpy indexing for more information:
            https://numpy.org/doc/stable/user/basics.indexing.html

        Returns
        -------
        BoundingBox | np.ndarray
            Whatever indexing the underlying box array with ``index`` returns:
            a single bounding box for a full index, otherwise an array of
            bounding boxes.

        """
        return self._boxes[index]

    @property
    def boxes(self) -> np.ndarray:
        """The bounding boxes/patches of the sampler used to sample dataset."""
        return self._boxes

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the patch sampler."""
        return self._shape


class RowSampler(PatchSampler):
    """A sampler samples data from a dataset in a row-wise manner.

    The region of interest is cut into ``row_num`` horizontal strips that all
    span the full ROI width. Every strip is ``height`` pixels tall except the
    last one, whose top edge is snapped to the top of the ROI.
    """

    def __init__(
        self,
        dataset: GeoDataset,
        roi: BoundingBox | None = None,
        row_num: int | None = None,
        height: int | None = None,
        verbose: bool = True,
    ) -> None:
        """Initialize a RowSampler.

        Parameters
        ----------
        dataset : GeoDataset
            The dataset needs to be sampled.
        roi : BoundingBox or Sequence, optional
            The region of interest bounding box. If not provided, the bounding
            box of the dataset will be used.
        row_num : int, optional
            The number of patches to be sampled for row-wise sampling. If
            height is provided, this parameter will be ignored.
        height : int, optional
            The height (in pixels) of the patch to be sampled for row-wise
            sampling. If not provided, the row_num will be used.
        verbose : bool, optional
            Whether to print verbose information. The verbose of the dataset
            will be set to this value. Default is True.

        Raises
        ------
        ValueError
            If neither ``height`` nor ``row_num`` is provided, or if the one
            that is used is not a positive integer.

        """
        self.dataset = dataset
        self.res = dataset.res
        self.dataset.verbose = verbose
        if roi is not None:
            self.dataset.roi = roi

        profile = dataset.get_profile("roi")
        ds_height = profile["height"]

        if height is not None:
            height = int(height)
            # A non-positive height would otherwise divide by zero or yield a
            # negative (i.e. silently empty) number of rows.
            if height <= 0:
                msg = f"height must be a positive integer, but got {height}."
                raise ValueError(msg)
            row_num = math.ceil(ds_height / height)
        else:
            if row_num is None:
                msg = "Either height or row_num must be provided."
                raise ValueError(msg)
            if int(row_num) <= 0:
                msg = f"row_num must be a positive integer, but got {row_num}."
                raise ValueError(msg)
            if row_num > ds_height:
                msg = (
                    f"row_num ({row_num}) is larger than the height ({ds_height})\n"
                    "of the dataset. The row_num will be set to the height of the "
                    "dataset.\n If this cannot meet your requirement, please "
                    "try to choose other Sampler."
                )
                logger.warning(msg, stacklevel=2)
                row_num = ds_height
            row_num = int(row_num)
            height = math.floor(ds_height / row_num)

        self.row_num = row_num
        self.height = height
        self._length = row_num
        self._shape = (row_num,)
        self._boxes = self._gen_patch_boxes()

    def _gen_patch_boxes(self) -> np.ndarray:
        """Build the array of row-strip bounding boxes, bottom row first."""
        roi = self.dataset.roi
        last_row = self.row_num - 1
        patch_boxes = []
        for i in range(self.row_num):
            bottom = (i * self.height) * self.res[1] + roi.bottom
            # The last strip absorbs the remainder so the strips cover the
            # whole ROI even when its height is not a multiple of ``height``.
            top = roi.top if i == last_row else self.height * self.res[1] + bottom
            patch_boxes.append(
                BoundingBox(roi.left, bottom, roi.right, top, crs=self.dataset.crs),
            )
        return np.asarray(patch_boxes, dtype=np.object_)

    def __iter__(self) -> Iterator:
        """Iterate over the bounding boxes of the patches."""
        for i in range(self.row_num):
            yield self.boxes[i]
class ColSampler(PatchSampler):
    """A sampler samples data from a dataset in a col-wise manner.

    This class is used to sample data from a dataset. The dataset is
    represented by a bounding box, and the sampler is used to sample data in
    the bounding box. The result of sampling is an iterator that yields data
    from the dataset.

    The region of interest is cut into ``col_num`` vertical strips that all
    span the full ROI height. Every strip is ``width`` pixels wide except the
    last one, whose right edge is snapped to the right edge of the ROI.
    """

    def __init__(
        self,
        dataset: GeoDataset,
        roi: BoundingBox | None = None,
        col_num: int | None = None,
        width: int | None = None,
        verbose: bool = True,
    ) -> None:
        """Initialize a ColSampler.

        Parameters
        ----------
        dataset : GeoDataset
            The dataset needs to be sampled.
        roi : BoundingBox or Sequence, optional
            The region of interest bounding box. If not provided, the bounding
            box of the dataset will be used.
        col_num : int, optional
            The number of patches to be sampled for col-wise sampling. If width
            is provided, this parameter will be ignored.
        width : int, optional
            The width (in pixels) of the patch to be sampled for col-wise
            sampling. If not provided, the col_num will be used.
        verbose : bool, optional
            Whether to print verbose information. The verbose of the dataset
            will be set to this value. Default is True.

        Raises
        ------
        ValueError
            If neither ``width`` nor ``col_num`` is provided, or if the one
            that is used is not a positive integer.

        """
        self.dataset = dataset
        # Columns advance along x, so keep the x-resolution (``res[0]``). The
        # attribute stays a scalar, as before, but it used to hold ``res[1]``.
        self.res = dataset.res[0]
        self.dataset.verbose = verbose
        if roi is not None:
            self.dataset.roi = roi

        profile = dataset.get_profile("roi")
        ds_width = profile["width"]

        if width is not None:
            width = int(width)
            # A non-positive width would otherwise divide by zero or yield a
            # negative (i.e. silently empty) number of columns.
            if width <= 0:
                msg = f"width must be a positive integer, but got {width}."
                raise ValueError(msg)
            col_num = math.ceil(ds_width / width)
        else:
            if col_num is None:
                msg = "Either width or col_num must be provided."
                raise ValueError(msg)
            if int(col_num) <= 0:
                msg = f"col_num must be a positive integer, but got {col_num}."
                raise ValueError(msg)
            if col_num > ds_width:
                msg = (
                    f"col_num ({col_num}) is larger than the width ({ds_width})\n"
                    "of the dataset. The col_num will be set to the width of the "
                    "dataset.\n If this cannot meet your requirement, please "
                    "try to choose other Sampler."
                )
                logger.warning(msg)
                # Clamp to the dataset width; ``width`` is None in this branch.
                col_num = ds_width
            col_num = int(col_num)
            width = math.floor(ds_width / col_num)

        self.col_num = col_num
        self.width = width
        # Populate the state the PatchSampler accessors (len/str/indexing) read.
        self._length = col_num
        self._shape = (col_num,)
        self._boxes = self._gen_patch_boxes()

    def _gen_patch_boxes(self) -> np.ndarray:
        """Build the array of column-strip bounding boxes, left column first."""
        roi = self.dataset.roi
        last_col = self.col_num - 1
        patch_boxes = []
        for j in range(self.col_num):
            left = (j * self.width) * self.res + roi.left
            # The last strip absorbs the remainder so the strips cover the
            # whole ROI even when its width is not a multiple of ``width``.
            right = roi.right if j == last_col else self.width * self.res + left
            patch_boxes.append(
                BoundingBox(left, roi.bottom, right, roi.top, crs=self.dataset.crs),
            )
        return np.asarray(patch_boxes, dtype=np.object_)

    def __iter__(self) -> Iterator:
        """Iterate over the bounding boxes of the patches."""
        for j in range(self.col_num):
            yield self.boxes[j]
class RowColSampler(PatchSampler):
    """A sampler samples data from a dataset in a row-col-wise manner.

    This class is used to sample data from a dataset. The dataset is
    represented by a bounding box, and the sampler is used to sample data in
    the bounding box. The result of sampling is an iterator that yields data
    from the dataset.

    The region of interest is cut into a ``row_num`` x ``col_num`` grid. The
    patches of the last row / last column are snapped to the top / right edge
    of the ROI, so they may differ in size from the others.
    """

    def __init__(
        self,
        dataset: GeoDataset,
        roi: BoundingBox | None = None,
        height: int | None = None,
        width: int | None = None,
        row_num: int | None = None,
        col_num: int | None = None,
        verbose: bool = True,
    ) -> None:
        """Initialize a RowColSampler.

        Parameters
        ----------
        dataset : GeoDataset
            The dataset to be sampled.
        roi : BoundingBox or Sequence, optional
            The region of interest bounding box. If not provided, the bounding
            box of the dataset will be used.
        height : int, optional
            The height (in pixels) for each patch. Takes precedence over
            row_num; if not provided, row_num will be used.
        width : int, optional
            The width (in pixels) for each patch. Takes precedence over
            col_num; if not provided, col_num will be used.
        row_num : int, optional
            The number of rows to be sampled for row-col-wise sampling. If
            height is provided, this parameter will be ignored.
        col_num : int, optional
            The number of columns to be sampled for row-col-wise sampling. If
            width is provided, this parameter will be ignored.
        verbose : bool, optional
            Whether to print verbose information. The verbose of the dataset
            will be set to this value. Default is True.

        Raises
        ------
        ValueError
            If, for either direction, neither the patch size nor the patch
            count is provided, or the one that is used is not a positive
            integer.

        """
        self.dataset = dataset
        self.res = dataset.res
        self.dataset.verbose = verbose
        if roi is not None:
            self.dataset.roi = roi

        profile = dataset.get_profile("roi")
        ds_height = profile["height"]
        ds_width = profile["width"]

        # row direction
        row_num, height = self._resolve_axis(
            ds_height, height, row_num, "height", "row_num",
        )
        # col direction
        col_num, width = self._resolve_axis(
            ds_width, width, col_num, "width", "col_num",
        )

        self.row_num = row_num
        self.height = height
        self.col_num = col_num
        self.width = width
        self._shape = (row_num, col_num)
        self._length = row_num * col_num
        self._boxes = self._gen_patch_boxes()

    @staticmethod
    def _resolve_axis(
        extent: int,
        size: int | None,
        num: int | None,
        size_name: str,
        num_name: str,
    ) -> tuple[int, int]:
        """Return ``(patch count, patch size in pixels)`` for one direction.

        ``extent`` is the dataset size along that direction. ``size`` wins when
        given; otherwise ``num`` is used after being clamped to ``extent``.
        ``size_name`` / ``num_name`` are only used to word the messages.
        """
        if size is not None:
            size = int(size)
            # A non-positive size would otherwise divide by zero or yield a
            # negative (i.e. silently empty) patch count.
            if size <= 0:
                msg = f"{size_name} must be a positive integer, but got {size}."
                raise ValueError(msg)
            return math.ceil(extent / size), size

        if num is None:
            msg = f"Either {size_name} or {num_name} must be provided."
            raise ValueError(msg)
        if int(num) <= 0:
            msg = f"{num_name} must be a positive integer, but got {num}."
            raise ValueError(msg)
        if num > extent:
            # Plain string (no trailing comma) so the logger receives text
            # rather than a one-element tuple.
            msg = (
                f"{num_name} ({num}) is larger than the {size_name} ({extent})\n"
                f"of the dataset. The {num_name} will be set to the {size_name} of the"
                " dataset.\n If this cannot meet your requirement, please try"
                " to choose other Sampler."
            )
            logger.warning(msg)
            num = extent
        num = int(num)
        return num, math.floor(extent / num)

    def _gen_patch_boxes(self) -> np.ndarray:
        """Build the nested (row, column) array of patch bounding boxes."""
        roi = self.dataset.roi
        last_row = self.row_num - 1
        last_col = self.col_num - 1
        patch_boxes = []
        for i in range(self.row_num):
            # bottom/top depend on the row only, so compute them once per row.
            bottom = (i * self.height) * self.res[1] + roi.bottom
            top = roi.top if i == last_row else self.height * self.res[1] + bottom
            patch_boxes_row = []
            for j in range(self.col_num):
                left = (j * self.width) * self.res[0] + roi.left
                right = (
                    roi.right if j == last_col else self.width * self.res[0] + left
                )
                patch_boxes_row.append(
                    BoundingBox(left, bottom, right, top, crs=self.dataset.crs),
                )
            patch_boxes.append(patch_boxes_row)
        return np.asarray(patch_boxes, dtype=np.object_)

    def __iter__(self) -> Iterator:
        """Iterate over the bounding boxes of the patches, row by row."""
        for i in range(self.row_num):
            for j in range(self.col_num):
                yield self.boxes[i, j]