diff --git a/tony-core/src/test/java/com/linkedin/tony/util/TestUtils.java b/tony-core/src/test/java/com/linkedin/tony/util/TestUtils.java index 5234962d..b7f3dc98 100644 --- a/tony-core/src/test/java/com/linkedin/tony/util/TestUtils.java +++ b/tony-core/src/test/java/com/linkedin/tony/util/TestUtils.java @@ -7,22 +7,30 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.linkedin.tony.TFConfig; +import com.linkedin.tony.TonyConfigurationKeys; import com.linkedin.tony.tensorflow.TensorFlowContainerRequest; import java.io.File; import java.io.IOException; +import java.net.SocketException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.CommonConfigurationKeys; +import org.apache.hadoop.yarn.api.records.Container; import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.exceptions.YarnException; import org.testng.annotations.Test; import static org.mockito.Mockito.*; import static org.testng.Assert.*; - public class TestUtils { + @Test public void testParseMemoryString() { assertEquals(Utils.parseMemoryString("2g"), "2048"); @@ -62,12 +70,12 @@ public void testParseContainerRequests() { conf.setInt("tony.chief.gpus", 1); Map requests = Utils.parseContainerRequests(conf); - assertEquals(3, requests.get("worker").getNumInstances()); - assertEquals(1, requests.get("evaluator").getNumInstances()); - assertEquals(1, requests.get("worker").getGPU()); - assertEquals(2, requests.get("evaluator").getVCores()); + assertEquals(requests.get("worker").getNumInstances(), 3); + assertEquals(requests.get("evaluator").getNumInstances(), 1); + assertEquals(requests.get("worker").getGPU(), 1); + assertEquals(requests.get("evaluator").getVCores(), 2); // Check default value. - assertEquals(2048, requests.get("worker").getMemory()); + assertEquals(requests.get("worker").getMemory(), 2048); // Check job does not exist if no instances are configured. assertFalse(requests.containsKey("chief")); } @@ -90,13 +98,12 @@ public void testIsNotArchive() { assertFalse(Utils.isArchive(file1.getAbsolutePath())); } - @Test public void testRenameFile() throws IOException { File tempFile = File.createTempFile("testRenameFile-", "-suffix"); tempFile.deleteOnExit(); boolean result = Utils.renameFile(tempFile.getAbsolutePath(), - tempFile.getAbsolutePath() + "bak"); + tempFile.getAbsolutePath() + "bak"); assertTrue(Files.exists(Paths.get(tempFile.getAbsolutePath() + "bak"))); assertTrue(result); Files.deleteIfExists(Paths.get(tempFile.getAbsolutePath() + "bak")); @@ -108,11 +115,11 @@ public void testConstructTFConfig() throws IOException { String tfConfig = Utils.constructTFConfig(spec, "worker", 1); ObjectMapper mapper = new ObjectMapper(); TFConfig config = mapper.readValue(tfConfig, new TypeReference() { }); - assertEquals("worker", config.getTask().getType()); - assertEquals(1, config.getTask().getIndex()); - assertEquals("host0:1234", config.getCluster().get("worker").get(0)); - assertEquals("host1:1234", config.getCluster().get("worker").get(1)); - assertEquals("host2:1234", config.getCluster().get("ps").get(0)); + assertEquals(config.getTask().getType(), "worker"); + assertEquals(config.getTask().getIndex(), 1); + assertEquals(config.getCluster().get("worker").get(0), "host0:1234"); + assertEquals(config.getCluster().get("worker").get(1), "host1:1234"); + assertEquals(config.getCluster().get("ps").get(0), "host2:1234"); } @Test @@ -122,4 +129,139 @@ public void testBuildRMUrl() { String expected = "http://testrmaddress/cluster/app/1"; assertEquals(Utils.buildRMUrl(yarnConf, "1"), expected); } + + @Test + public void testPollTillNonNull() { + assertNull(Utils.pollTillNonNull(() -> null, 1, 1)); + assertTrue(Utils.pollTillNonNull(() -> true, 1, 1)); + } + + @Test + public void testConstructUrl() { + assertEquals(Utils.constructUrl("foobar"), "http://foobar"); + assertEquals(Utils.constructUrl("http://foobar"), "http://foobar"); + } + + @Test + public void testConstructContainerUrl() { + Container container = mock(Container.class); + assertNotNull(Utils.constructContainerUrl(container)); + assertNotNull(Utils.constructContainerUrl("foo", null)); + } + + @Test + public void testParseKeyValue() { + HashMap hashMap = new HashMap<>(); + hashMap.put("bar", ""); + hashMap.put("foo", "1"); + hashMap.put("baz", "3"); + + assertEquals(Utils.parseKeyValue(null), new HashMap<>()); + assertEquals(Utils.parseKeyValue( + new String[]{"foo=1", "bar", "baz=3"}), hashMap); + } + + @Test + public void testExecuteShell() throws IOException, InterruptedException { + assertEquals(Utils.executeShell("foo", 0, null), 127); + } + + @Test + public void testGetCurrentHostName() { + assertNull(Utils.getCurrentHostName()); + } + + @Test + public void testGetHostNameOrIpFromTokenConf() + throws SocketException, YarnException { + Configuration conf = mock(Configuration.class); + when(conf.getBoolean( + CommonConfigurationKeys.HADOOP_SECURITY_TOKEN_SERVICE_USE_IP, + CommonConfigurationKeys + .HADOOP_SECURITY_TOKEN_SERVICE_USE_IP_DEFAULT)) + .thenReturn(false); + assertNull(Utils.getHostNameOrIpFromTokenConf(conf)); + } + + @Test + public void testGetAllJobTypes() { + Configuration conf = new Configuration(); + conf.addResource("tony-default.xml"); + conf.setInt("tony.worker.instances", 3); + conf.setInt("tony.evaluator.instances", 1); + conf.setInt("tony.worker.gpus", 1); + conf.setInt("tony.evaluator.vcores", 2); + conf.setInt("tony.chief.gpus", 1); + + assertEquals(Utils.getAllJobTypes(conf), + new HashSet(Arrays.asList("worker", "evaluator"))); + } + + @Test + public void testGetNumTotalTasks() { + Configuration conf = new Configuration(); + conf.addResource("tony-default.xml"); + conf.setInt("tony.worker.instances", 3); + conf.setInt("tony.evaluator.instances", 1); + conf.setInt("tony.worker.gpus", 1); + conf.setInt("tony.evaluator.vcores", 2); + conf.setInt("tony.chief.gpus", 1); + + assertEquals(Utils.getNumTotalTasks(conf), 4); + } + + @Test + public void testGetTaskType() { + assertNull(Utils.getTaskType("foo")); + assertEquals(Utils.getTaskType("tony.evaluator.instances"), + "evaluator"); + } + + @Test + public void testGetClientResourcesPath() { + assertEquals(Utils.getClientResourcesPath("foo", "bar"), + "foo-bar"); + } + + @Test + public void testGetUntrackedJobTypes() { + Configuration conf = new Configuration(); + conf.addResource("tony-default.xml"); + conf.setInt("tony.worker.instances", 3); + conf.setInt("tony.evaluator.instances", 1); + conf.setInt("tony.worker.gpus", 1); + conf.setInt("tony.evaluator.vcores", 2); + conf.setInt("tony.chief.gpus", 1); + + assertEquals(Utils.getUntrackedJobTypes(conf), + new String[]{"ps"}, "Arrays do not match"); + } + + @Test + public void testIsJobTypeTracked() { + Configuration conf = new Configuration(); + conf.addResource("tony-default.xml"); + conf.setInt("tony.worker.instances", 3); + conf.setInt("tony.evaluator.instances", 1); + conf.setInt("tony.worker.gpus", 1); + conf.setInt("tony.evaluator.vcores", 2); + conf.setInt("tony.chief.gpus", 1); + + assertTrue(Utils.isJobTypeTracked("tony.worker.gpus", conf)); + } + + @Test + public void testGetContainerEnvForDocker() { + Configuration conf = mock(Configuration.class); + when(conf.getBoolean(TonyConfigurationKeys.DOCKER_ENABLED, + TonyConfigurationKeys.DEFAULT_DOCKER_ENABLED)) + .thenReturn(true); + assertEquals(Utils.getContainerEnvForDocker(conf, "tony.worker.gpus"), + new HashMap<>()); + + when(conf.get(TonyConfigurationKeys + .getDockerImageKey("tony.worker.gpus"))).thenReturn("foo"); + assertEquals(Utils.getContainerEnvForDocker(conf, "tony.worker.gpus"), + new HashMap<>()); + } }