Skip to content

Commit

Permalink
Merge branch 'main' into increase_lockwait
Browse files Browse the repository at this point in the history
  • Loading branch information
bkieft-usa authored Feb 10, 2025
2 parents da26c57 + c9a3a9c commit 6d962e1
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 12 deletions.
41 changes: 31 additions & 10 deletions metatlas/targeted/rt_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,16 +378,37 @@ def get_atlas_name(ids: AnalysisIdentifiers, workflow: Workflow, analysis: Analy
)


def align_atlas(
    atlas: metob.Atlas, model: Model, rt_offset: float, align_rt_min_max: bool = False
) -> metob.Atlas:
    """Use model to align RTs within atlas.

    The input atlas is not modified: a recursive clone is made and only the
    first RT reference of each compound identification in the clone is
    updated.

    Args:
        atlas: atlas whose retention times should be aligned.
        model: object whose ``predict`` method maps an array of original RT
            values to aligned RT values.
        rt_offset: half-width of the RT window placed around each predicted
            peak. Only used when ``align_rt_min_max`` is false.
        align_rt_min_max: when true, the existing ``rt_min`` and ``rt_max``
            values are themselves passed through ``model`` and ``rt_offset``
            is ignored. Defaults to ``False`` so that three-argument callers
            keep the fixed-offset behaviour.

    Returns:
        The aligned clone of ``atlas``.
    """
    if align_rt_min_max:
        logger.info(
            "Using model to predict new RT peak, min, and max values for each compound "
            "(ignoring rt offset of %s).",
            rt_offset,
        )
    else:
        logger.info(
            "Using model to predict new RT peak for each compound and then setting "
            "min and max by %s mins.",
            rt_offset,
        )

    aligned = atlas.clone(recursive=True)
    cids = aligned.compound_identifications
    # Only the first RT reference of each compound identification is aligned.
    rt_refs = [cid.rt_references[0] for cid in cids]
    new_peaks = model.predict(np.array([ref.rt_peak for ref in rt_refs], dtype=float))

    if not align_rt_min_max:
        for rt_ref, peak in zip(rt_refs, new_peaks):
            rt_ref.rt_peak = peak
            rt_ref.rt_min = peak - rt_offset
            rt_ref.rt_max = peak + rt_offset
        return aligned

    # Read every original bound before any reference is overwritten.
    new_mins = model.predict(np.array([ref.rt_min for ref in rt_refs], dtype=float))
    new_maxs = model.predict(np.array([ref.rt_max for ref in rt_refs], dtype=float))
    for cid, rt_ref, peak, new_min, new_max in zip(cids, rt_refs, new_peaks, new_mins, new_maxs):
        # A predicted bound closer than 0.05 to the peak (which includes a
        # bound on the wrong side of it) or further than 5 away is reported
        # but still applied.
        if peak - new_min < 0.05 or peak - new_min > 5:
            logger.warning(
                "Bound between RT peak and RT minimum for %s is abnormal: peak=%s and minimum=%s.",
                cid.name,
                peak,
                new_min,
            )
        if new_max - peak < 0.05 or new_max - peak > 5:
            logger.warning(
                "Bound between RT maximum and RT peak for %s is abnormal: peak=%s and maximum=%s.",
                cid.name,
                peak,
                new_max,
            )
        rt_ref.rt_peak = peak
        rt_ref.rt_min = new_min
        rt_ref.rt_max = new_max
    return aligned


Expand Down Expand Up @@ -419,7 +440,7 @@ def create_aligned_atlases(
logger.info("Creating atlas %s", name)
out_atlas_file_name = ids.output_dir / f"{name}.csv"

aligned_atlas = align_atlas(template_atlas, model, analysis.atlas.rt_offset) if analysis.atlas.do_alignment else template_atlas
aligned_atlas = align_atlas(template_atlas, model, analysis.atlas.rt_offset, analysis.atlas.align_rt_min_max) if analysis.atlas.do_alignment else template_atlas
logger.info("Collecting data for pre-filter") if analysis.atlas.do_prefilter else None
aligned_filtered_atlas = filter_atlas(aligned_atlas, ids, analysis, data) if analysis.atlas.do_prefilter else aligned_atlas
aligned_filtered_atlas.name = name
Expand Down
1 change: 1 addition & 0 deletions metatlas/tools/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ class Atlas(BaseModel):
name: str
do_alignment: bool = False
do_prefilter: bool = False
align_rt_min_max: bool = False
rt_offset: float = 0.5

@validator("unique_id")
Expand Down
2 changes: 2 additions & 0 deletions test_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ workflows:
unique_id: 89694aa326cd46958d38d8e9066de16c
do_alignment: True
do_prefilter: False
align_rt_min_max: False
parameters:
copy_atlas: True
polarity: positive
Expand Down Expand Up @@ -110,6 +111,7 @@ workflows:
unique_id: f74a731c590544aba5c3720b346e508e
do_alignment: True
do_prefilter: True
align_rt_min_max: False
rt_offset: 0.2
parameters:
copy_atlas: True
Expand Down
35 changes: 33 additions & 2 deletions tests/unit/test_rt_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,46 @@ def test_plot_actual_vs_aligned_rts01(model):
rt_alignment.plot_actual_vs_aligned_rts(arrays, arrays, rts_df, "file_name", model, model, model)


def test_align_atlas01(atlas_with_2_cids, model):
    """A zero rt_offset without min/max alignment collapses both bounds onto the peak."""
    out = rt_alignment.align_atlas(atlas_with_2_cids, model, 0, False)
    assert out.name == "HILICz150_ANT20190824_PRD_EMA_Unlab_POS"
    expected_peaks = (-0.8035359946292822, 0.02331840799266649)
    for idx, expected in enumerate(expected_peaks):
        source_ref = atlas_with_2_cids.compound_identifications[idx].rt_references[0]
        aligned_ref = out.compound_identifications[idx].rt_references[0]
        # The input atlas must not have been given the aligned value.
        assert not math.isclose(source_ref.rt_peak, expected, abs_tol=1e-7)
        assert math.isclose(aligned_ref.rt_peak, expected, abs_tol=1e-7)
        assert math.isclose(aligned_ref.rt_min, expected, abs_tol=1e-7)
        assert math.isclose(aligned_ref.rt_max, expected, abs_tol=1e-7)
    assert len(out.compound_identifications) == 2

def test_align_atlas02(atlas_with_2_cids, model):
    """Without min/max alignment the bounds sit rt_offset either side of the aligned peak."""
    rt_offset = 0.5
    out = rt_alignment.align_atlas(atlas_with_2_cids, model, rt_offset, False)
    assert out.name == "HILICz150_ANT20190824_PRD_EMA_Unlab_POS"
    for idx, expected_peak in enumerate((-0.8035359946292822, 0.02331840799266649)):
        before = atlas_with_2_cids.compound_identifications[idx].rt_references[0]
        after = out.compound_identifications[idx].rt_references[0]
        # The input atlas must not have been given the aligned value.
        assert not math.isclose(before.rt_peak, expected_peak, abs_tol=1e-7)
        assert math.isclose(after.rt_peak, expected_peak, abs_tol=1e-7)
        assert math.isclose(after.rt_min, expected_peak - rt_offset, abs_tol=1e-7)
        assert math.isclose(after.rt_max, expected_peak + rt_offset, abs_tol=1e-7)
    assert len(out.compound_identifications) == 2

def test_align_atlas03(atlas_with_2_cids, model):
    """With min/max alignment enabled the bounds do not follow rt_offset.

    rt_offset is 0.25 while the asserted bounds are 0.5 away from the aligned
    peak, so the assertions only hold if the offset was not used for them.
    """
    rt_offset = 0.25
    out = rt_alignment.align_atlas(atlas_with_2_cids, model, rt_offset, True)
    assert out.name == "HILICz150_ANT20190824_PRD_EMA_Unlab_POS"
    aligned_cids = out.compound_identifications
    source_cids = atlas_with_2_cids.compound_identifications
    for position, peak in ((0, -0.8035359946292822), (1, 0.02331840799266649)):
        source_ref = source_cids[position].rt_references[0]
        aligned_ref = aligned_cids[position].rt_references[0]
        # The input atlas must not have been given the aligned value.
        assert not math.isclose(source_ref.rt_peak, peak, abs_tol=1e-7)
        assert math.isclose(aligned_ref.rt_peak, peak, abs_tol=1e-7)
        assert math.isclose(aligned_ref.rt_min, peak - 0.5, abs_tol=1e-7)
        assert math.isclose(aligned_ref.rt_max, peak + 0.5, abs_tol=1e-7)
    assert len(aligned_cids) == 2

def test_create_aligned_atlases(model, analysis_ids, metatlas_dataset, workflow):
atlases = rt_alignment.create_aligned_atlases(model, model, model, analysis_ids, workflow, metatlas_dataset)
Expand Down

0 comments on commit 6d962e1

Please sign in to comment.