diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 5c66d995..670c07ec 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -37,4 +37,4 @@ jobs: - name: Test examples run: | cd mesa-examples - pytest -rA -Werror -Wdefault::FutureWarning test_gis_examples.py + pytest -rA -Werror -Wdefault::FutureWarning -Wdefault::DeprecationWarning test_gis_examples.py diff --git a/mesa_geo/raster_layers.py b/mesa_geo/raster_layers.py index 4210d5a8..d68d459d 100644 --- a/mesa_geo/raster_layers.py +++ b/mesa_geo/raster_layers.py @@ -6,7 +6,6 @@ from __future__ import annotations import copy -import inspect import itertools import math import warnings @@ -18,7 +17,7 @@ from affine import Affine from mesa import Model from mesa.agent import Agent -from mesa.space import Coordinate, FloatCoordinate, accept_tuple_argument +from mesa.space import Coordinate, FloatCoordinate, PropertyLayer, accept_tuple_argument from rasterio.warp import ( Resampling, calculate_default_transform, @@ -185,6 +184,7 @@ def __init__( *, rowcol=None, xy=None, + raster_layer=None, ): """ Initialize a cell. @@ -197,11 +197,13 @@ def __init__( Origin is at upper left corner of the grid :param xy: Geographic/projected (x, y) coordinates of the cell center in the CRS. """ - - super().__init__(model) + self.model = model + self.unique_id = None + if not hasattr(self, "random") and model is not None: + self.random = model.random self._pos = pos self._rowcol = indices if rowcol is None else rowcol - self._xy = xy + self.raster_layer = raster_layer @property def pos(self) -> Coordinate | None: @@ -269,12 +271,116 @@ def xy(self) -> FloatCoordinate | None: """ Geographic/projected (x, y) coordinates of the cell center in the CRS. """ - return self._xy + if getattr(self, "raster_layer", None) is not None and self._rowcol is not None: + return rio.transform.xy( + self.raster_layer.transform, + self._rowcol[0], + self._rowcol[1], + offset="center", + ) + return None + + def __getattr__(self, name: str): + if name == "raster_layer": + raise AttributeError + if ( + getattr(self, "raster_layer", None) is not None + and name in self.raster_layer.attributes + ): + x, y = self.pos + return self.raster_layer._properties[name].data[x, y] + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) + + def __setattr__(self, name: str, value): + if ( + name != "raster_layer" + and getattr(self, "raster_layer", None) is not None + and name in self.raster_layer.attributes + ): + pos = getattr(self, "_pos", None) + if pos is not None: + try: + x, y = pos + self.raster_layer._properties[name].data[x, y] = value + return + except (TypeError, ValueError): + pass + super().__setattr__(name, value) def step(self): pass +class _CellWrapper: + def __init__(self, raster_layer): + self.rl = raster_layer + self._cell_cache = {} + + def _create_cell(self, x, y): + idx = (x, y) + if idx in self._cell_cache: + return self._cell_cache[idx] + + try: + cell = self.rl.cell_cls( + self.rl.model, + pos=(x, y), + indices=(self.rl.height - y - 1, x), + raster_layer=self.rl, + ) + except TypeError: + cell = self.rl.cell_cls( + self.rl.model, + pos=(x, y), + indices=(self.rl.height - y - 1, x), + ) + cell.raster_layer = self.rl + + for attr in self.rl.attributes: + cell.__dict__.pop(attr, None) + + self._cell_cache[idx] = cell + return cell + + def __getitem__(self, index): + class _CellColumn: + def __init__(self, wrapper, x): + self.wrapper = wrapper + self.x = x + + def __getitem__(self, y): + if isinstance(y, int): + return self.wrapper._create_cell(self.x, y) + elif isinstance(y, slice): + return [ + self.wrapper._create_cell(self.x, yi) + for yi in range(*y.indices(self.wrapper.rl.height)) + ] + raise TypeError("Column indices must be integers or slices") + + def __iter__(self): + for y in range(self.wrapper.rl.height): + yield self.wrapper._create_cell(self.x, y) + + def __len__(self): + return self.wrapper.rl.height + + if isinstance(index, int): + return _CellColumn(self, index) + elif isinstance(index, slice): + return [ + _CellColumn(self, xi) for xi in range(*index.indices(self.rl.width)) + ] + + raise TypeError("Raster indices must be integers or slices") + + def __iter__(self): + for x in range(self.rl.width): + yield self[x] + + class RasterLayer(RasterBase): """ Some methods in `RasterLayer` are copied from `mesa.space.Grid`, including: @@ -307,74 +413,32 @@ class RasterLayer(RasterBase): whereas it is `self.cells: List[List[Cell]]` here in `RasterLayer`. """ - cells: list[list[Cell]] + _properties: dict[str, PropertyLayer] _neighborhood_cache: dict[Any, list[Coordinate]] - _attributes: set[str] def __init__( self, width, height, crs, total_bounds, model, cell_cls: type[Cell] = Cell ): - super().__init__(width, height, crs, total_bounds) self.model = model self.cell_cls = cell_cls - self._initialize_cells() + self._properties = {} self._attributes = set() self._neighborhood_cache = {} + super().__init__(width, height, crs, total_bounds) def _update_transform(self) -> None: super()._update_transform() - if getattr(self, "cells", None): - self._sync_cell_xy() def _sync_cell_xy(self) -> None: - for column in self.cells: - for cell in column: - row, col = cell.rowcol - cell._xy = rio.transform.xy(self.transform, row, col, offset="center") - - def _initialize_cells(self) -> None: - try: - init_params = inspect.signature(self.cell_cls.__init__).parameters - except (TypeError, ValueError): - supports_legacy_pos_indices = False - else: - supports_legacy_pos_indices = ( - "pos" in init_params and "indices" in init_params - ) - - if supports_legacy_pos_indices: - - def make_cell(grid_x: int, grid_y: int, row_idx: int, col_idx: int, xy): - # Backward-compatible path for legacy signature: - # __init__(self, model, pos=None, indices=None, ...) - cell = self.cell_cls( - self.model, - pos=(grid_x, grid_y), - indices=(row_idx, col_idx), - ) - # Legacy constructor path does not accept xy; set it manually. - cell._xy = xy - return cell - else: - # New constructor path: __init__(self, model, pos=None, rowcol=None, xy=None, ...) - # or: __init__(self, model, **kwargs) - def make_cell(grid_x: int, grid_y: int, row_idx: int, col_idx: int, xy): - return self.cell_cls( - self.model, - pos=(grid_x, grid_y), - rowcol=(row_idx, col_idx), - xy=xy, - ) + warnings.warn( + "RasterLayer._sync_cell_xy is deprecated and has no effect.", + DeprecationWarning, + stacklevel=2, + ) - self.cells = [] - for grid_x in range(self.width): - col: list[Cell] = [] - for grid_y in range(self.height): - row_idx, col_idx = self.height - grid_y - 1, grid_x - xy = rio.transform.xy(self.transform, row_idx, col_idx, offset="center") - cell = make_cell(grid_x, grid_y, row_idx, col_idx, xy) - col.append(cell) - self.cells.append(col) + @property + def cells(self): + return _CellWrapper(self) @property def attributes(self) -> set[str]: @@ -384,7 +448,7 @@ def attributes(self) -> set[str]: :return: Attributes of the cells in the raster layer. :rtype: Set[str] """ - return self._attributes + return set(self._properties.keys()) @overload def __getitem__(self, index: int) -> list[Cell]: ... @@ -516,13 +580,16 @@ def _default_attr_name() -> str: for band_idx, name in enumerate(names): attr = _default_attr_name() if name is None else name self._attributes.add(attr) - for grid_x in range(self.width): - for grid_y in range(self.height): - setattr( - self.cells[grid_x][grid_y], - attr, - data[band_idx, self.height - grid_y - 1, grid_x], - ) + prop_type = data.dtype.type + prop_layer = PropertyLayer( + attr, + self.width, + self.height, + default_value=prop_type(0), + dtype=prop_type, + ) + prop_layer.data = np.flip(data[band_idx], axis=0).T + self._properties[attr] = prop_layer def get_raster(self, attr_name: str | Sequence[str] | None = None) -> np.ndarray: """ @@ -549,20 +616,23 @@ def get_raster(self, attr_name: str | Sequence[str] | None = None) -> np.ndarray ) if attr_name is None: num_bands = len(self.attributes) - attr_names = self.attributes + attr_names = list(self.attributes) elif isinstance(attr_name, Sequence) and not isinstance(attr_name, str): num_bands = len(attr_name) attr_names = list(attr_name) else: num_bands = 1 attr_names = [attr_name] - data = np.empty((num_bands, self.height, self.width)) + + dtype = ( + np.result_type(*[self._properties[name].data.dtype for name in attr_names]) + if attr_names + else float + ) + data = np.empty((num_bands, self.height, self.width), dtype=dtype) for ind, name in enumerate(attr_names): - for grid_x in range(self.width): - for grid_y in range(self.height): - data[ind, self.height - grid_y - 1, grid_x] = getattr( - self.cells[grid_x][grid_y], name - ) + prop_data = self._properties[name].data + data[ind] = np.flip(prop_data.T, axis=0) return data def iter_neighborhood( @@ -744,8 +814,6 @@ def to_crs(self, crs, inplace=False) -> RasterLayer | None: ] layer.crs = crs layer._transform = transform - if getattr(layer, "cells", None): - layer._sync_cell_xy() if not inplace: return layer @@ -793,7 +861,6 @@ def from_file( ] obj = cls(width, height, dataset.crs, total_bounds, model, cell_cls) obj._transform = dataset.transform - obj._sync_cell_xy() obj.apply_raster(values, attr_name=attr_name) return obj diff --git a/tests/test_RasterLayer.py b/tests/test_RasterLayer.py index 78f5499b..8433db3b 100644 --- a/tests/test_RasterLayer.py +++ b/tests/test_RasterLayer.py @@ -5,6 +5,7 @@ import mesa import numpy as np +import pytest import rasterio as rio import mesa_geo as mg @@ -307,6 +308,18 @@ def test_get_max_cell(self): self.assertEqual(max_cell.pos, (1, 1)) self.assertEqual(max_cell.elevation, 4) + def test_cell_incompatible_assignment(self): + self.raster_layer.apply_raster( + np.array([[[1, 2], [3, 4], [5, 6]]]), attr_name="elevation" + ) + cell = self.raster_layer.cells[0][0] + # Valid assignment updates backend + cell.elevation = 99 + self.assertEqual(cell.elevation, 99) + # Incompatible assignment skips backend but doesn't raise + cell.elevation = None + self.assertIsNone(cell.__dict__.get("elevation")) + def test_deprecated_pos_indices_accessors(self): cell = self.raster_layer.cells[0][0] with warnings.catch_warnings(record=True) as captured: @@ -434,3 +447,110 @@ def test_from_file_multiband_attr_name_none(self): for idx in range(data.shape[0]) ) ) + + +def test_cell_missing_raster_layer_kwarg(): + class OldSchoolCell(mg.Cell): + def __init__(self, model, pos=None, indices=None): + super().__init__(model, pos, indices) + + rl = mg.RasterLayer( + 10, + 10, + "epsg:4326", + total_bounds=[0, 0, 10, 10], + model=mesa.Model(), + cell_cls=OldSchoolCell, + ) + cell = rl.cells[0][0] + assert isinstance(cell, OldSchoolCell) + assert cell.raster_layer is rl + + +def test_cell_wrapper_dunder_methods(): + rl = mg.RasterLayer( + 10, 10, "epsg:4326", total_bounds=[0, 0, 10, 10], model=mesa.Model() + ) + + # Test __iter__ of _CellWrapper + cols = list(rl.cells) + assert len(cols) == 10 + + # Test __iter__ and __len__ of _CellColumn + col = rl.cells[0] + cells = list(col) + assert len(cells) == 10 + assert len(col) == 10 + + # Test exceptions + with pytest.raises(TypeError): + _ = rl.cells["invalid"] + + with pytest.raises(TypeError): + _ = rl.cells[0]["invalid"] + + +def test_cell_coverage(): + rl = mg.RasterLayer( + 10, 10, "epsg:4326", total_bounds=[0, 0, 10, 10], model=mesa.Model() + ) + rl.apply_raster(np.ones((1, 10, 10), dtype=np.int32), attr_name="test_attr") + + wrapper = rl.cells + cell1 = wrapper[0][0] + cell2 = wrapper[0][0] # Hit cache + assert cell1 is cell2 + + with pytest.raises(AttributeError): + _ = cell1.non_existent_attr + + # Hit TypeError fallback in Cell.__setattr__ by assigning None to numpy numeric matrix + cell1.test_attr = None + assert cell1.__dict__["test_attr"] is None + + # Hit the pos is None branch + cell1._pos = None + cell1.test_attr = 5 + assert cell1.__dict__["test_attr"] == 5 + + cell1.step() + + with pytest.warns(DeprecationWarning): + rl._sync_cell_xy() + + +def test_cell_extra_coverage(): + # Hit missing random assignment coverage + cell = mg.Cell(model=mesa.Model(), pos=(5, 5)) + assert hasattr(cell, "random") + + # Test xy property returns None when raster_layer or rowcol is None + assert cell.xy is None + + +def test_raster_layer_getitem(): + rl = mg.RasterLayer( + 10, 10, "epsg:4326", total_bounds=[0, 0, 10, 10], model=mesa.Model() + ) + + # 1. Test int indexing: rl[x] -> returns list of cells + col = rl[3] + assert len(col) == 10 + assert col[5].pos == (3, 5) + + # 2. Test sequence of coordinates: rl[[(0,0), (1,1)]] -> returns list of cells + cells = rl[[(0, 0), (1, 1)]] + assert len(cells) == 2 + assert cells[0].pos == (0, 0) + assert cells[1].pos == (1, 1) + + # 3. Test grid indexing: rl[x, y] -> returns single cell + cell = rl[3, 5] + assert cell.pos == (3, 5) + + # 4. Test slice indexing: rl[3:5, 4:6] + # Note from code: rl[slice, slice] loops over the slices + cells_slice = rl[3:5, 4:6] + # Expect generator/list of cells + cells_list = list(cells_slice) + assert len(cells_list) > 0