diff --git a/libcontainer/cgroups/utils.go b/libcontainer/cgroups/utils.go index d303cf204c9..d95c53ab1dc 100644 --- a/libcontainer/cgroups/utils.go +++ b/libcontainer/cgroups/utils.go @@ -415,26 +415,29 @@ func ConvertCPUSharesToCgroupV2Value(cpuShares uint64) uint64 { // ConvertMemorySwapToCgroupV2Value converts MemorySwap value from OCI spec // for use by cgroup v2 drivers. A conversion is needed since Resources.MemorySwap -// is defined as memory+swap combined, while in cgroup v2 swap is a separate value. +// is defined as memory+swap combined, while in cgroup v2 swap is a separate value, +// so we need to subtract memory from it where it makes sense. func ConvertMemorySwapToCgroupV2Value(memorySwap, memory int64) (int64, error) { - // for compatibility with cgroup1 controller, set swap to unlimited in - // case the memory is set to unlimited, and swap is not explicitly set, - // treating the request as "set both memory and swap to unlimited". - if memory == -1 && memorySwap == 0 { + switch { + case memory == -1 && memorySwap == 0: + // For compatibility with cgroup1 controller, set swap to unlimited in + // case the memory is set to unlimited and the swap is not explicitly set, + // treating the request as "set both memory and swap to unlimited". return -1, nil - } - if memorySwap == -1 || memorySwap == 0 { - // -1 is "max", 0 is "unset", so treat as is + case memorySwap == -1, memorySwap == 0: + // Treat -1 ("max") and 0 ("unset") swap as is. return memorySwap, nil - } - // sanity checks - if memory == 0 || memory == -1 { + case memory == -1: + // Unlimited memory, so treat swap as is. + return memorySwap, nil + case memory == 0: + // Unset or unknown memory, can't calculate swap. return 0, errors.New("unable to set swap limit without memory limit") - } - if memory < 0 { + case memory < 0: + // Does not make sense to subtract a negative value. return 0, fmt.Errorf("invalid memory value: %d", memory) - } - if memorySwap < memory { + case memorySwap < memory: + // Sanity check. return 0, errors.New("memory+swap limit should be >= memory limit") } diff --git a/libcontainer/cgroups/utils_test.go b/libcontainer/cgroups/utils_test.go index c9feb84484c..fc81992f0b0 100644 --- a/libcontainer/cgroups/utils_test.go +++ b/libcontainer/cgroups/utils_test.go @@ -554,82 +554,97 @@ func TestConvertCPUSharesToCgroupV2Value(t *testing.T) { func TestConvertMemorySwapToCgroupV2Value(t *testing.T) { cases := []struct { + descr string memswap, memory int64 expected int64 expErr bool }{ { + descr: "all unset", memswap: 0, memory: 0, expected: 0, }, { + descr: "unlimited memory+swap, unset memory", memswap: -1, memory: 0, expected: -1, }, { + descr: "unlimited memory", + memswap: 300, + memory: -1, + expected: 300, + }, + { + descr: "all unlimited", memswap: -1, memory: -1, expected: -1, }, { + descr: "negative memory+swap", memswap: -2, memory: 0, expErr: true, }, { + descr: "unlimited memory+swap, set memory", memswap: -1, memory: 1000, expected: -1, }, { + descr: "memory+swap == memory", memswap: 1000, memory: 1000, expected: 0, }, { + descr: "memory+swap > memory", memswap: 500, memory: 200, expected: 300, }, { + descr: "memory+swap < memory", memswap: 300, memory: 400, expErr: true, }, { + descr: "unset memory", memswap: 300, memory: 0, expErr: true, }, { + descr: "negative memory", memswap: 300, memory: -300, expErr: true, }, - { - memswap: 300, - memory: -1, - expErr: true, - }, } for _, c := range cases { - swap, err := ConvertMemorySwapToCgroupV2Value(c.memswap, c.memory) - if c.expErr { - if err == nil { - t.Errorf("memswap: %d, memory %d, expected error, got %d, nil", c.memswap, c.memory, swap) + c := c + t.Run(c.descr, func(t *testing.T) { + swap, err := ConvertMemorySwapToCgroupV2Value(c.memswap, c.memory) + if c.expErr { + if err == nil { + t.Errorf("memswap: %d, memory %d, expected error, got %d, nil", c.memswap, c.memory, swap) + } + // No more checks. + return } - // no more checks - continue - } - if err != nil { - t.Errorf("memswap: %d, memory %d, expected success, got error %s", c.memswap, c.memory, err) - } - if swap != c.expected { - t.Errorf("memswap: %d, memory %d, expected %d, got %d", c.memswap, c.memory, c.expected, swap) - } + if err != nil { + t.Errorf("memswap: %d, memory %d, expected success, got error %s", c.memswap, c.memory, err) + } + if swap != c.expected { + t.Errorf("memswap: %d, memory %d, expected %d, got %d", c.memswap, c.memory, c.expected, swap) + } + }) } }