diff --git a/massspecgym/models/base.py b/massspecgym/models/base.py
index 2359930..e35e1c7 100644
--- a/massspecgym/models/base.py
+++ b/massspecgym/models/base.py
@@ -28,7 +28,7 @@ def __init__(
lr: float = 1e-4,
weight_decay: float = 0.0,
log_only_loss_at_stages: T.Sequence[Stage | str] = (),
- bootstrap_metrics: bool = True,
+ bootstrap_metrics: bool = False,
df_test_path: T.Optional[str | Path] = None,
*args,
**kwargs
diff --git a/notebooks/evaluation.ipynb b/notebooks/evaluation.ipynb
index 5a4b5aa..3b12dfd 100644
--- a/notebooks/evaluation.ipynb
+++ b/notebooks/evaluation.ipynb
@@ -6,6 +6,7 @@
"metadata": {},
"outputs": [],
"source": [
+ "import random\n",
"from pathlib import Path\n",
"\n",
"import numpy as np\n",
@@ -14,12 +15,26 @@
"from scipy.stats import bootstrap\n",
"from tqdm import tqdm\n",
"\n",
- "tqdm.pandas()"
+ "tqdm.pandas()\n",
+ "\n",
+ "# Set random seeds for reproducibility\n",
+ "seed = 0\n",
+ "random.seed(seed)\n",
+ "np.random.seed(seed)\n",
+ "pd.set_option('compute.use_numexpr', False) # Disable numexpr to ensure reproducibility\n",
+ "pd.set_option('compute.use_bottleneck', False) # Disable bottleneck to ensure reproducibility"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This notebooks takes pickled test dataframes automatically stored during testing of the models (i.e., running `trainer.test(model, ...)`) and calculates means and confidence intervals for all metrics. The cell below shows an example of a test dataframe.\n"
]
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 2,
"metadata": {},
"outputs": [
{
@@ -199,19 +214,19 @@
"[17556 rows x 6 columns]"
]
},
- "execution_count": 19,
+ "execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "# Example of a test dataframe:\n",
+ "# Example of a test dataframe for the retrieval challenge:\n",
"pd.read_pickle('../data/test_results/retrieval/rebuttal_MIST_test_formula_2024-08-13_15-07-19.pkl')"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -245,11 +260,12 @@
"\n",
" # Calculate confidence intervals for all metrics into a single table\n",
" def get_ci(col_vals, confidence_level=0.999, n_resamples=20_000):\n",
- " res = bootstrap((col_vals,), np.mean, confidence_level=confidence_level, n_resamples=n_resamples)\n",
+ " res = bootstrap((col_vals,), np.mean, confidence_level=confidence_level, n_resamples=n_resamples, random_state=seed)\n",
" ci = res.confidence_interval\n",
" return f'{ci.low:.2f}-{ci.high:.2f}'\n",
" def get_ci_for_each_col(df_method):\n",
" return df_method.apply(get_ci, axis=0)\n",
+ " tqdm.pandas(desc=\"Bootstrapping predictions for each method\", postfix=None)\n",
" df_ci = df.groupby('method')[metric_cols].progress_apply(lambda df_method: get_ci_for_each_col(df_method))\n",
"\n",
" # Merge tables with means and confidence intervals\n",
@@ -258,416 +274,198 @@
" return df_mean"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Evaluation for the retrieval challenge"
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|██████████| 9/9 [05:55<00:00, 39.47s/it]\n"
+ "Bootstrapping predictions for each method: 100%|██████████| 13/13 [07:49<00:00, 36.10s/it]\n"
]
- },
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " test_hit_rate@1 | \n",
- " test_hit_rate@5 | \n",
- " test_hit_rate@20 | \n",
- " test_mces@1 | \n",
- "
\n",
- " \n",
- " method | \n",
- " | \n",
- " | \n",
- " | \n",
- " | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " rebuttal_MIST_test_formula_2024-08-13_15-07-19 | \n",
- " 9.57 (8.88-10.30) | \n",
- " 22.11 (21.13-23.24) | \n",
- " 41.12 (39.91-42.29) | \n",
- " 12.75 (12.58-12.92) | \n",
- "
\n",
- " \n",
- " rebuttal_deepsets_test_formula_2024-08-15_16-45-06 | \n",
- " 4.42 (3.91-4.93) | \n",
- " 14.46 (13.60-15.39) | \n",
- " 30.76 (29.64-31.90) | \n",
- " 15.04 (14.89-15.19) | \n",
- "
\n",
- " \n",
- " rebuttal_deepsets_test_mass_2024-08-14_22-51-05 | \n",
- " 1.47 (1.20-1.79) | \n",
- " 6.21 (5.63-6.84) | \n",
- " 19.23 (18.27-20.22) | \n",
- " 25.11 (24.84-25.38) | \n",
- "
\n",
- " \n",
- " rebuttal_enhanced_MIST_test_mass_2024-08-13_01-18-44 | \n",
- " 14.64 (13.78-15.53) | \n",
- " 34.87 (33.70-36.06) | \n",
- " 59.15 (57.95-60.33) | \n",
- " 15.37 (15.13-15.62) | \n",
- "
\n",
- " \n",
- " rebuttal_fingerprint_ffn_test_formula_2024-08-15_15-45-02 | \n",
- " 5.09 (4.57-5.62) | \n",
- " 14.69 (13.83-15.57) | \n",
- " 31.97 (30.80-33.13) | \n",
- " 14.94 (14.79-15.10) | \n",
- "
\n",
- " \n",
- " rebuttal_fingerprint_ffn_test_mass_2024-08-15_15-39-32 | \n",
- " 2.54 (2.16-2.97) | \n",
- " 7.59 (6.93-8.27) | \n",
- " 20.0 (19.06-21.05) | \n",
- " 24.66 (24.37-24.95) | \n",
- "
\n",
- " \n",
- " rebuttal_random_test_formula_2024-08-13_16-14-07 | \n",
- " 3.06 (2.67-3.51) | \n",
- " 11.35 (10.59-12.14) | \n",
- " 27.74 (26.62-28.94) | \n",
- " 13.87 (13.70-14.03) | \n",
- "
\n",
- " \n",
- " rebuttal_random_test_formula_2024-08-13_17-08-09 | \n",
- " 3.06 (2.64-3.50) | \n",
- " 11.35 (10.58-12.13) | \n",
- " 27.74 (26.66-28.80) | \n",
- " 13.87 (13.70-14.03) | \n",
- "
\n",
- " \n",
- " rebuttal_random_test_mass_2024-08-13_17-08-09 | \n",
- " 0.37 (0.24-0.54) | \n",
- " 2.01 (1.68-2.38) | \n",
- " 8.22 (7.57-8.93) | \n",
- " 30.81 (30.43-31.24) | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " test_hit_rate@1 \\\n",
- "method \n",
- "rebuttal_MIST_test_formula_2024-08-13_15-07-19 9.57 (8.88-10.30) \n",
- "rebuttal_deepsets_test_formula_2024-08-15_16-45-06 4.42 (3.91-4.93) \n",
- "rebuttal_deepsets_test_mass_2024-08-14_22-51-05 1.47 (1.20-1.79) \n",
- "rebuttal_enhanced_MIST_test_mass_2024-08-13_01-... 14.64 (13.78-15.53) \n",
- "rebuttal_fingerprint_ffn_test_formula_2024-08-1... 5.09 (4.57-5.62) \n",
- "rebuttal_fingerprint_ffn_test_mass_2024-08-15_1... 2.54 (2.16-2.97) \n",
- "rebuttal_random_test_formula_2024-08-13_16-14-07 3.06 (2.67-3.51) \n",
- "rebuttal_random_test_formula_2024-08-13_17-08-09 3.06 (2.64-3.50) \n",
- "rebuttal_random_test_mass_2024-08-13_17-08-09 0.37 (0.24-0.54) \n",
- "\n",
- " test_hit_rate@5 \\\n",
- "method \n",
- "rebuttal_MIST_test_formula_2024-08-13_15-07-19 22.11 (21.13-23.24) \n",
- "rebuttal_deepsets_test_formula_2024-08-15_16-45-06 14.46 (13.60-15.39) \n",
- "rebuttal_deepsets_test_mass_2024-08-14_22-51-05 6.21 (5.63-6.84) \n",
- "rebuttal_enhanced_MIST_test_mass_2024-08-13_01-... 34.87 (33.70-36.06) \n",
- "rebuttal_fingerprint_ffn_test_formula_2024-08-1... 14.69 (13.83-15.57) \n",
- "rebuttal_fingerprint_ffn_test_mass_2024-08-15_1... 7.59 (6.93-8.27) \n",
- "rebuttal_random_test_formula_2024-08-13_16-14-07 11.35 (10.59-12.14) \n",
- "rebuttal_random_test_formula_2024-08-13_17-08-09 11.35 (10.58-12.13) \n",
- "rebuttal_random_test_mass_2024-08-13_17-08-09 2.01 (1.68-2.38) \n",
- "\n",
- " test_hit_rate@20 \\\n",
- "method \n",
- "rebuttal_MIST_test_formula_2024-08-13_15-07-19 41.12 (39.91-42.29) \n",
- "rebuttal_deepsets_test_formula_2024-08-15_16-45-06 30.76 (29.64-31.90) \n",
- "rebuttal_deepsets_test_mass_2024-08-14_22-51-05 19.23 (18.27-20.22) \n",
- "rebuttal_enhanced_MIST_test_mass_2024-08-13_01-... 59.15 (57.95-60.33) \n",
- "rebuttal_fingerprint_ffn_test_formula_2024-08-1... 31.97 (30.80-33.13) \n",
- "rebuttal_fingerprint_ffn_test_mass_2024-08-15_1... 20.0 (19.06-21.05) \n",
- "rebuttal_random_test_formula_2024-08-13_16-14-07 27.74 (26.62-28.94) \n",
- "rebuttal_random_test_formula_2024-08-13_17-08-09 27.74 (26.66-28.80) \n",
- "rebuttal_random_test_mass_2024-08-13_17-08-09 8.22 (7.57-8.93) \n",
- "\n",
- " test_mces@1 \n",
- "method \n",
- "rebuttal_MIST_test_formula_2024-08-13_15-07-19 12.75 (12.58-12.92) \n",
- "rebuttal_deepsets_test_formula_2024-08-15_16-45-06 15.04 (14.89-15.19) \n",
- "rebuttal_deepsets_test_mass_2024-08-14_22-51-05 25.11 (24.84-25.38) \n",
- "rebuttal_enhanced_MIST_test_mass_2024-08-13_01-... 15.37 (15.13-15.62) \n",
- "rebuttal_fingerprint_ffn_test_formula_2024-08-1... 14.94 (14.79-15.10) \n",
- "rebuttal_fingerprint_ffn_test_mass_2024-08-15_1... 24.66 (24.37-24.95) \n",
- "rebuttal_random_test_formula_2024-08-13_16-14-07 13.87 (13.70-14.03) \n",
- "rebuttal_random_test_formula_2024-08-13_17-08-09 13.87 (13.70-14.03) \n",
- "rebuttal_random_test_mass_2024-08-13_17-08-09 30.81 (30.43-31.24) "
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
}
],
"source": [
"dir_results = Path('../data/test_results/retrieval')\n",
"task = 'retrieval'\n",
"\n",
- "df = evaluate(dir_results, task)\n",
- "df"
+ "df = evaluate(dir_results, task)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Main challenge"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "| method | test_hit_rate@1 | test_hit_rate@5 | test_hit_rate@20 | test_mces@1 |\n",
+ "|:---------------------------------------------------------------------------------|:--------------------|:--------------------|:--------------------|:--------------------|\n",
+ "| rebuttal_random_test_mass_2024-08-13_17-08-09 | 0.37 (0.24-0.54) | 2.01 (1.68-2.39) | 8.22 (7.53-8.89) | 30.81 (30.40-31.21) |\n",
+ "| rebuttal_deepsets_test_mass_2024-08-14_22-51-05 | 1.47 (1.18-1.77) | 6.21 (5.64-6.82) | 19.23 (18.24-20.26) | 25.11 (24.84-25.39) |\n",
+ "| rebuttal_fingerprint_ffn_sigmoid_mist_canopus_1550_test_mass_2024-08-17_02-30-13 | 1.65 (1.36-1.98) | 5.45 (4.89-6.02) | 15.15 (14.29-16.05) | 26.76 (26.47-27.06) |\n",
+ "| rebuttal_fingerprint_ffn_test_mass_2024-08-15_15-39-32 | 2.54 (2.17-2.99) | 7.59 (6.96-8.28) | 20.0 (19.01-20.98) | 24.66 (24.38-24.94) |\n",
+ "| rebuttal_deepsets_ff_test_mass_2024-08-17_02-30-13 | 5.24 (4.71-5.83) | 12.58 (11.80-13.42) | 28.21 (27.10-29.36) | 22.13 (21.85-22.43) |\n",
+ "| rebuttal_enhanced_MIST_test_mass_2024-08-13_01-18-44 | 14.64 (13.82-15.54) | 34.87 (33.69-36.10) | 59.15 (57.89-60.39) | 15.37 (15.12-15.62) |\n"
+ ]
+ }
+ ],
+ "source": [
+ "df_paper = df.reset_index()\n",
+ "df_paper = df_paper[(~df_paper['method'].str.contains('formula')) | (df_paper['method'].str.contains('no_formula'))]\n",
+ "df_paper = df_paper.sort_values('test_hit_rate@1', ascending=True, key=lambda x: x.str.split(' ').str[0].astype(float))\n",
+ "print(df_paper.to_markdown(index=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Bonus chemical formulae challenge"
]
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "| method | test_hit_rate@1 | test_hit_rate@5 | test_hit_rate@20 | test_mces@1 |\n",
- "|:----------------------------------------------------------|:------------------|:--------------------|:--------------------|:--------------------|\n",
- "| rebuttal_random_test_formula_2024-08-13_16-14-07 | 3.06 (2.67-3.51) | 11.35 (10.59-12.14) | 27.74 (26.62-28.94) | 13.87 (13.70-14.03) |\n",
- "| rebuttal_random_test_formula_2024-08-13_17-08-09 | 3.06 (2.64-3.50) | 11.35 (10.58-12.13) | 27.74 (26.66-28.80) | 13.87 (13.70-14.03) |\n",
- "| rebuttal_deepsets_test_formula_2024-08-15_16-45-06 | 4.42 (3.91-4.93) | 14.46 (13.60-15.39) | 30.76 (29.64-31.90) | 15.04 (14.89-15.19) |\n",
- "| rebuttal_fingerprint_ffn_test_formula_2024-08-15_15-45-02 | 5.09 (4.57-5.62) | 14.69 (13.83-15.57) | 31.97 (30.80-33.13) | 14.94 (14.79-15.10) |\n",
- "| rebuttal_MIST_test_formula_2024-08-13_15-07-19 | 9.57 (8.88-10.30) | 22.11 (21.13-23.24) | 41.12 (39.91-42.29) | 12.75 (12.58-12.92) |\n"
+ "| method | test_hit_rate@1 | test_hit_rate@5 | test_hit_rate@20 | test_mces@1 |\n",
+ "|:------------------------------------------------------------------------------------|:------------------|:--------------------|:--------------------|:--------------------|\n",
+ "| rebuttal_random_test_formula_2024-08-13_16-14-07 | 3.06 (2.64-3.52) | 11.35 (10.60-12.12) | 27.74 (26.52-28.84) | 13.87 (13.70-14.03) |\n",
+ "| rebuttal_random_test_formula_2024-08-13_17-08-09 | 3.06 (2.64-3.52) | 11.35 (10.60-12.12) | 27.74 (26.52-28.84) | 13.87 (13.70-14.03) |\n",
+ "| rebuttal_fingerprint_ffn_sigmoid_mist_canopus_1550_test_formula_2024-08-17_02-30-13 | 4.07 (3.61-4.54) | 13.13 (12.33-13.95) | 29.44 (28.32-30.53) | 15.5 (15.34-15.64) |\n",
+ "| rebuttal_deepsets_test_formula_2024-08-15_16-45-06 | 4.42 (3.92-4.97) | 14.46 (13.58-15.36) | 30.76 (29.67-31.93) | 15.04 (14.89-15.19) |\n",
+ "| rebuttal_fingerprint_ffn_test_formula_2024-08-15_15-45-02 | 5.09 (4.57-5.66) | 14.69 (13.83-15.56) | 31.97 (30.86-33.10) | 14.94 (14.79-15.09) |\n",
+ "| rebuttal_deepsets_ff_test_formula_2024-08-17_02-30-13 | 6.56 (5.95-7.16) | 16.46 (15.58-17.35) | 33.46 (32.39-34.59) | 14.14 (13.98-14.31) |\n",
+ "| rebuttal_MIST_test_formula_2024-08-13_15-07-19 | 9.57 (8.88-10.30) | 22.11 (21.10-23.13) | 41.12 (39.98-42.34) | 12.75 (12.59-12.91) |\n"
]
}
],
"source": [
"df_paper = df.reset_index()\n",
- "df_paper = df_paper[df_paper['method'].str.contains('formula')]\n",
+ "df_paper = df_paper[(df_paper['method'].str.contains('formula')) & (~df_paper['method'].str.contains('no_formula'))]\n",
"df_paper = df_paper.sort_values('test_hit_rate@1', ascending=True, key=lambda x: x.str.split(' ').str[0].astype(float))\n",
"print(df_paper.to_markdown(index=False))"
]
},
{
- "attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
- "| | Hit rate @ 1 ↑ | Hit rate @ 5 ↑ | Hit rate @ 20 ↑ | MCES @ 1 ↓ |\n",
- "|:-------------------------------------------------------|:--------------------:|:--------------------:|:--------------------:|:--------------------:|\n",
- "| Random | 0.37 (0.24-0.54) | 2.01 (1.68-2.38) | 8.22 (7.57-8.93) | 30.81 (30.43-31.24) |\n",
- "| DeepSets | 1.47 (1.20-1.79) | 6.21 (5.63-6.84) | 19.23 (18.27-20.22) | 25.11 (24.84-25.38) |\n",
- "| FingerprintFFN | 2.54 (2.16-2.97) | 7.59 (6.93-8.27) | 20.0 (19.06-21.05) | 24.66 (24.37-24.95) |\n",
- "| MIST | **14.64** (13.78-15.53) | **34.87** (33.70-36.06) | **59.15** (57.95-60.33) | **15.37** (15.13-15.62) |\n",
- "| *Bonus chemical formulae challenge* | | | | |\n",
- "| Random | 3.06 (2.64-3.50) | 11.35 (10.58-12.13) | 27.74 (26.66-28.80) | 13.87 (13.70-14.03) |\n",
- "| DeepSets | 4.42 (3.91-4.93) | 14.46 (13.60-15.39) | 30.76 (29.64-31.90) | 15.04 (14.89-15.19) |\n",
- "| FingerprintFFN | 5.09 (4.57-5.62) | 14.69 (13.83-15.57) | 31.97 (30.80-33.13) | 14.94 (14.79-15.10) |\n",
- "| MIST | **9.57** (8.88-10.30) | **22.11** (21.13-23.24) | **41.12** (39.91-42.29) | **12.75** (12.58-12.92) |"
+ "## Evaluation for the de novo challenge"
]
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- " 0%| | 0/4 [00:00, ?it/s]/scratch/project_465000883/bushuiev/miniconda3/envs/massspecgym/lib/python3.11/site-packages/scipy/stats/_resampling.py:144: RuntimeWarning: invalid value encountered in scalar divide\n",
+ "Bootstrapping predictions for each method: 0%| | 0/11 [00:00, ?it/s]/scratch/project_465000883/bushuiev/miniconda3/envs/massspecgym/lib/python3.11/site-packages/scipy/stats/_resampling.py:144: RuntimeWarning: invalid value encountered in scalar divide\n",
" a_hat = 1/6 * sum(nums) / sum(dens)**(3/2)\n",
"/scratch/project_465000883/bushuiev/miniconda3/envs/massspecgym/lib/python3.11/site-packages/scipy/stats/_resampling.py:97: DegenerateDataWarning: The BCa confidence interval cannot be calculated. This problem is known to occur when the distribution is degenerate or the statistic is np.min.\n",
" warnings.warn(DegenerateDataWarning(msg))\n",
- "100%|██████████| 4/4 [04:13<00:00, 63.32s/it]\n"
+ "Bootstrapping predictions for each method: 100%|██████████| 11/11 [10:12<00:00, 55.66s/it]\n"
]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " test_top_1_accuracy | \n",
- " test_top_1_mces_dist | \n",
- " test_top_1_max_tanimoto_sim | \n",
- " test_top_10_accuracy | \n",
- " test_top_10_mces_dist | \n",
- " test_top_10_max_tanimoto_sim | \n",
- "
\n",
- " \n",
- " method | \n",
- " | \n",
- " | \n",
- " | \n",
- " | \n",
- " | \n",
- " | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " random_baseline_formula | \n",
- " 0.0 (nan-nan) | \n",
- " 21.11 (20.97-21.26) | \n",
- " 0.08 (0.08-0.08) | \n",
- " 0.0 (nan-nan) | \n",
- " 18.25 (18.14-18.35) | \n",
- " 0.11 (0.11-0.11) | \n",
- "
\n",
- " \n",
- " random_baseline_no_formula | \n",
- " 0.0 (nan-nan) | \n",
- " 28.59 (28.33-28.84) | \n",
- " 0.07 (0.07-0.07) | \n",
- " 0.0 (nan-nan) | \n",
- " 25.72 (25.48-25.96) | \n",
- " 0.1 (0.10-0.10) | \n",
- "
\n",
- " \n",
- " rebuttal_selfies_transformer_test_2024-08-15_16-05-36 | \n",
- " 0.0 (nan-nan) | \n",
- " 33.28 (32.98-33.58) | \n",
- " 0.1 (0.10-0.10) | \n",
- " 0.0 (nan-nan) | \n",
- " 21.84 (21.67-22.00) | \n",
- " 0.15 (0.15-0.15) | \n",
- "
\n",
- " \n",
- " rebuttal_smiles_transformer_test_2024-08-15_17-11-37 | \n",
- " 0.0 (nan-nan) | \n",
- " 53.8 (52.95-54.65) | \n",
- " 0.07 (0.07-0.08) | \n",
- " 0.0 (nan-nan) | \n",
- " 21.97 (21.78-22.16) | \n",
- " 0.17 (0.17-0.17) | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " test_top_1_accuracy \\\n",
- "method \n",
- "random_baseline_formula 0.0 (nan-nan) \n",
- "random_baseline_no_formula 0.0 (nan-nan) \n",
- "rebuttal_selfies_transformer_test_2024-08-15_16... 0.0 (nan-nan) \n",
- "rebuttal_smiles_transformer_test_2024-08-15_17-... 0.0 (nan-nan) \n",
- "\n",
- " test_top_1_mces_dist \\\n",
- "method \n",
- "random_baseline_formula 21.11 (20.97-21.26) \n",
- "random_baseline_no_formula 28.59 (28.33-28.84) \n",
- "rebuttal_selfies_transformer_test_2024-08-15_16... 33.28 (32.98-33.58) \n",
- "rebuttal_smiles_transformer_test_2024-08-15_17-... 53.8 (52.95-54.65) \n",
- "\n",
- " test_top_1_max_tanimoto_sim \\\n",
- "method \n",
- "random_baseline_formula 0.08 (0.08-0.08) \n",
- "random_baseline_no_formula 0.07 (0.07-0.07) \n",
- "rebuttal_selfies_transformer_test_2024-08-15_16... 0.1 (0.10-0.10) \n",
- "rebuttal_smiles_transformer_test_2024-08-15_17-... 0.07 (0.07-0.08) \n",
- "\n",
- " test_top_10_accuracy \\\n",
- "method \n",
- "random_baseline_formula 0.0 (nan-nan) \n",
- "random_baseline_no_formula 0.0 (nan-nan) \n",
- "rebuttal_selfies_transformer_test_2024-08-15_16... 0.0 (nan-nan) \n",
- "rebuttal_smiles_transformer_test_2024-08-15_17-... 0.0 (nan-nan) \n",
- "\n",
- " test_top_10_mces_dist \\\n",
- "method \n",
- "random_baseline_formula 18.25 (18.14-18.35) \n",
- "random_baseline_no_formula 25.72 (25.48-25.96) \n",
- "rebuttal_selfies_transformer_test_2024-08-15_16... 21.84 (21.67-22.00) \n",
- "rebuttal_smiles_transformer_test_2024-08-15_17-... 21.97 (21.78-22.16) \n",
- "\n",
- " test_top_10_max_tanimoto_sim \n",
- "method \n",
- "random_baseline_formula 0.11 (0.11-0.11) \n",
- "random_baseline_no_formula 0.1 (0.10-0.10) \n",
- "rebuttal_selfies_transformer_test_2024-08-15_16... 0.15 (0.15-0.15) \n",
- "rebuttal_smiles_transformer_test_2024-08-15_17-... 0.17 (0.17-0.17) "
- ]
- },
- "execution_count": 15,
- "metadata": {},
- "output_type": "execute_result"
}
],
"source": [
"dir_results = Path('../data/test_results/de_novo')\n",
"task = 'de_novo'\n",
"\n",
- "df = evaluate(dir_results, task)\n",
- "df"
+ "df = evaluate(dir_results, task)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Main challenge"
]
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "| method | test_top_1_accuracy | test_top_1_mces_dist | test_top_1_max_tanimoto_sim | test_top_10_accuracy | test_top_10_mces_dist | test_top_10_max_tanimoto_sim |\n",
- "|:------------------------------------------------------|:----------------------|:-----------------------|:------------------------------|:-----------------------|:------------------------|:-------------------------------|\n",
- "| random_baseline_formula | 0.0 (nan-nan) | 21.11 (20.97-21.26) | 0.08 (0.08-0.08) | 0.0 (nan-nan) | 18.25 (18.14-18.35) | 0.11 (0.11-0.11) |\n",
- "| random_baseline_no_formula | 0.0 (nan-nan) | 28.59 (28.33-28.84) | 0.07 (0.07-0.07) | 0.0 (nan-nan) | 25.72 (25.48-25.96) | 0.1 (0.10-0.10) |\n",
- "| rebuttal_selfies_transformer_test_2024-08-15_16-05-36 | 0.0 (nan-nan) | 33.28 (32.98-33.58) | 0.1 (0.10-0.10) | 0.0 (nan-nan) | 21.84 (21.67-22.00) | 0.15 (0.15-0.15) |\n",
- "| rebuttal_smiles_transformer_test_2024-08-15_17-11-37 | 0.0 (nan-nan) | 53.8 (52.95-54.65) | 0.07 (0.07-0.08) | 0.0 (nan-nan) | 21.97 (21.78-22.16) | 0.17 (0.17-0.17) |\n"
+ "| method | test_top_1_accuracy | test_top_1_mces_dist | test_top_1_max_tanimoto_sim | test_top_10_accuracy | test_top_10_mces_dist | test_top_10_max_tanimoto_sim |\n",
+ "|:------------------------------------------------------------------------|:----------------------|:-----------------------|:------------------------------|:-----------------------|:------------------------|:-------------------------------|\n",
+ "| rebuttal_smiles_transformer_mist_canopus_1550_test_2024-08-17_02-30-13 | 0.0 (nan-nan) | 96.17 (95.78-96.53) | 0.01 (0.00-0.01) | 0.0 (nan-nan) | 70.88 (70.09-71.68) | 0.04 (0.04-0.04) |\n",
+ "| rebuttal_smiles_transformer_mist_canopus_test_2024-08-16_22-34-55 | 0.0 (nan-nan) | 96.06 (95.67-96.43) | 0.01 (0.00-0.01) | 0.0 (nan-nan) | 70.77 (69.96-71.53) | 0.04 (0.04-0.04) |\n",
+ "| rebuttal_selfies_transformer_mist_canopus_test_2024-08-16_22-33-21 | 0.0 (nan-nan) | 39.43 (39.10-39.76) | 0.08 (0.08-0.08) | 0.0 (nan-nan) | 27.21 (26.99-27.46) | 0.13 (0.13-0.13) |\n",
+ "| rebuttal_selfies_transformer_mist_canopus_1550_test_2024-08-17_02-30-13 | 0.0 (nan-nan) | 40.21 (39.88-40.56) | 0.08 (0.08-0.08) | 0.0 (nan-nan) | 27.14 (26.91-27.38) | 0.13 (0.12-0.13) |\n",
+ "| random_baseline_no_formula | 0.0 (nan-nan) | 28.59 (28.33-28.84) | 0.07 (0.07-0.07) | 0.0 (nan-nan) | 25.72 (25.49-25.95) | 0.1 (0.10-0.10) |\n",
+ "| rebuttal_smiles_transformer_test_2024-08-15_17-11-37 | 0.0 (nan-nan) | 53.8 (52.95-54.61) | 0.07 (0.07-0.08) | 0.0 (nan-nan) | 21.97 (21.79-22.16) | 0.17 (0.17-0.17) |\n",
+ "| rebuttal_selfies_transformer_test_2024-08-15_16-05-36 | 0.0 (nan-nan) | 33.28 (33.00-33.57) | 0.1 (0.10-0.10) | 0.0 (nan-nan) | 21.84 (21.67-22.00) | 0.15 (0.15-0.15) |\n"
]
}
],
"source": [
"df_paper = df.reset_index()\n",
- "# df_paper = df_paper[~df_paper['method'].str.contains('formula')]\n",
- "df_paper = df_paper.sort_values('test_top_1_mces_dist', ascending=True, key=lambda x: x.str.split(' ').str[0].astype(float))\n",
+ "df_paper = df_paper[(~df_paper['method'].str.contains('formula')) | (df_paper['method'].str.contains('no_formula'))]\n",
+ "df_paper = df_paper.sort_values('test_top_10_mces_dist', ascending=False, key=lambda x: x.str.split(' ').str[0].astype(float))\n",
"print(df_paper.to_markdown(index=False))"
]
},
{
- "attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
- "| | Top-1 Accuracy ↑ | Top-1 MCES ↓ | Top-1 Tanimoto ↑ | Top-10 Accuracy ↑ | Top-10 MCES ↓ | Top-10 Tanimoto ↑ |\n",
- "|:------------------------------------------------------|:----------------------:|:-----------------------:|:------------------------------:|:-----------------------:|:------------------------:|:-------------------------------:|\n",
- "| Random chemical generation | 0.0 | **28.59** (28.33-28.84) | 0.07 (0.07-0.07) | 0.0 | 25.72 (25.48-25.96) | 0.1 (0.10-0.10) |\n",
- "| SMILES Transformer | 0.0 | 53.8 (52.95-54.65) | 0.07 (0.07-0.08) | 0.0 | 21.97 (21.78-22.16) | **0.17** (0.17-0.17) |\n",
- "| SELFIES Transformer | 0.0 | 33.28 (32.98-33.58) | **0.1** (0.10-0.10) | 0.0 | **21.84** (21.67-22.00) | 0.15 (0.15-0.15) |\n",
- "| *Bonus chemical formulae challenge*\n",
- "| Random chemical generation | 0.0 | **21.11** (20.97-21.26) | **0.08** (0.08-0.08) | 0.0 | **18.25** (18.14-18.35) | **0.11** (0.11-0.11) |"
+ "### Bonus chemical formulae challenge"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "| method | test_top_1_accuracy | test_top_1_mces_dist | test_top_1_max_tanimoto_sim | test_top_10_accuracy | test_top_10_mces_dist | test_top_10_max_tanimoto_sim |\n",
+ "|:-----------------------------------------------------------------|:----------------------|:-----------------------|:------------------------------|:-----------------------|:------------------------|:-------------------------------|\n",
+ "| rebuttal_smiles_transformer_formula_test_2024-08-17_02-30-13 | 0.0 (nan-nan) | 79.39 (78.64-80.08) | 0.03 (0.03-0.04) | 0.0 (nan-nan) | 52.13 (51.45-52.81) | 0.1 (0.09-0.10) |\n",
+ "| rebuttal_selfies_transformer_formula_test_2024-08-17_02-30-13 | 0.0 (nan-nan) | 38.88 (38.57-39.20) | 0.08 (0.08-0.08) | 0.0 (nan-nan) | 26.87 (26.66-27.11) | 0.13 (0.13-0.13) |\n",
+ "| rebuttal_selfies_transformer_formula_v2_test_2024-08-18_14-28-08 | 0.0 (nan-nan) | 38.88 (38.57-39.21) | 0.08 (0.08-0.08) | 0.0 (nan-nan) | 26.87 (26.66-27.11) | 0.13 (0.13-0.13) |\n",
+ "| random_baseline_formula | 0.0 (nan-nan) | 21.11 (20.97-21.26) | 0.08 (0.08-0.08) | 0.0 (nan-nan) | 18.25 (18.14-18.35) | 0.11 (0.11-0.11) |\n"
+ ]
+ }
+ ],
+ "source": [
+ "df_paper = df.reset_index()\n",
+ "df_paper = df_paper[(df_paper['method'].str.contains('formula')) & (~df_paper['method'].str.contains('no_formula'))]\n",
+ "df_paper = df_paper.sort_values('test_top_10_mces_dist', ascending=False, key=lambda x: x.str.split(' ').str[0].astype(float))\n",
+ "print(df_paper.to_markdown(index=False))"
]
}
],