Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the way analysis settings are set #110

Open
wants to merge 10 commits into
base: v2.0.0
Choose a base branch
from
63 changes: 54 additions & 9 deletions tests/analyses/new/test_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import inspect
import json
from dataclasses import fields
Expand All @@ -7,6 +9,7 @@
import numpy as np
import pytest
from pandas import DataFrame
from pydantic import Field
from pydantic.dataclasses import dataclass
from pydantic.fields import FieldInfo

Expand All @@ -15,14 +18,24 @@
AnalysisData,
AnalysisMetadata,
AnalysisResult,
AnalysisSettings,
BaseAnalysisWrapper,
_validated_setter,
)
from zospy.analyses.new.decorators import analysis_settings
from zospy.analyses.new.parsers.types import ValidatedDataFrame
from zospy.analyses.new.reports.surface_data import SurfaceDataSettings
from zospy.analyses.new.systemviewers.base import SystemViewerWrapper


def all_subclasses(cls):
return set(cls.__subclasses__()).union([s for c in cls.__subclasses__() for s in all_subclasses(c)])


analysis_wrapper_classes = all_subclasses(BaseAnalysisWrapper)
analysis_wrapper_classes.remove(SystemViewerWrapper)


class TestValidatedSetter:
class MockSettings:
int_setting: int = 1
Expand Down Expand Up @@ -53,20 +66,16 @@ def test_set_non_existing(self):
settings.non_existing = 2


analysis_wrapper_classes = BaseAnalysisWrapper.__subclasses__()
analysis_wrapper_classes.remove(SystemViewerWrapper)


@dataclass
class MockAnalysisData:
int_data: int = 1
string_data: str = "a"


@dataclass
@analysis_settings
class MockAnalysisSettings:
int_setting: int = 1
string_setting: str = "a"
int_setting: int = Field(default=1, description="An integer setting")
string_setting: str = Field(default="a", description="A string setting")


class MockAnalysis(BaseAnalysisWrapper[MockAnalysisData, MockAnalysisSettings]):
Expand All @@ -75,8 +84,14 @@ class MockAnalysis(BaseAnalysisWrapper[MockAnalysisData, MockAnalysisSettings]):
_needs_config_file = False
_needs_text_output_file = False

def __init__(self, int_setting: int = 1, string_setting: str = "a", *, block_remove_temp_files: bool = False):
super().__init__(MockAnalysisSettings(), locals())
def __init__(
self,
*,
int_setting: int = 1,
string_setting: str = "a",
block_remove_temp_files: bool = False,
):
super().__init__(settings_kws=locals())

self.block_remove_temp_files = block_remove_temp_files

Expand Down Expand Up @@ -112,11 +127,21 @@ def get_settings_defaults(settings_class):

return result

def test_get_settings_type(self):
assert MockAnalysis._settings_type == MockAnalysisSettings # noqa: SLF001

def test_settings_type_is_specified(self):
assert MockAnalysis._settings_type is not AnalysisSettings # noqa: SLF001

@pytest.mark.parametrize("cls", analysis_wrapper_classes)
def test_analyses_correct_analysis_name(self, cls):
assert cls.TYPE is not None
assert hasattr(constants.Analysis.AnalysisIDM, cls.TYPE)

@pytest.mark.parametrize("cls", analysis_wrapper_classes)
def test_init_all_keyword_only_parameters(self, cls):
all(p.kind.name == "KEYWORD_ONLY" for _, p in inspect.signature(cls).parameters.items())

@pytest.mark.parametrize("cls", analysis_wrapper_classes)
def test_init_contains_all_settings(self, cls):
if cls().settings is None:
Expand All @@ -139,6 +164,26 @@ def test_analyses_default_values(self, cls):
assert field_name in init_signature.parameters
assert init_signature.parameters[field_name].default == default_value

def test_change_settings_from_parameters(self):
analysis = MockAnalysis(int_setting=2, string_setting="b")

assert analysis.settings.int_setting == 2
assert analysis.settings.string_setting == "b"

def test_change_settings_from_object(self):
settings = MockAnalysisSettings(int_setting=2, string_setting="b")
analysis = MockAnalysis.from_settings(settings)

assert analysis.settings.int_setting == 2
assert analysis.settings.string_setting == "b"

def test_settings_object_is_copied(self):
settings = MockAnalysisSettings(int_setting=2, string_setting="b")
analysis = MockAnalysis.from_settings(settings)

assert analysis.settings is not settings
assert analysis.settings == settings

@pytest.mark.parametrize(
"temp_file_type,filename",
[
Expand Down
2 changes: 1 addition & 1 deletion tests/analyses/new/test_systemviewers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class MockSystemViewerSettings:

class MockSystemViewer(SystemViewerWrapper[MockSystemViewerSettings]):
def __init__(self, *, number: int = 5, settings: TestBase.MockSystemViewerSettings | None = None):
super().__init__(settings or TestBase.MockSystemViewerSettings(), locals())
super().__init__(locals())

def _create_analysis(self, *, settings_first=True): # noqa: ARG002
self._analysis = SimpleNamespace(
Expand Down
104 changes: 95 additions & 9 deletions zospy/analyses/new/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,17 @@

from __future__ import annotations

import dataclasses
import os
import weakref
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields, is_dataclass
from dataclasses import dataclass, is_dataclass
from datetime import datetime # noqa: TCH003 Pydantic needs datetime to be present at runtime
from enum import Enum
from importlib import import_module
from pathlib import Path
from tempfile import mkstemp
from typing import TYPE_CHECKING, Generic, Literal, TypedDict, TypeVar, cast
from typing import TYPE_CHECKING, Generic, Literal, TypedDict, TypeVar, cast, get_args

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -544,8 +545,23 @@ class BaseAnalysisWrapper(ABC, Generic[AnalysisData, AnalysisSettings]):
_needs_config_file: bool = False
_needs_text_output_file: bool = False

def __init__(self, settings: AnalysisSettings, settings_arguments: dict[str, any]):
self._init_settings(settings, settings_arguments)
def __init__(self, *, settings_kws: dict[str, any] | None = None):
"""Create a new analysis wrapper.

Settings can be changed by passing the settings as keyword arguments. Use the `from_settings` method to specify
the settings using a settings object.

Parameters
----------
settings_kws : dict[str, any]
Arguments to set the settings of the analysis.

Raises
------
ValueError
If `settings` is not a dataclass.
"""
self._copy_settings(settings_kws=settings_kws)

self._config_file = None
self._text_output_file = None
Expand All @@ -555,18 +571,88 @@ def __init__(self, settings: AnalysisSettings, settings_arguments: dict[str, any
self._remove_config_file = False
self._remove_text_output_file = False

def _init_settings(self, settings: AnalysisSettings, parameters: dict[str, any]):
self._settings = settings
def __init_subclass__(
cls,
*,
analysis_type: str | None = None,
mode: Literal["Sequential", "Nonsequential"] | None = None,
needs_config_file: bool = False,
needs_text_output_file: bool = False,
**kwargs,
):
"""Determine the settings type and class-level configuration of the analysis."""
cls.TYPE = analysis_type
cls.MODE = mode
cls._needs_config_file = needs_config_file
cls._needs_text_output_file = needs_text_output_file

if not hasattr(cls, "_settings_type"):
if hasattr(cls, "__orig_bases__"):
base = cls.__orig_bases__[0]
cls._settings_type: type[AnalysisSettings] = get_args(base)[1]
else:
cls._settings_type = None

super().__init_subclass__(**kwargs)

def _copy_settings(
self, *, settings: AnalysisSettings | None = None, settings_kws: dict[str, any] | None = None
) -> None:
"""Copy settings to the settings object of the analysis.

Settings should be specified using either the `settings` argument or the `parameters` dictionary. An error is
raised if both are specified. If no settings are specified, the settings object is set to the default settings.

Parameters
----------
settings : AnalysisSettings
Analysis settings object.
settings_kws
Dictionary with the settings parameters.

Raises
------
ValueError
If both `settings` and `parameters` are specified.
If `settings` is not a dataclass.
"""
if settings is not None and settings_kws is not None:
raise ValueError(
"Settings should either be specified as a settings object or as keyword arguments, not both."
)

if settings is None:
settings = None if self._settings_type is None else self._settings_type()

if settings is None:
self._settings = None
return

if not is_dataclass(settings):
raise ValueError("settings should be a dataclass.")

for field in fields(settings):
if field.name in parameters:
setattr(self.settings, field.name, parameters[field.name])
# Create a new settings object with the specified parameters. If no parameters are specified, this creates a
# copy of the settings object. This is done to avoid modifying the original settings object.
self._settings = dataclasses.replace(settings, **(settings_kws or {}))

@classmethod
def from_settings(cls, settings: AnalysisSettings):
"""Create a new analysis with the specified settings.

Parameters
----------
settings : AnalysisSettings
Settings of the analysis.

Returns
-------
BaseAnalysisWrapper
The analysis wrapper.
"""
instance = cls()
instance._copy_settings(settings=settings) # noqa: SLF001

return instance

@property
def settings(self) -> AnalysisSettings:
Expand Down
14 changes: 6 additions & 8 deletions zospy/analyses/new/mtf/fft_through_focus_mtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,13 @@ class FFTThroughFocusMTFSettings:
use_dashes: bool = Field(default=False, description="Use dashes")


class FFTThroughFocusMTF(BaseAnalysisWrapper[Union[DataFrame, None], FFTThroughFocusMTFSettings]):
class FFTThroughFocusMTF(
BaseAnalysisWrapper[Union[DataFrame, None], FFTThroughFocusMTFSettings],
analysis_type="FftThroughFocusMtf",
needs_config_file=True,
):
"""FFT Through Focus MTF analysis."""

TYPE = "FftThroughFocusMtf"

_needs_config_file = True
_needs_text_output_file = False

def __init__(
self,
*,
Expand All @@ -76,15 +75,14 @@ def __init__(
mtf_type: constants.Analysis.Settings.Mtf.MtfTypes | str = "Modulation",
use_polarization: bool = False,
use_dashes: bool = False,
settings: FFTThroughFocusMTFSettings | None = None,
):
"""Create a new FFT Through Focus MTF analysis.

See Also
--------
FFTThroughFocusMTFSettings : Settings for the FFT Through Focus MTF analysis.
"""
super().__init__(settings or FFTThroughFocusMTFSettings(), locals())
super().__init__(settings_kws=locals())

def run_analysis(self) -> DataFrame | None:
"""Run the FFT Through Focus MTF analysis."""
Expand Down
17 changes: 8 additions & 9 deletions zospy/analyses/new/mtf/huygens_mtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,9 @@ class HuygensMtfSettings:
use_dashes: bool = Field(default=False, description="Use dashes")


class HuygensMTF(BaseAnalysisWrapper[DataFrame, HuygensMtfSettings]):
class HuygensMTF(BaseAnalysisWrapper[DataFrame, HuygensMtfSettings], analysis_type="HuygensMtf"):
"""Huygens Modulation Transfer Function (MTF) analysis."""

TYPE = "HuygensMtf"

_needs_config_file = False
_needs_text_output_file = False

def __init__(
self,
*,
Expand All @@ -81,10 +76,14 @@ def __init__(
maximum_frequency: float = 150.0,
use_polarization: bool = False,
use_dashes: bool = False,
settings: HuygensMtfSettings | None = None,
):
"""Create a new Huygens MTF analysis."""
super().__init__(settings or HuygensMtfSettings(), locals())
"""Create a new Huygens MTF analysis.

See Also
--------
HuygensMtfSettings : Settings for the Huygens MTF analysis.
"""
super().__init__(settings_kws=locals())

def run_analysis(self) -> DataFrame | None:
"""Run the Huygens MTF analysis."""
Expand Down
15 changes: 7 additions & 8 deletions zospy/analyses/new/polarization/pupil_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,14 @@ class PolarizationPupilMapSettings:
sub_configs: str = Field(default="", description="Subtract configurations")


class PolarizationPupilMap(BaseAnalysisWrapper[None, PolarizationPupilMapSettings]):
class PolarizationPupilMap(
BaseAnalysisWrapper[None, PolarizationPupilMapSettings],
analysis_type="PolarizationPupilMap",
needs_config_file=True,
needs_text_output_file=True,
):
"""Polarization pupil map analysis."""

TYPE = "PolarizationPupilMap"

_needs_config_file = True
_needs_text_output_file = True

def __init__(
self,
*,
Expand All @@ -105,15 +105,14 @@ def __init__(
sampling: str | int = "11x11",
add_configs: str = "",
sub_configs: str = "",
settings: PolarizationPupilMapSettings | None = None,
):
"""Create a new polarization pupil map analysis.

See Also
--------
PolarizationPupilMapSettings : Settings for the polarization pupil map analysis.
"""
super().__init__(settings or PolarizationPupilMapSettings(), locals())
super().__init__(settings_kws=locals())

def run_analysis(self) -> PolarizationPupilMapResult:
"""Run the polarization pupil map analysis."""
Expand Down
Loading
Loading