Skip to content

Commit

Permalink
Merge pull request #107 from NREL/feature/serde-api-features
Browse files Browse the repository at this point in the history
Feature/serde api features
  • Loading branch information
calbaker authored Jan 16, 2025
2 parents f9a1d7e + 47df893 commit 8833767
Show file tree
Hide file tree
Showing 28 changed files with 366 additions and 84 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/py-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ['3.9', '3.10', '3.11']
python-version: ['3.10', '3.11']

env:
PYTHON: ${{ matrix.python-version }}
Expand Down
3 changes: 1 addition & 2 deletions .github/workflows/wheels.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ jobs:
- macos
- windows
python-version:
- "9"
- "10"
- "11"
include:
Expand All @@ -36,7 +35,7 @@ jobs:
- name: set up python
uses: actions/setup-python@v4
with:
python-version: "3.9"
python-version: "3.11"

- name: set up rust
if: matrix.os != 'ubuntu'
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ authors = [
description = "Tool for modeling and optimization of advanced locomotive powertrains for freight rail decarbonization."
readme = "README.md"
license = { file = "LICENSE.md" }
requires-python = ">=3.9, <3.12"
requires-python = ">=3.10, <3.12"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: BSD License",
Expand All @@ -47,6 +47,7 @@ dependencies = [
"pyarrow",
"requests",
"PyYAML==6.0.2",
"msgpack==1.1.0",
]

[project.urls]
Expand Down
75 changes: 62 additions & 13 deletions python/altrios/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,27 +119,76 @@ def history_path_list(self, element_as_list:bool=False) -> List[str]:
item for item in self.variable_path_list(
element_as_list=element_as_list) if "history" in item_str(item)
]
return history_path_list

def to_pydict(self) -> Dict:
return history_path_list

# TODO connect to crate features
data_formats = [
'yaml',
'msg_pack',
# 'toml',
'json',
]

def to_pydict(self, data_fmt: str = "msg_pack", flatten: bool = False) -> Dict:
"""
Returns self converted to pure python dictionary with no nested Rust objects
# Arguments
- `flatten`: if True, returns dict without any hierarchy
- `data_fmt`: data format for intermediate conversion step
"""
from yaml import load
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader
pydict = load(self.to_yaml(), Loader = Loader)
return pydict
data_fmt = data_fmt.lower()
assert data_fmt in data_formats, f"`data_fmt` must be one of {data_formats}"
match data_fmt:
case "msg_pack":
import msgpack
pydict = msgpack.loads(self.to_msg_pack())
case "yaml":
from yaml import load
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader
pydict = load(self.to_yaml(), Loader=Loader)
case "json":
from json import loads
pydict = loads(self.to_json())

if not flatten:
return pydict
else:
return next(iter(pd.json_normalize(pydict, sep=".").to_dict(orient='records')))

@classmethod
def from_pydict(cls, pydict: Dict) -> Self:
def from_pydict(cls, pydict: Dict, data_fmt: str = "msg_pack", skip_init: bool = True) -> Self:
"""
Instantiates Self from pure python dictionary
# Arguments
- `pydict`: dictionary to be converted to ALTRIOS object
- `data_fmt`: data format for intermediate conversion step
- `skip_init`: passed to `SerdeAPI` methods to control whether initialization
is skipped
"""
import yaml
return cls.from_yaml(yaml.dump(pydict),skip_init=False)
data_fmt = data_fmt.lower()
assert data_fmt in data_formats, f"`data_fmt` must be one of {data_formats}"
match data_fmt.lower():
case "yaml":
import yaml
obj = cls.from_yaml(yaml.dump(pydict), skip_init=skip_init)
case "msg_pack":
import msgpack
try:
obj = cls.from_msg_pack(
msgpack.packb(pydict), skip_init=skip_init)
except Exception as err:
print(
f"{err}\nFalling back to YAML.")
obj = cls.from_pydict(
pydict, data_fmt="yaml", skip_init=skip_init)
case "json":
from json import dumps
obj = cls.from_json(dumps(pydict), skip_init=skip_init)

return obj

def to_dataframe(self, pandas:bool=False) -> [pd.DataFrame, pl.DataFrame, pl.LazyFrame]:
"""
Expand Down
5 changes: 4 additions & 1 deletion python/altrios/altrios_pyo3.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ class SerdeAPI(object):
@classmethod
def from_yaml(cls) -> Self: ...
@classmethod
def from_file(cls) -> Self: ...
def from_file(cls, skip_init=False) -> Self: ...
def to_file(self): ...
def to_bincode(self) -> bytes: ...
def to_json(self) -> str: ...
def to_yaml(self) -> str: ...
def to_pydict(self, data_fmt: str = "msg_pack", flatten: bool = False) -> Dict: ...
@classmethod
def from_pydict(cls, pydict: Dict, data_fmt: str = "msg_pack") -> Self:


class Consist(SerdeAPI):
Expand Down
2 changes: 1 addition & 1 deletion python/altrios/demos/sim_manager_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
t0_import = time.perf_counter()
t0_total = time.perf_counter()

rail_vehicles=[alt.RailVehicle.from_file(vehicle_file)
rail_vehicles=[alt.RailVehicle.from_file(vehicle_file, skip_init=False)
for vehicle_file in Path(alt.resources_root() / "rolling_stock/").glob('*.yaml')]

location_map = alt.import_locations(alt.resources_root() / "networks/default_locations.csv")
Expand Down
4 changes: 2 additions & 2 deletions python/altrios/demos/version_migration_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def migrate_network() -> Tuple[alt.Network, alt.Network]:
old_network_path = alt.resources_root() / "networks/Taconite_v0.1.6.yaml"
new_network_path = alt.resources_root() / "networks/Taconite.yaml"

network_from_old = alt.Network.from_file(old_network_path)
network_from_new = alt.Network.from_file(new_network_path)
network_from_old = alt.Network.from_file(old_network_path, skip_init=False)
network_from_new = alt.Network.from_file(new_network_path, skip_init=False)

# `network_from_old` could be used to overwrite the file in the new format with
# ```
Expand Down
4 changes: 2 additions & 2 deletions python/altrios/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ def simulate_prescribed_rollout(
else:
demand_paths.append(demand_file)

rail_vehicles=[alt.RailVehicle.from_file(vehicle_file)
rail_vehicles=[alt.RailVehicle.from_file(vehicle_file, skip_init=False)
for vehicle_file in Path(alt.resources_root() / "rolling_stock/").glob('*.yaml')]

location_map = alt.import_locations(
str(alt.resources_root() / "networks/default_locations.csv")
)
network = alt.Network.from_file(network_filename_path)
network = alt.Network.from_file(network_filename_path, skip_init=False)
sim_days = defaults.SIMULATION_DAYS
scenarios = []
for idx, scenario_year in enumerate(years):
Expand Down
125 changes: 125 additions & 0 deletions python/altrios/tests/test_serde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import time
import altrios as alt

SAVE_INTERVAL = 100
def get_solved_speed_limit_train_sim():
# Build the train config
rail_vehicle_loaded = alt.RailVehicle.from_file(
alt.resources_root() / "rolling_stock/Manifest_Loaded.yaml")
rail_vehicle_empty = alt.RailVehicle.from_file(
alt.resources_root() / "rolling_stock/Manifest_Empty.yaml")

# https://docs.rs/altrios-core/latest/altrios_core/train/struct.TrainConfig.html
train_config = alt.TrainConfig(
rail_vehicles=[rail_vehicle_loaded, rail_vehicle_empty],
n_cars_by_type={
"Manifest_Loaded": 50,
"Manifest_Empty": 50,
},
train_length_meters=None,
train_mass_kilograms=None,
)

# Build the locomotive consist model
# instantiate battery model
# https://docs.rs/altrios-core/latest/altrios_core/consist/locomotive/powertrain/reversible_energy_storage/struct.ReversibleEnergyStorage.html#
res = alt.ReversibleEnergyStorage.from_file(
alt.resources_root() / "powertrains/reversible_energy_storages/Kokam_NMC_75Ah_flx_drive.yaml"
)

edrv = alt.ElectricDrivetrain(
pwr_out_frac_interp=[0., 1.],
eta_interp=[0.98, 0.98],
pwr_out_max_watts=5e9,
save_interval=SAVE_INTERVAL,
)

bel: alt.Locomotive = alt.Locomotive.build_battery_electric_loco(
reversible_energy_storage=res,
drivetrain=edrv,
loco_params=alt.LocoParams.from_dict(dict(
pwr_aux_offset_watts=8.55e3,
pwr_aux_traction_coeff_ratio=540.e-6,
force_max_newtons=667.2e3,
)))

# construct a vector of one BEL and several conventional locomotives
loco_vec = [bel.clone()] + [alt.Locomotive.default()] * 7
# instantiate consist
loco_con = alt.Consist(
loco_vec
)

# Instantiate the intermediate `TrainSimBuilder`
tsb = alt.TrainSimBuilder(
train_id="0",
origin_id="A",
destination_id="B",
train_config=train_config,
loco_con=loco_con,
)

# Load the network and construct the timed link path through the network.
network = alt.Network.from_file(
alt.resources_root() / 'networks/simple_corridor_network.yaml')

location_map = alt.import_locations(
alt.resources_root() / "networks/simple_corridor_locations.csv")
train_sim: alt.SetSpeedTrainSim = tsb.make_speed_limit_train_sim(
location_map=location_map,
save_interval=1,
)
train_sim.set_save_interval(SAVE_INTERVAL)
est_time_net, _consist = alt.make_est_times(train_sim, network)

timed_link_path = alt.run_dispatch(
network,
alt.SpeedLimitTrainSimVec([train_sim]),
[est_time_net],
False,
False,
)[0]

train_sim.walk_timed_path(
network=network,
timed_path=timed_link_path,
)
assert len(train_sim.history) > 1

return train_sim


def test_pydict():
ts = get_solved_speed_limit_train_sim()

t0 = time.perf_counter_ns()
ts_dict_msg = ts.to_pydict(flatten=False, data_fmt="msg_pack")
ts_msg = alt.SpeedLimitTrainSim.from_pydict(
ts_dict_msg, data_fmt="msg_pack")
t1 = time.perf_counter_ns()
t_msg = t1 - t0
print(f"\nElapsed time for MessagePack: {t_msg:.3e} ns ")

t0 = time.perf_counter_ns()
ts_dict_yaml = ts.to_pydict(flatten=False, data_fmt="yaml")
ts_yaml = alt.SpeedLimitTrainSim.from_pydict(ts_dict_yaml, data_fmt="yaml")
t1 = time.perf_counter_ns()
t_yaml = t1 - t0
print(f"Elapsed time for YAML: {t_yaml:.3e} ns ")
print(f"YAML time per MessagePack time: {(t_yaml / t_msg):.3e} ")

t0 = time.perf_counter_ns()
ts_dict_json = ts.to_pydict(flatten=False, data_fmt="json")
_ts_json = alt.SpeedLimitTrainSim.from_pydict(
ts_dict_json, data_fmt="json")
t1 = time.perf_counter_ns()
t_json = t1 - t0
print(f"Elapsed time for json: {t_json:.3e} ns ")
print(f"JSON time per MessagePack time: {(t_json / t_msg):.3e} ")

# `to_pydict` is necessary because of some funkiness with direct equality comparison
assert ts_msg.to_pydict() == ts.to_pydict()
assert ts_yaml.to_pydict() == ts.to_pydict()

if __name__ == "__main__":
test_pydict()
2 changes: 1 addition & 1 deletion python/altrios/train_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,7 @@ def run_train_planner(

if __name__ == "__main__":

rail_vehicles=[alt.RailVehicle.from_file(vehicle_file)
rail_vehicles=[alt.RailVehicle.from_file(vehicle_file, skip_init=False)
for vehicle_file in Path(alt.resources_root() / "rolling_stock/").glob('*.yaml')]

location_map = alt.import_locations(
Expand Down
23 changes: 23 additions & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion rust/altrios-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ rust-version = { workspace = true }
[dependencies]
csv = "1.1.6"
serde = { version = "1.0.136", features = ["derive"] }
rmp-serde = { version = "1.3.0", optional = true }
serde_yaml = "0.8.23"
serde_json = "1.0"
uom = { workspace = true, features = ["use_serde"] }
Expand Down Expand Up @@ -56,9 +57,13 @@ tempfile = "3.10.1"
derive_more = { version = "1.0.0", features = ["from_str", "from", "is_variant", "try_into"] }

[features]
default = []
default = ["serde-default"]
## Enables several text file formats for serialization and deserialization
serde-default = ["msgpack"]
## Exposes ALTRIOS structs, methods, and functions to Python.
pyo3 = ["dep:pyo3"]
## Enables message pack serialization and deserialization via `rmp-serde`
msgpack = ["dep:rmp-serde"]

[lints.rust]
# `'cfg(debug_advance_rewind)'` is expected for debugging in `advance_rewind.rs`
Expand Down
Loading

0 comments on commit 8833767

Please sign in to comment.