G2PT is an auto-regressive transformer model designed to learn graph structures through next-token prediction.
📑 paper: https://www.arxiv.org/abs/2501.01073
🤗 checkpoints: G2PT Collection
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("xchen16/g2pt-guacamol-small-deg")
model = AutoModelForCausalLM.from_pretrained("xchen16/g2pt-guacamol-small-deg")

# Generate sequences, starting from the begin-of-construction token
inputs = tokenizer(['<boc>'], return_tensors="pt")
outputs = model.generate(
    inputs["input_ids"],
    max_length=tokenizer.model_max_length,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    do_sample=True,
    temperature=1.0,
)
sequences = tokenizer.batch_decode(outputs)

# Converting sequences to SMILES / RDKit molecules / nx graphs
...
```
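The repository provides its own utilities for converting generated sequences back into molecules and graphs. As a rough, hedged sketch of the idea only: the flat `node ...` / `edge ...` token layout below is an assumption made for illustration, not the actual G2PT sequence grammar, so use the repo's converters in practice.

```python
import networkx as nx

def tokens_to_graph(tokens):
    """Hypothetical parser for illustration only: assumes a flat stream of
    'node <idx> <type>' and 'edge <src> <dst> <bond>' entries. The real
    G2PT sequence grammar is defined by the repository's converters."""
    g = nx.Graph()
    it = iter(tokens)
    for tok in it:
        if tok == "node":
            idx, ntype = next(it), next(it)
            g.add_node(int(idx), label=ntype)
        elif tok == "edge":
            src, dst, bond = next(it), next(it), next(it)
            g.add_edge(int(src), int(dst), label=bond)
    return g

# Example with a made-up token stream for a two-node, one-edge graph
g = tokens_to_graph("node 0 C node 1 O edge 0 1 single".split())
print(g.nodes(data=True), g.edges(data=True))
```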
| Model size | Ordering | QM9 | Moses | GuacaMol |
|---|---|---|---|---|
| Small | BFS | ✅ | ✅ | ✅ |
| Small | DEG | ✅ | ✅ | ✅ |
| Base | BFS | | ✅ | ✅ |
| Base | DEG | | ✅ | ✅ |
| Large | BFS | | ✅ | ✅ |
| Large | DEG | | ✅ | ✅ |
More coming soon...
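The quick-start checkpoint above suggests that Hub repo IDs follow the pattern `xchen16/g2pt-<dataset>-<size>-<ordering>`. This naming convention is an assumption inferred from that single ID, so verify against the G2PT Collection before relying on it:

```python
# Assumed naming convention, inferred from "xchen16/g2pt-guacamol-small-deg";
# check the G2PT Collection on the Hub before relying on it.
def checkpoint_id(dataset: str, size: str, ordering: str) -> str:
    return f"xchen16/g2pt-{dataset}-{size}-{ordering}"

print(checkpoint_id("moses", "base", "bfs"))  # xchen16/g2pt-moses-base-bfs
```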
- First, get the code:

  ```bash
  git clone https://github.com/tufts-ml/g2pt_hf.git
  cd g2pt_hf
  ```

- Set up your Python environment:

  ```bash
  conda create -n g2pt python==3.10
  conda activate g2pt
  ```

- Install dependencies (a quick sanity check follows the list):

  ```bash
  pip install -r requirements.txt
  ```
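To confirm the environment is usable, here is a minimal check, assuming requirements.txt installs torch and transformers (not verified here):

```bash
# Assumes requirements.txt pins torch and transformers
python -c "import torch, transformers; print(torch.__version__, transformers.__version__)"
```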
For dataset preparation instructions, refer to datasets/README.md. To use custom data, make sure to provide the corresponding tokenizer configurations; see tokenizers.
Launch training with the provided script:

```bash
sh scripts/pretrain.sh
```
Default training configuration:

- To run distributed training across N GPUs, set `--nproc_per_node=N` (see the sketch after this list).
- Modify `configs/datasets` and `configs/networks` for your tasks. Training arguments are in `configs/default.py`.
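As a hedged illustration of the flag above, assuming `scripts/pretrain.sh` wraps a `torchrun` launch (the script's contents and the entry-point name are assumptions, not confirmed by this README):

```bash
# Hypothetical launch on 4 GPUs; substitute the entry point actually
# used inside scripts/pretrain.sh
torchrun --nproc_per_node=4 pretrain.py
```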
Generate new graphs using:

```bash
sh scripts/sample.sh
```
If you use G2PT in your research, please cite our paper:
```bibtex
@article{chen2025graph,
  title={Graph Generative Pre-trained Transformer},
  author={Chen, Xiaohui and Wang, Yinkai and He, Jiaxing and Du, Yuanqi and Hassoun, Soha and Xu, Xiaolin and Liu, Li-Ping},
  journal={arXiv preprint arXiv:2501.01073},
  year={2025}
}
```