From 4bc08ab555cec5b2defa3622174a5e7a3d8fd9dc Mon Sep 17 00:00:00 2001
From: Leo Honkanen <leo@valohai.com>
Date: Fri, 1 Dec 2023 13:58:27 +0200
Subject: [PATCH] More peculiar handling of resource devices

---
 tests/test_workload_resources.py        | 15 +++++++--
 valohai_yaml/objs/workload_resources.py | 44 ++++++++++++++++---------
 2 files changed, 41 insertions(+), 18 deletions(-)

diff --git a/tests/test_workload_resources.py b/tests/test_workload_resources.py
index 71ced92..5b80b4c 100644
--- a/tests/test_workload_resources.py
+++ b/tests/test_workload_resources.py
@@ -22,6 +22,10 @@
     "devices": {"foo": 1, "bar": 2},
 }
 
+RESOURCE_DATA_WITH_DELIBERATE_EMPTY_DEVICES: dict = {
+    "devices": {},
+}
+
 
 def test_create_resources():
     """All YAML properties are correctly parsed into the object."""
@@ -36,7 +40,7 @@ def test_create_resources():
     assert resources.memory.max == 20
 
     assert isinstance(resources.devices, ResourceDevices)
-    assert resources.devices.get_data() == {"foo": 1, "bar": 2}
+    assert resources.devices.get_data_or_none() == {"foo": 1, "bar": 2}
 
 
 def test_missing_resources():
@@ -53,12 +57,19 @@ def test_missing_resources():
     assert resources.memory.max is None
 
     assert resources.devices is not None
-    assert resources.devices.devices == {}
+    assert resources.devices.devices is None
 
     # the empty dict-initialized resources also serialize back into an empty dict
     assert resources.serialize() == {}
 
 
+def test_cleared_devices():
+    resources = WorkloadResources.parse(RESOURCE_DATA_WITH_DELIBERATE_EMPTY_DEVICES)
+
+    assert resources.devices.devices == {}
+    assert resources.serialize() == RESOURCE_DATA_WITH_DELIBERATE_EMPTY_DEVICES
+
+
 @pytest.mark.parametrize(
     "resource_name,missing_key",
     [
diff --git a/valohai_yaml/objs/workload_resources.py b/valohai_yaml/objs/workload_resources.py
index 651ed6e..4a227e6 100644
--- a/valohai_yaml/objs/workload_resources.py
+++ b/valohai_yaml/objs/workload_resources.py
@@ -1,16 +1,28 @@
-from typing import Dict, Optional
+from __future__ import annotations
 
 from valohai_yaml.objs.base import Item
 from valohai_yaml.types import SerializedDict
 
 
-class ResourceCPU(Item):
+class WorkloadResourceItem(Item):
+    """
+    Adds get_data_or_none method supporting distinction between an empty SerializedDict and None.
+
+    The method allows defining an empty devices dictionary for ResourceDevices that will override default values,
+    if any are defined. Other subclasses will get default behaviour.
+    """
+
+    def get_data_or_none(self) -> SerializedDict | None:
+        return self.get_data() or None
+
+
+class ResourceCPU(WorkloadResourceItem):
     """CPU configuration."""
 
     def __init__(
         self,
-        max: Optional[int] = None,
-        min: Optional[int] = None,
+        max: int | None = None,
+        min: int | None = None,
     ) -> None:
         self.max = max
         self.min = min
@@ -25,13 +37,13 @@ def get_data(self) -> SerializedDict:
         }
 
 
-class ResourceMemory(Item):
+class ResourceMemory(WorkloadResourceItem):
     """Memory configuration."""
 
     def __init__(
         self,
-        max: Optional[int] = None,
-        min: Optional[int] = None,
+        max: int | None = None,
+        min: int | None = None,
     ) -> None:
         self.max = max
         self.min = min
@@ -46,20 +58,20 @@ def get_data(self) -> SerializedDict:
         }
 
 
-class ResourceDevices(Item):
+class ResourceDevices(WorkloadResourceItem):
     """Devices configuration."""
 
-    def __init__(self, devices: SerializedDict) -> None:
+    def __init__(self, devices: SerializedDict | None) -> None:
         """
         Devices list device name: nr of devices.
 
         Keys (and number of items) unknown, e.g.:
         'nvidia.com/cpu': 2, 'nvidia.com/gpu': 1.
         """
-        self.devices: Dict[str, int] = devices
+        self.devices: dict[str, int] | None = devices
 
     @classmethod
-    def parse(cls, data: SerializedDict) -> "ResourceDevices":
+    def parse(cls, data: SerializedDict | None) -> ResourceDevices:
         """
         Initialize a devices resource.
 
@@ -72,7 +84,7 @@ def __repr__(self) -> str:
         """List the devices."""
         return f"ResourceDevices({self.devices})"
 
-    def get_data(self) -> SerializedDict:
+    def get_data_or_none(self) -> SerializedDict | None:
         return self.devices
 
 
@@ -96,10 +108,10 @@ def __init__(
         self.devices = devices
 
     @classmethod
-    def parse(cls, data: SerializedDict) -> "WorkloadResources":
+    def parse(cls, data: SerializedDict) -> WorkloadResources:
         cpu_data = data.get("cpu", {})
         memory_data = data.get("memory", {})
-        device_data = data.get("devices", {})
+        device_data = data.get("devices")
         data_with_resources = dict(
             data,
             cpu=ResourceCPU.parse(cpu_data),
@@ -111,8 +123,8 @@ def parse(cls, data: SerializedDict) -> "WorkloadResources":
     def get_data(self) -> SerializedDict:
         data = {}
         for key, value in super().get_data().items():
-            item_data = value.get_data()
-            if item_data:
+            item_data = value.get_data_or_none()
+            if item_data is not None:
                 data[key] = item_data
         return data