diff --git a/massspecgym/models/base.py b/massspecgym/models/base.py index 2359930..e35e1c7 100644 --- a/massspecgym/models/base.py +++ b/massspecgym/models/base.py @@ -28,7 +28,7 @@ def __init__( lr: float = 1e-4, weight_decay: float = 0.0, log_only_loss_at_stages: T.Sequence[Stage | str] = (), - bootstrap_metrics: bool = True, + bootstrap_metrics: bool = False, df_test_path: T.Optional[str | Path] = None, *args, **kwargs diff --git a/notebooks/evaluation.ipynb b/notebooks/evaluation.ipynb index 5a4b5aa..3b12dfd 100644 --- a/notebooks/evaluation.ipynb +++ b/notebooks/evaluation.ipynb @@ -6,6 +6,7 @@ "metadata": {}, "outputs": [], "source": [ + "import random\n", "from pathlib import Path\n", "\n", "import numpy as np\n", @@ -14,12 +15,26 @@ "from scipy.stats import bootstrap\n", "from tqdm import tqdm\n", "\n", - "tqdm.pandas()" + "tqdm.pandas()\n", + "\n", + "# Set random seeds for reproducibility\n", + "seed = 0\n", + "random.seed(seed)\n", + "np.random.seed(seed)\n", + "pd.set_option('compute.use_numexpr', False) # Disable numexpr to ensure reproducibility\n", + "pd.set_option('compute.use_bottleneck', False) # Disable bottleneck to ensure reproducibility" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebooks takes pickled test dataframes automatically stored during testing of the models (i.e., running `trainer.test(model, ...)`) and calculates means and confidence intervals for all metrics. The cell below shows an example of a test dataframe.\n" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -199,19 +214,19 @@ "[17556 rows x 6 columns]" ] }, - "execution_count": 19, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# Example of a test dataframe:\n", + "# Example of a test dataframe for the retrieval challenge:\n", "pd.read_pickle('../data/test_results/retrieval/rebuttal_MIST_test_formula_2024-08-13_15-07-19.pkl')" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -245,11 +260,12 @@ "\n", " # Calculate confidence intervals for all metrics into a single table\n", " def get_ci(col_vals, confidence_level=0.999, n_resamples=20_000):\n", - " res = bootstrap((col_vals,), np.mean, confidence_level=confidence_level, n_resamples=n_resamples)\n", + " res = bootstrap((col_vals,), np.mean, confidence_level=confidence_level, n_resamples=n_resamples, random_state=seed)\n", " ci = res.confidence_interval\n", " return f'{ci.low:.2f}-{ci.high:.2f}'\n", " def get_ci_for_each_col(df_method):\n", " return df_method.apply(get_ci, axis=0)\n", + " tqdm.pandas(desc=\"Bootstrapping predictions for each method\", postfix=None)\n", " df_ci = df.groupby('method')[metric_cols].progress_apply(lambda df_method: get_ci_for_each_col(df_method))\n", "\n", " # Merge tables with means and confidence intervals\n", @@ -258,416 +274,198 @@ " return df_mean" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluation for the retrieval challenge" + ] + }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 9/9 [05:55<00:00, 39.47s/it]\n" + "Bootstrapping predictions for each method: 100%|██████████| 13/13 [07:49<00:00, 36.10s/it]\n" ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
test_hit_rate@1test_hit_rate@5test_hit_rate@20test_mces@1
method
rebuttal_MIST_test_formula_2024-08-13_15-07-199.57 (8.88-10.30)22.11 (21.13-23.24)41.12 (39.91-42.29)12.75 (12.58-12.92)
rebuttal_deepsets_test_formula_2024-08-15_16-45-064.42 (3.91-4.93)14.46 (13.60-15.39)30.76 (29.64-31.90)15.04 (14.89-15.19)
rebuttal_deepsets_test_mass_2024-08-14_22-51-051.47 (1.20-1.79)6.21 (5.63-6.84)19.23 (18.27-20.22)25.11 (24.84-25.38)
rebuttal_enhanced_MIST_test_mass_2024-08-13_01-18-4414.64 (13.78-15.53)34.87 (33.70-36.06)59.15 (57.95-60.33)15.37 (15.13-15.62)
rebuttal_fingerprint_ffn_test_formula_2024-08-15_15-45-025.09 (4.57-5.62)14.69 (13.83-15.57)31.97 (30.80-33.13)14.94 (14.79-15.10)
rebuttal_fingerprint_ffn_test_mass_2024-08-15_15-39-322.54 (2.16-2.97)7.59 (6.93-8.27)20.0 (19.06-21.05)24.66 (24.37-24.95)
rebuttal_random_test_formula_2024-08-13_16-14-073.06 (2.67-3.51)11.35 (10.59-12.14)27.74 (26.62-28.94)13.87 (13.70-14.03)
rebuttal_random_test_formula_2024-08-13_17-08-093.06 (2.64-3.50)11.35 (10.58-12.13)27.74 (26.66-28.80)13.87 (13.70-14.03)
rebuttal_random_test_mass_2024-08-13_17-08-090.37 (0.24-0.54)2.01 (1.68-2.38)8.22 (7.57-8.93)30.81 (30.43-31.24)
\n", - "
" - ], - "text/plain": [ - " test_hit_rate@1 \\\n", - "method \n", - "rebuttal_MIST_test_formula_2024-08-13_15-07-19 9.57 (8.88-10.30) \n", - "rebuttal_deepsets_test_formula_2024-08-15_16-45-06 4.42 (3.91-4.93) \n", - "rebuttal_deepsets_test_mass_2024-08-14_22-51-05 1.47 (1.20-1.79) \n", - "rebuttal_enhanced_MIST_test_mass_2024-08-13_01-... 14.64 (13.78-15.53) \n", - "rebuttal_fingerprint_ffn_test_formula_2024-08-1... 5.09 (4.57-5.62) \n", - "rebuttal_fingerprint_ffn_test_mass_2024-08-15_1... 2.54 (2.16-2.97) \n", - "rebuttal_random_test_formula_2024-08-13_16-14-07 3.06 (2.67-3.51) \n", - "rebuttal_random_test_formula_2024-08-13_17-08-09 3.06 (2.64-3.50) \n", - "rebuttal_random_test_mass_2024-08-13_17-08-09 0.37 (0.24-0.54) \n", - "\n", - " test_hit_rate@5 \\\n", - "method \n", - "rebuttal_MIST_test_formula_2024-08-13_15-07-19 22.11 (21.13-23.24) \n", - "rebuttal_deepsets_test_formula_2024-08-15_16-45-06 14.46 (13.60-15.39) \n", - "rebuttal_deepsets_test_mass_2024-08-14_22-51-05 6.21 (5.63-6.84) \n", - "rebuttal_enhanced_MIST_test_mass_2024-08-13_01-... 34.87 (33.70-36.06) \n", - "rebuttal_fingerprint_ffn_test_formula_2024-08-1... 14.69 (13.83-15.57) \n", - "rebuttal_fingerprint_ffn_test_mass_2024-08-15_1... 7.59 (6.93-8.27) \n", - "rebuttal_random_test_formula_2024-08-13_16-14-07 11.35 (10.59-12.14) \n", - "rebuttal_random_test_formula_2024-08-13_17-08-09 11.35 (10.58-12.13) \n", - "rebuttal_random_test_mass_2024-08-13_17-08-09 2.01 (1.68-2.38) \n", - "\n", - " test_hit_rate@20 \\\n", - "method \n", - "rebuttal_MIST_test_formula_2024-08-13_15-07-19 41.12 (39.91-42.29) \n", - "rebuttal_deepsets_test_formula_2024-08-15_16-45-06 30.76 (29.64-31.90) \n", - "rebuttal_deepsets_test_mass_2024-08-14_22-51-05 19.23 (18.27-20.22) \n", - "rebuttal_enhanced_MIST_test_mass_2024-08-13_01-... 59.15 (57.95-60.33) \n", - "rebuttal_fingerprint_ffn_test_formula_2024-08-1... 31.97 (30.80-33.13) \n", - "rebuttal_fingerprint_ffn_test_mass_2024-08-15_1... 20.0 (19.06-21.05) \n", - "rebuttal_random_test_formula_2024-08-13_16-14-07 27.74 (26.62-28.94) \n", - "rebuttal_random_test_formula_2024-08-13_17-08-09 27.74 (26.66-28.80) \n", - "rebuttal_random_test_mass_2024-08-13_17-08-09 8.22 (7.57-8.93) \n", - "\n", - " test_mces@1 \n", - "method \n", - "rebuttal_MIST_test_formula_2024-08-13_15-07-19 12.75 (12.58-12.92) \n", - "rebuttal_deepsets_test_formula_2024-08-15_16-45-06 15.04 (14.89-15.19) \n", - "rebuttal_deepsets_test_mass_2024-08-14_22-51-05 25.11 (24.84-25.38) \n", - "rebuttal_enhanced_MIST_test_mass_2024-08-13_01-... 15.37 (15.13-15.62) \n", - "rebuttal_fingerprint_ffn_test_formula_2024-08-1... 14.94 (14.79-15.10) \n", - "rebuttal_fingerprint_ffn_test_mass_2024-08-15_1... 24.66 (24.37-24.95) \n", - "rebuttal_random_test_formula_2024-08-13_16-14-07 13.87 (13.70-14.03) \n", - "rebuttal_random_test_formula_2024-08-13_17-08-09 13.87 (13.70-14.03) \n", - "rebuttal_random_test_mass_2024-08-13_17-08-09 30.81 (30.43-31.24) " - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ "dir_results = Path('../data/test_results/retrieval')\n", "task = 'retrieval'\n", "\n", - "df = evaluate(dir_results, task)\n", - "df" + "df = evaluate(dir_results, task)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Main challenge" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "| method | test_hit_rate@1 | test_hit_rate@5 | test_hit_rate@20 | test_mces@1 |\n", + "|:---------------------------------------------------------------------------------|:--------------------|:--------------------|:--------------------|:--------------------|\n", + "| rebuttal_random_test_mass_2024-08-13_17-08-09 | 0.37 (0.24-0.54) | 2.01 (1.68-2.39) | 8.22 (7.53-8.89) | 30.81 (30.40-31.21) |\n", + "| rebuttal_deepsets_test_mass_2024-08-14_22-51-05 | 1.47 (1.18-1.77) | 6.21 (5.64-6.82) | 19.23 (18.24-20.26) | 25.11 (24.84-25.39) |\n", + "| rebuttal_fingerprint_ffn_sigmoid_mist_canopus_1550_test_mass_2024-08-17_02-30-13 | 1.65 (1.36-1.98) | 5.45 (4.89-6.02) | 15.15 (14.29-16.05) | 26.76 (26.47-27.06) |\n", + "| rebuttal_fingerprint_ffn_test_mass_2024-08-15_15-39-32 | 2.54 (2.17-2.99) | 7.59 (6.96-8.28) | 20.0 (19.01-20.98) | 24.66 (24.38-24.94) |\n", + "| rebuttal_deepsets_ff_test_mass_2024-08-17_02-30-13 | 5.24 (4.71-5.83) | 12.58 (11.80-13.42) | 28.21 (27.10-29.36) | 22.13 (21.85-22.43) |\n", + "| rebuttal_enhanced_MIST_test_mass_2024-08-13_01-18-44 | 14.64 (13.82-15.54) | 34.87 (33.69-36.10) | 59.15 (57.89-60.39) | 15.37 (15.12-15.62) |\n" + ] + } + ], + "source": [ + "df_paper = df.reset_index()\n", + "df_paper = df_paper[(~df_paper['method'].str.contains('formula')) | (df_paper['method'].str.contains('no_formula'))]\n", + "df_paper = df_paper.sort_values('test_hit_rate@1', ascending=True, key=lambda x: x.str.split(' ').str[0].astype(float))\n", + "print(df_paper.to_markdown(index=False))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Bonus chemical formulae challenge" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "| method | test_hit_rate@1 | test_hit_rate@5 | test_hit_rate@20 | test_mces@1 |\n", - "|:----------------------------------------------------------|:------------------|:--------------------|:--------------------|:--------------------|\n", - "| rebuttal_random_test_formula_2024-08-13_16-14-07 | 3.06 (2.67-3.51) | 11.35 (10.59-12.14) | 27.74 (26.62-28.94) | 13.87 (13.70-14.03) |\n", - "| rebuttal_random_test_formula_2024-08-13_17-08-09 | 3.06 (2.64-3.50) | 11.35 (10.58-12.13) | 27.74 (26.66-28.80) | 13.87 (13.70-14.03) |\n", - "| rebuttal_deepsets_test_formula_2024-08-15_16-45-06 | 4.42 (3.91-4.93) | 14.46 (13.60-15.39) | 30.76 (29.64-31.90) | 15.04 (14.89-15.19) |\n", - "| rebuttal_fingerprint_ffn_test_formula_2024-08-15_15-45-02 | 5.09 (4.57-5.62) | 14.69 (13.83-15.57) | 31.97 (30.80-33.13) | 14.94 (14.79-15.10) |\n", - "| rebuttal_MIST_test_formula_2024-08-13_15-07-19 | 9.57 (8.88-10.30) | 22.11 (21.13-23.24) | 41.12 (39.91-42.29) | 12.75 (12.58-12.92) |\n" + "| method | test_hit_rate@1 | test_hit_rate@5 | test_hit_rate@20 | test_mces@1 |\n", + "|:------------------------------------------------------------------------------------|:------------------|:--------------------|:--------------------|:--------------------|\n", + "| rebuttal_random_test_formula_2024-08-13_16-14-07 | 3.06 (2.64-3.52) | 11.35 (10.60-12.12) | 27.74 (26.52-28.84) | 13.87 (13.70-14.03) |\n", + "| rebuttal_random_test_formula_2024-08-13_17-08-09 | 3.06 (2.64-3.52) | 11.35 (10.60-12.12) | 27.74 (26.52-28.84) | 13.87 (13.70-14.03) |\n", + "| rebuttal_fingerprint_ffn_sigmoid_mist_canopus_1550_test_formula_2024-08-17_02-30-13 | 4.07 (3.61-4.54) | 13.13 (12.33-13.95) | 29.44 (28.32-30.53) | 15.5 (15.34-15.64) |\n", + "| rebuttal_deepsets_test_formula_2024-08-15_16-45-06 | 4.42 (3.92-4.97) | 14.46 (13.58-15.36) | 30.76 (29.67-31.93) | 15.04 (14.89-15.19) |\n", + "| rebuttal_fingerprint_ffn_test_formula_2024-08-15_15-45-02 | 5.09 (4.57-5.66) | 14.69 (13.83-15.56) | 31.97 (30.86-33.10) | 14.94 (14.79-15.09) |\n", + "| rebuttal_deepsets_ff_test_formula_2024-08-17_02-30-13 | 6.56 (5.95-7.16) | 16.46 (15.58-17.35) | 33.46 (32.39-34.59) | 14.14 (13.98-14.31) |\n", + "| rebuttal_MIST_test_formula_2024-08-13_15-07-19 | 9.57 (8.88-10.30) | 22.11 (21.10-23.13) | 41.12 (39.98-42.34) | 12.75 (12.59-12.91) |\n" ] } ], "source": [ "df_paper = df.reset_index()\n", - "df_paper = df_paper[df_paper['method'].str.contains('formula')]\n", + "df_paper = df_paper[(df_paper['method'].str.contains('formula')) & (~df_paper['method'].str.contains('no_formula'))]\n", "df_paper = df_paper.sort_values('test_hit_rate@1', ascending=True, key=lambda x: x.str.split(' ').str[0].astype(float))\n", "print(df_paper.to_markdown(index=False))" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "| | Hit rate @ 1 ↑ | Hit rate @ 5 ↑ | Hit rate @ 20 ↑ | MCES @ 1 ↓ |\n", - "|:-------------------------------------------------------|:--------------------:|:--------------------:|:--------------------:|:--------------------:|\n", - "| Random | 0.37 (0.24-0.54) | 2.01 (1.68-2.38) | 8.22 (7.57-8.93) | 30.81 (30.43-31.24) |\n", - "| DeepSets | 1.47 (1.20-1.79) | 6.21 (5.63-6.84) | 19.23 (18.27-20.22) | 25.11 (24.84-25.38) |\n", - "| FingerprintFFN | 2.54 (2.16-2.97) | 7.59 (6.93-8.27) | 20.0 (19.06-21.05) | 24.66 (24.37-24.95) |\n", - "| MIST | **14.64** (13.78-15.53) | **34.87** (33.70-36.06) | **59.15** (57.95-60.33) | **15.37** (15.13-15.62) |\n", - "| *Bonus chemical formulae challenge* | | | | |\n", - "| Random | 3.06 (2.64-3.50) | 11.35 (10.58-12.13) | 27.74 (26.66-28.80) | 13.87 (13.70-14.03) |\n", - "| DeepSets | 4.42 (3.91-4.93) | 14.46 (13.60-15.39) | 30.76 (29.64-31.90) | 15.04 (14.89-15.19) |\n", - "| FingerprintFFN | 5.09 (4.57-5.62) | 14.69 (13.83-15.57) | 31.97 (30.80-33.13) | 14.94 (14.79-15.10) |\n", - "| MIST | **9.57** (8.88-10.30) | **22.11** (21.13-23.24) | **41.12** (39.91-42.29) | **12.75** (12.58-12.92) |" + "## Evaluation for the de novo challenge" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 0/4 [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
test_top_1_accuracytest_top_1_mces_disttest_top_1_max_tanimoto_simtest_top_10_accuracytest_top_10_mces_disttest_top_10_max_tanimoto_sim
method
random_baseline_formula0.0 (nan-nan)21.11 (20.97-21.26)0.08 (0.08-0.08)0.0 (nan-nan)18.25 (18.14-18.35)0.11 (0.11-0.11)
random_baseline_no_formula0.0 (nan-nan)28.59 (28.33-28.84)0.07 (0.07-0.07)0.0 (nan-nan)25.72 (25.48-25.96)0.1 (0.10-0.10)
rebuttal_selfies_transformer_test_2024-08-15_16-05-360.0 (nan-nan)33.28 (32.98-33.58)0.1 (0.10-0.10)0.0 (nan-nan)21.84 (21.67-22.00)0.15 (0.15-0.15)
rebuttal_smiles_transformer_test_2024-08-15_17-11-370.0 (nan-nan)53.8 (52.95-54.65)0.07 (0.07-0.08)0.0 (nan-nan)21.97 (21.78-22.16)0.17 (0.17-0.17)
\n", - "" - ], - "text/plain": [ - " test_top_1_accuracy \\\n", - "method \n", - "random_baseline_formula 0.0 (nan-nan) \n", - "random_baseline_no_formula 0.0 (nan-nan) \n", - "rebuttal_selfies_transformer_test_2024-08-15_16... 0.0 (nan-nan) \n", - "rebuttal_smiles_transformer_test_2024-08-15_17-... 0.0 (nan-nan) \n", - "\n", - " test_top_1_mces_dist \\\n", - "method \n", - "random_baseline_formula 21.11 (20.97-21.26) \n", - "random_baseline_no_formula 28.59 (28.33-28.84) \n", - "rebuttal_selfies_transformer_test_2024-08-15_16... 33.28 (32.98-33.58) \n", - "rebuttal_smiles_transformer_test_2024-08-15_17-... 53.8 (52.95-54.65) \n", - "\n", - " test_top_1_max_tanimoto_sim \\\n", - "method \n", - "random_baseline_formula 0.08 (0.08-0.08) \n", - "random_baseline_no_formula 0.07 (0.07-0.07) \n", - "rebuttal_selfies_transformer_test_2024-08-15_16... 0.1 (0.10-0.10) \n", - "rebuttal_smiles_transformer_test_2024-08-15_17-... 0.07 (0.07-0.08) \n", - "\n", - " test_top_10_accuracy \\\n", - "method \n", - "random_baseline_formula 0.0 (nan-nan) \n", - "random_baseline_no_formula 0.0 (nan-nan) \n", - "rebuttal_selfies_transformer_test_2024-08-15_16... 0.0 (nan-nan) \n", - "rebuttal_smiles_transformer_test_2024-08-15_17-... 0.0 (nan-nan) \n", - "\n", - " test_top_10_mces_dist \\\n", - "method \n", - "random_baseline_formula 18.25 (18.14-18.35) \n", - "random_baseline_no_formula 25.72 (25.48-25.96) \n", - "rebuttal_selfies_transformer_test_2024-08-15_16... 21.84 (21.67-22.00) \n", - "rebuttal_smiles_transformer_test_2024-08-15_17-... 21.97 (21.78-22.16) \n", - "\n", - " test_top_10_max_tanimoto_sim \n", - "method \n", - "random_baseline_formula 0.11 (0.11-0.11) \n", - "random_baseline_no_formula 0.1 (0.10-0.10) \n", - "rebuttal_selfies_transformer_test_2024-08-15_16... 0.15 (0.15-0.15) \n", - "rebuttal_smiles_transformer_test_2024-08-15_17-... 0.17 (0.17-0.17) " - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ "dir_results = Path('../data/test_results/de_novo')\n", "task = 'de_novo'\n", "\n", - "df = evaluate(dir_results, task)\n", - "df" + "df = evaluate(dir_results, task)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Main challenge" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "| method | test_top_1_accuracy | test_top_1_mces_dist | test_top_1_max_tanimoto_sim | test_top_10_accuracy | test_top_10_mces_dist | test_top_10_max_tanimoto_sim |\n", - "|:------------------------------------------------------|:----------------------|:-----------------------|:------------------------------|:-----------------------|:------------------------|:-------------------------------|\n", - "| random_baseline_formula | 0.0 (nan-nan) | 21.11 (20.97-21.26) | 0.08 (0.08-0.08) | 0.0 (nan-nan) | 18.25 (18.14-18.35) | 0.11 (0.11-0.11) |\n", - "| random_baseline_no_formula | 0.0 (nan-nan) | 28.59 (28.33-28.84) | 0.07 (0.07-0.07) | 0.0 (nan-nan) | 25.72 (25.48-25.96) | 0.1 (0.10-0.10) |\n", - "| rebuttal_selfies_transformer_test_2024-08-15_16-05-36 | 0.0 (nan-nan) | 33.28 (32.98-33.58) | 0.1 (0.10-0.10) | 0.0 (nan-nan) | 21.84 (21.67-22.00) | 0.15 (0.15-0.15) |\n", - "| rebuttal_smiles_transformer_test_2024-08-15_17-11-37 | 0.0 (nan-nan) | 53.8 (52.95-54.65) | 0.07 (0.07-0.08) | 0.0 (nan-nan) | 21.97 (21.78-22.16) | 0.17 (0.17-0.17) |\n" + "| method | test_top_1_accuracy | test_top_1_mces_dist | test_top_1_max_tanimoto_sim | test_top_10_accuracy | test_top_10_mces_dist | test_top_10_max_tanimoto_sim |\n", + "|:------------------------------------------------------------------------|:----------------------|:-----------------------|:------------------------------|:-----------------------|:------------------------|:-------------------------------|\n", + "| rebuttal_smiles_transformer_mist_canopus_1550_test_2024-08-17_02-30-13 | 0.0 (nan-nan) | 96.17 (95.78-96.53) | 0.01 (0.00-0.01) | 0.0 (nan-nan) | 70.88 (70.09-71.68) | 0.04 (0.04-0.04) |\n", + "| rebuttal_smiles_transformer_mist_canopus_test_2024-08-16_22-34-55 | 0.0 (nan-nan) | 96.06 (95.67-96.43) | 0.01 (0.00-0.01) | 0.0 (nan-nan) | 70.77 (69.96-71.53) | 0.04 (0.04-0.04) |\n", + "| rebuttal_selfies_transformer_mist_canopus_test_2024-08-16_22-33-21 | 0.0 (nan-nan) | 39.43 (39.10-39.76) | 0.08 (0.08-0.08) | 0.0 (nan-nan) | 27.21 (26.99-27.46) | 0.13 (0.13-0.13) |\n", + "| rebuttal_selfies_transformer_mist_canopus_1550_test_2024-08-17_02-30-13 | 0.0 (nan-nan) | 40.21 (39.88-40.56) | 0.08 (0.08-0.08) | 0.0 (nan-nan) | 27.14 (26.91-27.38) | 0.13 (0.12-0.13) |\n", + "| random_baseline_no_formula | 0.0 (nan-nan) | 28.59 (28.33-28.84) | 0.07 (0.07-0.07) | 0.0 (nan-nan) | 25.72 (25.49-25.95) | 0.1 (0.10-0.10) |\n", + "| rebuttal_smiles_transformer_test_2024-08-15_17-11-37 | 0.0 (nan-nan) | 53.8 (52.95-54.61) | 0.07 (0.07-0.08) | 0.0 (nan-nan) | 21.97 (21.79-22.16) | 0.17 (0.17-0.17) |\n", + "| rebuttal_selfies_transformer_test_2024-08-15_16-05-36 | 0.0 (nan-nan) | 33.28 (33.00-33.57) | 0.1 (0.10-0.10) | 0.0 (nan-nan) | 21.84 (21.67-22.00) | 0.15 (0.15-0.15) |\n" ] } ], "source": [ "df_paper = df.reset_index()\n", - "# df_paper = df_paper[~df_paper['method'].str.contains('formula')]\n", - "df_paper = df_paper.sort_values('test_top_1_mces_dist', ascending=True, key=lambda x: x.str.split(' ').str[0].astype(float))\n", + "df_paper = df_paper[(~df_paper['method'].str.contains('formula')) | (df_paper['method'].str.contains('no_formula'))]\n", + "df_paper = df_paper.sort_values('test_top_10_mces_dist', ascending=False, key=lambda x: x.str.split(' ').str[0].astype(float))\n", "print(df_paper.to_markdown(index=False))" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "| | Top-1 Accuracy ↑ | Top-1 MCES ↓ | Top-1 Tanimoto ↑ | Top-10 Accuracy ↑ | Top-10 MCES ↓ | Top-10 Tanimoto ↑ |\n", - "|:------------------------------------------------------|:----------------------:|:-----------------------:|:------------------------------:|:-----------------------:|:------------------------:|:-------------------------------:|\n", - "| Random chemical generation | 0.0 | **28.59** (28.33-28.84) | 0.07 (0.07-0.07) | 0.0 | 25.72 (25.48-25.96) | 0.1 (0.10-0.10) |\n", - "| SMILES Transformer | 0.0 | 53.8 (52.95-54.65) | 0.07 (0.07-0.08) | 0.0 | 21.97 (21.78-22.16) | **0.17** (0.17-0.17) |\n", - "| SELFIES Transformer | 0.0 | 33.28 (32.98-33.58) | **0.1** (0.10-0.10) | 0.0 | **21.84** (21.67-22.00) | 0.15 (0.15-0.15) |\n", - "| *Bonus chemical formulae challenge*\n", - "| Random chemical generation | 0.0 | **21.11** (20.97-21.26) | **0.08** (0.08-0.08) | 0.0 | **18.25** (18.14-18.35) | **0.11** (0.11-0.11) |" + "### Bonus chemical formulae challenge" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "| method | test_top_1_accuracy | test_top_1_mces_dist | test_top_1_max_tanimoto_sim | test_top_10_accuracy | test_top_10_mces_dist | test_top_10_max_tanimoto_sim |\n", + "|:-----------------------------------------------------------------|:----------------------|:-----------------------|:------------------------------|:-----------------------|:------------------------|:-------------------------------|\n", + "| rebuttal_smiles_transformer_formula_test_2024-08-17_02-30-13 | 0.0 (nan-nan) | 79.39 (78.64-80.08) | 0.03 (0.03-0.04) | 0.0 (nan-nan) | 52.13 (51.45-52.81) | 0.1 (0.09-0.10) |\n", + "| rebuttal_selfies_transformer_formula_test_2024-08-17_02-30-13 | 0.0 (nan-nan) | 38.88 (38.57-39.20) | 0.08 (0.08-0.08) | 0.0 (nan-nan) | 26.87 (26.66-27.11) | 0.13 (0.13-0.13) |\n", + "| rebuttal_selfies_transformer_formula_v2_test_2024-08-18_14-28-08 | 0.0 (nan-nan) | 38.88 (38.57-39.21) | 0.08 (0.08-0.08) | 0.0 (nan-nan) | 26.87 (26.66-27.11) | 0.13 (0.13-0.13) |\n", + "| random_baseline_formula | 0.0 (nan-nan) | 21.11 (20.97-21.26) | 0.08 (0.08-0.08) | 0.0 (nan-nan) | 18.25 (18.14-18.35) | 0.11 (0.11-0.11) |\n" + ] + } + ], + "source": [ + "df_paper = df.reset_index()\n", + "df_paper = df_paper[(df_paper['method'].str.contains('formula')) & (~df_paper['method'].str.contains('no_formula'))]\n", + "df_paper = df_paper.sort_values('test_top_10_mces_dist', ascending=False, key=lambda x: x.str.split(' ').str[0].astype(float))\n", + "print(df_paper.to_markdown(index=False))" ] } ],