Skip to content

Commit d5eaf5c

Browse files
fix: Move heavy imports to runtime in spatial_genes.py
Moved torch, numpy, pandas, and sklearn imports from module level to runtime within functions to avoid import errors in minimal environments. This allows the CLI to start without requiring all optional dependencies. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent baf948d commit d5eaf5c

File tree

1 file changed

+59
-23
lines changed

1 file changed

+59
-23
lines changed

chatspatial/tools/spatial_genes.py

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,25 @@
99
import sys
1010
import tempfile
1111
import shutil
12-
import numpy as np
13-
import pandas as pd
14-
import torch
15-
import torch.nn as nn
16-
from typing import Dict, List, Tuple, Optional, Any
12+
from typing import Dict, List, Tuple, Optional, Any, TYPE_CHECKING
1713
from pathlib import Path
18-
import scanpy as sc
19-
from sklearn.decomposition import PCA
20-
from sklearn.preprocessing import StandardScaler
2114
import warnings
2215
import logging
2316

17+
if TYPE_CHECKING:
18+
import numpy as np
19+
import pandas as pd
20+
import torch
21+
import torch.nn as nn
22+
import scanpy as sc
23+
from sklearn.decomposition import PCA
24+
from sklearn.preprocessing import StandardScaler
25+
2426
logger = logging.getLogger(__name__)
2527

26-
# Try to import GASTON from standard Python package installation
27-
try:
28-
import gaston
29-
from gaston import neural_net, spatial_gene_classification, binning_and_plotting
30-
from gaston import dp_related, segmented_fit, process_NN_output
31-
GASTON_AVAILABLE = True
32-
GASTON_IMPORT_ERROR = None
33-
except ImportError as e:
34-
GASTON_AVAILABLE = False
35-
# Only show warning when GASTON is actually requested
36-
GASTON_IMPORT_ERROR = str(e)
28+
# GASTON import will be done at runtime
29+
GASTON_AVAILABLE = None
30+
GASTON_IMPORT_ERROR = None
3731

3832
from ..models.data import SpatialVariableGenesParameters
3933
from ..models.analysis import SpatialVariableGenesResult
@@ -98,6 +92,25 @@ async def _identify_spatial_genes_gaston(
9892
context=None
9993
) -> SpatialVariableGenesResult:
10094
"""Identify spatial variable genes using GASTON method."""
95+
# Import dependencies at runtime
96+
import numpy as np
97+
import pandas as pd
98+
import torch
99+
import torch.nn as nn
100+
101+
# Check GASTON availability at runtime
102+
global GASTON_AVAILABLE, GASTON_IMPORT_ERROR
103+
if GASTON_AVAILABLE is None:
104+
try:
105+
import gaston
106+
from gaston import neural_net, spatial_gene_classification, binning_and_plotting
107+
from gaston import dp_related, segmented_fit, process_NN_output
108+
GASTON_AVAILABLE = True
109+
GASTON_IMPORT_ERROR = None
110+
except ImportError as e:
111+
GASTON_AVAILABLE = False
112+
GASTON_IMPORT_ERROR = str(e)
113+
101114
if not GASTON_AVAILABLE:
102115
error_msg = (
103116
f"GASTON is not available: {GASTON_IMPORT_ERROR}\n\n"
@@ -231,6 +244,10 @@ async def _identify_spatial_genes_spatialde(
231244
context=None
232245
) -> SpatialVariableGenesResult:
233246
"""Identify spatial variable genes using SpatialDE method."""
247+
# Import dependencies at runtime
248+
import numpy as np
249+
import pandas as pd
250+
import scanpy as sc
234251
try:
235252
import SpatialDE
236253
from SpatialDE.util import qvalue
@@ -355,7 +372,9 @@ async def _identify_spatial_genes_spatialde(
355372
return result
356373

357374

358-
async def _gaston_feature_engineering_glmpca(adata, n_components: int, context) -> np.ndarray:
375+
async def _gaston_feature_engineering_glmpca(adata, n_components: int, context):
376+
# Import dependencies at runtime
377+
import numpy as np
359378
"""GASTON-specific feature engineering using GLM-PCA (algorithm requirement)."""
360379
try:
361380
from glmpca.glmpca import glmpca
@@ -387,7 +406,11 @@ async def _gaston_feature_engineering_glmpca(adata, n_components: int, context)
387406
return glmpca_result["factors"]
388407

389408

390-
async def _gaston_feature_engineering_pearson(adata, n_components: int, context) -> np.ndarray:
409+
async def _gaston_feature_engineering_pearson(adata, n_components: int, context):
410+
# Import dependencies at runtime
411+
import numpy as np
412+
import scanpy as sc
413+
from sklearn.decomposition import PCA
391414
"""GASTON-specific feature engineering using Pearson residuals PCA (algorithm requirement)."""
392415
if context:
393416
await context.info("Computing GASTON Pearson residuals feature engineering (algorithm requirement)")
@@ -409,6 +432,10 @@ async def _train_gaston_model(
409432
context
410433
) -> Tuple[Any, List[float], float]:
411434
"""Train GASTON neural network model."""
435+
# Import dependencies at runtime
436+
import numpy as np
437+
import torch
438+
import torch.nn as nn
412439

413440
# Load data
414441
S = np.load(coords_file)
@@ -442,11 +469,15 @@ async def _train_gaston_model(
442469

443470

444471
async def _analyze_spatial_patterns(
445-
model, spatial_coords: np.ndarray, expression_features: np.ndarray,
472+
model, spatial_coords, expression_features,
446473
adata, params: SpatialVariableGenesParameters, context
447474
) -> Dict[str, Any]:
448475
"""Analyze spatial patterns from trained GASTON model using complete GASTON workflow."""
449-
476+
# Import dependencies at runtime
477+
import numpy as np
478+
import pandas as pd
479+
import torch
480+
450481
if context:
451482
await context.info("Processing neural network output following GASTON tutorial")
452483

@@ -649,6 +680,9 @@ async def _identify_spatial_genes_spark(
649680
context=None
650681
) -> SpatialVariableGenesResult:
651682
"""Identify spatial variable genes using SPARK method."""
683+
# Import dependencies at runtime
684+
import numpy as np
685+
import pandas as pd
652686
try:
653687
from rpy2 import robjects as ro
654688
from rpy2.robjects import conversion, default_converter
@@ -854,6 +888,8 @@ async def _identify_spatial_genes_spark(
854888

855889
def _set_random_seeds(seed: int):
856890
"""Set random seeds for reproducibility."""
891+
import numpy as np
892+
import torch
857893
np.random.seed(seed)
858894
torch.manual_seed(seed)
859895
if torch.cuda.is_available():

0 commit comments

Comments
 (0)