From 0c533248505470125922403885e0f6e9df4e5801 Mon Sep 17 00:00:00 2001 From: Leo Honkanen Date: Fri, 1 Dec 2023 13:58:27 +0200 Subject: [PATCH] More peculiar handling of resource devices --- tests/test_workload_resources.py | 13 +++++++++- valohai_yaml/objs/workload_resources.py | 34 +++++++++++++++++-------- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/tests/test_workload_resources.py b/tests/test_workload_resources.py index 71ced92..11ee187 100644 --- a/tests/test_workload_resources.py +++ b/tests/test_workload_resources.py @@ -22,6 +22,10 @@ "devices": {"foo": 1, "bar": 2}, } +RESOURCE_DATA_WITH_DELIBERATE_EMPTY_DEVICES: dict = { + "devices": {}, +} + def test_create_resources(): """All YAML properties are correctly parsed into the object.""" @@ -53,12 +57,19 @@ def test_missing_resources(): assert resources.memory.max is None assert resources.devices is not None - assert resources.devices.devices == {} + assert resources.devices.devices is None # the empty dict-initialized resources also serialize back into an empty dict assert resources.serialize() == {} +def test_cleared_devices(): + resources = WorkloadResources.parse(RESOURCE_DATA_WITH_DELIBERATE_EMPTY_DEVICES) + + assert resources.devices.devices == {} + assert resources.serialize() == RESOURCE_DATA_WITH_DELIBERATE_EMPTY_DEVICES + + @pytest.mark.parametrize( "resource_name,missing_key", [ diff --git a/valohai_yaml/objs/workload_resources.py b/valohai_yaml/objs/workload_resources.py index 651ed6e..44e47cb 100644 --- a/valohai_yaml/objs/workload_resources.py +++ b/valohai_yaml/objs/workload_resources.py @@ -1,9 +1,21 @@ -from typing import Dict, Optional +import functools +from typing import Any, Callable, Dict, Optional from valohai_yaml.objs.base import Item from valohai_yaml.types import SerializedDict +def none_if_empty(f: Callable) -> Callable: + @functools.wraps(f) + def wrap(*args: list, **kwargs: dict) -> Any: + ret = f(*args, **kwargs) + if not ret: + return None + return ret + + return wrap + + class ResourceCPU(Item): """CPU configuration.""" @@ -19,7 +31,8 @@ def __repr__(self) -> str: """CPU data.""" return f'ResourceCPU("max": {self.max}, "min": {self.min})' - def get_data(self) -> SerializedDict: + @none_if_empty + def get_data_or_none(self) -> SerializedDict: return { key: value for key, value in super().get_data().items() if value is not None } @@ -40,7 +53,8 @@ def __repr__(self) -> str: """Memory data.""" return f'ResourceMemory("max": {self.max}, "min": {self.min})' - def get_data(self) -> SerializedDict: + @none_if_empty + def get_data_or_none(self) -> SerializedDict: return { key: value for key, value in super().get_data().items() if value is not None } @@ -49,17 +63,17 @@ def get_data(self) -> SerializedDict: class ResourceDevices(Item): """Devices configuration.""" - def __init__(self, devices: SerializedDict) -> None: + def __init__(self, devices: SerializedDict | None) -> None: """ Devices list device name: nr of devices. Keys (and number of items) unknown, e.g.: 'nvidia.com/cpu': 2, 'nvidia.com/gpu': 1. """ - self.devices: Dict[str, int] = devices + self.devices: Optional[Dict[str, int]] = devices @classmethod - def parse(cls, data: SerializedDict) -> "ResourceDevices": + def parse(cls, data: SerializedDict | None) -> "ResourceDevices": """ Initialize a devices resource. @@ -72,7 +86,7 @@ def __repr__(self) -> str: """List the devices.""" return f"ResourceDevices({self.devices})" - def get_data(self) -> SerializedDict: + def get_data_or_none(self) -> SerializedDict | None: return self.devices @@ -99,7 +113,7 @@ def __init__( def parse(cls, data: SerializedDict) -> "WorkloadResources": cpu_data = data.get("cpu", {}) memory_data = data.get("memory", {}) - device_data = data.get("devices", {}) + device_data = data.get("devices") data_with_resources = dict( data, cpu=ResourceCPU.parse(cpu_data), @@ -111,8 +125,8 @@ def parse(cls, data: SerializedDict) -> "WorkloadResources": def get_data(self) -> SerializedDict: data = {} for key, value in super().get_data().items(): - item_data = value.get_data() - if item_data: + item_data = value.get_data_or_none() + if item_data is not None: data[key] = item_data return data