This is the official code for Discovering the Representation Bottleneck of Graph Neural Networks from Multi-order Interactions.
[IEEE TKDE] [arXiv].
Unlike social networks or knowledge graphs, molecular graphs have no explicitly defined edges in 3D Euclidean space, so researchers usually employ KNN-graphs or fully-connected graphs to construct the connectivity between nodes or entities (e.g., atoms or residues). Our work reveals that these two standard graph construction methods can introduce an improper inductive bias, which prevents GNNs from learning interactions of the optimal order [1] (i.e., complexity) and therefore reduces their performance. To overcome this limitation, we design a new graph rewiring approach that dynamically adjusts the receptive field of each node or entity.
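For intuition, the two standard construction schemes can be sketched as follows (a minimal NumPy illustration; the function names and the choice of k are ours, not from the released code):

```python
import numpy as np

def knn_edges(pos, k):
    """Directed edge list (i, j) connecting each node i to its k nearest neighbours."""
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)          # exclude self-loops
    nbrs = np.argsort(dist, axis=1)[:, :k]  # k closest nodes per row
    return [(i, int(j)) for i in range(len(pos)) for j in nbrs[i]]

def fully_connected_edges(n):
    """Directed edge list over all ordered pairs of distinct nodes."""
    return [(i, j) for i in range(n) for j in range(n) if i != j]

pos = np.random.rand(6, 3)            # 6 atoms in 3D Euclidean space
print(len(knn_edges(pos, 2)))         # 6 * 2 = 12 directed edges
print(len(fully_connected_edges(6)))  # 6 * 5 = 30 directed edges
```

The KNN-graph fixes each node's receptive field to its k closest neighbours, while the fully-connected graph makes it global; the rewiring approach sits between these two extremes.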
The following packages are required before running the code.
pip install torch
pip install scikit-learn
pip install einops
pip install matplotlib
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-{your torch version}+cu{your cuda version}.html

We examine the characteristics of GNNs on four different tasks. Among them, Newtonian dynamics and molecular dynamics (MD) simulations are node-level regression tasks, while Hamiltonian dynamics and molecular property prediction are graph-level regression tasks. Please follow the guidance below to generate and preprocess the data.
Here we follow Cranmer et al. (2020) [2] and adapt their code to generate the data.
# install necessary packages
pip install celluloid
pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install jaxlib

Note that jax.ops.index_update is deprecated as of JAX 0.2.22, and we modify the code accordingly via the x0.at[].set() syntax. Moreover, loading JAX might fail with the error "Couldn't invoke ptxas" because the path of ptxas is not visible to the system.
A possible solution is to install CUDA manually using the install_cuda_11_1.sh file.
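For reference, the removed jax.ops.index_update call maps onto JAX's functional .at[].set() update as follows (a minimal sketch with an illustrative array, not the repository's actual variables):

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Old API (removed in recent JAX releases):
#   x = jax.ops.index_update(x, 2, 7.0)
# New functional-update syntax; returns a new array, x itself is unchanged:
x_new = x.at[2].set(7.0)

print(x_new)  # [0. 0. 7. 0. 0.]
```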
Then run the following command:
sudo bash install_cuda_11_1.sh

Finally, generate and preprocess the raw data and save it as a .pt file for future usage.
python data/dataset_nbody.py

The MD dataset, ISO17, is provided by the Quantum Machine organization and is available on its official website. After downloading the source data, run the following script to preprocess it.
python data/dataset_iso17.py

The QM7 and QM8 datasets are accessible from the same Quantum Machine website as the MD dataset.
python data/dataset_qm.py

Run the following command to pre-train a 3D GNN model.
python train.py --data=qm7 --method=egnn --gpu=0,1

# load a pretrained model
python test.py --data=qm7 --method=egnn --pretrain=1
# randomly initialize a model
python test.py --data=qm7 --method=egnn --pretrain=0

In the section "Revisiting Representation Bottlenecks of DNNs", we additionally investigate the representation bottleneck of another commonly used type of DNNs, i.e., CNNs. The cnn folder contains the corresponding implementation, which evaluates the multi-order interaction strengths of timm backbones for visual representation learning. It is based on Deng et al. [1] and their official code.
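As a rough sketch of the multi-order interaction metric of Deng et al. [1], the order-m interaction between two input variables i and j can be estimated by Monte Carlo sampling of contexts S of size m (the toy scoring function f and all names below are illustrative, not the released implementation):

```python
import itertools
import random

def multi_order_interaction(f, n, i, j, m, n_samples=200, seed=0):
    """Monte Carlo estimate of the order-m interaction between variables i and j:
    I^(m)(i, j) = E_{S, |S|=m} [ f(S+{i,j}) - f(S+{i}) - f(S+{j}) + f(S) ]."""
    rng = random.Random(seed)
    others = [v for v in range(n) if v not in (i, j)]
    total = 0.0
    for _ in range(n_samples):
        S = frozenset(rng.sample(others, m))  # random context of m other variables
        total += f(S | {i, j}) - f(S | {i}) - f(S | {j}) + f(S)
    return total / n_samples

# Toy score: sum of pairwise products of feature values, so variables interact.
x = [1.0, 2.0, 3.0, 4.0, 5.0]
f = lambda S: sum(x[a] * x[b] for a, b in itertools.combinations(sorted(S), 2))

print(multi_order_interaction(f, n=5, i=0, j=1, m=2))  # ≈ x[0] * x[1] = 2.0
```

For this additive-pairwise f, all context terms cancel and the estimate reduces exactly to x[0] * x[1]; a trained network would instead exhibit interaction strengths that vary with the order m, which is what the evaluation scripts measure.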
This cnn repository works with PyTorch 1.8 or higher and timm. Example installation steps with a recent PyTorch:
conda create -n bottleneck python=3.8 pytorch=1.12 cudatoolkit=11.3 torchvision -c pytorch -y
conda activate bottleneck
pip install -r cnn/requirements.txt

Then, please download the datasets and place them under ./cnn/datasets. CIFAR-10 will be downloaded automatically, while ImageNet should be downloaded and unzipped manually.
We only support the evaluation of pre-trained models. Please download the released pre-trained models from timm
and place them in ./cnn/timm_hub. Then run the following example on ImageNet in ./cnn/interaction_in1k.sh:
cd cnn
bash interaction.sh
At the top of the script, uncomment the setting (model name and checkpoint) that you want to run.
The results will be saved in the results directory by default.
If you have any questions, please do not hesitate to contact Fang WU.
Please consider citing our paper if you find it helpful. Thank you! 😜
@article{wu2024discovering,
title={Discovering the Representation Bottleneck of Graph Neural Networks},
author={Wu, Fang and Li, Siyuan and Li, Stan Z},
journal={IEEE Transactions on Knowledge and Data Engineering},
year={2024},
publisher={IEEE}
}

[1] Deng, H., Ren, Q., Chen, X., Zhang, H., Ren, J., & Zhang, Q. (2021). Discovering and explaining the representation bottleneck of DNNs. arXiv preprint arXiv:2111.06236.
[2] Cranmer, M., et al. (2020). Discovering symbolic models from deep learning with inductive biases. Advances in Neural Information Processing Systems, 33, 17429-17442.
