
Commit

Switched back to saving one .pkl for all transformers. The pkl is still a tuple, but the elements are now dictionaries instead of lists. The rest of the code should work the same.
stewarthe6 committed Dec 11, 2024
1 parent 18987d3 commit 4689f9e
Showing 1 changed file with 43 additions and 42 deletions.
85 changes: 43 additions & 42 deletions atomsci/ddm/pipeline/model_wrapper.py
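Before the diff, it may help to sketch the payload change the commit message describes: previously each fold k got its own transformers_{k}.pkl whose tuple elements were lists; now a single transformers.pkl holds a tuple of dictionaries keyed by fold. A minimal sketch of the new save side — the DummyTransformer stub and the fold keys 0/'final' are illustrative assumptions, not the pipeline's real classes:

    import os
    import pickle
    import tempfile

    class DummyTransformer:
        """Stand-in for the real DeepChem transformer objects."""
        pass

    # New layout: one pickle holding a tuple of dicts keyed by fold
    # (here 0 and 'final'), instead of one transformers_{k}.pkl per
    # fold whose tuple elements were plain lists.
    transformers   = {'final': [DummyTransformer()], 0: [DummyTransformer()]}
    transformers_x = {'final': [], 0: []}
    transformers_w = {'final': [], 0: []}

    out_path = os.path.join(tempfile.gettempdir(), 'transformers.pkl')
    with open(out_path, 'wb') as f:
        pickle.dump((transformers, transformers_x, transformers_w), f)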
@@ -373,6 +373,7 @@ def create_transformers(self, training_datasets):
            params.transformer_key: A string pointing to the dataset key containing the transformer in the datastore, or the path to the transformer
        """
+        total_transformers = 0
        for k, td in training_datasets.items():
            self.transformers[k] = self._create_output_transformers(td)

@@ -381,15 +382,16 @@ def create_transformers(self, training_datasets):

            # Set up transformers for weights, if needed
            self.transformers_w[k] = trans.create_weight_transformers(self.params, td)

-            if len(self.transformers[k]) + len(self.transformers_x[k]) + len(self.transformers_w[k]) > 0:
+            total_transformers = len(self.transformers[k]) + len(self.transformers_x[k]) + len(self.transformers_w[k])

-                # Transformers are no longer saved as separate datastore objects; they are included in the model tarball
-                self.params.transformer_key = os.path.join(self.output_dir, f'transformers_{k}.pkl')
-                with open(self.params.transformer_key, 'wb') as txfmrpkl:
-                    pickle.dump((self.transformers[k], self.transformers_x[k], self.transformers_w[k]), txfmrpkl)
-                self.log.info("Wrote transformers to %s" % self.params.transformer_key)
-                self.params.transformer_oid = ""
-                self.params.transformer_bucket = ""
+        if total_transformers > 0:
+            # Transformers are no longer saved as separate datastore objects; they are included in the model tarball
+            self.params.transformer_key = os.path.join(self.output_dir, 'transformers.pkl')
+            with open(self.params.transformer_key, 'wb') as txfmrpkl:
+                pickle.dump((self.transformers, self.transformers_x, self.transformers_w), txfmrpkl)
+            self.log.info("Wrote transformers to %s" % self.params.transformer_key)
+            self.params.transformer_oid = ""
+            self.params.transformer_bucket = ""

# ****************************************************************************************

@@ -403,42 +405,39 @@ def reload_transformers(self):
        if not trans.transformers_needed(self.params):
            return

-        for i in trans.get_transformer_keys(self.params):
-            # for backwards compatibility: if this file exists, all folds use the same transformers
-            local_path = f"{self.output_dir}/transformers.pkl"
-            if not os.path.exists(local_path):
-                local_path = f"{self.output_dir}/transformers_{i}.pkl"
-
-            if os.path.exists(local_path):
-                self.log.info(f"Reloading transformers from model tarball {local_path}")
-                with open(local_path, 'rb') as txfmr:
-                    transformers_tuple = pickle.load(txfmr)
-            else:
-                if self.params.transformer_key is not None:
-                    if self.params.save_results:
-                        self.log.info(f"Reloading transformers from datastore key {self.params.transformer_key}")
-                        transformers_tuple = dsf.retrieve_dataset_by_datasetkey(
-                            dataset_key = self.params.transformer_key,
-                            bucket = self.params.transformer_bucket,
-                            client = self.ds_client )
-                    else:
-                        self.log.info(f"Reloading transformers from file {self.params.transformer_key}")
-                        with open(self.params.transformer_key, 'rb') as txfmr:
-                            transformers_tuple = pickle.load(txfmr)
-                else:
-                    # Shouldn't happen
-                    raise Exception("Transformers needed to reload model, but no transformer_key specified.")
-
-            if len(transformers_tuple) == 3:
-                ty, tx, tw = transformers_tuple
-            else:
-                ty, tx = transformers_tuple
-                tw = []
-
-            self.transformers[i] = ty
-            self.transformers_x[i] = tx
-            self.transformers_w[i] = tw
+        # for backwards compatibility: if this file exists, all folds use the same transformers
+        local_path = f"{self.output_dir}/transformers.pkl"
+
+        if os.path.exists(local_path):
+            self.log.info(f"Reloading transformers from model tarball {local_path}")
+            with open(local_path, 'rb') as txfmr:
+                transformers_tuple = pickle.load(txfmr)
+        else:
+            if self.params.transformer_key is not None:
+                if self.params.save_results:
+                    self.log.info(f"Reloading transformers from datastore key {self.params.transformer_key}")
+                    transformers_tuple = dsf.retrieve_dataset_by_datasetkey(
+                        dataset_key = self.params.transformer_key,
+                        bucket = self.params.transformer_bucket,
+                        client = self.ds_client )
+                else:
+                    self.log.info(f"Reloading transformers from file {self.params.transformer_key}")
+                    with open(self.params.transformer_key, 'rb') as txfmr:
+                        transformers_tuple = pickle.load(txfmr)
+            else:
+                # Shouldn't happen
+                raise Exception("Transformers needed to reload model, but no transformer_key specified.")
+
+        if len(transformers_tuple) == 3:
+            ty, tx, tw = transformers_tuple
+        else:
+            ty, tx = transformers_tuple
+            tw = []
+
+        self.transformers = ty
+        self.transformers_x = tx
+        self.transformers_w = tw

# ****************************************************************************************
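Read back in reload_transformers, the tuple unpacks into the per-fold dictionaries that the prediction paths index by fold key; the two-element branch keeps older pickles, written before weight transformers existed, loadable. A small sketch of the load side, continuing the hypothetical file from the save sketch above:

    import pickle

    with open(out_path, 'rb') as f:      # out_path from the save sketch above
        loaded = pickle.load(f)

    if len(loaded) == 3:
        ty, tx, tw = loaded              # current format: (y, x, w) transformer dicts
    else:
        ty, tx = loaded                  # older two-element pickles predate weight transformers
        tw = []

    final_y_transformers = ty['final']   # per-fold access, as in the predict paths below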

@@ -1072,7 +1071,7 @@ def generate_predictions(self, dataset):
        """
        pred, std = None, None
        self.log.info("Predicting values for current model")

+        dataset = self.transform_dataset(dataset, 'final')
# For deepchem's predict_uncertainty function, you are not allowed to specify transformers. That means that the
# predictions are being made in the transformed space, not the original space. We call undo_transforms() to generate
# the transformed predictions. To transform the standard deviations, we rely on the fact that at present we only use
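The comment above (cut off at the hunk boundary) is explaining that predictions come back in transformed space and are mapped back with undo_transforms(), while the standard deviations only need rescaling. A hedged numeric illustration with made-up normalization statistics, standing in for the DeepChem transformer behavior rather than reproducing it:

    import numpy as np

    y_mean, y_std = 5.2, 1.7          # hypothetical y-normalization fit on training data
    pred_t = np.array([0.3, -1.1])    # predictions in transformed (normalized) space
    std_t = np.array([0.2, 0.5])      # uncertainties in transformed space

    pred = pred_t * y_std + y_mean    # what undo_transforms() recovers for the mean
    std = std_t * y_std               # stds are rescaled only, never re-centered
    print(pred, std)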
@@ -1907,6 +1906,7 @@ def generate_predictions(self, dataset):
        pred, std = None, None
        self.log.info("Evaluating current model")

+        dataset = self.transform_dataset(dataset, 'final')
        pred = self.model.predict(dataset, self.transformers['final'])
        ncmpds = pred.shape[0]
        pred = pred.reshape((ncmpds,1,-1))
@@ -2272,7 +2272,8 @@ def generate_predictions(self, dataset):
        pred, std = None, None
        self.log.warning("Evaluating current model")

-        pred = self.model.predict(dataset, self.transformers)
+        dataset = self.transform_dataset(dataset, 'final')
+        pred = self.model.predict(dataset, self.transformers['final'])
        ncmpds = pred.shape[0]
        pred = pred.reshape((ncmpds, 1, -1))
        self.log.warning("uncertainty not supported by xgboost models")
