Skip to content

Commit

Permalink
feat: Add new function that creates an instance of a model with the o…
Browse files Browse the repository at this point in the history
…ption to skip validation (#110)

This PR adds a new method that creates an instance of a model but with
the option skip the validation of pydantic by setting a
`skip_validation` flag.
  • Loading branch information
pesap authored Jan 26, 2025
1 parent 05befac commit dab93ff
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 42 deletions.
34 changes: 26 additions & 8 deletions src/r2x/parser/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,29 @@
"""

# System packages
import json
from copy import deepcopy
import inspect
from dataclasses import dataclass, field
import json
from abc import ABC, abstractmethod
from typing import Any, TypeVar
from collections.abc import Callable, Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, TypeVar

import pandas as pd
import polars as pl

# Third-party packages
from infrasys.component import Component
from loguru import logger
import polars as pl
import pandas as pd
from plexosdb import XMLHandler

# Local packages
from r2x.api import System
from r2x.config_scenario import Scenario
from plexosdb import XMLHandler
from .polars_helpers import pl_filter_year, pl_lowercase, pl_rename

from ..utils import check_file_exists
from .polars_helpers import pl_filter_year, pl_lowercase, pl_rename


@dataclass
Expand Down Expand Up @@ -307,3 +310,18 @@ def get_parser_data(
logger.debug("Starting creation of system: {}", config.name)

return parser


def create_model_instance(
model_class: type["Component"], skip_validation: bool = False, **field_values
) -> Any:
"""Create R2X model instance."""
valid_fields = {
key: value
for key, value in field_values.items()
if key in model_class.model_fields
if value is not None
}
if skip_validation:
return model_class.model_construct(**valid_fields) # type: ignore
return model_class.model_validate(valid_fields)
66 changes: 32 additions & 34 deletions src/r2x/parser/reeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from r2x.models.core import MinMax
from r2x.models.costs import HydroGenerationCost, ThermalGenerationCost
from r2x.models.generators import HydroDispatch, HydroEnergyReservoir, RenewableGen, ThermalGen
from r2x.parser.handler import BaseParser
from r2x.parser.handler import BaseParser, create_model_instance
from r2x.units import ActivePower, EmissionRate, Energy, Percentage, Time, ureg
from r2x.utils import get_enum_from_string, match_category, read_csv

Expand Down Expand Up @@ -76,6 +76,7 @@ def __init__(self, *args, **kwargs) -> None:
raise AttributeError("Missing solve year from the configuration class.")
self.device_map = self.reeds_config.defaults["reeds_device_map"]
self.weather_year: int = self.reeds_config.weather_year
self.skip_validation: bool = getattr(self.reeds_config, "skip_validation", False)

# Add hourly_time_index
self.hourly_time_index = np.arange(
Expand Down Expand Up @@ -120,14 +121,15 @@ def _construct_buses(self):

zones = bus_data["transmission_region"].unique()
for zone in zones:
self.system.add_component(LoadZone(name=zone))
self.system.add_component(self._create_model_instance(LoadZone, name=zone))

for area in bus_data["state"].unique():
self.system.add_component(Area(name=area))
self.system.add_component(self._create_model_instance(Area, name=area))

for idx, bus in enumerate(bus_data.iter_rows(named=True)):
self.system.add_component(
ACBus(
self._create_model_instance(
ACBus,
number=idx + 1,
name=bus["region"],
area=self.system.get_component(Area, name=bus["state"]),
Expand All @@ -149,7 +151,8 @@ def _construct_reserves(self):
vors = self.reeds_config.defaults["reserve_vors"].get(name)
reserve_area = self.system.get_component(LoadZone, name=reserve)
self.system.add_component(
Reserve(
self._create_model_instance(
Reserve,
name=f"{reserve}_{name}",
region=reserve_area,
reserve_type=ReserveType[name],
Expand All @@ -161,7 +164,7 @@ def _construct_reserves(self):
)
)
# Add reserve map
self.system.add_component(ReserveMap(name="reserve_map"))
self.system.add_component(self._create_model_instance(ReserveMap, name="reserve_map"))

def _construct_branches(self):
logger.info("Creating branch objects.")
Expand Down Expand Up @@ -200,7 +203,8 @@ def _construct_branches(self):

losses = branch["losses"] if branch["losses"] else 0
self.system.add_component(
MonitoredLine(
self._create_model_instance(
MonitoredLine,
category=branch["kind"],
name=branch_name,
from_bus=from_bus,
Expand All @@ -218,7 +222,7 @@ def _construct_tx_interfaces(self):
MonitoredLine, filter_func=lambda x: x.from_bus.load_zone.name != x.to_bus.load_zone.name
)
interfaces = defaultdict(dict) # Holder of interfaces
tx_interface_map = TransmissionInterfaceMap(name="transmission_map")
tx_interface_map = self._create_model_instance(TransmissionInterfaceMap, name="transmission_map")
for line in interface_lines:
zone_from = line.from_bus.load_zone.name
zone_to = line.to_bus.load_zone.name
Expand Down Expand Up @@ -249,7 +253,8 @@ def _construct_tx_interfaces(self):
# Ramp multiplier defines the MW/min ratio for the interface
ramp_multiplier = self.reeds_config.defaults["interface_max_ramp_up_multiplier"]
self.system.add_component(
TransmissionInterface(
self._create_model_instance(
TransmissionInterface,
name=interface_name,
active_power_flow_limits=MinMax(-max_power_flow, max_power_flow),
direction_mapping={}, # TBD
Expand Down Expand Up @@ -287,11 +292,8 @@ def _construct_emissions(self) -> None:
generator_emission = emit_rates.filter(pl.col("generator_name") == generator.name)
for row in generator_emission.iter_rows(named=True):
row["rate"] = EmissionRate(row["rate"], "kg/MWh")
valid_fields = {key: value for key, value in row.items() if key in Emission.model_fields}
valid_fields["emission_type"] = get_enum_from_string(
valid_fields["emission_type"], EmissionType
)
emission_model = Emission(**valid_fields)
row["emission_type"] = get_enum_from_string(row["emission_type"], EmissionType)
emission_model = self._create_model_instance(Emission, **row)
self.system.add_component(emission_model)

def _construct_generators(self) -> None: # noqa: C901
Expand Down Expand Up @@ -489,23 +491,13 @@ def _construct_generators(self) -> None: # noqa: C901
# will need to change it.
row["active_power_limits"] = MinMax(min=0, max=row["active_power"])

valid_fields = {
key: value for key, value in row.items() if key in gen_model.model_fields if value is not None
row["ext"] = {}
row["ext"] = {
"tech": row["tech"],
"reeds_tech": row["tech"],
"reeds_vintage": row["tech_vintage"],
}
valid_fields["ext"] = {}
# valid_fields["ext"] = {
# key: value for key, value in row.items() if key not in valid_fields if value
# }

valid_fields["ext"].update(
{
"tech": row["tech"],
"reeds_tech": row["tech"],
"reeds_vintage": row["tech_vintage"],
}
)

self.system.add_component(gen_model(**valid_fields))
self.system.add_component(self._create_model_instance(gen_model, **row))

def _construct_load(self):
logger.info("Adding load time series.")
Expand All @@ -531,7 +523,9 @@ def _construct_load(self):
)
user_dict = {"solve_year": self.reeds_config.weather_year}
max_load = np.max(ts.data)
load = PowerLoad(name=f"{bus.name}", bus=bus, max_active_power=max_load)
load = self._create_model_instance(
PowerLoad, name=f"{bus.name}", bus=bus, max_active_power=max_load
)
self.system.add_component(load)
self.system.add_time_series(ts, load, **user_dict)

Expand Down Expand Up @@ -841,7 +835,8 @@ def _construct_hybrid_systems(self):
storage_unit_fields["prime_mover_type"] = PrimeMoversType.BA

# If at some point we change the power of the storage it should be here
storage_unit = GenericBattery(
storage_unit = self._create_model_instance(
GenericBattery,
name=f"{hybrid_name}",
active_power=device.active_power, # Assume same power for the battery
category="pvb-storage",
Expand All @@ -856,12 +851,15 @@ def _construct_hybrid_systems(self):
ts = self.system.get_time_series(device)
self.system.add_time_series(ts, new_device)
self.system.remove_component(device)
hybrid_construct = HybridSystem(
name=f"{hybrid_name}", renewable_unit=new_device, storage_unit=storage_unit
hybrid_construct = self._create_model_instance(
HybridSystem, name=f"{hybrid_name}", renewable_unit=new_device, storage_unit=storage_unit
)
self.system.add_component(hybrid_construct)

def _aggregate_renewable_generators(self, data: pl.DataFrame) -> pl.DataFrame:
return data.group_by(["tech", "region"]).agg(
[pl.col("active_power").sum(), pl.exclude("active_power").first()]
)

def _create_model_instance(self, model_class, **kwargs):
return create_model_instance(model_class, skip_validation=self.skip_validation, **kwargs)
15 changes: 15 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from r2x.enums import PrimeMoversType
from r2x.models import Generator, ACBus, Emission, HydroPumpedStorage, ThermalStandard
from r2x.models import MinMax
from r2x.parser.handler import create_model_instance
from r2x.units import EmissionRate, ureg


Expand Down Expand Up @@ -48,3 +49,17 @@ def test_serialize_active_power_limits():

output = generator.serialize_active_power_limits(active_power_limits)
assert output == {"min": 0, "max": 100}


def test_create_model_instance():
name = "TestGen"
generator = create_model_instance(Generator, name=name)
assert isinstance(generator, Generator)
assert isinstance(generator.name, str)
assert generator.name == name

name = ["TestGen"]
generator = create_model_instance(Generator, name=name, skip_validation=True)
assert isinstance(generator, Generator)
assert isinstance(generator.name, list)
assert generator.name == name

0 comments on commit dab93ff

Please sign in to comment.