Skip to content

Commit

Permalink
pass both source and targe wavenumbers at init for Resample transform
Browse files Browse the repository at this point in the history
  • Loading branch information
franckalbinet committed Jan 23, 2025
1 parent eeebdab commit 98aee71
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 18 deletions.
25 changes: 15 additions & 10 deletions nbs/00_core.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion settings.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[DEFAULT]
repo = soilspectfm
lib_name = soilspectfm
version = 0.0.3
version = 0.0.4
min_python = 3.7
license = apache2
black_formatting = False
Expand Down
2 changes: 1 addition & 1 deletion soilspectfm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.3"
__version__ = "0.0.4"
10 changes: 4 additions & 6 deletions soilspectfm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,19 +204,17 @@ def transform(self, X, y=None): return -np.log10(np.clip(X, self.eps, 1))
class Resample(BaseEstimator, TransformerMixin):
"Resampling transformer compatible with scikit-learn."
def __init__(self,
source_x: np.ndarray, # Source x-axis points (wavenumbers or wavelengths)
target_x: np.ndarray, # Target x-axis points (wavenumbers or wavelengths) for resampling
interpolation_kind: str='cubic' # Type of spline interpolation to use
):
store_attr()

def fit(self,
X: np.ndarray, # Spectral data to be resampled
x: np.ndarray=None, # Original x-axis points (wavenumbers or wavelengths)
y: np.ndarray=None # Original y-axis points
y: np.ndarray=None # Not used
):
"Fit the transformer"
if x is None: raise ValueError("Original x-axis (wavenumbers or wavelengths) must be provided")
self.original_x_ = np.asarray(x)
"No-op in that particular case"
return self

def transform(self,
Expand All @@ -227,7 +225,7 @@ def transform(self,
X_transformed = np.zeros((X.shape[0], len(self.target_x)))

for i in range(X.shape[0]):
cs = CubicSpline(self.original_x_, X[i])
cs = CubicSpline(self.source_x, X[i])
X_transformed[i] = cs(self.target_x)

return X_transformed

0 comments on commit 98aee71

Please sign in to comment.