Skip to content

Commit 5fff32e

Browse files
authored
Merge pull request #2 from zimea/feat/aggregate-node-features
Feat/aggregate node features
2 parents 55e95f2 + f94b736 commit 5fff32e

File tree

7 files changed

+1590
-9
lines changed

7 files changed

+1590
-9
lines changed

docs/notebooks/multi_sample_comparison_of_node_features.ipynb

Lines changed: 1182 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ requires = [ "hatchling" ]
55
[project]
66
name = "spatial-sample-aggregation"
77
version = "0.0.1"
8-
description = "Aggregate spatial slides into sample-level statistyics"
8+
description = "Aggregate spatial slides into sample-level statistics"
99
readme = "README.md"
1010
license = { file = "LICENSE" }
1111
maintainers = [
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .aggregate import basic_tool
1+
from .aggregate import aggregate_by_node
2+
from .compute_node_features import aggregate_by_group, compute_node_feature, get_neighbor_counts
Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import pandas as pd
22
from anndata import AnnData
3+
from squidpy._constants._pkg_constants import Key
4+
from squidpy.gr._utils import _assert_categorical_obs, _assert_connectivity_key
5+
6+
from .compute_node_features import aggregate_by_group, compute_node_feature
37

48

59
def aggregate_by_edge(
6-
adata: AnnData, sample_key: str, annotation_key: str, use_edge_weight: bool = False
10+
adata: AnnData, library_key: str, annotation_key: str, use_edge_weight: bool = False
711
) -> pd.DataFrame:
812
"""
913
Aggregate spatial neighborhood graph taking into account neighbors
@@ -15,14 +19,56 @@ def aggregate_by_edge(
1519

1620

1721
def aggregate_by_node(
    adata: AnnData,
    *,
    library_key: str,
    cluster_key: str | None = None,  # TODO: annotation_key --> cluster_key to adapt to squidpy notation
    metric: str = "shannon",
    aggregation: str = "mean",  # TODO: new parameter --> check squidpy
    connectivity_key: str = "spatial_connectivities",  # TODO: new parameter
    key_added: str | None = None,
    **kwargs,
) -> None:
    """
    Compute a node-level metric and aggregate it by a sample group.

    Parameters
    ----------
    - adata: AnnData, input data
    - library_key: str, column in `adata.obs` to group by
    - cluster_key: Optional[str], cell type or similar annotation
    - metric: str, metric to compute ('shannon', 'degree', 'mean_distance')
    - aggregation: str, aggregation method ('mean', 'median', 'sum', 'none')
    - connectivity_key: str, adjacency matrix key
    - key_added: Optional[str], key under which aggregated results are stored in `adata.uns`. Defaults to `metric`.
    - kwargs: Additional parameters passed to metric computation functions.

    Returns
    -------
    - None (Results are stored in `adata.obs[key_added]` and the aggregated features are added in `adata.uns[key_added]` if aggregation is not None)
    """
    # Determine where to store the results (default to metric name)
    if key_added is None:
        key_added = metric

    # TODO: adapt to squidpy: connectivity_key = Key.obsp.spatial_conn(connectivity_key)
    # `cluster_key` is optional, so only validate it when provided; asserting on
    # None would raise even though the downstream aggregation handles None.
    if cluster_key is not None:
        _assert_categorical_obs(adata, cluster_key)
    _assert_connectivity_key(adata, connectivity_key)

    # Compute node-level feature
    node_features = compute_node_feature(
        adata, metric, connectivity_key=connectivity_key, cluster_key=cluster_key, library_key=library_key, **kwargs
    )

    # TODO: adapt to squidpy gr_utils _save_data(adata, attr="obs", key=Key.obs.feature(feature_column), data=node_features)
    adata.obs[key_added] = node_features  # TODO: store in obs here or in the individual functions

    # Aggregate the computed metric at the sample level
    aggregate_by_group(
        adata,
        library_key=library_key,
        node_feature_key=key_added,
        cluster_key=cluster_key,
        aggregation=aggregation,
        key_added=key_added,
    )
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
import numpy as np
2+
import pandas as pd
3+
import scipy
4+
from anndata import AnnData
5+
from scipy.stats import entropy
6+
from squidpy._utils import NDArrayA
7+
8+
9+
# TODO: this should go into squidpy/gr/_nhood.py
10+
def _get_neighbor_counts(
11+
data: NDArrayA,
12+
indices: NDArrayA,
13+
indptr: NDArrayA,
14+
cats: NDArrayA, # Array mapping cell indices to their types
15+
output: NDArrayA, # Shape: (n_cells, n_celltypes)
16+
) -> NDArrayA:
17+
indices_list = np.split(indices, indptr[1:-1])
18+
data_list = np.split(data, indptr[1:-1])
19+
for i in range(len(data_list)): # Iterate over cells
20+
cur_row = i # Each row corresponds to a cell
21+
cur_indices = indices_list[i]
22+
cur_data = data_list[i]
23+
for j, val in zip(cur_indices, cur_data, strict=False):
24+
cur_col = cats[j] # Column corresponds to cell type
25+
output[cur_row, cur_col] += val
26+
return output
27+
28+
29+
def get_neighbor_counts(
30+
adata, cluster_key="cell_type", connectivity_key="spatial_connectivities", key_added="composition_matrix"
31+
):
32+
"""Computes the number of each cell type in one-hop neighbors and stores it in adata.obsm['neighbor_counts']."""
33+
cats = adata.obs[cluster_key]
34+
mask = ~pd.isnull(cats).values
35+
cats = cats.loc[mask]
36+
if not len(cats):
37+
raise RuntimeError(f"After removing NaNs in `adata.obs[{cluster_key!r}]`, none remain.")
38+
39+
g = adata.obsp[connectivity_key]
40+
41+
if isinstance(g, scipy.sparse.coo_matrix):
42+
g = g.tocsr()
43+
g = g[mask, :][:, mask]
44+
n_cats = len(cats.cat.categories)
45+
46+
g_data = np.broadcast_to(1, shape=len(g.data))
47+
dtype = int if pd.api.types.is_bool_dtype(g.dtype) or pd.api.types.is_integer_dtype(g.dtype) else float
48+
output: NDArrayA = np.zeros((len(cats), n_cats), dtype=dtype)
49+
50+
neighbor_counts = _get_neighbor_counts(g_data, g.indices, g.indptr, cats.cat.codes.to_numpy(), output)
51+
52+
# adding the neighbor counts to adata.obsm
53+
# TODO: adapt to squidpy gr_utils _save_data(adata, attr="obsm", key=Key.obsm.feature(feature_column), data=node_features)
54+
adata.obsm[key_added] = neighbor_counts
55+
56+
return neighbor_counts
57+
58+
59+
def compute_node_feature(adata: AnnData, metric: str, connectivity_key: str, **kwargs) -> NDArrayA:
    """
    Compute a node-level feature based on the selected metric.

    Parameters
    ----------
    - adata: AnnData object
    - metric: str, the metric to compute ('shannon', 'degree', 'mean_distance')
    - connectivity_key: str, the key for the adjacency matrix in `adata.obsp`
    - kwargs: additional parameters for specific computations (e.g., `n_hops` for Shannon)

    Returns
    -------
    - np.ndarray: Node-level feature values indexed by cell ID, shaped (n_cells, 1)
    """
    # Dispatch table: metric name -> node-level feature function.
    dispatch = {
        "shannon": compute_shannon_diversity,
        "degree": calculate_degree,
        "mean_distance": calculate_mean_distance,
    }

    feature_fn = dispatch.get(metric)
    if feature_fn is None:
        raise ValueError(f"Unsupported metric: {metric}. Choose from 'shannon', 'degree', or 'mean_distance'")

    # Reshape to a column vector so every metric yields a consistent (n_cells, 1) array.
    values = feature_fn(adata, connectivity_key=connectivity_key, **kwargs)
    return values.reshape(-1, 1)
84+
85+
86+
def calculate_degree(adata: AnnData, connectivity_key: str, **kwargs) -> NDArrayA:
    """Compute the degree of each node as the row-sum of the adjacency matrix."""
    graph = adata.obsp[connectivity_key]
    return graph.sum(axis=1)
89+
90+
91+
def calculate_mean_distance(adata: AnnData, connectivity_key: str, **kwargs) -> NDArrayA:
    """Compute the mean distance to neighbors.

    NOTE(review): the mean is taken over *all* cells in the row — the dense
    zeros for non-neighbors are included, not just the connected neighbors.
    Confirm this is the intended semantics.
    """
    dense = adata.obsp[connectivity_key].toarray()
    return np.nanmean(dense, axis=1)
94+
95+
96+
def compute_shannon_diversity(
    adata: AnnData,
    connectivity_key: str = "spatial_connectivities",
    cluster_key: str = "cell_type",
    key_added: str = "composition_matrix",
    **kwargs,
) -> NDArrayA:
    """
    Compute Shannon diversity index for each node based on neighbor counts.

    Parameters
    ----------
    - adata: AnnData object
    - connectivity_key: str, key in adata.obsp corresponding to the adjacency matrix
    - cluster_key: str, column in adata.obs that contains categorical annotations (e.g., cell type)
    - key_added: str, obsm key under which the neighbor-count matrix is stored
    - kwargs: additional arguments (not used here but included for interface consistency)

    Returns
    -------
    - np.ndarray: Shannon diversity values indexed by cell ID; isolated nodes
      (no neighbors) are assigned 0.0 diversity
    """
    # Compute neighbor counts directly
    neighbor_counts = get_neighbor_counts(
        adata, cluster_key=cluster_key, connectivity_key=connectivity_key, key_added=key_added
    )

    # Normalize to probabilities. Guard isolated nodes (zero neighbor counts):
    # dividing by a zero row-sum would emit RuntimeWarnings and produce NaNs.
    totals = neighbor_counts.sum(axis=1, keepdims=True)
    probabilities = neighbor_counts / np.where(totals > 0, totals, 1)

    # Compute Shannon diversity (entropy), ignoring zero probabilities.
    # An all-zero row (isolated node) is defined to have zero diversity.
    def _row_entropy(p):
        p = p[p > 0]
        return entropy(p, base=2) if p.size else 0.0

    shannon_diversity = np.apply_along_axis(_row_entropy, 1, probabilities)

    return shannon_diversity.astype(np.float64)
129+
130+
131+
def aggregate_by_group(
    adata: AnnData,
    library_key: str,
    node_feature_key: str,
    cluster_key: str | None = None,
    aggregation: str = "mean",
    key_added: str = "aggregated_features",
) -> None:
    """
    Aggregate node-level features by a sample group and optionally by annotation.

    Parameters
    ----------
    - adata: AnnData object
    - library_key: str, column in `adata.obs` indicating the sample group
    - node_feature_key: str, column in `adata.obs` containing the node-level feature to aggregate
    - cluster_key: Optional[str], column in `adata.obs` for additional grouping (e.g., cell type)
    - aggregation: str, aggregation method ('mean', 'median', 'sum', None)
    - key_added: str, key under which results are stored in `adata.uns`

    Returns
    -------
    - None (Results are stored in `adata.uns[key_added]`)

    Raises
    ------
    - ValueError if a required column is missing from `adata.obs` or the
      aggregation method is not supported
    """
    if node_feature_key not in adata.obs.columns:
        raise ValueError(f"Column '{node_feature_key}' not found in adata.obs")

    if library_key not in adata.obs.columns:
        raise ValueError(f"Column '{library_key}' not found in adata.obs")

    if cluster_key and cluster_key not in adata.obs.columns:
        raise ValueError(f"Column '{cluster_key}' not found in adata.obs")

    # `None` means "skip aggregation": node-level values stay in `adata.obs` only.
    if aggregation is None:
        return

    # pandas accepts these method names directly in `.agg(...)`.
    supported_aggregations = {"mean", "median", "sum"}
    if aggregation not in supported_aggregations:
        raise ValueError(f"Unsupported aggregation method: {aggregation}")

    # Perform aggregation
    if cluster_key:
        aggregated = (
            adata.obs.groupby([library_key, cluster_key])[node_feature_key]
            .agg(aggregation)
            .unstack()  # Pivot so that cluster_key values become columns
        )
    else:
        aggregated = adata.obs.groupby(library_key)[node_feature_key].agg(aggregation)

    # TODO: adapt to squidpy save function
    adata.uns[key_added] = aggregated

tests/conftest.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import pytest
2+
import pandas as pd
3+
import numpy as np
4+
import scipy.sparse
5+
from anndata import AnnData
6+
7+
@pytest.fixture
def sample_adata():
    """Create a small AnnData object for testing: 2 samples with 10 cells each."""
    n_cells = 20

    # Observation table: repeating cell-type pattern, first 10 cells in S1,
    # next 10 in S2, plus a random node-level feature.
    obs = pd.DataFrame(
        {
            "cell_id": [f"cell_{i}" for i in range(n_cells)],
            "cell_type": ["A", "B", "C", "A", "B", "C", "A", "B", "C", "A"] * 2,
            "sample_id": ["S1"] * 10 + ["S2"] * 10,
            "node_feature": np.random.rand(n_cells),
        }
    ).set_index("cell_id")

    # One 10-cell connected component; the full graph is two disconnected
    # copies of it (one per sample) assembled block-diagonally.
    component = scipy.sparse.csr_matrix(
        [
            [0, 1, 0, 1, 0, 0, 0, 0, 0, 1],
            [1, 0, 1, 0, 1, 0, 0, 0, 0, 0],
            [0, 1, 0, 1, 0, 1, 0, 0, 0, 0],
            [1, 0, 1, 0, 0, 0, 1, 0, 0, 0],
            [0, 1, 0, 0, 0, 1, 0, 1, 0, 0],
            [0, 0, 1, 0, 1, 0, 1, 0, 1, 0],
            [0, 0, 0, 1, 0, 1, 0, 1, 0, 0],
            [0, 0, 0, 0, 1, 0, 1, 0, 1, 0],
            [0, 0, 0, 0, 0, 1, 0, 1, 0, 1],
            [1, 0, 0, 0, 0, 0, 0, 0, 1, 0],
        ]
    )
    adjacency_matrix = scipy.sparse.block_diag([component, component])

    # Assemble the AnnData object with categorical annotations, random spatial
    # coordinates, and the adjacency matrix.
    adata = AnnData(obs=obs)
    adata.obs["sample_id"] = adata.obs["sample_id"].astype("category")
    adata.obs["cell_type"] = adata.obs["cell_type"].astype("category")
    adata.obsm["spatial"] = np.random.rand(n_cells, 2)
    adata.obsp["spatial_connectivities"] = adjacency_matrix

    return adata

0 commit comments

Comments
 (0)