Review notes (text before the first "diff --git" line is ignored by git apply / patch):
- Line breaks and indentation were restored from a whitespace-collapsed copy of this patch. The
  indentation of context lines and the positions of blank context lines are a best guess based
  on the hunk line counts; use `git apply --ignore-whitespace` and verify against the tree.
- align_atlas was revised relative to the submitted patch: loop variables no longer shadow the
  builtins min/max, the clone and peak prediction shared by both branches are done once, the
  flag is tested for truthiness instead of `is True`, and align_rt_min_max defaults to False so
  existing three-argument callers keep working. Log messages render the same text as before.
- The "index" blob ids are those of the submitted patch and do not reflect the revision.

diff --git a/metatlas/targeted/rt_alignment.py b/metatlas/targeted/rt_alignment.py
index 7010a8db..d2c2b85b 100644
--- a/metatlas/targeted/rt_alignment.py
+++ b/metatlas/targeted/rt_alignment.py
@@ -378,16 +378,37 @@ def get_atlas_name(ids: AnalysisIdentifiers, workflow: Workflow, analysis: Analy
     )
 
 
-def align_atlas(atlas: metob.Atlas, model: Model, rt_offset: float) -> metob.Atlas:
-    """use model to align RTs within atlas"""
-    aligned = atlas.clone(recursive=True)
-    old_peaks = [cid.rt_references[0].rt_peak for cid in aligned.compound_identifications]
-    new_peaks = model.predict(np.array(old_peaks, dtype=float))
-    for peak, cid in zip(new_peaks, aligned.compound_identifications):
-        rt_ref = cid.rt_references[0]
-        rt_ref.rt_peak = peak
-        rt_ref.rt_min = peak - rt_offset
-        rt_ref.rt_max = peak + rt_offset
+def align_atlas(atlas: metob.Atlas, model: Model, rt_offset: float, align_rt_min_max: bool = False) -> metob.Atlas:
+    """Use model to align RTs within a clone of atlas; the input atlas is not modified.
+
+    The RT peak of each compound's first RT reference is replaced by the model's prediction.
+    If align_rt_min_max is true, RT min and max are predicted by the model as well and
+    rt_offset is ignored; otherwise they are set to the predicted peak -/+ rt_offset.
+    """
+    aligned = atlas.clone(recursive=True)
+    cids = aligned.compound_identifications
+    new_peaks = model.predict(np.array([cid.rt_references[0].rt_peak for cid in cids], dtype=float))
+    if align_rt_min_max:
+        logger.info(f"Using model to predict new RT peak, min, and max values for each compound (ignoring rt offset of {rt_offset}).")
+        new_mins = model.predict(np.array([cid.rt_references[0].rt_min for cid in cids], dtype=float))
+        new_maxs = model.predict(np.array([cid.rt_references[0].rt_max for cid in cids], dtype=float))
+        for peak, rt_min, rt_max, cid in zip(new_peaks, new_mins, new_maxs, cids):
+            if peak - rt_min < 0.05 or peak - rt_min > 5:
+                logger.warning(f"Bound between RT peak and RT minimum for {cid.name} is abnormal: peak={peak} and minimum={rt_min}.")
+            if rt_max - peak < 0.05 or rt_max - peak > 5:
+                logger.warning(f"Bound between RT maximum and RT peak for {cid.name} is abnormal: peak={peak} and maximum={rt_max}.")
+            rt_ref = cid.rt_references[0]
+            rt_ref.rt_peak = peak
+            rt_ref.rt_min = rt_min
+            rt_ref.rt_max = rt_max
+    else:
+        logger.info(f"Using model to predict new RT peak for each compound and then setting min and max by {rt_offset} mins.")
+        for peak, cid in zip(new_peaks, cids):
+            rt_ref = cid.rt_references[0]
+            rt_ref.rt_peak = peak
+            rt_ref.rt_min = peak - rt_offset
+            rt_ref.rt_max = peak + rt_offset
+
     return aligned
 
 
@@ -419,7 +440,7 @@ def create_aligned_atlases(
 
             logger.info("Creating atlas %s", name)
             out_atlas_file_name = ids.output_dir / f"{name}.csv"
-            aligned_atlas = align_atlas(template_atlas, model, analysis.atlas.rt_offset) if analysis.atlas.do_alignment else template_atlas
+            aligned_atlas = align_atlas(template_atlas, model, analysis.atlas.rt_offset, analysis.atlas.align_rt_min_max) if analysis.atlas.do_alignment else template_atlas
             logger.info("Collecting data for pre-filter") if analysis.atlas.do_prefilter else None
             aligned_filtered_atlas = filter_atlas(aligned_atlas, ids, analysis, data) if analysis.atlas.do_prefilter else aligned_atlas
             aligned_filtered_atlas.name = name
diff --git a/metatlas/tools/config.py b/metatlas/tools/config.py
index 109490bd..98b6d0a8 100644
--- a/metatlas/tools/config.py
+++ b/metatlas/tools/config.py
@@ -144,6 +144,7 @@ class Atlas(BaseModel):
     name: str
     do_alignment: bool = False
     do_prefilter: bool = False
+    align_rt_min_max: bool = False
     rt_offset: float = 0.5
 
     @validator("unique_id")
diff --git a/test_config.yaml b/test_config.yaml
index 84f5a96a..8ce3c0bf 100644
--- a/test_config.yaml
+++ b/test_config.yaml
@@ -71,6 +71,7 @@ workflows:
           unique_id: 89694aa326cd46958d38d8e9066de16c
           do_alignment: True
           do_prefilter: False
+          align_rt_min_max: False
         parameters:
           copy_atlas: True
           polarity: positive
@@ -110,6 +111,7 @@ workflows:
           unique_id: f74a731c590544aba5c3720b346e508e
           do_alignment: True
           do_prefilter: True
+          align_rt_min_max: False
           rt_offset: 0.2
         parameters:
           copy_atlas: True
diff --git a/tests/unit/test_rt_alignment.py b/tests/unit/test_rt_alignment.py
index ee43ce95..ee6c9a63 100644
--- a/tests/unit/test_rt_alignment.py
+++ b/tests/unit/test_rt_alignment.py
@@ -22,15 +22,46 @@ def test_plot_actual_vs_aligned_rts01(model):
     rt_alignment.plot_actual_vs_aligned_rts(arrays, arrays, rts_df, "file_name", model, model, model)
 
 
-def test_align_atlas(atlas_with_2_cids, model):
-    out = rt_alignment.align_atlas(atlas_with_2_cids, model, 0)
+def test_align_atlas01(atlas_with_2_cids, model):
+    out = rt_alignment.align_atlas(atlas_with_2_cids, model, 0, False)
     assert out.name == "HILICz150_ANT20190824_PRD_EMA_Unlab_POS"
     assert not math.isclose(atlas_with_2_cids.compound_identifications[0].rt_references[0].rt_peak, -0.8035359946292822, abs_tol=1e-7)
     assert math.isclose(out.compound_identifications[0].rt_references[0].rt_peak, -0.8035359946292822, abs_tol=1e-7)
+    assert math.isclose(out.compound_identifications[0].rt_references[0].rt_min, -0.8035359946292822, abs_tol=1e-7)
+    assert math.isclose(out.compound_identifications[0].rt_references[0].rt_max, -0.8035359946292822, abs_tol=1e-7)
     assert not math.isclose(atlas_with_2_cids.compound_identifications[1].rt_references[0].rt_peak, 0.02331840799266649, abs_tol=1e-7)
     assert math.isclose(out.compound_identifications[1].rt_references[0].rt_peak, 0.02331840799266649, abs_tol=1e-7)
+    assert math.isclose(out.compound_identifications[1].rt_references[0].rt_min, 0.02331840799266649, abs_tol=1e-7)
+    assert math.isclose(out.compound_identifications[1].rt_references[0].rt_max, 0.02331840799266649, abs_tol=1e-7)
     assert len(out.compound_identifications) == 2
 
+def test_align_atlas02(atlas_with_2_cids, model):
+    rt_offset = 0.5
+    out = rt_alignment.align_atlas(atlas_with_2_cids, model, rt_offset, False)
+    assert out.name == "HILICz150_ANT20190824_PRD_EMA_Unlab_POS"
+    assert not math.isclose(atlas_with_2_cids.compound_identifications[0].rt_references[0].rt_peak, -0.8035359946292822, abs_tol=1e-7)
+    assert math.isclose(out.compound_identifications[0].rt_references[0].rt_peak, -0.8035359946292822, abs_tol=1e-7)
+    assert math.isclose(out.compound_identifications[0].rt_references[0].rt_min, -0.8035359946292822-rt_offset, abs_tol=1e-7)
+    assert math.isclose(out.compound_identifications[0].rt_references[0].rt_max, -0.8035359946292822+rt_offset, abs_tol=1e-7)
+    assert not math.isclose(atlas_with_2_cids.compound_identifications[1].rt_references[0].rt_peak, 0.02331840799266649, abs_tol=1e-7)
+    assert math.isclose(out.compound_identifications[1].rt_references[0].rt_peak, 0.02331840799266649, abs_tol=1e-7)
+    assert math.isclose(out.compound_identifications[1].rt_references[0].rt_min, 0.02331840799266649-rt_offset, abs_tol=1e-7)
+    assert math.isclose(out.compound_identifications[1].rt_references[0].rt_max, 0.02331840799266649+rt_offset, abs_tol=1e-7)
+    assert len(out.compound_identifications) == 2
+
+def test_align_atlas03(atlas_with_2_cids, model):
+    rt_offset = 0.25
+    out = rt_alignment.align_atlas(atlas_with_2_cids, model, rt_offset, True)
+    assert out.name == "HILICz150_ANT20190824_PRD_EMA_Unlab_POS"
+    assert not math.isclose(atlas_with_2_cids.compound_identifications[0].rt_references[0].rt_peak, -0.8035359946292822, abs_tol=1e-7)
+    assert math.isclose(out.compound_identifications[0].rt_references[0].rt_peak, -0.8035359946292822, abs_tol=1e-7)
+    assert math.isclose(out.compound_identifications[0].rt_references[0].rt_min, -0.8035359946292822-0.5, abs_tol=1e-7)
+    assert math.isclose(out.compound_identifications[0].rt_references[0].rt_max, -0.8035359946292822+0.5, abs_tol=1e-7)
+    assert not math.isclose(atlas_with_2_cids.compound_identifications[1].rt_references[0].rt_peak, 0.02331840799266649, abs_tol=1e-7)
+    assert math.isclose(out.compound_identifications[1].rt_references[0].rt_peak, 0.02331840799266649, abs_tol=1e-7)
+    assert math.isclose(out.compound_identifications[1].rt_references[0].rt_min, 0.02331840799266649-0.5, abs_tol=1e-7)
+    assert math.isclose(out.compound_identifications[1].rt_references[0].rt_max, 0.02331840799266649+0.5, abs_tol=1e-7)
+    assert len(out.compound_identifications) == 2
 
 def test_create_aligned_atlases(model, analysis_ids, metatlas_dataset, workflow):
     atlases = rt_alignment.create_aligned_atlases(model, model, model, analysis_ids, workflow, metatlas_dataset)