diff --git a/models/crabnet_hyperparameter/.gitignore b/models/crabnet_hyperparameter/.gitignore index 5da6c5d..4957a60 100644 --- a/models/crabnet_hyperparameter/.gitignore +++ b/models/crabnet_hyperparameter/.gitignore @@ -3,3 +3,4 @@ # Except this file !.gitignore !dummy +!cv diff --git a/models/crabnet_hyperparameter/cv/.gitignore b/models/crabnet_hyperparameter/cv/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/models/crabnet_hyperparameter/dummy/.gitignore b/models/crabnet_hyperparameter/dummy/.gitignore index 5e7d273..cf3699e 100644 --- a/models/crabnet_hyperparameter/dummy/.gitignore +++ b/models/crabnet_hyperparameter/dummy/.gitignore @@ -2,3 +2,4 @@ * # Except this file !.gitignore +!cv diff --git a/models/crabnet_hyperparameter/dummy/cv/.gitignore b/models/crabnet_hyperparameter/dummy/cv/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/models/crabnet_hyperparameter/dummy/cv/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/notebooks/crabnet_hyperparameter/1.0-sgb-collect-from-mongodb.ipynb b/notebooks/crabnet_hyperparameter/1.0-sgb-collect-from-mongodb.ipynb index 8b20dc5..8054970 100644 --- a/notebooks/crabnet_hyperparameter/1.0-sgb-collect-from-mongodb.ipynb +++ b/notebooks/crabnet_hyperparameter/1.0-sgb-collect-from-mongodb.ipynb @@ -12,7 +12,8 @@ "# Requires the PyMongo package.\n", "# https://api.mongodb.com/python/current\n", "\n", - "client = MongoClient(f\"mongodb+srv://{MONGODB_USERNAME}:{MONGODB_PASSWORD}@matsci-opt-benchmarks.ehu7qrh.mongodb.net/?retryWrites=true&w=majority\")\n", + "cluster_uri = \"matsci-opt-benchmarks.ehu7qrh\"\n", + "client = MongoClient(f\"mongodb+srv://{MONGODB_USERNAME}:{MONGODB_PASSWORD}@{cluster_uri}.mongodb.net/?retryWrites=true&w=majority\")\n", "\n", "database_name = \"crabnet-hyperparameter\"\n", "collection_name = \"sobol\"" diff --git a/notebooks/crabnet_hyperparameter/1.2-jp-surrogate.ipynb b/notebooks/crabnet_hyperparameter/1.2-jp-surrogate.ipynb index 91f44a4..60aebd2 100644 --- a/notebooks/crabnet_hyperparameter/1.2-jp-surrogate.ipynb +++ b/notebooks/crabnet_hyperparameter/1.2-jp-surrogate.ipynb @@ -14,11 +14,10 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ - "import pickle\n", "from sklearn.ensemble import RandomForestRegressor\n", "from sklearn.metrics import mean_absolute_error\n", "from sklearn.model_selection import KFold, GroupKFold\n", @@ -37,11 +36,13 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ - "dummy = True\n", + "from pathlib import Path\n", + "\n", + "dummy = False\n", "\n", "task_name = \"crabnet_hyperparameter\"\n", "\n", @@ -49,7 +50,12 @@ "model_dir = path.join(\"..\", \"..\", \"models\", task_name)\n", "\n", "if dummy:\n", - " model_dir = path.join(model_dir, \"dummy\")" + " model_dir = path.join(model_dir, \"dummy\")\n", + " \n", + "cv_model_dir = path.join(model_dir, \"cv\")\n", + "\n", + "Path(model_dir).mkdir(exist_ok=True, parents=True) # technically redundant\n", + "Path(cv_model_dir).mkdir(exist_ok=True, parents=True)" ] }, { @@ -62,14 +68,17 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "sobol_reg = pd.read_csv(path.join(data_dir, \"sobol_regression.csv\"))\n", "\n", "if dummy:\n", - " sobol_reg = sobol_reg.head(100)" + " data_dir = path.join(data_dir, \"dummy\")\n", + " sobol_reg = sobol_reg.head(100)\n", + " \n", + 
"Path(data_dir).mkdir(exist_ok=True, parents=True)" ] }, { @@ -82,7 +91,7 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -108,12 +117,14 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "# argument for rfr_mae, X_array, y_array, model_name to save model as .pkl\n", - "def rfr_group_mae(X_array, y_array, group_array, model_name_stem, objective_name, random_state=13):\n", + "def rfr_group_mae(\n", + " X_array, y_array, group_array, model_name_stem, objective_name, random_state=13\n", + "):\n", " kf = GroupKFold(n_splits=5)\n", " mae_scores = []\n", " y_preds = []\n", @@ -134,11 +145,11 @@ " mae = mean_absolute_error(y_test, y_pred)\n", " mae_scores.append(mae)\n", " # save model as .pkl\n", - " joblib.dump(model, f\"{model_name_stem}_{i}.pkl\")\n", + " joblib.dump(model, f\"{model_name_stem}_{i}.pkl\", compress=7)\n", "\n", " avg_mae = np.mean(mae_scores)\n", " std_mae = np.std(mae_scores)\n", - " \n", + "\n", " print(f\"MAE for {objective_name}: {avg_mae:.4f} +/- {std_mae:.4f}\")\n", " results = {\"mae\": mae_scores, \"y_pred\": y_preds, \"y_true\": y_trues}\n", " return results" @@ -154,7 +165,7 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -187,16 +198,7 @@ " \"elem_prop_mat2vec\",\n", " \"elem_prop_onehot\",\n", " \"hardware_2080ti\"\n", - "]\n", - "\n", - "# fba_isna_prob_features = common_features\n", - "# ls_isna_prob_features = common_features\n", - "\n", - "# fba_features = common_features + [\"fba_rank\"]\n", - "# ls_features = common_features + [\"ls_rank\"]\n", - "\n", - "# fba_time_s_features = common_features + [\"fba_time_s_rank\"]\n", - "# ls_time_s_features = common_features + [\"ls_time_s_rank\"]" + "]" ] }, { @@ -211,31 +213,33 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "MAE for mae: 0.0886 +/- 0.0158\n" + "MAE for mae: 0.0217 +/- 0.0004\n" ] } ], "source": [ - "mae_features = common_features + [\"mae_rank\"] ## ToDo: mae or mae_rank (mae.1)\n", + "mae_features = common_features + [\"mae_rank\"]\n", "\n", "X_array_mae = sobol_reg[mae_features].to_numpy()\n", "y_array_mae = sobol_reg[[\"mae\"]].to_numpy().ravel()\n", "\n", "sobol_reg_mae_group = (\n", - " sobol_reg[mae_features]\n", + " sobol_reg[common_features]\n", " .round(6)\n", " .apply(lambda row: \"_\".join(row.values.astype(str)), axis=1)\n", ")\n", "\n", "mae_model_stem = path.join(model_dir, \"sobol_reg_mae\")\n", - "mae_results = rfr_group_mae(X_array_mae, y_array_mae, sobol_reg_mae_group, mae_model_stem, \"mae\")" + "mae_results = rfr_group_mae(\n", + " X_array_mae, y_array_mae, sobol_reg_mae_group, mae_model_stem, \"mae\"\n", + ")" ] }, { @@ -248,25 +252,25 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "MAE for rmse: 0.0883 +/- 0.0257\n" + "MAE for rmse: 0.0265 +/- 0.0004\n" ] } ], "source": [ - "rmse_features = common_features + [\"rmse_rank\"] ## ToDo: rmse or rmse_rank (rmse.1)\n", + "rmse_features = common_features + [\"rmse_rank\"]\n", "\n", "X_array_rmse = sobol_reg[rmse_features].to_numpy()\n", "y_array_rmse = sobol_reg[[\"rmse\"]].to_numpy().ravel()\n", "\n", "sobol_reg_rmse_group = (\n", - " sobol_reg[rmse_features]\n", + " 
sobol_reg[common_features]\n", " .round(6)\n", " .apply(lambda row: \"_\".join(row.values.astype(str)), axis=1)\n", ")\n", @@ -287,32 +291,36 @@ }, { "cell_type": "code", - "execution_count": 87, + "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "MAE for model_size: 7078125.7105 +/- 791229.1921\n" + "MAE for model_size: 317796.8646 +/- 6005.9939\n" ] } ], "source": [ - "model_size_features = common_features \n", + "model_size_features = common_features\n", "\n", "X_array_model_size = sobol_reg[model_size_features].to_numpy()\n", "y_array_model_size = sobol_reg[[\"model_size\"]].to_numpy().ravel()\n", "\n", "sobol_reg_model_size_group = (\n", - " sobol_reg[model_size_features]\n", + " sobol_reg[common_features]\n", " .round(6)\n", " .apply(lambda row: \"_\".join(row.values.astype(str)), axis=1)\n", ")\n", "\n", "model_size_model_stem = path.join(model_dir, \"sobol_reg_model_size\")\n", "model_size_results = rfr_group_mae(\n", - " X_array_model_size, y_array_model_size, sobol_reg_model_size_group, model_size_model_stem, \"model_size\"\n", + " X_array_model_size,\n", + " y_array_model_size,\n", + " sobol_reg_model_size_group,\n", + " model_size_model_stem,\n", + " \"model_size\",\n", ")" ] }, @@ -326,870 +334,44 @@ }, { "cell_type": "code", - "execution_count": 88, + "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "MAE for runtime: 71.1022 +/- 12.7256\n" + "MAE for runtime: 20.5904 +/- 1.0632\n" ] } ], "source": [ - "runtime_features = common_features + [\"runtime_rank\"] ## ToDo: runtime or runtime_rank (runtime.1)\n", + "runtime_features = common_features + [\"runtime_rank\"]\n", "\n", "X_array_runtime = sobol_reg[runtime_features].to_numpy()\n", "y_array_runtime = sobol_reg[[\"runtime\"]].to_numpy().ravel()\n", "\n", "sobol_reg_runtime_group = (\n", - " sobol_reg[runtime_features]\n", + " sobol_reg[common_features]\n", " .round(6)\n", " .apply(lambda row: \"_\".join(row.values.astype(str)), axis=1)\n", ")\n", "\n", "runtime_model_stem = path.join(model_dir, \"sobol_reg_runtime\")\n", "runtime_results = rfr_group_mae(\n", - " X_array_runtime, y_array_runtime, sobol_reg_runtime_group, runtime_model_stem, \"runtime\"\n", + " X_array_runtime,\n", + " y_array_runtime,\n", + " sobol_reg_runtime_group,\n", + " runtime_model_stem,\n", + " \"runtime\",\n", ")" ] }, { "cell_type": "code", - "execution_count": 89, + "execution_count": 33, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'mae': {'mae': [0.09736450358524826,\n", - " 0.09584076684453244,\n", - " 0.0802529964475443,\n", - " 0.10721362715898333,\n", - " 0.06212542711224781],\n", - " 'y_pred': [[1.0470445123262662,\n", - " 1.1388262858078486,\n", - " 1.0937742371687706,\n", - " 1.092910442683129,\n", - " 1.0438372602378871,\n", - " 1.0979699665832756,\n", - " 0.4680361319697645,\n", - " 0.4694053084366586,\n", - " 0.558221162505513,\n", - " 0.8858907433235408,\n", - " 0.7634087634084612,\n", - " 0.9416579741236427,\n", - " 0.4489247723793967,\n", - " 0.438963360103239,\n", - " 0.4383556589032155,\n", - " 0.4480121911319065,\n", - " 0.3982781319502199,\n", - " 0.44545946894258875,\n", - " 0.44618629554354483,\n", - " 0.4244994333327561],\n", - " [0.9314951872437939,\n", - " 1.0570991161603451,\n", - " 1.0727408998185621,\n", - " 0.8653992024529414,\n", - " 0.4556970487915757,\n", - " 0.6124129060295667,\n", - " 0.39256521816360324,\n", - " 0.4853519476046882,\n", - " 0.45922470429958195,\n", - " 
0.4867680789728061,\n", - " 0.3957088248979133,\n", - " 0.3986279115375167,\n", - " 0.4179404932305161,\n", - " 0.46804234963511776,\n", - " 0.43770644923194396,\n", - " 0.5005538058504512,\n", - " 0.4957694754658767,\n", - " 0.6732913507628604,\n", - " 0.6992719613452546,\n", - " 0.6175151008557497],\n", - " [1.0645154938648376,\n", - " 1.0513774450679632,\n", - " 0.4124506529524868,\n", - " 0.44970922600866325,\n", - " 0.44055496630936086,\n", - " 0.5421493464381575,\n", - " 0.4516257085298328,\n", - " 0.4258607698930314,\n", - " 0.7842562108258788,\n", - " 0.40015367476313785,\n", - " 0.807394855530477,\n", - " 0.545868524628608,\n", - " 0.7518140327928516,\n", - " 0.4032509027923053,\n", - " 0.8469963227257022,\n", - " 0.7114102397051942,\n", - " 0.5066308963672581,\n", - " 0.5040870817923644,\n", - " 0.7852454384834193,\n", - " 0.4105958258809023],\n", - " [1.0924168055612038,\n", - " 0.5529444684588257,\n", - " 0.6802815405923517,\n", - " 0.47296371379381014,\n", - " 0.5455000012091871,\n", - " 0.6681200081076015,\n", - " 0.7565501612890135,\n", - " 0.6340613799549157,\n", - " 0.4465716679751337,\n", - " 0.6676576312856797,\n", - " 0.38892204021160354,\n", - " 0.44439471503151656,\n", - " 0.4777860593937262,\n", - " 0.44103111165157843,\n", - " 0.46411154479667255,\n", - " 0.44609222794217035,\n", - " 0.46331540104243724,\n", - " 0.39487414144119076,\n", - " 0.42166744226578273,\n", - " 0.5688390301089565],\n", - " [1.1032782248332371,\n", - " 1.0530333359531,\n", - " 1.1088351354161683,\n", - " 0.9841679384498019,\n", - " 0.7005165482624112,\n", - " 0.42518213319652937,\n", - " 0.4089415760184945,\n", - " 0.7424869449077462,\n", - " 0.5628183964224851,\n", - " 0.7159215170923885,\n", - " 0.5574749370189526,\n", - " 0.7078909625637984,\n", - " 0.46589749271394615,\n", - " 0.48995625546955857,\n", - " 0.49224159509110976,\n", - " 0.4137654322396781,\n", - " 0.4266609970794159,\n", - " 0.7267706503027708,\n", - " 0.4989256874920076,\n", - " 0.37636780471835524]],\n", - " 'y_true': [[1.0552956436332182,\n", - " 1.1646900603902348,\n", - " 0.8825866566388886,\n", - " 0.9526161819708144,\n", - " 1.0247243415469645,\n", - " 1.0632357963246917,\n", - " 0.4467830208391259,\n", - " 0.4879459462990078,\n", - " 0.7295502814362195,\n", - " 0.6036306083593018,\n", - " 0.4950718890674334,\n", - " 0.6873335839811521,\n", - " 0.4697655548505207,\n", - " 0.5089084580495541,\n", - " 0.3829123496867304,\n", - " 0.524473419983807,\n", - " 0.411288796630173,\n", - " 0.3598592958901402,\n", - " 0.4440858954943622,\n", - " 0.2560991213476931],\n", - " [1.1636267432773704,\n", - " 1.0329359661069408,\n", - " 1.069201801398877,\n", - " 0.7986457110702646,\n", - " 0.4961310269854214,\n", - " 0.3910859773521523,\n", - " 0.3781965069907805,\n", - " 0.5076757418334757,\n", - " 0.4366450097351446,\n", - " 0.4056658706592403,\n", - " 0.1847152402180571,\n", - " 0.3674441869937287,\n", - " 0.3938699664585747,\n", - " 0.4617969166526025,\n", - " 0.4384921978179171,\n", - " 0.4580442337381689,\n", - " 0.7810141802427323,\n", - " 0.4708583080815828,\n", - " 0.4271250467533048,\n", - " 0.5050356267317561],\n", - " [1.1281137998730446,\n", - " 1.0945798451843127,\n", - " 0.3698879828365178,\n", - " 0.4531683274902793,\n", - " 0.4525050852451548,\n", - " 0.4879867846446609,\n", - " 0.3517701866868379,\n", - " 0.266989165110708,\n", - " 0.7149614023677884,\n", - " 0.40370649311889,\n", - " 0.8699616604560994,\n", - " 0.6329857875493202,\n", - " 0.4865620239311286,\n", - " 0.4153386451995923,\n", - " 
0.9599266516672662,\n", - " 0.905210358687475,\n", - " 0.4729544192430172,\n", - " 0.4788339612073082,\n", - " 0.9877029619290842,\n", - " 0.470004454747044],\n", - " [1.140685128033838,\n", - " 0.7358502915281392,\n", - " 0.823283850791903,\n", - " 0.3770487405189237,\n", - " 0.632236364652154,\n", - " 0.7570929615386877,\n", - " 0.9228007848571124,\n", - " 0.8552324450242426,\n", - " 0.4397101259946004,\n", - " 0.7418079102259512,\n", - " 0.3105765446476314,\n", - " 0.5272420833655439,\n", - " 0.492218651605294,\n", - " 0.3810726320132326,\n", - " 0.3284284262940278,\n", - " 0.4428304188026957,\n", - " 0.4214869295850267,\n", - " 0.8272488626167098,\n", - " 0.2959023736082823,\n", - " 0.5132978670584217],\n", - " [1.090678079017915,\n", - " 1.0367006433420138,\n", - " 1.154388738166125,\n", - " 1.0419262190112009,\n", - " 0.7369600156783515,\n", - " 0.4540547082609593,\n", - " 0.3695250474272188,\n", - " 0.9375034384771675,\n", - " 0.445065585423383,\n", - " 0.841639333632844,\n", - " 0.4379726318549143,\n", - " 0.7742928276560113,\n", - " 0.4290468745763599,\n", - " 0.5052136967579737,\n", - " 0.444035506999191,\n", - " 0.4046760710931941,\n", - " 0.5068712581170945,\n", - " 0.6778075029531191,\n", - " 0.6022706271919798,\n", - " 0.3371497033997458]]},\n", - " 'rmse': {'mae': [0.08249603882369612,\n", - " 0.10168581037225795,\n", - " 0.05762723560523438,\n", - " 0.13041153919371945,\n", - " 0.06924802382426436],\n", - " 'y_pred': [[1.5238692246360168,\n", - " 1.4655829603969355,\n", - " 1.4119493312595102,\n", - " 1.4602867156359536,\n", - " 1.5552393195051957,\n", - " 1.4723981225437484,\n", - " 0.9680403806558324,\n", - " 0.9879015442716637,\n", - " 1.0535927467944404,\n", - " 1.4196304727738704,\n", - " 1.3337738333385416,\n", - " 1.3887454259765786,\n", - " 0.9376103267542049,\n", - " 0.9045619746669434,\n", - " 0.9219045661163912,\n", - " 0.9428727072758584,\n", - " 0.8313254924530462,\n", - " 0.8714695269875063,\n", - " 0.9568763853347154,\n", - " 0.916312221215246],\n", - " [1.599377566495635,\n", - " 1.456209200683957,\n", - " 1.4931920196340065,\n", - " 1.442799569358048,\n", - " 0.8925604853768674,\n", - " 1.0794832164567894,\n", - " 0.8094332370998285,\n", - " 1.0014597683934843,\n", - " 0.9652937385233382,\n", - " 0.9801220580556906,\n", - " 0.8077079692417787,\n", - " 0.8443761378169541,\n", - " 0.8679153032476179,\n", - " 0.9418620186541284,\n", - " 0.8833955933844,\n", - " 0.97078712267161,\n", - " 1.0028033737631723,\n", - " 1.242607196773291,\n", - " 1.0532296234334095,\n", - " 1.252921853844784],\n", - " [1.5934330098578129,\n", - " 1.4600890937063096,\n", - " 0.8673188763744352,\n", - " 0.9812690373674894,\n", - " 0.9879829873927076,\n", - " 0.9339159552791316,\n", - " 0.8415739310391548,\n", - " 0.7950122373153122,\n", - " 1.2488793875515696,\n", - " 0.8471858542611423,\n", - " 1.4567857127582058,\n", - " 1.043455631820574,\n", - " 1.2222334182540955,\n", - " 0.9012598664260956,\n", - " 1.4531442951710176,\n", - " 1.3939602663867714,\n", - " 0.963088389259521,\n", - " 0.8678056459005251,\n", - " 1.2570670134849484,\n", - " 0.8868542004532824],\n", - " [1.5131655253583194,\n", - " 1.0402493531199817,\n", - " 1.1998484536433174,\n", - " 0.9279572083492303,\n", - " 0.9720606733113156,\n", - " 1.083777964375988,\n", - " 1.1220832288509783,\n", - " 1.0549121886099275,\n", - " 0.916662294070513,\n", - " 1.1965897496439573,\n", - " 0.7873796903789142,\n", - " 0.949181793658731,\n", - " 0.8457189758051021,\n", - " 0.9300122466982974,\n", - " 0.9442269675230726,\n", - " 
0.9437990586017737,\n", - " 0.9770798055002058,\n", - " 0.8256185133883598,\n", - " 0.8281322155107563,\n", - " 1.095986525120956],\n", - " [1.5131399634886111,\n", - " 1.5999155603569504,\n", - " 1.5115796168242792,\n", - " 1.446659752744492,\n", - " 1.235356950151376,\n", - " 0.967072953166603,\n", - " 0.859989517636921,\n", - " 1.2344246596698993,\n", - " 0.987335422865397,\n", - " 1.2571665619513466,\n", - " 0.9946099615435798,\n", - " 1.1997631227319576,\n", - " 0.9065267557570403,\n", - " 0.8623732569844076,\n", - " 0.9157866886589251,\n", - " 0.8704202799321991,\n", - " 0.8951210057562146,\n", - " 1.2243937179420878,\n", - " 1.0451594188685709,\n", - " 0.7870806630856159]],\n", - " 'y_true': [[1.5368549746246047,\n", - " 1.354634969632929,\n", - " 1.468365235186352,\n", - " 1.432583175562272,\n", - " 1.5953749264201045,\n", - " 1.5155996316597349,\n", - " 0.9473464815039158,\n", - " 1.0067917765775811,\n", - " 1.1683311644901675,\n", - " 1.1669372644677438,\n", - " 1.042634731951559,\n", - " 1.3491979367998248,\n", - " 0.9500993048983102,\n", - " 0.9356712458032672,\n", - " 0.8919129473740345,\n", - " 1.050051966761668,\n", - " 0.8170810501327951,\n", - " 0.7654839976332243,\n", - " 0.897154923929182,\n", - " 0.6562046541374233],\n", - " [1.6090842663557108,\n", - " 1.4621137150191188,\n", - " 1.7027129046257758,\n", - " 1.3168271971110843,\n", - " 0.9232822581274416,\n", - " 0.8795262338727019,\n", - " 0.7590143519912055,\n", - " 1.0431462202954926,\n", - " 0.941929825269014,\n", - " 0.953297968558842,\n", - " 0.482990999782018,\n", - " 0.8461856079301379,\n", - " 0.8677945407345222,\n", - " 0.957337036408988,\n", - " 0.9422039933002132,\n", - " 0.9333570456811676,\n", - " 1.17588838414187,\n", - " 0.969262696715333,\n", - " 0.8354014787894858,\n", - " 1.045900564757796],\n", - " [1.6780268046321445,\n", - " 1.4677877256446736,\n", - " 0.8164530075035101,\n", - " 0.995037499534239,\n", - " 0.9632097083694048,\n", - " 0.9110115231703496,\n", - " 0.816569779371218,\n", - " 0.6448597992716313,\n", - " 1.240579980419104,\n", - " 0.8285725679388085,\n", - " 1.449907560592053,\n", - " 1.0635796836867624,\n", - " 0.9965109965269594,\n", - " 0.9243284548523027,\n", - " 1.5049105252470611,\n", - " 1.2881527343668449,\n", - " 1.014500726239169,\n", - " 0.9917168457135808,\n", - " 1.3711145163991243,\n", - " 0.9099871445205638],\n", - " [1.5225364023852097,\n", - " 1.1703898717120615,\n", - " 1.4110015022881397,\n", - " 0.8488938986725486,\n", - " 1.2016814372062825,\n", - " 1.3590846720972698,\n", - " 1.332394937496506,\n", - " 1.2512817028768608,\n", - " 0.8857464212343196,\n", - " 1.3243060125182828,\n", - " 0.7192116523596643,\n", - " 1.0789196251188578,\n", - " 0.9073749955687876,\n", - " 0.8704382894938802,\n", - " 0.7940547994548796,\n", - " 0.8998348279543004,\n", - " 0.9629522026648972,\n", - " 1.2367215748015905,\n", - " 0.6927009076837077,\n", - " 1.061658542665004],\n", - " [1.5269084877210015,\n", - " 1.5418760346263127,\n", - " 1.4840953658081717,\n", - " 1.513990789832791,\n", - " 1.3299708078058443,\n", - " 0.9241146135501812,\n", - " 0.8403109062618637,\n", - " 1.5116387014497448,\n", - " 0.941793741682284,\n", - " 1.4092692450224695,\n", - " 0.8731110155829619,\n", - " 1.1760985853239514,\n", - " 0.84639945250204,\n", - " 0.9314097391725594,\n", - " 0.8742672046769618,\n", - " 0.8651805875174402,\n", - " 0.9675233068032923,\n", - " 1.0909693027831833,\n", - " 1.0754003793297588,\n", - " 0.7580068612234607]]},\n", - " 'model_size': {'mae': [7685169.652999999,\n", - " 
7415964.952499999,\n", - " 6376521.887500001,\n", - " 7987441.926500002,\n", - " 5925530.132999999],\n", - " 'y_pred': [[31748552.36,\n", - " 31748552.36,\n", - " 31748552.36,\n", - " 31748552.36,\n", - " 9301260.22,\n", - " 27660297.88,\n", - " 12913672.6,\n", - " 18186586.08,\n", - " 71267520.59,\n", - " 28021586.32,\n", - " 59307251.77,\n", - " 29296837.77,\n", - " 6955578.7,\n", - " 69463756.22,\n", - " 18430565.41,\n", - " 26020299.6,\n", - " 20214633.62,\n", - " 10855294.48,\n", - " 16527473.03,\n", - " 7390933.59],\n", - " [12706501.16,\n", - " 12706501.16,\n", - " 12706501.16,\n", - " 12706501.16,\n", - " 13112502.5,\n", - " 19286362.42,\n", - " 7217941.92,\n", - " 8166640.3,\n", - " 20552773.59,\n", - " 65239937.16,\n", - " 14878526.27,\n", - " 18397631.56,\n", - " 59068318.48,\n", - " 73900487.06,\n", - " 28608262.87,\n", - " 9251750.62,\n", - " 8915253.24,\n", - " 38844889.04,\n", - " 19273094.57,\n", - " 21965206.69],\n", - " [29629095.12,\n", - " 29629095.12,\n", - " 12053296.17,\n", - " 29629095.12,\n", - " 75664404.03,\n", - " 13107456.84,\n", - " 12875849.13,\n", - " 6520985.37,\n", - " 23659107.35,\n", - " 11372748.16,\n", - " 11574762.56,\n", - " 27277148.17,\n", - " 24450764.13,\n", - " 66727161.47,\n", - " 13343967.74,\n", - " 85768467.49,\n", - " 72033486.4,\n", - " 9823250.47,\n", - " 64812296.15,\n", - " 54510690.28],\n", - " [15643649.29,\n", - " 57202105.02,\n", - " 57202105.02,\n", - " 10289500.47,\n", - " 55523251.1,\n", - " 70367512.4,\n", - " 18992382.27,\n", - " 21567118.92,\n", - " 21431064.15,\n", - " 10437022.92,\n", - " 33683562.09,\n", - " 64273865.78,\n", - " 11721530.91,\n", - " 15768739.21,\n", - " 11487622.58,\n", - " 17249875.73,\n", - " 34189479.21,\n", - " 25481965.81,\n", - " 10932738.73,\n", - " 35851355.2],\n", - " [10498125.83,\n", - " 10601349.26,\n", - " 11757068.89,\n", - " 16322486.52,\n", - " 8364310.65,\n", - " 10338359.44,\n", - " 25787910.26,\n", - " 34969170.54,\n", - " 8473082.27,\n", - " 11804697.33,\n", - " 15285946.12,\n", - " 13679656.17,\n", - " 9804040.85,\n", - " 85236074.77,\n", - " 10465363.04,\n", - " 69454719.01,\n", - " 75274700.21,\n", - " 15788693.84,\n", - " 22011361.68,\n", - " 6797345.46]],\n", - " 'y_true': [[41159778,\n", - " 41159778,\n", - " 41159778,\n", - " 41159778,\n", - " 5102929,\n", - " 40617714,\n", - " 6997884,\n", - " 27062263,\n", - " 56853250,\n", - " 42985088,\n", - " 77029944,\n", - " 29944886,\n", - " 3031520,\n", - " 79253550,\n", - " 22855870,\n", - " 13342089,\n", - " 22904778,\n", - " 8998537,\n", - " 17351820,\n", - " 7565082],\n", - " [18601918,\n", - " 18601918,\n", - " 18601918,\n", - " 18601918,\n", - " 7674204,\n", - " 27351886,\n", - " 4687410,\n", - " 7841161,\n", - " 26151862,\n", - " 61236446,\n", - " 6428942,\n", - " 17893714,\n", - " 54964846,\n", - " 89098622,\n", - " 11026806,\n", - " 7893004,\n", - " 5905129,\n", - " 68864068,\n", - " 8242593,\n", - " 29485308],\n", - " [25128952,\n", - " 25128952,\n", - " 15325998,\n", - " 25128952,\n", - " 72081090,\n", - " 13060242,\n", - " 13264468,\n", - " 3748587,\n", - " 18660636,\n", - " 7586719,\n", - " 5346557,\n", - " 18032692,\n", - " 19951674,\n", - " 76211851,\n", - " 11395531,\n", - " 101380714,\n", - " 51767910,\n", - " 7797450,\n", - " 50201360,\n", - " 43248868],\n", - " [12812508,\n", - " 73233782,\n", - " 73233782,\n", - " 5305072,\n", - " 47576038,\n", - " 112092000,\n", - " 15703767,\n", - " 37629854,\n", - " 29786462,\n", - " 3556682,\n", - " 31306682,\n", - " 53040399,\n", - " 12911521,\n", - " 9244838,\n", - " 
12012606,\n", - " 15718572,\n", - " 35350584,\n", - " 19749710,\n", - " 11673447,\n", - " 40447886],\n", - " [3400847,\n", - " 6985015,\n", - " 8778822,\n", - " 14357975,\n", - " 6779961,\n", - " 6016406,\n", - " 11484551,\n", - " 23918572,\n", - " 6921258,\n", - " 8005441,\n", - " 17032265,\n", - " 11238030,\n", - " 8908197,\n", - " 96337570,\n", - " 7525216,\n", - " 85178362,\n", - " 55216167,\n", - " 20940184,\n", - " 19718419,\n", - " 2906496]]},\n", - " 'runtime': {'mae': [61.035749037981034,\n", - " 95.65970339143279,\n", - " 63.12033299612999,\n", - " 71.04112536096572,\n", - " 64.65392153012752],\n", - " 'y_pred': [[52.875231685638425,\n", - " 76.65202642679215,\n", - " 81.40229902744294,\n", - " 96.19935775518417,\n", - " 35.146308376789094,\n", - " 67.51439192295075,\n", - " 61.51354663848877,\n", - " 55.33929547548294,\n", - " 92.81132875680923,\n", - " 173.54928274631501,\n", - " 108.78835130691529,\n", - " 58.64955832958221,\n", - " 144.53799430131912,\n", - " 223.42509981632233,\n", - " 93.1848129272461,\n", - " 196.73684736251832,\n", - " 213.0635284614563,\n", - " 178.3861657357216,\n", - " 143.48366355657578,\n", - " 204.03698788642885],\n", - " [65.62846801280975,\n", - " 97.97939052343368,\n", - " 61.810300290584564,\n", - " 50.9579337477684,\n", - " 186.24329578876495,\n", - " 115.73471203804016,\n", - " 186.37057279586793,\n", - " 112.67457494497299,\n", - " 59.409752097129825,\n", - " 83.54339394330978,\n", - " 225.72646994113921,\n", - " 267.267543451786,\n", - " 175.14875828027726,\n", - " 163.36634949445724,\n", - " 149.33159998178482,\n", - " 285.50515351057055,\n", - " 140.39637072324751,\n", - " 57.18339666604996,\n", - " 173.057906999588,\n", - " 75.97468067407608],\n", - " [103.2425888299942,\n", - " 21.249649152755737,\n", - " 255.3177734351158,\n", - " 165.80423557043076,\n", - " 143.2561492228508,\n", - " 238.89448299646378,\n", - " 293.60839938879013,\n", - " 246.08150813102722,\n", - " 118.81786162137985,\n", - " 230.92243660926817,\n", - " 29.917886481285095,\n", - " 75.69309362173081,\n", - " 244.86630887031555,\n", - " 177.65726355314254,\n", - " 34.15758186101913,\n", - " 79.37328704833985,\n", - " 147.83103462696076,\n", - " 198.3243732881546,\n", - " 233.69759202718734,\n", - " 88.13960065364837],\n", - " [38.22541909456253,\n", - " 67.39686560153962,\n", - " 96.88583463668823,\n", - " 124.29468606710434,\n", - " 134.39707334041594,\n", - " 286.01724915742875,\n", - " 202.0696332192421,\n", - " 287.7504031777382,\n", - " 165.280366461277,\n", - " 161.00114835739134,\n", - " 230.7469719028473,\n", - " 121.32596241474151,\n", - " 287.1210882949829,\n", - " 143.58286722660065,\n", - " 158.74365959405898,\n", - " 92.38912449121476,\n", - " 60.00808850288391,\n", - " 228.0931861972809,\n", - " 220.54842390060423,\n", - " 108.45025491237641],\n", - " [56.46034274578094,\n", - " 65.54415522575378,\n", - " 55.75057134389877,\n", - " 57.13177669286728,\n", - " 114.16553794622422,\n", - " 80.16156712293625,\n", - " 238.8992388010025,\n", - " 165.8148323774338,\n", - " 299.2917211318016,\n", - " 96.49497563838959,\n", - " 114.77048594713212,\n", - " 158.92026156902313,\n", - " 197.08038917303085,\n", - " 266.2151762461662,\n", - " 188.52955590963364,\n", - " 249.78887855291367,\n", - " 158.62946053028108,\n", - " 267.39132384061816,\n", - " 61.368955957889554,\n", - " 195.17161348104477]],\n", - " 'y_true': [[50.76671862602234,\n", - " 13.19624400138855,\n", - " 19.814854621887207,\n", - " 26.702231645584103,\n", - " 12.969313144683838,\n", - " 
60.6592378616333,\n", - " 54.58466482162476,\n", - " 44.85875821113586,\n", - " 54.25396156311035,\n", - " 110.62842774391174,\n", - " 124.68207693099976,\n", - " 18.491509437561035,\n", - " 56.84877347946167,\n", - " 412.9428143501282,\n", - " 111.9625208377838,\n", - " 463.8674437999725,\n", - " 194.7752683162689,\n", - " 71.33686780929565,\n", - " 89.62382888793945,\n", - " 281.8189051151276],\n", - " [155.83911991119385,\n", - " 11.38468098640442,\n", - " 13.824389934539797,\n", - " 6.711249351501465,\n", - " 50.69706988334656,\n", - " 161.21976828575134,\n", - " 195.9798419475556,\n", - " 57.24438500404358,\n", - " 66.38810992240906,\n", - " 240.0329852104187,\n", - " 594.8048851490021,\n", - " 389.2514734268189,\n", - " 106.56243181228638,\n", - " 622.2987248897552,\n", - " 113.0280704498291,\n", - " 257.1114857196808,\n", - " 111.37134504318236,\n", - " 34.5356810092926,\n", - " 267.2505464553833,\n", - " 70.50088453292847],\n", - " [156.34023070335388,\n", - " 98.9586899280548,\n", - " 163.69966220855713,\n", - " 136.59544610977173,\n", - " 155.54405570030212,\n", - " 164.82566928863525,\n", - " 273.2950313091278,\n", - " 129.67002964019775,\n", - " 86.27063322067261,\n", - " 109.60883498191832,\n", - " 4.304563522338867,\n", - " 41.56924533843994,\n", - " 338.26779794692993,\n", - " 182.1256983280182,\n", - " 12.22360372543335,\n", - " 20.86880207061768,\n", - " 36.924010276794434,\n", - " 78.03244352340698,\n", - " 87.01349568367004,\n", - " 70.23752951622009],\n", - " [50.72808718681336,\n", - " 24.536136627197266,\n", - " 43.1810953617096,\n", - " 96.64365434646606,\n", - " 80.80923509597778,\n", - " 158.458904504776,\n", - " 203.4946177005768,\n", - " 280.6553316116333,\n", - " 311.0933437347412,\n", - " 110.7545862197876,\n", - " 133.32953357696533,\n", - " 96.6609344482422,\n", - " 216.9139444828033,\n", - " 417.8767619132996,\n", - " 450.074337720871,\n", - " 60.3603572845459,\n", - " 59.7351233959198,\n", - " 205.33756804466248,\n", - " 247.85906052589417,\n", - " 50.3548641204834],\n", - " [50.76292824745178,\n", - " 19.667855501174927,\n", - " 21.57206130027771,\n", - " 44.222909450531006,\n", - " 106.16650390625,\n", - " 76.2901964187622,\n", - " 212.37204718589783,\n", - " 62.14056086540222,\n", - " 153.83009886741638,\n", - " 53.25572633743286,\n", - " 56.82449007034302,\n", - " 80.68149018287659,\n", - " 339.5134189128876,\n", - " 403.2442240715027,\n", - " 290.67460441589355,\n", - " 349.98391246795654,\n", - " 114.70717477798462,\n", - " 300.6930778026581,\n", - " 19.84216070175171,\n", - " 322.07845091819763]]}}" - ] - }, - "execution_count": 89, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "main_results = {\n", " \"mae\": mae_results,\n", @@ -1198,42 +380,24 @@ " \"runtime\": runtime_results,\n", "}\n", "with open(path.join(data_dir, \"model_metadata.json\"), \"w\") as f:\n", - " json.dump(main_results, f)\n", - " \n", - "main_results" + " json.dump(main_results, f)" ] }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'mae': [RandomForestRegressor(random_state=13),\n", - " RandomForestRegressor(random_state=13),\n", - " RandomForestRegressor(random_state=13),\n", - " RandomForestRegressor(random_state=13),\n", - " RandomForestRegressor(random_state=13)],\n", - " 'rmse': [RandomForestRegressor(random_state=13),\n", - " RandomForestRegressor(random_state=13),\n", - " RandomForestRegressor(random_state=13),\n", - " 
RandomForestRegressor(random_state=13),\n", - " RandomForestRegressor(random_state=13)],\n", - " 'model_size': [RandomForestRegressor(random_state=13),\n", - " RandomForestRegressor(random_state=13),\n", - " RandomForestRegressor(random_state=13),\n", - " RandomForestRegressor(random_state=13),\n", - " RandomForestRegressor(random_state=13)],\n", - " 'runtime': [RandomForestRegressor(random_state=13),\n", - " RandomForestRegressor(random_state=13),\n", - " RandomForestRegressor(random_state=13),\n", - " RandomForestRegressor(random_state=13),\n", - " RandomForestRegressor(random_state=13)]}" + "{'mae': RandomForestRegressor(random_state=13),\n", + " 'rmse': RandomForestRegressor(random_state=13),\n", + " 'model_size': RandomForestRegressor(random_state=13),\n", + " 'runtime': RandomForestRegressor(random_state=13)}" ] }, - "execution_count": 90, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } @@ -1245,12 +409,13 @@ " \"model_size\": model_size_model_stem,\n", " \"runtime\": runtime_model_stem,\n", "}\n", - "models = {}\n", - "for key, model_path in model_paths.items():\n", - " models[key] = [joblib.load(f\"{model_path}_{i}.pkl\") for i in range(5)]\n", + "for i in range(5):\n", + " models = {}\n", + " for key, model_path in model_paths.items():\n", + " models[key] = joblib.load(f\"{model_path}_{i}.pkl\")\n", "\n", - "with open(path.join(data_dir, \"cross_validation_models.pkl\"), \"wb\") as f:\n", - " pickle.dump(models, f)\n", + " with open(path.join(cv_model_dir, f\"cross_validation_models_{i}.pkl\"), \"wb\") as f:\n", + " joblib.dump(models, f, compress=7)\n", "\n", "models" ] @@ -1260,142 +425,81 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### code graveyard " + "## Production models (full training data)\n", + "Six keys in the dictionary, each key is a value of a label, and its value pair is the trained model.\n", + "This trained model is stored in the models folder with the pickle file name \"trained_model.pkl\"" ] }, { "cell_type": "code", - "execution_count": 91, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ - "### Ordinal Encoding\n", - "\n", - "# \"bias\" [False, True]},\n", - "# \"criterion\" [\"RobustL1\", \"RobustL2\"]},\n", - "# \"elem_prop\"[\"mat2vec\", \"magpie\", \"onehot\"],\n", - "\n", - "# sobol_reg[\"bias\"].replace([\"False\", \"True\"], [0,1], inplace=True)\n", - "# sobol_reg[\"criterion\"].replace([\"RobustL1\", \"RobustL2\"], [0,1], inplace=True)\n", - "# sobol_reg[\"elem_prop\"].replace([\"mat2vec\", \"magpie\", \"onehot\"], [0,1,2], inplace=True)\n", - "# sobol_reg[\"hardware\"].replace([\"2080ti\"], [0], inplace=True)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "metadata": {}, - "outputs": [], - "source": [ - "# print(\"Average MAE for fba_isna_prob\",rfr_mae(X_array_fba_isna_prob, y_array_fba_isna_prob,'fba_isna_prob.pkl'))\n", - "\n", - "# load trained model\n", - "# loaded_model = joblib.load('fba_isna_prob_model.pkl')\n", - "\n", - "# Save the model\n", - "# with open('../models/fba_isna_prob.pkl', 'wb') as f:\n", - "# pickle.dump(model, f)\n", - "\n", - "# # Load the model\n", - "# with open('path/to/save/model.pkl', 'rb') as f:\n", - "# loaded_model = pickle.load(f)" + "def train_and_save(\n", + " sr_feat_array,\n", + " sr_labels_array,\n", + " sr_label_names,\n", + "):\n", + " models = {}\n", + "\n", + " for X1, y1, name1 in zip(sr_feat_array, sr_labels_array, sr_label_names):\n", + " print(f\"X1 sr shape: {X1.shape}, Y1 sr shape: {y1.shape}\")\n", + " model = 
RandomForestRegressor(random_state=13)\n", + " model.fit(X1, y1)\n", + " models[name1] = model\n", + "\n", + " return models" ] }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 36, "metadata": {}, - "outputs": [], - "source": [ - "# url_sobol_filter = \"https://zenodo.org/record/7513019/files/sobol_probability_filter.csv\"\n", - "# sobol_filter = pd.read_csv(url_sobol_filter)\n", - "\n", - "# url_sobol_reg = \"https://zenodo.org/record/7513019/files/sobol_regression.csv\"\n", - "# sobol_reg = pd.read_csv(url_sobol_reg)" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "X1 sr shape: (173219, 29), Y1 sr shape: (173219,)\n", + "X1 sr shape: (173219, 29), Y1 sr shape: (173219,)\n", + "X1 sr shape: (173219, 28), Y1 sr shape: (173219,)\n", + "X1 sr shape: (173219, 29), Y1 sr shape: (173219,)\n" + ] + }, + { + "data": { + "text/plain": [ + "['..\\\\..\\\\models\\\\crabnet_hyperparameter\\\\surrogate_models.pkl']" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# os.getcwd()\n", - "# os.chdir(\"../data/raw\")\n", - "\n", - "# sobol_filter.to_csv('sobol_filter.csv', index=False)\n", + "# List of x_arrays, y_arrays, and target_names\n", + "sobol_reg_x_arrays = [X_array_mae, X_array_rmse, X_array_model_size, X_array_runtime]\n", + "sobol_reg_labels = [y_array_mae, y_array_rmse, y_array_model_size, y_array_runtime]\n", + "sobol_reg_target_names = [\"mae\", \"rmse\", \"model_size\", \"runtime\"]\n", + "\n", + "# Train and save the model on all the data\n", + "models = train_and_save(\n", + " sobol_reg_x_arrays,\n", + " sobol_reg_labels,\n", + " sobol_reg_target_names,\n", + ")\n", "\n", - "# sobol_reg.to_csv('sobol_reg.csv', index=False)" + "joblib.dump(models, path.join(model_dir, \"surrogate_models.pkl\"), compress=7)" ] }, { "cell_type": "code", - "execution_count": 55, + "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "# read in sobol_regression.csv\n", - "# url_sobol_reg = \"https://zenodo.org/record/7513019/files/sobol_regression.csv\"" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "metadata": {}, - "outputs": [], - "source": [ - "# sobol_reg_x = sobol_reg[\n", - "# [\n", - "# \"mu1_div_mu3\",\n", - "# \"mu2_div_mu3\",\n", - "# \"std1\",\n", - "# \"std2\",\n", - "# \"std3\",\n", - "# \"comp1\",\n", - "# \"comp2\",\n", - "# \"num_particles\",\n", - "# \"safety_factor\",\n", - "# \"fba_rank\",\n", - "# \"ls_rank\",\n", - "# \"fba_time_s_rank\",\n", - "# \"ls_time_s_rank\",\n", - "# ]\n", - "# ]" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "metadata": {}, - "outputs": [], - "source": [ - "# print(len(sobol_reg_x))\n", - "# print(len(fba))" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "metadata": {}, - "outputs": [], - "source": [ - "# print(\n", - "# \"Average MAE for ls_time_s\",\n", - "# rfr_mae(X_array_fba_time_s, y_array_ls_time_s, \"sobol_reg_ls_time_s.pkl\"),\n", - "# )" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "metadata": {}, - "outputs": [], - "source": [ - "# # parse data for target \"fba_isna_prob\"\n", - "# fba_isna_prob = sobol_filter[\"fba_isna_prob\"]\n", - "# sobolPF_fba_isna_prob = sobol_filter.drop([\"ls_isna_prob\", \"fba_isna_prob\"], axis=1)\n", - "# fba_isna_prob = fba_isna_prob.to_frame()" - ] + "source": [] } ], "metadata": { @@ -1419,7 +523,7 @@ "orig_nbformat": 4, 
"vscode": { "interpreter": { - "hash": "5814ca482226f814bec6d6a290f5fe630a2a0cb71af6cb85828c242f7bf3b839" + "hash": "01883adffc5ff99e80740fdb2688c7d7f1b5220f2274814f600fbe3b3887f376" } } }, diff --git a/notebooks/crabnet_hyperparameter/1.3-sgb-zenodo-upload.ipynb b/notebooks/crabnet_hyperparameter/1.3-sgb-zenodo-upload.ipynb index 3bc1d71..a645c17 100644 --- a/notebooks/crabnet_hyperparameter/1.3-sgb-zenodo-upload.ipynb +++ b/notebooks/crabnet_hyperparameter/1.3-sgb-zenodo-upload.ipynb @@ -41,11 +41,20 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "'touch' is not recognized as an internal or external command,\n", + "operable program or batch file.\n" + ] + } + ], "source": [ - "!touch $HOME/.config/zenodo.ini" + "# !touch $HOME/.config/zenodo.ini" ] }, { @@ -59,16 +68,16 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ - "sandbox = True" + "sandbox = False" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -76,12 +85,13 @@ "from my_secrets import ZENODO_API_KEY, ZENODO_SANDBOX_API_KEY\n", "\n", "task_name = \"crabnet-hyperparameter\"\n", + "task_name_underscore = task_name.replace(\"-\", \"_\")\n", "\n", "# Define the metadata that will be used on initial upload\n", "data = Metadata(\n", - " title=\"Materials Science Optimization Benchmark Dataset for Multi-fidelity Hard-sphere Packing Simulations\",\n", + " title=\"Materials Science Optimization Benchmark Dataset for High-dimensional, Multi-objective, Multi-fidelity Optimization of CrabNet Hyperparameters\",\n", " upload_type=\"dataset\",\n", - " description=\"Benchmarks are an essential driver of progress in scientific disciplines. Ideal benchmarks mimic real-world tasks as closely as possible, where insufficient difficulty or applicability can stunt growth in the field. Benchmarks should also have sufficiently low computational overhead to promote accessibility and repeatability. The goal is then to win a “Turing test” of sorts by creating a surrogate model that is indistinguishable from the ground truth observation (at least within the dataset bounds that were explored), necessitating a large amount of data. In the fields of materials science and chemistry, industry-relevant optimization tasks are often hierarchical, noisy, multi-fidelity, multi-objective, high-dimensional, and non-linearly correlated while exhibiting mixed numerical and categorical variables subject to linear and non-linear constraints. To complicate matters, unexpected, failed simulation or experimental regions may be present in the search space. In this study, 438371 random hard-sphere packing simulations representing 279 CPU days worth of computational overhead were performed across nine input parameters with linear constraints and two discrete fidelities each with continuous fidelity parameters and results were logged to a free-tier shared MongoDB Atlas database. Two core tabular datasets resulted from this study: 1. a failure probability dataset containing unique input parameter sets and the estimated probabilities that the simulation will fail at each of the two steps, and 2. a regression dataset mapping input parameter sets (including repeats) to particle packing fractions and computational runtimes for each of the two steps. 
These two datasets can be used to create a surrogate model as close as possible to running the actual simulations by incorporating simulation failure and heteroskedastic noise. For the regression dataset, percentile ranks were computed within each of the groups of identical parameter sets to enable capturing heteroskedastic noise. This is in contrast with a more traditional approach that imposes a-priori assumptions such as Gaussian noise e.g., by providing a mean and standard deviation. A similar approach can be applied to other benchmark datasets to bridge the gap between optimization benchmarks with low computational overhead and realistically complex, real-world optimization scenarios.\",\n", + " description=\"Benchmarks are an essential driver of progress in scientific disciplines. Ideal benchmarks mimic real-world tasks as closely as possible, where insufficient difficulty or applicability can stunt growth in the field. Benchmarks should also have sufficiently low computational overhead to promote accessibility and repeatability. The goal is then to win a “Turing test” of sorts by creating a surrogate model that is indistinguishable from the ground truth observation (at least within the dataset bounds that were explored), necessitating a large amount of data. In materials science and chemistry, industry-relevant optimization tasks are often hierarchical, noisy, multi-fidelity, multi-objective, high-dimensional, and non-linearly correlated while exhibiting mixed numerical and categorical variables subject to linear and non-linear constraints. To complicate matters, unexpected, failed simulation or experimental regions may be present in the search space. In this study, 173219 quasi-random hyperparameter combinations were generated across 23 hyperparameters and used to train CrabNet on the Matbench experimental band gap dataset. The results were logged to a free-tier shared MongoDB Atlas database. This study resulted in a regression dataset mapping hyperparameter combinations (including repeats) to MAE, RMSE, computational runtime, and model size for a CrabNet model trained on the Matbench experimental band gap benchmark task. This dataset is used to create a surrogate model as close as possible to running the actual simulations by incorporating heteroskedastic noise. Failure cases for bad hyperparameter combinations were excluded via careful construction of the hyperparameter search space, and so were not considered as was done in prior work. For the regression dataset, percentile ranks were computed within each of the groups of identical parameter sets to enable capturing heteroskedastic noise. This contrasts with a more traditional approach that imposes a-priori assumptions such as Gaussian noise, e.g., by providing a mean and standard deviation. 
A similar approach can be applied to other benchmark datasets to bridge the gap between optimization benchmarks with low computational overhead and realistically complex, real-world optimization scenarios.\",\n", " creators=[\n", " Creator(\n", " name=\"Baird, Sterling G.\",\n", @@ -89,24 +99,28 @@ " orcid=\"0000-0002-4491-6876\",\n", " ),\n", " Creator(\n", - " name=\"Parikh, Jeet\",\n", - " affiliation=\"\",\n", - " orcid=\"\",\n", + " name=\"Parikh, Jeet N.\",\n", + " affiliation=\"Northwood High School\",\n", + " orcid=\"0000-0002-8706-2962\",\n", " ),\n", " ],\n", ")\n", "\n", - "# unique keys generate a new deposition ID\n", + "# unique keys generate a new deposition ID in $HOME/.config/zenodo.ini\n", "key = f\"matsciopt-{task_name}-benchmark-dataset\"\n", "access_token = ZENODO_SANDBOX_API_KEY if sandbox else ZENODO_API_KEY\n", "res = ensure_zenodo(\n", " key,\n", " data=data,\n", " paths=[\n", - " f\"data/processed/{task_name}/sobol_probability_filter.csv\",\n", - " f\"data/processed/{task_name}/sobol_regression.csv\",\n", - " f\"data/processed/{task_name}/model_metadata.json\",\n", - " # f\"data/processed/{task_name}/cross_validation_models.pkl\",\n", + " f\"../../data/processed/{task_name_underscore}/sobol_regression.csv\",\n", + " f\"../../data/processed/{task_name_underscore}/model_metadata.json\",\n", + " f\"../../models/{task_name_underscore}/surrogate_models.pkl\",\n", + " f\"../../models/{task_name_underscore}/cv/cross_validation_models_0.pkl\",\n", + " f\"../../models/{task_name_underscore}/cv/cross_validation_models_1.pkl\",\n", + " f\"../../models/{task_name_underscore}/cv/cross_validation_models_2.pkl\",\n", + " f\"../../models/{task_name_underscore}/cv/cross_validation_models_3.pkl\",\n", + " f\"../../models/{task_name_underscore}/cv/cross_validation_models_4.pkl\",\n", " ],\n", " sandbox=sandbox, # remove this when you're ready to upload to real Zenodo\n", " access_token=access_token,\n", @@ -115,6 +129,13 @@ "\n", "pprint(res.json())\n" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {