Skip to content

Commit

Permalink
More peculiar handling of resource devices
Browse files Browse the repository at this point in the history
  • Loading branch information
hylje committed Dec 1, 2023
1 parent bc5b464 commit 4bc08ab
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 18 deletions.
15 changes: 13 additions & 2 deletions tests/test_workload_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
"devices": {"foo": 1, "bar": 2},
}

RESOURCE_DATA_WITH_DELIBERATE_EMPTY_DEVICES: dict = {
"devices": {},
}


def test_create_resources():
"""All YAML properties are correctly parsed into the object."""
Expand All @@ -36,7 +40,7 @@ def test_create_resources():
assert resources.memory.max == 20

assert isinstance(resources.devices, ResourceDevices)
assert resources.devices.get_data() == {"foo": 1, "bar": 2}
assert resources.devices.get_data_or_none() == {"foo": 1, "bar": 2}


def test_missing_resources():
Expand All @@ -53,12 +57,19 @@ def test_missing_resources():
assert resources.memory.max is None

assert resources.devices is not None
assert resources.devices.devices == {}
assert resources.devices.devices is None

# the empty dict-initialized resources also serialize back into an empty dict
assert resources.serialize() == {}


def test_cleared_devices():
resources = WorkloadResources.parse(RESOURCE_DATA_WITH_DELIBERATE_EMPTY_DEVICES)

assert resources.devices.devices == {}
assert resources.serialize() == RESOURCE_DATA_WITH_DELIBERATE_EMPTY_DEVICES


@pytest.mark.parametrize(
"resource_name,missing_key",
[
Expand Down
44 changes: 28 additions & 16 deletions valohai_yaml/objs/workload_resources.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
from typing import Dict, Optional
from __future__ import annotations

from valohai_yaml.objs.base import Item
from valohai_yaml.types import SerializedDict


class ResourceCPU(Item):
class WorkloadResourceItem(Item):
"""
Adds get_data_or_none method supporting distinction between an empty SerializedDict and None.
The method allows defining an empty devices dictionary for ResourceDevices that will override default values,
if any are defined. Other subclasses will get default behaviour.
"""

def get_data_or_none(self) -> SerializedDict | None:
return self.get_data() or None


class ResourceCPU(WorkloadResourceItem):
"""CPU configuration."""

def __init__(
self,
max: Optional[int] = None,
min: Optional[int] = None,
max: int | None = None,
min: int | None = None,
) -> None:
self.max = max
self.min = min
Expand All @@ -25,13 +37,13 @@ def get_data(self) -> SerializedDict:
}


class ResourceMemory(Item):
class ResourceMemory(WorkloadResourceItem):
"""Memory configuration."""

def __init__(
self,
max: Optional[int] = None,
min: Optional[int] = None,
max: int | None = None,
min: int | None = None,
) -> None:
self.max = max
self.min = min
Expand All @@ -46,20 +58,20 @@ def get_data(self) -> SerializedDict:
}


class ResourceDevices(Item):
class ResourceDevices(WorkloadResourceItem):
"""Devices configuration."""

def __init__(self, devices: SerializedDict) -> None:
def __init__(self, devices: SerializedDict | None) -> None:
"""
Devices list device name: nr of devices.
Keys (and number of items) unknown, e.g.:
'nvidia.com/cpu': 2, 'nvidia.com/gpu': 1.
"""
self.devices: Dict[str, int] = devices
self.devices: dict[str, int] | None = devices

@classmethod
def parse(cls, data: SerializedDict) -> "ResourceDevices":
def parse(cls, data: SerializedDict | None) -> ResourceDevices:
"""
Initialize a devices resource.
Expand All @@ -72,7 +84,7 @@ def __repr__(self) -> str:
"""List the devices."""
return f"ResourceDevices({self.devices})"

def get_data(self) -> SerializedDict:
def get_data_or_none(self) -> SerializedDict | None:
return self.devices


Expand All @@ -96,10 +108,10 @@ def __init__(
self.devices = devices

@classmethod
def parse(cls, data: SerializedDict) -> "WorkloadResources":
def parse(cls, data: SerializedDict) -> WorkloadResources:
cpu_data = data.get("cpu", {})
memory_data = data.get("memory", {})
device_data = data.get("devices", {})
device_data = data.get("devices")
data_with_resources = dict(
data,
cpu=ResourceCPU.parse(cpu_data),
Expand All @@ -111,8 +123,8 @@ def parse(cls, data: SerializedDict) -> "WorkloadResources":
def get_data(self) -> SerializedDict:
data = {}
for key, value in super().get_data().items():
item_data = value.get_data()
if item_data:
item_data = value.get_data_or_none()
if item_data is not None:
data[key] = item_data
return data

Expand Down

0 comments on commit 4bc08ab

Please sign in to comment.