Skip to content

Commit

Permalink
feat: update ligand group method (#125)
Browse files Browse the repository at this point in the history
* feat: update ligand group method

* test: change unidock-tools dock energy range

* test: unidock_tools dock test case energy_range

* Update random seed
  • Loading branch information
ysyecust authored Apr 28, 2024
1 parent d7d958f commit 6c4c72d
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 35 deletions.
60 changes: 48 additions & 12 deletions unidock/src/cuda/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ static constexpr size_t MAX_THREAD_ = 41700000 ; // modified for vina1.2, to cal
static constexpr size_t MAX_LIGAND_NUM_ = 10250;
};
struct SmallConfig {
static constexpr size_t MAX_NUM_OF_LIG_TORSION_ = 14;
static constexpr size_t MAX_NUM_OF_LIG_TORSION_ = 8;
static constexpr size_t MAX_NUM_OF_FLEX_TORSION_ = 1;
static constexpr size_t MAX_NUM_OF_RIGID_ = 14;
static constexpr size_t MAX_NUM_OF_ATOMS_ = 80;
static constexpr size_t MAX_NUM_OF_RIGID_ = 12;
static constexpr size_t MAX_NUM_OF_ATOMS_ = 40;
static constexpr size_t SIZE_OF_MOLEC_STRUC_ =
((3 + 4 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_ + 1) * sizeof(float));
static constexpr size_t SIZE_OF_CHANGE_STRUC_ =
Expand Down Expand Up @@ -132,18 +132,18 @@ static constexpr size_t MAX_THREAD_ = 41700000 ; // modified for vina1.2, to cal
static constexpr size_t MAX_LIGAND_NUM_ = 10250;
};
struct MediumConfig {
static constexpr size_t MAX_NUM_OF_LIG_TORSION_ = 18;
static constexpr size_t MAX_NUM_OF_LIG_TORSION_ = 16;
static constexpr size_t MAX_NUM_OF_FLEX_TORSION_ = 1;
static constexpr size_t MAX_NUM_OF_RIGID_ = 18;
static constexpr size_t MAX_NUM_OF_ATOMS_ = 100;
static constexpr size_t MAX_NUM_OF_ATOMS_ = 80;
static constexpr size_t SIZE_OF_MOLEC_STRUC_ =
((3 + 4 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_ + 1) * sizeof(float));
static constexpr size_t SIZE_OF_CHANGE_STRUC_ =
((3 + 3 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_ + 1) * sizeof(float));
static constexpr size_t MAX_HESSIAN_MATRIX_D_SIZE_ =
((6 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_)
* (6 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_ + 1) / 2);
static constexpr size_t MAX_NUM_OF_LIG_PAIRS_ =330;
static constexpr size_t MAX_NUM_OF_LIG_PAIRS_ =600;
static constexpr size_t MAX_NUM_OF_BFGS_STEPS_ =64;
static constexpr size_t MAX_NUM_OF_RANDOM_MAP_= 1000 ;// not too large (stack overflow!)
static constexpr size_t GRIDS_SIZE_ =37 ; // larger than vina1.1, max(XS_TYPE_SIZE, AD_TYPE_SIZE + 2)
Expand All @@ -170,16 +170,16 @@ static constexpr size_t MAX_LIGAND_NUM_ = 10250;
struct LargeConfig {
static constexpr size_t MAX_NUM_OF_LIG_TORSION_ = 24;
static constexpr size_t MAX_NUM_OF_FLEX_TORSION_ = 1;
static constexpr size_t MAX_NUM_OF_RIGID_ = 24;
static constexpr size_t MAX_NUM_OF_ATOMS_ = 100;
static constexpr size_t MAX_NUM_OF_RIGID_ = 36;
static constexpr size_t MAX_NUM_OF_ATOMS_ = 120;
static constexpr size_t SIZE_OF_MOLEC_STRUC_ =
((3 + 4 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_ + 1) * sizeof(float));
static constexpr size_t SIZE_OF_CHANGE_STRUC_ =
((3 + 3 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_ + 1) * sizeof(float));
static constexpr size_t MAX_HESSIAN_MATRIX_D_SIZE_ =
((6 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_)
* (6 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_ + 1) / 2);
static constexpr size_t MAX_NUM_OF_LIG_PAIRS_ =512;
static constexpr size_t MAX_NUM_OF_LIG_PAIRS_ =1024;
static constexpr size_t MAX_NUM_OF_BFGS_STEPS_ =64;
static constexpr size_t MAX_NUM_OF_RANDOM_MAP_= 1000 ;// not too large (stack overflow!)
static constexpr size_t GRIDS_SIZE_ =37 ; // larger than vina1.1, max(XS_TYPE_SIZE, AD_TYPE_SIZE + 2)
Expand All @@ -204,18 +204,54 @@ static constexpr size_t MAX_THREAD_ = 41700000 ; // modified for vina1.2, to cal
static constexpr size_t MAX_LIGAND_NUM_ = 10250;
};
struct ExtraLargeConfig {
static constexpr size_t MAX_NUM_OF_LIG_TORSION_ = 36;
static constexpr size_t MAX_NUM_OF_FLEX_TORSION_ = 1;
static constexpr size_t MAX_NUM_OF_RIGID_ = 64;
static constexpr size_t MAX_NUM_OF_ATOMS_ = 160;
static constexpr size_t SIZE_OF_MOLEC_STRUC_ =
((3 + 4 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_ + 1) * sizeof(float));
static constexpr size_t SIZE_OF_CHANGE_STRUC_ =
((3 + 3 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_ + 1) * sizeof(float));
static constexpr size_t MAX_HESSIAN_MATRIX_D_SIZE_ =
((6 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_)
* (6 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_ + 1) / 2);
static constexpr size_t MAX_NUM_OF_LIG_PAIRS_ =2048;
static constexpr size_t MAX_NUM_OF_BFGS_STEPS_ =64;
static constexpr size_t MAX_NUM_OF_RANDOM_MAP_= 1000 ;// not too large (stack overflow!)
static constexpr size_t GRIDS_SIZE_ =37 ; // larger than vina1.1, max(XS_TYPE_SIZE, AD_TYPE_SIZE + 2)

static constexpr size_t MAX_NUM_OF_GRID_MI_ =128; // 55
static constexpr size_t MAX_NUM_OF_GRID_MJ_= 128; // 55
static constexpr size_t MAX_NUM_OF_GRID_MK_ =128 ; // 81
static constexpr size_t MAX_NUM_OF_GRID_POINT_ =512000;

//#define GRID_MI 65//55
//#define GRID_MJ 71//55
//#define GRID_MK 61//81
static constexpr size_t MAX_PRECAL_NUM_ATOM_ =30;
static constexpr size_t MAX_P_DATA_M_DATA_SIZE_ =MAX_NUM_OF_ATOMS_*(MAX_NUM_OF_ATOMS_+1)/2;
// modified for vina1.2, should be larger, n*(n+1)/2, n=num_of_atom, select n=140
//#define MAX_NUM_OF_GRID_ATOMS 150
static constexpr size_t FAST_SIZE_ =2051 ;// modified for vina1.2 m_max_cutoff^2 * factor + 3, ad4=13424
static constexpr size_t SMOOTH_SIZE_ =2051;
static constexpr size_t MAX_CONTAINER_SIZE_EVERY_WI_ =5;

static constexpr size_t MAX_THREAD_ = 41700000 ; // modified for vina1.2, to calculate random map memory upper bound
static constexpr size_t MAX_LIGAND_NUM_ = 10250;
};
struct MaxConfig {
static constexpr size_t MAX_NUM_OF_LIG_TORSION_ = 48;
static constexpr size_t MAX_NUM_OF_FLEX_TORSION_ = 1;
static constexpr size_t MAX_NUM_OF_RIGID_ = 48;
static constexpr size_t MAX_NUM_OF_ATOMS_ = 150;
static constexpr size_t MAX_NUM_OF_RIGID_ = 128;
static constexpr size_t MAX_NUM_OF_ATOMS_ = 300;
static constexpr size_t SIZE_OF_MOLEC_STRUC_ =
((3 + 4 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_ + 1) * sizeof(float));
static constexpr size_t SIZE_OF_CHANGE_STRUC_ =
((3 + 3 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_ + 1) * sizeof(float));
static constexpr size_t MAX_HESSIAN_MATRIX_D_SIZE_ =
((6 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_)
* (6 + MAX_NUM_OF_LIG_TORSION_ + MAX_NUM_OF_FLEX_TORSION_ + 1) / 2);
static constexpr size_t MAX_NUM_OF_LIG_PAIRS_ =1024;
static constexpr size_t MAX_NUM_OF_LIG_PAIRS_ =4096;
static constexpr size_t MAX_NUM_OF_BFGS_STEPS_ =64;
static constexpr size_t MAX_NUM_OF_RANDOM_MAP_= 1000 ;// not too large (stack overflow!)
static constexpr size_t GRIDS_SIZE_ =37 ; // larger than vina1.1, max(XS_TYPE_SIZE, AD_TYPE_SIZE + 2)
Expand Down
35 changes: 35 additions & 0 deletions unidock/src/cuda/monte_carlo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,34 @@ std::vector<output_type> monte_carlo_template::cuda_to_vina<ExtraLargeConfig>(ou
}
return results_vina;
}
template<>
std::vector<output_type> monte_carlo_template::cuda_to_vina<MaxConfig>(output_type_cuda_t_<MaxConfig> results_ptr[],
int thread) const {
// DEBUG_PRINTF("entering cuda_to_vina\n");
std::vector<output_type> results_vina;
for (int i = 0; i < thread; ++i) {
output_type_cuda_t_<MaxConfig> results = results_ptr[i];
conf tmp_c;
tmp_c.ligands.resize(1);
// Position
for (int j = 0; j < 3; j++) tmp_c.ligands[0].rigid.position[j] = results.position[j];
// Orientation
qt q(results.orientation[0], results.orientation[1], results.orientation[2],
results.orientation[3]);
tmp_c.ligands[0].rigid.orientation = q;
output_type tmp_vina(tmp_c, results.e);
// torsion
for (int j = 0; j < results.lig_torsion_size; j++)
tmp_vina.c.ligands[0].torsions.push_back(results.lig_torsion[j]);
// coords
for (int j = 0; j < MaxConfig::MAX_NUM_OF_ATOMS_; j++) {
vec v_tmp(results.coords[j][0], results.coords[j][1], results.coords[j][2]);
if (v_tmp[0] * v_tmp[1] * v_tmp[2] != 0) tmp_vina.coords.push_back(v_tmp);
}
results_vina.push_back(tmp_vina);
}
return results_vina;
}
__host__ void monte_carlo::operator()(
std::vector<model> &m_gpu, std::vector<output_container> &out_gpu,
std::vector<precalculate_byatom> &p_gpu, triangular_matrix_cuda_t *m_data_list_gpu,
Expand Down Expand Up @@ -2972,6 +3000,13 @@ __host__ void monte_carlo_template::do_docking<ExtraLargeConfig>(std::vector<mod
unsigned long long seed, std::vector<std::vector<bias_element>> &bias_batch_list) const {
monte_carlo_template::do_docking_base<ExtraLargeConfig>(m_gpu, out_gpu,p_gpu, m_data_list_gpu,ig, corner1, corner2, generator, verbosity,seed, bias_batch_list);
}
template <>
__host__ void monte_carlo_template::do_docking<MaxConfig>(std::vector<model> &m_gpu, std::vector<output_container> &out_gpu,
std::vector<precalculate_byatom> &p_gpu, triangular_matrix_cuda_t *m_data_list_gpu,
const igrid &ig, const vec &corner1, const vec &corner2, rng &generator, int verbosity,
unsigned long long seed, std::vector<std::vector<bias_element>> &bias_batch_list) const {
monte_carlo_template::do_docking_base<MaxConfig>(m_gpu, out_gpu,p_gpu, m_data_list_gpu,ig, corner1, corner2, generator, verbosity,seed, bias_batch_list);
}
/* Above based on monte-carlo.cpp */

// #endif
53 changes: 34 additions & 19 deletions unidock/src/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,30 @@ float calculateScore(const Ligand &ligand) {
bool compareLigands(const Ligand &a, const Ligand &b) {
return a.score < b.score;
}
void classifyLigands(const std::vector<Ligand>& ligands,std::vector<Ligand> &smallGroup, std::vector<Ligand> &mediumGroup,std::vector<Ligand> & largeGroup, std::vector<Ligand> &extraLargeGroup, std::vector<Ligand> &overflowGroup) {
const int atomThresholds[5] = {40, 80, 120, 160, 300};
const int torsionThresholds[5] = {8, 16, 24, 36, 48};
const int rigidThresholds[5] = {12, 24, 36, 64, 128};
const int pairThresholds[5] = {300, 600, 1024, 2048, 4096};

for (const auto& lig : ligands) {
if (lig.num_atoms <= atomThresholds[0] && lig.num_torsions <= torsionThresholds[0] &&
lig.num_rigids <= rigidThresholds[0] && lig.num_lig_pairs <= pairThresholds[0]) {
smallGroup.push_back(lig);
} else if (lig.num_atoms <= atomThresholds[1] && lig.num_torsions <= torsionThresholds[1] &&
lig.num_rigids <= rigidThresholds[1] && lig.num_lig_pairs <= pairThresholds[1]) {
mediumGroup.push_back(lig);
} else if (lig.num_atoms <= atomThresholds[2] && lig.num_torsions <= torsionThresholds[2] &&
lig.num_rigids <= rigidThresholds[2] && lig.num_lig_pairs <= pairThresholds[2]) {
largeGroup.push_back(lig);
} else if (lig.num_atoms <= atomThresholds[3] && lig.num_torsions <= torsionThresholds[3] &&
lig.num_rigids <= rigidThresholds[3] && lig.num_lig_pairs <= pairThresholds[3]) {
extraLargeGroup.push_back(lig);
} else {
overflowGroup.push_back(lig);
}
}
}
void printMaxValues(const std::vector<Ligand>& group) {
int max_atoms = std::numeric_limits<int>::min();
int max_torsions = std::numeric_limits<int>::min();
Expand Down Expand Up @@ -986,19 +1010,6 @@ bug reporting, license agreements, and more information. \n";
max_num_torsions = std::max(max_num_torsions, num_torsions_vector.at(i));
max_num_rigids = std::max(max_num_rigids,num_rigids_vector.at(i));
max_num_lig_pairs = std::max(max_num_lig_pairs,num_lig_pairs_vector.at(i));

// printf("num_atoms%ld\n",num_atoms_vector.at(i));
// printf("num_torsions:%ld\n",num_torsions_vector.at(i));
// printf("num_rigids:%ld\n",num_rigids_vector.at(i));
// printf("num_internal_pairs:%ld\n",num_lig_pairs_vector.at(i));

// all_ligands[i].second.about();

// printf("max_num_ligands%ld\n",max_num_ligands);
// printf("max_num_other_pairs%ld\n",max_num_other_pairs);
// printf("max_num_flex%ld\n",max_num_flex);
// printf("max_num_ligand_degrees_of_freedom%ld\n",max_num_ligand_degrees_of_freedom);
// printf("max_num_internal_pairs%ld\n",max_num_internal_pairs);
}

printf("max_num_atoms%ld\n",max_num_atoms);
Expand All @@ -1016,12 +1027,13 @@ bug reporting, license agreements, and more information. \n";
lig.score = calculateScore(lig);
ligands.push_back(lig);
}
std::sort(ligands.begin(), ligands.end(), compareLigands);
int groupSize = ligands.size() / 4;
std::vector<Ligand> smallGroup(ligands.begin(), ligands.begin() + groupSize);
std::vector<Ligand> mediumGroup(ligands.begin() + groupSize, ligands.begin() + 2 * groupSize);
std::vector<Ligand> largeGroup(ligands.begin() + 2 * groupSize, ligands.begin() + 3 * groupSize);
std::vector<Ligand> extraLargeGroup(ligands.begin() + 3 * groupSize, ligands.end());

std::vector<Ligand> smallGroup;
std::vector<Ligand> mediumGroup;
std::vector<Ligand> largeGroup;
std::vector<Ligand> extraLargeGroup;
std::vector<Ligand> maxGroup;
classifyLigands(ligands,smallGroup,mediumGroup,largeGroup,extraLargeGroup,maxGroup);
std::cout << "Small Group:" << std::endl;
printMaxValues(smallGroup);

Expand Down Expand Up @@ -1053,6 +1065,9 @@ bug reporting, license agreements, and more information. \n";
template_batch_docking<ExtraLargeConfig>(v,all_ligands,extraLargeGroup,"Extra Large",exhaustiveness, multi_bias,max_memory,
receptor_atom_numbers, out_dir,bias_file, num_modes,
min_rmsd,max_evals,max_step,seed, refine_step, local_only, energy_range);
template_batch_docking<MaxConfig>(v,all_ligands,maxGroup,"Max",exhaustiveness, multi_bias,max_memory,
receptor_atom_numbers, out_dir,bias_file, num_modes,
min_rmsd,max_evals,max_step,seed, refine_step, local_only, energy_range);
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions unidock_tools/tests/ut/dock/test_run_unidock.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_run_unidock_vina(receptor, ligand, pocket):
size_z=pocket[5],
scoring="vina",
num_modes=10,
energy_range=6.0,
energy_range=9.0,
seed=181129,
)

Expand Down Expand Up @@ -73,8 +73,8 @@ def test_run_unidock_ad4(receptor, ligand, pocket):
size_z=pocket[5],
scoring="ad4",
num_modes=5,
energy_range=6.0,
seed=181129,
energy_range=12.0,
seed=42,
)

result_ligand = result_ligands[0]
Expand All @@ -83,4 +83,4 @@ def test_run_unidock_ad4(receptor, ligand, pocket):
scores = scores_list[0]
assert len(scores) == 5

shutil.rmtree(workdir, ignore_errors=True)
shutil.rmtree(workdir, ignore_errors=True)

0 comments on commit 6c4c72d

Please sign in to comment.