Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions tests/core/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,40 +382,45 @@ def test_grid_dumps():
@pytest.mark.parametrize(
argnames=["x_coord", "y_coord", "exp_exception", "exp_message", "exp_map"],
argvalues=[
(
pytest.param(
[0, 1, 2],
[[0, 1], [0, 1]],
pytest.raises(ValueError),
"The x/y coordinate arrays are not 1 dimensional",
None,
id="coords_not_1D",
),
(
pytest.param(
[0, 1, 2],
[0, 1],
pytest.raises(ValueError),
"The x/y coordinates are of unequal length",
None,
id="coords_not_equal_length",
),
(
pytest.param(
[0, 1, 2],
[0, 1, 2],
does_not_raise(),
None,
[[], [], []],
id="outside_grid",
),
(
pytest.param(
[500000, 500100, 500200],
[200000, 200100, 200200],
does_not_raise(),
None,
[[90], [80, 81, 90, 91], [71, 72, 81, 82]],
id="on_borders",
),
(
pytest.param(
[500050, 500150, 500250],
[200050, 200150, 200250],
does_not_raise(),
None,
[[90], [81], [72]],
id="within_cells",
),
],
)
Expand All @@ -436,7 +441,37 @@ def test_map_xy_to_cell_ids(
assert str(excep.value) == exp_message

if exp_map is not None:
assert cell_map == exp_map
assert all([set(obs) == set(exp) for obs, exp in zip(cell_map, exp_map)])


@pytest.mark.parametrize(
argnames="nx, ny", argvalues=((50, 50), (100, 100), (150, 150))
)
def test_map_xy_to_cell_ids_performance(nx, ny):
"""Scaling test of map_xy_to_cell_ids.

Creates a grid of nx, ny cells and then simply maps the computed centroids back to
the cells as a tool to check the run time.

1953.27s call test_grid.py::test_map_xy_to_cell_ids_performance[150-150]
371.25s call test_grid.py::test_map_xy_to_cell_ids_performance[100-100]
23.24s call test_grid.py::test_map_xy_to_cell_ids_performance[50-50]

1.19s call test_grid.py::test_map_xy_to_cell_ids_performance[150-150]
0.56s call test_grid.py::test_map_xy_to_cell_ids_performance[100-100]
0.13s call test_grid.py::test_map_xy_to_cell_ids_performance[50-50]

TODO: Not sure this is best placed as a test, but saving the code and information
here until we have a more formal performance testing ground.
"""

from virtual_ecosystem.core.grid import Grid

grid = Grid(cell_nx=nx, cell_ny=ny)

_ = grid.map_xy_to_cell_id(
x_coords=grid.centroids[:, 0], y_coords=grid.centroids[:, 1]
)


@pytest.mark.parametrize(
Expand Down
10 changes: 8 additions & 2 deletions virtual_ecosystem/core/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import numpy as np
from numpy.typing import NDArray
from scipy.spatial.distance import cdist, pdist, squareform # type: ignore
from shapely import GeometryCollection, Point, Polygon, STRtree # type: ignore
from shapely.affinity import scale, translate # type: ignore
from shapely.geometry import GeometryCollection, Point, Polygon # type: ignore

from virtual_ecosystem.core.exceptions import ConfigurationError
from virtual_ecosystem.core.logger import LOGGER
Expand Down Expand Up @@ -238,6 +238,9 @@ def __init__(
self.centroids: np.ndarray
"""A list of the centroid of each cell as shapely.geometry.Point objects, in
cell_id order."""
self._strtree: STRtree
"""An STRtree object of the grid polygons, used for searching coordinates
matches in loading data."""

# Retrieve the creator function from the grid registry and handle unknowns
creator = GRID_REGISTRY.get(self.grid_type, None)
Expand Down Expand Up @@ -265,6 +268,9 @@ def __init__(
centroids = [cell.centroid for cell in self.polygons]
self.centroids = np.array([(gm.xy[0][0], gm.xy[1][0]) for gm in centroids])

# Populate the STRtree
self._strtree = STRtree(self.polygons)

# Get the bounds as a 4 tuple
self.bounds: GeometryCollection = GeometryCollection(self.polygons).bounds
"""A GeometryCollection providing the bounds of the cell polygons."""
Expand Down Expand Up @@ -510,7 +516,7 @@ def map_xy_to_cell_id(
# object https://shapely.readthedocs.io/en/latest/strtree.html

return [
[id for id, ply in zip(self.cell_id, self.polygons) if ply.intersects(pt)]
self._strtree.query(geometry=pt, predicate="intersects").tolist()
for pt in xyp
]

Expand Down