Skip to content

Commit

Permalink
updated plotting scheme for fits
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Dec 10, 2024
1 parent bc37643 commit 192ce0d
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 33 deletions.
15 changes: 11 additions & 4 deletions src/tavi/data/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,12 @@ def add_background(
def params(self) -> Parameters:
"""Get fitting parameters as a dictionary with the model prefix being the key"""

if self._parameters is None:
self._parameters = self.guess()
if self.result is not None:
params = self.result.params
else:
params = self.guess()

self._parameters = params
return self._parameters

def guess(self) -> Parameters:
Expand Down Expand Up @@ -171,5 +174,9 @@ def eval(self, pars: Parameters, x: np.ndarray) -> np.ndarray:
def fit(self, pars: Parameters) -> ModelResult:

result = self.model.fit(self.y, pars, x=self.x, weights=self.err)
self.result = result
return result
if result.success:
self.result = result
self._parameters = result.params
return result
else:
return None
45 changes: 20 additions & 25 deletions src/tavi/plotter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# import matplotlib.colors as colors
from functools import partial
from typing import Optional, Union
from typing import Optional

import numpy as np
from mpl_toolkits.axisartist.grid_finder import MaxNLocator
Expand Down Expand Up @@ -73,11 +73,6 @@ def add_scan(self, scan_data: ScanData1D, **kwargs):
for key, val in kwargs.items():
scan_data.fmt.update({key: val})

def _add_fit_from_eval(self, fit_data: FitData1D, **kwargs):
self.fit_data.append(fit_data)
for key, val in kwargs.items():
fit_data.fmt.update({key: val})

def _add_fit_from_fitting(self, fit_data: Fit1D, num_of_pts: Optional[int] = 100, **kwargs):
if (result := fit_data.result) is None:
raise ValueError("Fitting result is None.")
Expand All @@ -87,22 +82,24 @@ def _add_fit_from_fitting(self, fit_data: Fit1D, num_of_pts: Optional[int] = 100
for key, val in kwargs.items():
data.fmt.update({key: val})

# TODO
def add_fit(
self, fit_data: Union[tuple[np.ndarray, np.ndarray], Fit1D], num_of_pts: Optional[int] = 100, **kwargs
):
if isinstance(fit_data, tuple):
x, y = fit_data
self._add_fit_from_eval(FitData1D(x, y), **kwargs)
elif isinstance(fit_data, Fit1D):
self._add_fit_from_fitting(fit_data, num_of_pts, **kwargs)
else:
raise ValueError(f"Invalid input fit_data={fit_data}")
def add_fit(self, fit1d: Fit1D, x: Optional[np.ndarray] = None, DISPLAY_PARAMS=True, **kwargs):
if x is None:
x = fit1d.x
if (result := fit1d.result) is None: # evaluate
y = fit1d.eval(fit1d.params, x)
else: # fit
y = result.eval(param=fit1d.params, x=x)
fit_data = FitData1D(x, y)
self.fit_data.append(fit_data)
for key, val in kwargs.items():
fit_data.fmt.update({key: val})

# TODO
def add_fit_components(self, fit_data: Fit1D, num_of_pts: Optional[int] = 100, **kwargs):
if isinstance(fit_data, Fit1D) and (result := fit_data.result) is not None:
x = fit_data.x_to_plot(num_of_pts)
def add_fit_components(self, fit1d: Fit1D, x: Optional[np.ndarray] = None, DISPLAY_PARAMS=True, **kwargs):
if x is None:
x = fit1d.x
if (result := fit1d.result) is None: # fit first
pass
else:
components = result.eval_components(result.params, x=x)

num_components = len(components)
Expand All @@ -115,13 +112,11 @@ def add_fit_components(self, fit_data: Fit1D, num_of_pts: Optional[int] = 100, *
for i, (prefix, y) in enumerate(components.items()):
data = FitData1D(x=x, y=y)
self.fit_data.append(data)
data.fmt.update({"label": prefix[:-1]}) # remove "_"
label = prefix[:-1] # remove "_"
data.fmt.update({"label": label})
for key, val in kwargs.items():
data.fmt.update({key: val[i]})

else:
raise ValueError(f"Invalid input fit_data={fit_data}")

def add_reso_bar(self, pos: tuple, fwhm: float, **kwargs):
reso_data = ResoBar(pos, fwhm)
for key, val in kwargs.items():
Expand Down
8 changes: 4 additions & 4 deletions tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_fit_single_peak_internal_model(fit_data):
if PLOT:
p1 = Plot1D()
p1.add_scan(s1_scan, fmt="o", label="data")
p1.add_fit(f1, label="fit", color="C3", num_of_pts=50, marker="^")
p1.add_fit(f1, label="fit", color="C3", marker="^")
p1.add_fit_components(f1, color=["C4", "C5"])

fig, ax = plt.subplots()
Expand All @@ -114,7 +114,7 @@ def test_fit_two_peak(fit_data):

f1 = Fit1D(s1_scan, fit_range=(0.0, 4.0), name="scan42_fit2peaks")

# f1.add_background(model="Constant")
f1.add_background(model="Constant")
f1.add_signal(model="Gaussian")
f1.add_signal(model="Gaussian")

Expand All @@ -130,14 +130,14 @@ def test_fit_two_peak(fit_data):
result = f1.fit(pars)
assert np.allclose(result.params["s1_center"].value, 3.54, atol=0.01)
assert np.allclose(result.params["s1_fwhm"].value, 0.40, atol=0.01)
x = f1.x_to_plot(num_of_pts=100, min=-1, max=5)
x = f1.x_to_plot(num_of_pts=200, min=-1, max=5)
# y = f1.eval(result.params, x)

if PLOT:
p1 = Plot1D()
p1.add_scan(s1_scan, fmt="o", label="data")
p1.add_fit(f1, x=x, label="fit", color="C3")
p1.add_fit_components(f1, x=x, color=["C4", "C5"])
p1.add_fit_components(f1, x=x, color=["C1", "C2", "C4"])

fig, ax = plt.subplots()
p1.plot(ax)
Expand Down

0 comments on commit 192ce0d

Please sign in to comment.