Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse and serialize Step resources as WorkloadResources #140

Merged
merged 2 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tests/test_serialize_step.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
def test_serialize_workload_resources(step_with_resources):
"""Must not flatten workload resource data."""
config = step_with_resources
resources = config.steps["contains kubernetes resources"].resources
resources = config.steps["contains kubernetes resources"].serialize()["resources"]

assert isinstance(resources, dict), "Resources should be defined."
assert "cpu" in resources, "Resources should contain data."
Expand All @@ -10,7 +10,9 @@ def test_serialize_workload_resources(step_with_resources):
def test_serialize_partial_resources(step_with_partial_resources):
"""Serialized data only contains keys found in the config."""
config = step_with_partial_resources
resources = config.steps["contains partial workload resources"].resources
resources = config.steps["contains partial workload resources"].serialize()[
"resources"
]

assert "min" in resources["cpu"]
assert "max" not in resources["cpu"]
32 changes: 27 additions & 5 deletions tests/test_workload_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
"devices": {"foo": 1, "bar": 2},
}

RESOURCE_DATA_WITH_DELIBERATE_EMPTY_DEVICES: dict = {
"devices": {},
}


def test_create_resources():
"""All YAML properties are correctly parsed into the object."""
Expand All @@ -36,16 +40,34 @@ def test_create_resources():
assert resources.memory.max == 20

assert isinstance(resources.devices, ResourceDevices)
assert resources.devices.get_data() == {"foo": 1, "bar": 2}
assert resources.devices.get_data_or_none() == {"foo": 1, "bar": 2}


def test_missing_resources():
"""None of the workload properties are required."""
resources = WorkloadResources.parse(OrderedDict([]))

assert resources.cpu is None
assert resources.memory is None
assert resources.devices is None
# Subresources are created with None/empty leaf values
assert resources.cpu is not None
hylje marked this conversation as resolved.
Show resolved Hide resolved
assert resources.cpu.min is None
assert resources.cpu.max is None

assert resources.memory is not None
assert resources.memory.min is None
assert resources.memory.max is None

assert resources.devices is not None
assert resources.devices.devices is None

# the empty dict-initialized resources also serialize back into an empty dict
assert resources.serialize() == {}


def test_cleared_devices():
resources = WorkloadResources.parse(RESOURCE_DATA_WITH_DELIBERATE_EMPTY_DEVICES)

assert resources.devices.devices == {}
assert resources.serialize() == RESOURCE_DATA_WITH_DELIBERATE_EMPTY_DEVICES


@pytest.mark.parametrize(
Expand All @@ -62,7 +84,7 @@ def test_missing_sub_resources(resource_name, missing_key):
resources = create_resources(resource_name, missing_key)

for this_resource_name, sub_resources in resources.get_data().items():
for name, value in sub_resources.get_data().items():
for name, value in sub_resources.items():
if this_resource_name == resource_name and name == missing_key:
assert value is None
else:
Expand Down
5 changes: 3 additions & 2 deletions valohai_yaml/objs/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(

self.time_limit = time_limit
self.no_output_timeout = no_output_timeout
self.resources = resources
self.resources = resources if resources else WorkloadResources.parse({})
hylje marked this conversation as resolved.
Show resolved Hide resolved
self.stop_condition = stop_condition

@classmethod
Expand All @@ -84,6 +84,7 @@ def parse(cls, data: SerializedDict) -> "Step":
kwargs["source_path"] = kwargs.pop("source-path", None)
kwargs["stop_condition"] = kwargs.pop("stop-condition", None)
kwargs["upload_store"] = kwargs.pop("upload-store", None)
kwargs["resources"] = WorkloadResources.parse(kwargs.pop("resources", {}))
inst = cls(**kwargs)
inst._original_data = data
return inst
Expand Down Expand Up @@ -121,7 +122,7 @@ def serialize(self) -> OrderedDict: # type: ignore[type-arg]
("icon", self.icon),
("category", self.category),
("source-path", self.source_path),
("resources", self.resources),
("resources", self.resources.get_data() or None),
("stop-condition", self.stop_condition),
("upload-store", self.upload_store),
]:
Expand Down
72 changes: 51 additions & 21 deletions valohai_yaml/objs/workload_resources.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
from typing import Dict, Optional
from __future__ import annotations

from valohai_yaml.objs.base import Item
from valohai_yaml.types import SerializedDict


class ResourceCPU(Item):
class WorkloadResourceItem(Item):
"""
Adds get_data_or_none method supporting distinction between an empty SerializedDict and None.

The method allows defining an empty devices dictionary for ResourceDevices that will override default values,
if any are defined. Other subclasses will get default behaviour.
"""

def get_data_or_none(self) -> SerializedDict | None:
return self.get_data() or None


class ResourceCPU(WorkloadResourceItem):
"""CPU configuration."""

def __init__(
self,
max: Optional[int] = None,
min: Optional[int] = None,
max: int | None = None,
min: int | None = None,
) -> None:
self.max = max
self.min = min
Expand All @@ -19,14 +31,19 @@ def __repr__(self) -> str:
"""CPU data."""
return f'ResourceCPU("max": {self.max}, "min": {self.min})'

def get_data(self) -> SerializedDict:
return {
key: value for key, value in super().get_data().items() if value is not None
}

class ResourceMemory(Item):

class ResourceMemory(WorkloadResourceItem):
"""Memory configuration."""

def __init__(
self,
max: Optional[int] = None,
min: Optional[int] = None,
max: int | None = None,
min: int | None = None,
) -> None:
self.max = max
self.min = min
Expand All @@ -35,21 +52,26 @@ def __repr__(self) -> str:
"""Memory data."""
return f'ResourceMemory("max": {self.max}, "min": {self.min})'

def get_data(self) -> SerializedDict:
return {
key: value for key, value in super().get_data().items() if value is not None
}


class ResourceDevices(Item):
class ResourceDevices(WorkloadResourceItem):
"""Devices configuration."""

def __init__(self, devices: SerializedDict) -> None:
def __init__(self, devices: SerializedDict | None) -> None:
"""
Devices list device name: nr of devices.

Keys (and number of items) unknown, e.g.:
'nvidia.com/cpu': 2, 'nvidia.com/gpu': 1.
"""
self.devices: Dict[str, int] = devices
self.devices: dict[str, int] | None = devices

@classmethod
def parse(cls, data: SerializedDict) -> "ResourceDevices":
def parse(cls, data: SerializedDict | None) -> ResourceDevices:
"""
Initialize a devices resource.

Expand All @@ -62,7 +84,7 @@ def __repr__(self) -> str:
"""List the devices."""
return f"ResourceDevices({self.devices})"

def get_data(self) -> SerializedDict:
def get_data_or_none(self) -> SerializedDict | None:
return self.devices


Expand All @@ -77,27 +99,35 @@ class WorkloadResources(Item):
def __init__(
self,
*,
cpu: Optional[ResourceCPU],
memory: Optional[ResourceMemory],
devices: Optional[ResourceDevices],
cpu: ResourceCPU,
memory: ResourceMemory,
devices: ResourceDevices,
) -> None:
self.cpu = cpu
self.memory = memory
self.devices = devices

@classmethod
def parse(cls, data: SerializedDict) -> "WorkloadResources":
cpu_data = data.get("cpu")
memory_data = data.get("memory")
def parse(cls, data: SerializedDict) -> WorkloadResources:
cpu_data = data.get("cpu", {})
memory_data = data.get("memory", {})
device_data = data.get("devices")
data_with_resources = dict(
data,
cpu=ResourceCPU.parse(cpu_data) if cpu_data else None,
memory=ResourceMemory.parse(memory_data) if memory_data else None,
devices=ResourceDevices.parse(device_data) if device_data else None,
cpu=ResourceCPU.parse(cpu_data),
memory=ResourceMemory.parse(memory_data),
devices=ResourceDevices.parse(device_data),
)
return super().parse(data_with_resources)

def get_data(self) -> SerializedDict:
data = {}
for key, value in super().get_data().items():
item_data = value.get_data_or_none()
if item_data is not None:
data[key] = item_data
return data

def __repr__(self) -> str:
"""Resources contents."""
return f'WorkloadResources("cpu": {self.cpu}, "memory": {self.memory}, "devices": {self.devices})'