Skip to content

Commit

Permalink
Used parameter to determine if SMOTE or undersampling is being used
Browse files (browse the repository at this point in the history)
  • Loading branch information
stewarthe6 committed Sep 25, 2024
1 parent f247893 commit 2e03fef
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions atomsci/ddm/pipeline/model_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,17 +661,17 @@ def combined_training_data(self):
# All of the splits have the same combined train/valid data, regardless of whether we're using
# k-fold or train/valid/test splitting.
if self.combined_train_valid_data is None:
# normally combining one fold is sufficient, but if SMOTE is being used
# each fold will have compounds unique to it.
# Normally combining one fold is sufficient, but if SMOTE or undersampling is being used,
# each fold can contain compounds unique to it, so combining only the first fold isn't enough.
(train, valid) = self.train_valid_dsets[0]
combined_X = np.concatenate((train.X, valid.X), axis=0)
combined_y = np.concatenate((train.y, valid.y), axis=0)
combined_w = np.concatenate((train.w, valid.w), axis=0)
combined_ids = np.concatenate((train.ids, valid.ids))

contains_synthetic = any(id.startswith('synthetic_') for id in train.ids)
if contains_synthetic:
if self.params.sampling_method=='SMOTE' or self.params.sampling_method=='undersampling':
# for each successive fold, merge in any new compounds
# this loop just won't run if there are no additional folds
for train, valid in self.train_valid_dsets[1:]:
fold_ids = np.concatenate((train.ids, valid.ids))
new_id_indexes = [i for i in range(len(fold_ids)) if i not in combined_ids]
Expand Down

0 comments on commit 2e03fef

Please sign in to comment.