Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ jobs:
- name: Test examples
run: |
cd mesa-examples
pytest -rA -Werror -Wdefault::FutureWarning test_gis_examples.py
pytest -rA -Werror -Wdefault::FutureWarning -Wdefault::DeprecationWarning test_gis_examples.py
221 changes: 144 additions & 77 deletions mesa_geo/raster_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from __future__ import annotations

import copy
import inspect
import itertools
import math
import warnings
Expand All @@ -18,7 +17,7 @@
from affine import Affine
from mesa import Model
from mesa.agent import Agent
from mesa.space import Coordinate, FloatCoordinate, accept_tuple_argument
from mesa.space import Coordinate, FloatCoordinate, PropertyLayer, accept_tuple_argument
from rasterio.warp import (
Resampling,
calculate_default_transform,
Expand Down Expand Up @@ -185,6 +184,7 @@ def __init__(
*,
rowcol=None,
xy=None,
raster_layer=None,
):
"""
Initialize a cell.
Expand All @@ -197,11 +197,13 @@ def __init__(
Origin is at upper left corner of the grid
:param xy: Geographic/projected (x, y) coordinates of the cell center in the CRS.
"""

super().__init__(model)
self.model = model
self.unique_id = None
if not hasattr(self, "random") and model is not None:
self.random = model.random
self._pos = pos
self._rowcol = indices if rowcol is None else rowcol
self._xy = xy
self.raster_layer = raster_layer

@property
def pos(self) -> Coordinate | None:
Expand Down Expand Up @@ -269,12 +271,116 @@ def xy(self) -> FloatCoordinate | None:
"""
Geographic/projected (x, y) coordinates of the cell center in the CRS.
"""
return self._xy
if getattr(self, "raster_layer", None) is not None and self._rowcol is not None:
return rio.transform.xy(
self.raster_layer.transform,
self._rowcol[0],
self._rowcol[1],
offset="center",
)
return None

def __getattr__(self, name: str):
if name == "raster_layer":
raise AttributeError
if (
getattr(self, "raster_layer", None) is not None
and name in self.raster_layer.attributes
):
x, y = self.pos
return self.raster_layer._properties[name].data[x, y]
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)

def __setattr__(self, name: str, value):
if (
name != "raster_layer"
and getattr(self, "raster_layer", None) is not None
and name in self.raster_layer.attributes
):
pos = getattr(self, "_pos", None)
if pos is not None:
try:
x, y = pos
self.raster_layer._properties[name].data[x, y] = value
return
except (TypeError, ValueError):
pass
super().__setattr__(name, value)

def step(self):
pass


class _CellWrapper:
def __init__(self, raster_layer):
self.rl = raster_layer
self._cell_cache = {}

def _create_cell(self, x, y):
idx = (x, y)
if idx in self._cell_cache:
return self._cell_cache[idx]

try:
cell = self.rl.cell_cls(
self.rl.model,
pos=(x, y),
indices=(self.rl.height - y - 1, x),
raster_layer=self.rl,
)
except TypeError:
cell = self.rl.cell_cls(
self.rl.model,
pos=(x, y),
indices=(self.rl.height - y - 1, x),
)
cell.raster_layer = self.rl

for attr in self.rl.attributes:
cell.__dict__.pop(attr, None)

self._cell_cache[idx] = cell
return cell

def __getitem__(self, index):
class _CellColumn:
def __init__(self, wrapper, x):
self.wrapper = wrapper
self.x = x

def __getitem__(self, y):
if isinstance(y, int):
return self.wrapper._create_cell(self.x, y)
elif isinstance(y, slice):
return [
self.wrapper._create_cell(self.x, yi)
for yi in range(*y.indices(self.wrapper.rl.height))
]
raise TypeError("Column indices must be integers or slices")

def __iter__(self):
for y in range(self.wrapper.rl.height):
yield self.wrapper._create_cell(self.x, y)

def __len__(self):
return self.wrapper.rl.height

if isinstance(index, int):
return _CellColumn(self, index)
elif isinstance(index, slice):
return [
_CellColumn(self, xi) for xi in range(*index.indices(self.rl.width))
]

raise TypeError("Raster indices must be integers or slices")

def __iter__(self):
for x in range(self.rl.width):
yield self[x]


class RasterLayer(RasterBase):
"""
Some methods in `RasterLayer` are copied from `mesa.space.Grid`, including:
Expand Down Expand Up @@ -307,74 +413,32 @@ class RasterLayer(RasterBase):
whereas it is `self.cells: List[List[Cell]]` here in `RasterLayer`.
"""

cells: list[list[Cell]]
_properties: dict[str, PropertyLayer]
_neighborhood_cache: dict[Any, list[Coordinate]]
_attributes: set[str]

def __init__(
self, width, height, crs, total_bounds, model, cell_cls: type[Cell] = Cell
):
super().__init__(width, height, crs, total_bounds)
self.model = model
self.cell_cls = cell_cls
self._initialize_cells()
self._properties = {}
self._attributes = set()
self._neighborhood_cache = {}
super().__init__(width, height, crs, total_bounds)

def _update_transform(self) -> None:
super()._update_transform()
if getattr(self, "cells", None):
self._sync_cell_xy()

def _sync_cell_xy(self) -> None:
for column in self.cells:
for cell in column:
row, col = cell.rowcol
cell._xy = rio.transform.xy(self.transform, row, col, offset="center")

def _initialize_cells(self) -> None:
try:
init_params = inspect.signature(self.cell_cls.__init__).parameters
except (TypeError, ValueError):
supports_legacy_pos_indices = False
else:
supports_legacy_pos_indices = (
"pos" in init_params and "indices" in init_params
)

if supports_legacy_pos_indices:

def make_cell(grid_x: int, grid_y: int, row_idx: int, col_idx: int, xy):
# Backward-compatible path for legacy signature:
# __init__(self, model, pos=None, indices=None, ...)
cell = self.cell_cls(
self.model,
pos=(grid_x, grid_y),
indices=(row_idx, col_idx),
)
# Legacy constructor path does not accept xy; set it manually.
cell._xy = xy
return cell
else:
# New constructor path: __init__(self, model, pos=None, rowcol=None, xy=None, ...)
# or: __init__(self, model, **kwargs)
def make_cell(grid_x: int, grid_y: int, row_idx: int, col_idx: int, xy):
return self.cell_cls(
self.model,
pos=(grid_x, grid_y),
rowcol=(row_idx, col_idx),
xy=xy,
)
warnings.warn(
"RasterLayer._sync_cell_xy is deprecated and has no effect.",
DeprecationWarning,
stacklevel=2,
)

self.cells = []
for grid_x in range(self.width):
col: list[Cell] = []
for grid_y in range(self.height):
row_idx, col_idx = self.height - grid_y - 1, grid_x
xy = rio.transform.xy(self.transform, row_idx, col_idx, offset="center")
cell = make_cell(grid_x, grid_y, row_idx, col_idx, xy)
col.append(cell)
self.cells.append(col)
@property
def cells(self):
return _CellWrapper(self)

@property
def attributes(self) -> set[str]:
Expand All @@ -384,7 +448,7 @@ def attributes(self) -> set[str]:
:return: Attributes of the cells in the raster layer.
:rtype: Set[str]
"""
return self._attributes
return set(self._properties.keys())

@overload
def __getitem__(self, index: int) -> list[Cell]: ...
Expand Down Expand Up @@ -516,13 +580,16 @@ def _default_attr_name() -> str:
for band_idx, name in enumerate(names):
attr = _default_attr_name() if name is None else name
self._attributes.add(attr)
for grid_x in range(self.width):
for grid_y in range(self.height):
setattr(
self.cells[grid_x][grid_y],
attr,
data[band_idx, self.height - grid_y - 1, grid_x],
)
prop_type = data.dtype.type
prop_layer = PropertyLayer(
attr,
self.width,
self.height,
default_value=prop_type(0),
dtype=prop_type,
)
prop_layer.data = np.flip(data[band_idx], axis=0).T
self._properties[attr] = prop_layer

def get_raster(self, attr_name: str | Sequence[str] | None = None) -> np.ndarray:
"""
Expand All @@ -549,20 +616,23 @@ def get_raster(self, attr_name: str | Sequence[str] | None = None) -> np.ndarray
)
if attr_name is None:
num_bands = len(self.attributes)
attr_names = self.attributes
attr_names = list(self.attributes)
elif isinstance(attr_name, Sequence) and not isinstance(attr_name, str):
num_bands = len(attr_name)
attr_names = list(attr_name)
else:
num_bands = 1
attr_names = [attr_name]
data = np.empty((num_bands, self.height, self.width))

dtype = (
np.result_type(*[self._properties[name].data.dtype for name in attr_names])
if attr_names
else float
)
data = np.empty((num_bands, self.height, self.width), dtype=dtype)
for ind, name in enumerate(attr_names):
for grid_x in range(self.width):
for grid_y in range(self.height):
data[ind, self.height - grid_y - 1, grid_x] = getattr(
self.cells[grid_x][grid_y], name
)
prop_data = self._properties[name].data
data[ind] = np.flip(prop_data.T, axis=0)
return data

def iter_neighborhood(
Expand Down Expand Up @@ -744,8 +814,6 @@ def to_crs(self, crs, inplace=False) -> RasterLayer | None:
]
layer.crs = crs
layer._transform = transform
if getattr(layer, "cells", None):
layer._sync_cell_xy()

if not inplace:
return layer
Expand Down Expand Up @@ -793,7 +861,6 @@ def from_file(
]
obj = cls(width, height, dataset.crs, total_bounds, model, cell_cls)
obj._transform = dataset.transform
obj._sync_cell_xy()
obj.apply_raster(values, attr_name=attr_name)
return obj

Expand Down
Loading