diff --git a/src/bioemu/sample.py b/src/bioemu/sample.py index b92209d..668a3da 100644 --- a/src/bioemu/sample.py +++ b/src/bioemu/sample.py @@ -74,6 +74,7 @@ def main( denoiser_type: SupportedDenoisersLiteral | None = "dpm", denoiser_config_path: str | Path | None = None, cache_embeds_dir: str | Path | None = None, + cache_so3_dir: str | Path | None = None, msa_host_url: str | None = None, filter_samples: bool = True, ) -> None: @@ -95,6 +96,7 @@ def main( denoiser_type: Denoiser to use for sampling, if `denoiser_config_path` not specified. Comes in with default parameter configuration. Must be one of ['dpm', 'heun'] denoiser_config_path: Path to the denoiser config, defining the denoising process. cache_embeds_dir: Directory to store MSA embeddings. If not set, this defaults to `COLABFOLD_DIR/embeds_cache`. + cache_so3_dir: Directory to store SO3 precomputations. If not set, this defaults to `~/sampling_so3_cache`. msa_host_url: MSA server URL. If not set, this defaults to colabfold's remote server. If sequence is an a3m file, this is ignored. filter_samples: Filter out unphysical samples with e.g. long bond distances or steric clashes. """ @@ -111,6 +113,9 @@ def main( with open(model_config_path) as f: model_config = yaml.safe_load(f) + if cache_so3_dir is not None: + model_config["sdes"]["node_orientations"]["cache_dir"] = cache_so3_dir + # User may have provided an MSA file instead of a sequence. This will be used for embeddings. msa_file = sequence if str(sequence).endswith(".a3m") else None