From 4bc08ab555cec5b2defa3622174a5e7a3d8fd9dc Mon Sep 17 00:00:00 2001 From: Leo Honkanen <leo@valohai.com> Date: Fri, 1 Dec 2023 13:58:27 +0200 Subject: [PATCH] More peculiar handling of resource devices --- tests/test_workload_resources.py | 15 +++++++-- valohai_yaml/objs/workload_resources.py | 44 ++++++++++++++++--------- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/tests/test_workload_resources.py b/tests/test_workload_resources.py index 71ced92..5b80b4c 100644 --- a/tests/test_workload_resources.py +++ b/tests/test_workload_resources.py @@ -22,6 +22,10 @@ "devices": {"foo": 1, "bar": 2}, } +RESOURCE_DATA_WITH_DELIBERATE_EMPTY_DEVICES: dict = { + "devices": {}, +} + def test_create_resources(): """All YAML properties are correctly parsed into the object.""" @@ -36,7 +40,7 @@ def test_create_resources(): assert resources.memory.max == 20 assert isinstance(resources.devices, ResourceDevices) - assert resources.devices.get_data() == {"foo": 1, "bar": 2} + assert resources.devices.get_data_or_none() == {"foo": 1, "bar": 2} def test_missing_resources(): @@ -53,12 +57,19 @@ def test_missing_resources(): assert resources.memory.max is None assert resources.devices is not None - assert resources.devices.devices == {} + assert resources.devices.devices is None # the empty dict-initialized resources also serialize back into an empty dict assert resources.serialize() == {} +def test_cleared_devices(): + resources = WorkloadResources.parse(RESOURCE_DATA_WITH_DELIBERATE_EMPTY_DEVICES) + + assert resources.devices.devices == {} + assert resources.serialize() == RESOURCE_DATA_WITH_DELIBERATE_EMPTY_DEVICES + + @pytest.mark.parametrize( "resource_name,missing_key", [ diff --git a/valohai_yaml/objs/workload_resources.py b/valohai_yaml/objs/workload_resources.py index 651ed6e..4a227e6 100644 --- a/valohai_yaml/objs/workload_resources.py +++ b/valohai_yaml/objs/workload_resources.py @@ -1,16 +1,28 @@ -from typing import Dict, Optional +from __future__ import annotations from valohai_yaml.objs.base import Item from valohai_yaml.types import SerializedDict -class ResourceCPU(Item): +class WorkloadResourceItem(Item): + """ + Adds get_data_or_none method supporting distinction between an empty SerializedDict and None. + + The method allows defining an empty devices dictionary for ResourceDevices that will override default values, + if any are defined. Other subclasses will get default behaviour. + """ + + def get_data_or_none(self) -> SerializedDict | None: + return self.get_data() or None + + +class ResourceCPU(WorkloadResourceItem): """CPU configuration.""" def __init__( self, - max: Optional[int] = None, - min: Optional[int] = None, + max: int | None = None, + min: int | None = None, ) -> None: self.max = max self.min = min @@ -25,13 +37,13 @@ def get_data(self) -> SerializedDict: } -class ResourceMemory(Item): +class ResourceMemory(WorkloadResourceItem): """Memory configuration.""" def __init__( self, - max: Optional[int] = None, - min: Optional[int] = None, + max: int | None = None, + min: int | None = None, ) -> None: self.max = max self.min = min @@ -46,20 +58,20 @@ def get_data(self) -> SerializedDict: } -class ResourceDevices(Item): +class ResourceDevices(WorkloadResourceItem): """Devices configuration.""" - def __init__(self, devices: SerializedDict) -> None: + def __init__(self, devices: SerializedDict | None) -> None: """ Devices list device name: nr of devices. Keys (and number of items) unknown, e.g.: 'nvidia.com/cpu': 2, 'nvidia.com/gpu': 1. """ - self.devices: Dict[str, int] = devices + self.devices: dict[str, int] | None = devices @classmethod - def parse(cls, data: SerializedDict) -> "ResourceDevices": + def parse(cls, data: SerializedDict | None) -> ResourceDevices: """ Initialize a devices resource. @@ -72,7 +84,7 @@ def __repr__(self) -> str: """List the devices.""" return f"ResourceDevices({self.devices})" - def get_data(self) -> SerializedDict: + def get_data_or_none(self) -> SerializedDict | None: return self.devices @@ -96,10 +108,10 @@ def __init__( self.devices = devices @classmethod - def parse(cls, data: SerializedDict) -> "WorkloadResources": + def parse(cls, data: SerializedDict) -> WorkloadResources: cpu_data = data.get("cpu", {}) memory_data = data.get("memory", {}) - device_data = data.get("devices", {}) + device_data = data.get("devices") data_with_resources = dict( data, cpu=ResourceCPU.parse(cpu_data), @@ -111,8 +123,8 @@ def parse(cls, data: SerializedDict) -> "WorkloadResources": def get_data(self) -> SerializedDict: data = {} for key, value in super().get_data().items(): - item_data = value.get_data() - if item_data: + item_data = value.get_data_or_none() + if item_data is not None: data[key] = item_data return data