From 4a4301ef367059e440aec7cb5010f3e1ae97a5a8 Mon Sep 17 00:00:00 2001 From: Tejasv Singh Date: Sat, 28 Feb 2026 04:19:55 +0530 Subject: [PATCH] Refactor RasterLayer to use PropertyLayer as backend (#201) Summary This PR refactors RasterLayer to use Mesa's new PropertyLayer vectorized backend. This resolves the massive performance and memory bottlenecks caused by eager instantiation of Python objects for every pixel while completely preserving the existing Cell spatial API for backward compatibility. Bug / Issue Resolves #201 Context & Impact: The current implementation of RasterLayer (prior to this PR) proactively creates a Python Cell object (inheriting from Agent) for every single coordinate in a raster. When dealing with moderately-sized to large geographic rasters (e.g., GeoTIFFs), this object-oriented initialization scales extremely poorly, causing severe initialization delays and massive memory consumption overhead. Implementation To address this architectural bottleneck, I transitioned RasterLayer from an object-oriented memory model to a data-oriented (vectorized) memory model using PropertyLayer. - Added Vectorized Storage (PropertyLayer): RasterLayer no longer holds cells; it exclusively manages instances of Python numeric arrays mapped directly to numpy transformations. - Added O(1) Ephemeral Views: _CellWrapper dynamically proxies properties dynamically without permanent Agent object initialization overhead. --- .github/workflows/examples.yml | 2 +- mesa_geo/raster_layers.py | 221 +++++++++++++++++++++------------ tests/test_RasterLayer.py | 120 ++++++++++++++++++ 3 files changed, 265 insertions(+), 78 deletions(-) diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 5c66d995..670c07ec 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -37,4 +37,4 @@ jobs: - name: Test examples run: | cd mesa-examples - pytest -rA -Werror -Wdefault::FutureWarning test_gis_examples.py + pytest -rA -Werror -Wdefault::FutureWarning -Wdefault::DeprecationWarning test_gis_examples.py diff --git a/mesa_geo/raster_layers.py b/mesa_geo/raster_layers.py index 4210d5a8..d68d459d 100644 --- a/mesa_geo/raster_layers.py +++ b/mesa_geo/raster_layers.py @@ -6,7 +6,6 @@ from __future__ import annotations import copy -import inspect import itertools import math import warnings @@ -18,7 +17,7 @@ from affine import Affine from mesa import Model from mesa.agent import Agent -from mesa.space import Coordinate, FloatCoordinate, accept_tuple_argument +from mesa.space import Coordinate, FloatCoordinate, PropertyLayer, accept_tuple_argument from rasterio.warp import ( Resampling, calculate_default_transform, @@ -185,6 +184,7 @@ def __init__( *, rowcol=None, xy=None, + raster_layer=None, ): """ Initialize a cell. @@ -197,11 +197,13 @@ def __init__( Origin is at upper left corner of the grid :param xy: Geographic/projected (x, y) coordinates of the cell center in the CRS. """ - - super().__init__(model) + self.model = model + self.unique_id = None + if not hasattr(self, "random") and model is not None: + self.random = model.random self._pos = pos self._rowcol = indices if rowcol is None else rowcol - self._xy = xy + self.raster_layer = raster_layer @property def pos(self) -> Coordinate | None: @@ -269,12 +271,116 @@ def xy(self) -> FloatCoordinate | None: """ Geographic/projected (x, y) coordinates of the cell center in the CRS. """ - return self._xy + if getattr(self, "raster_layer", None) is not None and self._rowcol is not None: + return rio.transform.xy( + self.raster_layer.transform, + self._rowcol[0], + self._rowcol[1], + offset="center", + ) + return None + + def __getattr__(self, name: str): + if name == "raster_layer": + raise AttributeError + if ( + getattr(self, "raster_layer", None) is not None + and name in self.raster_layer.attributes + ): + x, y = self.pos + return self.raster_layer._properties[name].data[x, y] + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) + + def __setattr__(self, name: str, value): + if ( + name != "raster_layer" + and getattr(self, "raster_layer", None) is not None + and name in self.raster_layer.attributes + ): + pos = getattr(self, "_pos", None) + if pos is not None: + try: + x, y = pos + self.raster_layer._properties[name].data[x, y] = value + return + except (TypeError, ValueError): + pass + super().__setattr__(name, value) def step(self): pass +class _CellWrapper: + def __init__(self, raster_layer): + self.rl = raster_layer + self._cell_cache = {} + + def _create_cell(self, x, y): + idx = (x, y) + if idx in self._cell_cache: + return self._cell_cache[idx] + + try: + cell = self.rl.cell_cls( + self.rl.model, + pos=(x, y), + indices=(self.rl.height - y - 1, x), + raster_layer=self.rl, + ) + except TypeError: + cell = self.rl.cell_cls( + self.rl.model, + pos=(x, y), + indices=(self.rl.height - y - 1, x), + ) + cell.raster_layer = self.rl + + for attr in self.rl.attributes: + cell.__dict__.pop(attr, None) + + self._cell_cache[idx] = cell + return cell + + def __getitem__(self, index): + class _CellColumn: + def __init__(self, wrapper, x): + self.wrapper = wrapper + self.x = x + + def __getitem__(self, y): + if isinstance(y, int): + return self.wrapper._create_cell(self.x, y) + elif isinstance(y, slice): + return [ + self.wrapper._create_cell(self.x, yi) + for yi in range(*y.indices(self.wrapper.rl.height)) + ] + raise TypeError("Column indices must be integers or slices") + + def __iter__(self): + for y in range(self.wrapper.rl.height): + yield self.wrapper._create_cell(self.x, y) + + def __len__(self): + return self.wrapper.rl.height + + if isinstance(index, int): + return _CellColumn(self, index) + elif isinstance(index, slice): + return [ + _CellColumn(self, xi) for xi in range(*index.indices(self.rl.width)) + ] + + raise TypeError("Raster indices must be integers or slices") + + def __iter__(self): + for x in range(self.rl.width): + yield self[x] + + class RasterLayer(RasterBase): """ Some methods in `RasterLayer` are copied from `mesa.space.Grid`, including: @@ -307,74 +413,32 @@ class RasterLayer(RasterBase): whereas it is `self.cells: List[List[Cell]]` here in `RasterLayer`. """ - cells: list[list[Cell]] + _properties: dict[str, PropertyLayer] _neighborhood_cache: dict[Any, list[Coordinate]] - _attributes: set[str] def __init__( self, width, height, crs, total_bounds, model, cell_cls: type[Cell] = Cell ): - super().__init__(width, height, crs, total_bounds) self.model = model self.cell_cls = cell_cls - self._initialize_cells() + self._properties = {} self._attributes = set() self._neighborhood_cache = {} + super().__init__(width, height, crs, total_bounds) def _update_transform(self) -> None: super()._update_transform() - if getattr(self, "cells", None): - self._sync_cell_xy() def _sync_cell_xy(self) -> None: - for column in self.cells: - for cell in column: - row, col = cell.rowcol - cell._xy = rio.transform.xy(self.transform, row, col, offset="center") - - def _initialize_cells(self) -> None: - try: - init_params = inspect.signature(self.cell_cls.__init__).parameters - except (TypeError, ValueError): - supports_legacy_pos_indices = False - else: - supports_legacy_pos_indices = ( - "pos" in init_params and "indices" in init_params - ) - - if supports_legacy_pos_indices: - - def make_cell(grid_x: int, grid_y: int, row_idx: int, col_idx: int, xy): - # Backward-compatible path for legacy signature: - # __init__(self, model, pos=None, indices=None, ...) - cell = self.cell_cls( - self.model, - pos=(grid_x, grid_y), - indices=(row_idx, col_idx), - ) - # Legacy constructor path does not accept xy; set it manually. - cell._xy = xy - return cell - else: - # New constructor path: __init__(self, model, pos=None, rowcol=None, xy=None, ...) - # or: __init__(self, model, **kwargs) - def make_cell(grid_x: int, grid_y: int, row_idx: int, col_idx: int, xy): - return self.cell_cls( - self.model, - pos=(grid_x, grid_y), - rowcol=(row_idx, col_idx), - xy=xy, - ) + warnings.warn( + "RasterLayer._sync_cell_xy is deprecated and has no effect.", + DeprecationWarning, + stacklevel=2, + ) - self.cells = [] - for grid_x in range(self.width): - col: list[Cell] = [] - for grid_y in range(self.height): - row_idx, col_idx = self.height - grid_y - 1, grid_x - xy = rio.transform.xy(self.transform, row_idx, col_idx, offset="center") - cell = make_cell(grid_x, grid_y, row_idx, col_idx, xy) - col.append(cell) - self.cells.append(col) + @property + def cells(self): + return _CellWrapper(self) @property def attributes(self) -> set[str]: @@ -384,7 +448,7 @@ def attributes(self) -> set[str]: :return: Attributes of the cells in the raster layer. :rtype: Set[str] """ - return self._attributes + return set(self._properties.keys()) @overload def __getitem__(self, index: int) -> list[Cell]: ... @@ -516,13 +580,16 @@ def _default_attr_name() -> str: for band_idx, name in enumerate(names): attr = _default_attr_name() if name is None else name self._attributes.add(attr) - for grid_x in range(self.width): - for grid_y in range(self.height): - setattr( - self.cells[grid_x][grid_y], - attr, - data[band_idx, self.height - grid_y - 1, grid_x], - ) + prop_type = data.dtype.type + prop_layer = PropertyLayer( + attr, + self.width, + self.height, + default_value=prop_type(0), + dtype=prop_type, + ) + prop_layer.data = np.flip(data[band_idx], axis=0).T + self._properties[attr] = prop_layer def get_raster(self, attr_name: str | Sequence[str] | None = None) -> np.ndarray: """ @@ -549,20 +616,23 @@ def get_raster(self, attr_name: str | Sequence[str] | None = None) -> np.ndarray ) if attr_name is None: num_bands = len(self.attributes) - attr_names = self.attributes + attr_names = list(self.attributes) elif isinstance(attr_name, Sequence) and not isinstance(attr_name, str): num_bands = len(attr_name) attr_names = list(attr_name) else: num_bands = 1 attr_names = [attr_name] - data = np.empty((num_bands, self.height, self.width)) + + dtype = ( + np.result_type(*[self._properties[name].data.dtype for name in attr_names]) + if attr_names + else float + ) + data = np.empty((num_bands, self.height, self.width), dtype=dtype) for ind, name in enumerate(attr_names): - for grid_x in range(self.width): - for grid_y in range(self.height): - data[ind, self.height - grid_y - 1, grid_x] = getattr( - self.cells[grid_x][grid_y], name - ) + prop_data = self._properties[name].data + data[ind] = np.flip(prop_data.T, axis=0) return data def iter_neighborhood( @@ -744,8 +814,6 @@ def to_crs(self, crs, inplace=False) -> RasterLayer | None: ] layer.crs = crs layer._transform = transform - if getattr(layer, "cells", None): - layer._sync_cell_xy() if not inplace: return layer @@ -793,7 +861,6 @@ def from_file( ] obj = cls(width, height, dataset.crs, total_bounds, model, cell_cls) obj._transform = dataset.transform - obj._sync_cell_xy() obj.apply_raster(values, attr_name=attr_name) return obj diff --git a/tests/test_RasterLayer.py b/tests/test_RasterLayer.py index 78f5499b..8433db3b 100644 --- a/tests/test_RasterLayer.py +++ b/tests/test_RasterLayer.py @@ -5,6 +5,7 @@ import mesa import numpy as np +import pytest import rasterio as rio import mesa_geo as mg @@ -307,6 +308,18 @@ def test_get_max_cell(self): self.assertEqual(max_cell.pos, (1, 1)) self.assertEqual(max_cell.elevation, 4) + def test_cell_incompatible_assignment(self): + self.raster_layer.apply_raster( + np.array([[[1, 2], [3, 4], [5, 6]]]), attr_name="elevation" + ) + cell = self.raster_layer.cells[0][0] + # Valid assignment updates backend + cell.elevation = 99 + self.assertEqual(cell.elevation, 99) + # Incompatible assignment skips backend but doesn't raise + cell.elevation = None + self.assertIsNone(cell.__dict__.get("elevation")) + def test_deprecated_pos_indices_accessors(self): cell = self.raster_layer.cells[0][0] with warnings.catch_warnings(record=True) as captured: @@ -434,3 +447,110 @@ def test_from_file_multiband_attr_name_none(self): for idx in range(data.shape[0]) ) ) + + +def test_cell_missing_raster_layer_kwarg(): + class OldSchoolCell(mg.Cell): + def __init__(self, model, pos=None, indices=None): + super().__init__(model, pos, indices) + + rl = mg.RasterLayer( + 10, + 10, + "epsg:4326", + total_bounds=[0, 0, 10, 10], + model=mesa.Model(), + cell_cls=OldSchoolCell, + ) + cell = rl.cells[0][0] + assert isinstance(cell, OldSchoolCell) + assert cell.raster_layer is rl + + +def test_cell_wrapper_dunder_methods(): + rl = mg.RasterLayer( + 10, 10, "epsg:4326", total_bounds=[0, 0, 10, 10], model=mesa.Model() + ) + + # Test __iter__ of _CellWrapper + cols = list(rl.cells) + assert len(cols) == 10 + + # Test __iter__ and __len__ of _CellColumn + col = rl.cells[0] + cells = list(col) + assert len(cells) == 10 + assert len(col) == 10 + + # Test exceptions + with pytest.raises(TypeError): + _ = rl.cells["invalid"] + + with pytest.raises(TypeError): + _ = rl.cells[0]["invalid"] + + +def test_cell_coverage(): + rl = mg.RasterLayer( + 10, 10, "epsg:4326", total_bounds=[0, 0, 10, 10], model=mesa.Model() + ) + rl.apply_raster(np.ones((1, 10, 10), dtype=np.int32), attr_name="test_attr") + + wrapper = rl.cells + cell1 = wrapper[0][0] + cell2 = wrapper[0][0] # Hit cache + assert cell1 is cell2 + + with pytest.raises(AttributeError): + _ = cell1.non_existent_attr + + # Hit TypeError fallback in Cell.__setattr__ by assigning None to numpy numeric matrix + cell1.test_attr = None + assert cell1.__dict__["test_attr"] is None + + # Hit the pos is None branch + cell1._pos = None + cell1.test_attr = 5 + assert cell1.__dict__["test_attr"] == 5 + + cell1.step() + + with pytest.warns(DeprecationWarning): + rl._sync_cell_xy() + + +def test_cell_extra_coverage(): + # Hit missing random assignment coverage + cell = mg.Cell(model=mesa.Model(), pos=(5, 5)) + assert hasattr(cell, "random") + + # Test xy property returns None when raster_layer or rowcol is None + assert cell.xy is None + + +def test_raster_layer_getitem(): + rl = mg.RasterLayer( + 10, 10, "epsg:4326", total_bounds=[0, 0, 10, 10], model=mesa.Model() + ) + + # 1. Test int indexing: rl[x] -> returns list of cells + col = rl[3] + assert len(col) == 10 + assert col[5].pos == (3, 5) + + # 2. Test sequence of coordinates: rl[[(0,0), (1,1)]] -> returns list of cells + cells = rl[[(0, 0), (1, 1)]] + assert len(cells) == 2 + assert cells[0].pos == (0, 0) + assert cells[1].pos == (1, 1) + + # 3. Test grid indexing: rl[x, y] -> returns single cell + cell = rl[3, 5] + assert cell.pos == (3, 5) + + # 4. Test slice indexing: rl[3:5, 4:6] + # Note from code: rl[slice, slice] loops over the slices + cells_slice = rl[3:5, 4:6] + # Expect generator/list of cells + cells_list = list(cells_slice) + assert len(cells_list) > 0