Improving Structural Plausibility in 3D Molecule Generation via Property-Conditioned Training with Distorted Molecules
This repository accompanies the paper "Improving Structural Plausibility in 3D Molecule Generation via Property-Conditioned Training with Distorted Molecules" (preprint here). Our approach involves introducing distorted molecules into training datasets and annotating each molecule with a label that reflects its level of distortion, and consequently, its structural quality. By training generative models to distinguish between high- and low-quality molecular conformations, we enable selective sampling from high-quality regions of the learned space, resulting in an improvement in the validity of generated molecules.
Our conditional method has been tested on three models: EDM, GCDM, and MolFM.
Each model can be trained and sampled using its original source code without modification. To set up an environment compatible with all three models and clone all three repos, use the repos_and_env.sh script:
git clone https://github.com/lucyvost/distorted_diffusion.git
cd distorted_diffusion
bash repos_and_env.sh
This environment was created by the authors of GCDM, and includes all packages required for training and sampling all three models discussed:
hydra-core=1.2.0
matplotlib-base=3.4.3
numpy=1.23.1
pyg=2.2.0=py39_torch_1.12.0_cu116
python=3.9.15
pytorch=1.12.1=py3.9_cuda11.6_cudnn8.3.2_0
pytorch-cluster=1.6.0=py39_torch_1.12.0_cu116
pytorch-scatter=2.1.0=py39_torch_1.12.0_cu116
pytorch-sparse=0.6.16=py39_torch_1.12.0_cu116
pytorch-lightning=1.7.7
scikit-learn=1.1.2
torchmetrics=0.10.2
The installation script uses mamba, so once the environment is created, you may need to add the location of miniforge to your PATH before activating it. For more details, please check out their repo.
We use three molecular datasets for evaluation. To enable comparison with the pretrained baseline models, we follow the same processing and splitting regimes.
All three datasets can be downloaded and processed using the download_datasets.sh script. Note that this takes around 40 minutes and occupies roughly 70 GB of disk space.
bash download_datasets.sh
Alternatively, QM9 and GEOM can be individually downloaded and processed using the EDM repo:
QM9: downloaded and processed using this EDM script
GEOM: downloaded and processed following instructions here
To add distorted molecules and labels to a downloaded and preprocessed dataset, run:
python distort_molecules.py --datadir $datadir --max_dist 0.25 --ratio_distorted_mols 50
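The distortion step can be sketched roughly as follows. This is an illustrative reimplementation, not the repository's distort_molecules.py: the function name, the uniform-magnitude noise scheme, and the binary label convention are all assumptions made for the sketch.

```python
import numpy as np

def distort_coords(coords, max_dist, rng=None):
    """Displace each atom by a random vector of length at most max_dist.

    Illustrative only: the actual distort_molecules.py may use a
    different noise scheme and a graded (non-binary) distortion label.
    """
    rng = np.random.default_rng(rng)
    n = len(coords)
    # Random unit directions, one per atom.
    directions = rng.normal(size=(n, 3))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    # Random displacement magnitudes in [0, max_dist].
    magnitudes = rng.uniform(0.0, max_dist, size=(n, 1))
    return coords + directions * magnitudes

# Assumed label convention: 0 = original conformation, 1 = distorted copy.
coords = np.zeros((5, 3))
distorted = distort_coords(coords, max_dist=0.25, rng=0)
```

With --ratio_distorted_mols 50, roughly half of the training set would consist of such distorted copies, each carrying the "low quality" label so the model learns to separate the two populations.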
After downloading the datasets, train the model on any dataset using training_scripts/train_edm.sh, specifying the dataset name (qm9, geom or zinc) and the mode (baseline or conditional).
bash training_scripts/train_edm.sh $dataset $mode
Sample the model using sampling_scripts/sample_edm.sh, specifying the location of the checkpoints and the mode.
bash sampling_scripts/sample_edm.sh $path_to_checkpoints $mode
After downloading the datasets, train the model on any dataset using training_scripts/train_gcdm.sh, specifying the dataset name (qm9, geom or zinc) and the mode (baseline or conditional).
bash training_scripts/train_gcdm.sh $dataset $mode
Sample the model using sampling_scripts/sample_gcdm.sh, specifying the location of the checkpoints, the mode, and the dataset.
bash sampling_scripts/sample_gcdm.sh $path_to_checkpoints $mode $dataset
Note: since this work was carried out, the authors of MolFM have released a Docker container for their model. For this work, we used the code they provided as supplementary information here. Below is a guide to running this version of the code; for the new version, please follow the guidance on their repo.
After downloading the datasets, train the model on any dataset using training_scripts/train_molfm.sh, specifying the dataset name (qm9, geom or zinc) and the mode (baseline or conditional).
bash training_scripts/train_molfm.sh $dataset $mode
Sample the model using sampling_scripts/sample_molfm.sh, specifying the location of the checkpoints and the mode.
bash sampling_scripts/sample_molfm.sh $path_to_checkpoints $mode
We provide checkpoints for all of the models assessed in the manuscript in checkpoints. These can each be sampled using the corresponding shell scripts as shown above.
The molecules we generated with each model are available in the generated_molecules folder. To reproduce the results shown in Tables 1-3 of the manuscript, run
python assess_molecules.py $path_to_generations
This will return a table with individual PoseBusters pass rates as well as 95% confidence intervals. Note that because of the large number of molecules and the energy calculations PoseBusters performs, this script can take up to 40 minutes to run for a single set of molecules.
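For a pass rate over a fixed set of generated molecules, a 95% confidence interval can be computed with the textbook normal-approximation (Wald) binomial interval, sketched below. This is an assumption about the statistics involved, not a description of assess_molecules.py, which may use a different method (e.g. bootstrapping).

```python
import math

def pass_rate_ci(n_pass, n_total, z=1.96):
    """Wald 95% CI for a binomial proportion such as a PoseBusters pass rate.

    Illustrative only; the repository's script may compute its
    intervals differently.
    """
    p = n_pass / n_total
    half = z * math.sqrt(p * (1 - p) / n_total)
    # Clamp to [0, 1] since a proportion cannot leave that range.
    return p, max(0.0, p - half), min(1.0, p + half)

rate, lo, hi = pass_rate_ci(850, 1000)
# rate = 0.85, interval roughly (0.828, 0.872)
```

The Wald interval is adequate for the sample sizes typically used here (thousands of molecules); for very small samples or rates near 0 or 1, a Wilson or bootstrap interval would be more reliable.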