From dab93fffe116ae69143069260dc3f017b7a8e78a Mon Sep 17 00:00:00 2001 From: pesap Date: Sun, 26 Jan 2025 12:40:53 -0700 Subject: [PATCH] feat: Add new function that creates an instance of a model with the option to skip validation (#110) This PR adds a new method that creates an instance of a model but with the option skip the validation of pydantic by setting a `skip_validation` flag. --- src/r2x/parser/handler.py | 34 +++++++++++++++----- src/r2x/parser/reeds.py | 66 +++++++++++++++++++-------------------- tests/test_models.py | 15 +++++++++ 3 files changed, 73 insertions(+), 42 deletions(-) diff --git a/src/r2x/parser/handler.py b/src/r2x/parser/handler.py index aabef042..92a44a0a 100644 --- a/src/r2x/parser/handler.py +++ b/src/r2x/parser/handler.py @@ -4,26 +4,29 @@ """ # System packages -import json -from copy import deepcopy import inspect -from dataclasses import dataclass, field +import json from abc import ABC, abstractmethod -from typing import Any, TypeVar from collections.abc import Callable, Sequence +from copy import deepcopy +from dataclasses import dataclass, field from pathlib import Path +from typing import Any, TypeVar + +import pandas as pd +import polars as pl # Third-party packages +from infrasys.component import Component from loguru import logger -import polars as pl -import pandas as pd +from plexosdb import XMLHandler # Local packages from r2x.api import System from r2x.config_scenario import Scenario -from plexosdb import XMLHandler -from .polars_helpers import pl_filter_year, pl_lowercase, pl_rename + from ..utils import check_file_exists +from .polars_helpers import pl_filter_year, pl_lowercase, pl_rename @dataclass @@ -307,3 +310,18 @@ def get_parser_data( logger.debug("Starting creation of system: {}", config.name) return parser + + +def create_model_instance( + model_class: type["Component"], skip_validation: bool = False, **field_values +) -> Any: + """Create R2X model instance.""" + valid_fields = { + key: value + for key, value in field_values.items() + if key in model_class.model_fields + if value is not None + } + if skip_validation: + return model_class.model_construct(**valid_fields) # type: ignore + return model_class.model_validate(valid_fields) diff --git a/src/r2x/parser/reeds.py b/src/r2x/parser/reeds.py index ddb1725f..fb15c19c 100644 --- a/src/r2x/parser/reeds.py +++ b/src/r2x/parser/reeds.py @@ -41,7 +41,7 @@ from r2x.models.core import MinMax from r2x.models.costs import HydroGenerationCost, ThermalGenerationCost from r2x.models.generators import HydroDispatch, HydroEnergyReservoir, RenewableGen, ThermalGen -from r2x.parser.handler import BaseParser +from r2x.parser.handler import BaseParser, create_model_instance from r2x.units import ActivePower, EmissionRate, Energy, Percentage, Time, ureg from r2x.utils import get_enum_from_string, match_category, read_csv @@ -76,6 +76,7 @@ def __init__(self, *args, **kwargs) -> None: raise AttributeError("Missing solve year from the configuration class.") self.device_map = self.reeds_config.defaults["reeds_device_map"] self.weather_year: int = self.reeds_config.weather_year + self.skip_validation: bool = getattr(self.reeds_config, "skip_validation", False) # Add hourly_time_index self.hourly_time_index = np.arange( @@ -120,14 +121,15 @@ def _construct_buses(self): zones = bus_data["transmission_region"].unique() for zone in zones: - self.system.add_component(LoadZone(name=zone)) + self.system.add_component(self._create_model_instance(LoadZone, name=zone)) for area in bus_data["state"].unique(): - self.system.add_component(Area(name=area)) + self.system.add_component(self._create_model_instance(Area, name=area)) for idx, bus in enumerate(bus_data.iter_rows(named=True)): self.system.add_component( - ACBus( + self._create_model_instance( + ACBus, number=idx + 1, name=bus["region"], area=self.system.get_component(Area, name=bus["state"]), @@ -149,7 +151,8 @@ def _construct_reserves(self): vors = self.reeds_config.defaults["reserve_vors"].get(name) reserve_area = self.system.get_component(LoadZone, name=reserve) self.system.add_component( - Reserve( + self._create_model_instance( + Reserve, name=f"{reserve}_{name}", region=reserve_area, reserve_type=ReserveType[name], @@ -161,7 +164,7 @@ def _construct_reserves(self): ) ) # Add reserve map - self.system.add_component(ReserveMap(name="reserve_map")) + self.system.add_component(self._create_model_instance(ReserveMap, name="reserve_map")) def _construct_branches(self): logger.info("Creating branch objects.") @@ -200,7 +203,8 @@ def _construct_branches(self): losses = branch["losses"] if branch["losses"] else 0 self.system.add_component( - MonitoredLine( + self._create_model_instance( + MonitoredLine, category=branch["kind"], name=branch_name, from_bus=from_bus, @@ -218,7 +222,7 @@ def _construct_tx_interfaces(self): MonitoredLine, filter_func=lambda x: x.from_bus.load_zone.name != x.to_bus.load_zone.name ) interfaces = defaultdict(dict) # Holder of interfaces - tx_interface_map = TransmissionInterfaceMap(name="transmission_map") + tx_interface_map = self._create_model_instance(TransmissionInterfaceMap, name="transmission_map") for line in interface_lines: zone_from = line.from_bus.load_zone.name zone_to = line.to_bus.load_zone.name @@ -249,7 +253,8 @@ def _construct_tx_interfaces(self): # Ramp multiplier defines the MW/min ratio for the interface ramp_multiplier = self.reeds_config.defaults["interface_max_ramp_up_multiplier"] self.system.add_component( - TransmissionInterface( + self._create_model_instance( + TransmissionInterface, name=interface_name, active_power_flow_limits=MinMax(-max_power_flow, max_power_flow), direction_mapping={}, # TBD @@ -287,11 +292,8 @@ def _construct_emissions(self) -> None: generator_emission = emit_rates.filter(pl.col("generator_name") == generator.name) for row in generator_emission.iter_rows(named=True): row["rate"] = EmissionRate(row["rate"], "kg/MWh") - valid_fields = {key: value for key, value in row.items() if key in Emission.model_fields} - valid_fields["emission_type"] = get_enum_from_string( - valid_fields["emission_type"], EmissionType - ) - emission_model = Emission(**valid_fields) + row["emission_type"] = get_enum_from_string(row["emission_type"], EmissionType) + emission_model = self._create_model_instance(Emission, **row) self.system.add_component(emission_model) def _construct_generators(self) -> None: # noqa: C901 @@ -489,23 +491,13 @@ def _construct_generators(self) -> None: # noqa: C901 # will need to change it. row["active_power_limits"] = MinMax(min=0, max=row["active_power"]) - valid_fields = { - key: value for key, value in row.items() if key in gen_model.model_fields if value is not None + row["ext"] = {} + row["ext"] = { + "tech": row["tech"], + "reeds_tech": row["tech"], + "reeds_vintage": row["tech_vintage"], } - valid_fields["ext"] = {} - # valid_fields["ext"] = { - # key: value for key, value in row.items() if key not in valid_fields if value - # } - - valid_fields["ext"].update( - { - "tech": row["tech"], - "reeds_tech": row["tech"], - "reeds_vintage": row["tech_vintage"], - } - ) - - self.system.add_component(gen_model(**valid_fields)) + self.system.add_component(self._create_model_instance(gen_model, **row)) def _construct_load(self): logger.info("Adding load time series.") @@ -531,7 +523,9 @@ def _construct_load(self): ) user_dict = {"solve_year": self.reeds_config.weather_year} max_load = np.max(ts.data) - load = PowerLoad(name=f"{bus.name}", bus=bus, max_active_power=max_load) + load = self._create_model_instance( + PowerLoad, name=f"{bus.name}", bus=bus, max_active_power=max_load + ) self.system.add_component(load) self.system.add_time_series(ts, load, **user_dict) @@ -841,7 +835,8 @@ def _construct_hybrid_systems(self): storage_unit_fields["prime_mover_type"] = PrimeMoversType.BA # If at some point we change the power of the storage it should be here - storage_unit = GenericBattery( + storage_unit = self._create_model_instance( + GenericBattery, name=f"{hybrid_name}", active_power=device.active_power, # Assume same power for the battery category="pvb-storage", @@ -856,8 +851,8 @@ def _construct_hybrid_systems(self): ts = self.system.get_time_series(device) self.system.add_time_series(ts, new_device) self.system.remove_component(device) - hybrid_construct = HybridSystem( - name=f"{hybrid_name}", renewable_unit=new_device, storage_unit=storage_unit + hybrid_construct = self._create_model_instance( + HybridSystem, name=f"{hybrid_name}", renewable_unit=new_device, storage_unit=storage_unit ) self.system.add_component(hybrid_construct) @@ -865,3 +860,6 @@ def _aggregate_renewable_generators(self, data: pl.DataFrame) -> pl.DataFrame: return data.group_by(["tech", "region"]).agg( [pl.col("active_power").sum(), pl.exclude("active_power").first()] ) + + def _create_model_instance(self, model_class, **kwargs): + return create_model_instance(model_class, skip_validation=self.skip_validation, **kwargs) diff --git a/tests/test_models.py b/tests/test_models.py index 7b005d88..bfe22b10 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,6 +1,7 @@ from r2x.enums import PrimeMoversType from r2x.models import Generator, ACBus, Emission, HydroPumpedStorage, ThermalStandard from r2x.models import MinMax +from r2x.parser.handler import create_model_instance from r2x.units import EmissionRate, ureg @@ -48,3 +49,17 @@ def test_serialize_active_power_limits(): output = generator.serialize_active_power_limits(active_power_limits) assert output == {"min": 0, "max": 100} + + +def test_create_model_instance(): + name = "TestGen" + generator = create_model_instance(Generator, name=name) + assert isinstance(generator, Generator) + assert isinstance(generator.name, str) + assert generator.name == name + + name = ["TestGen"] + generator = create_model_instance(Generator, name=name, skip_validation=True) + assert isinstance(generator, Generator) + assert isinstance(generator.name, list) + assert generator.name == name