diff --git a/src/test/scala/org/apache/spark/ml/feature/ITSelectorSuite.scala b/src/test/scala/org/apache/spark/ml/feature/ITSelectorSuite.scala index 6ad9e0d..8c5b43a 100644 --- a/src/test/scala/org/apache/spark/ml/feature/ITSelectorSuite.scala +++ b/src/test/scala/org/apache/spark/ml/feature/ITSelectorSuite.scala @@ -8,7 +8,7 @@ import TestHelper._ /** - * Test infomartion theoretic feature selection + * Test information theoretic feature selection on datasets from Peng's webpage * * @author Sergio Ramirez */ @@ -21,20 +21,78 @@ class ITSelectorSuite extends FunSuite with BeforeAndAfterAll { sqlContext = new SQLContext(SPARK_CTX) } - /** Do entropy based binning of cars data from UC Irvine repository. */ + /** Do mRMR feature selection on COLON data. */ test("Run ITFS on colon data (nPart = 10, nfeat = 10)") { - val df = readColonData(sqlContext) + val df = readCSVData(sqlContext, "test_colon_s3.csv") val cols = df.columns val pad = 2 val allVectorsDense = true - val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head, 10, 10, allVectorsDense, pad) + val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head, + 10, 10, allVectorsDense, pad) assertResult("512, 764, 1324, 1380, 1411, 1422, 1581, 1670, 1671, 1971") { model.selectedFeatures.mkString(", ") } } + /** Do mRMR feature selection on LEUKEMIA data. */ + test("Run ITFS on leukemia data (nPart = 10, nfeat = 10)") { + + val df = readCSVData(sqlContext, "test_leukemia_s3.csv") + val cols = df.columns + val pad = 2 + val allVectorsDense = true + val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head, + 10, 10, allVectorsDense, pad) + + assertResult("1084, 1719, 1774, 1822, 2061, 2294, 3192, 4387, 4787, 6795") { + model.selectedFeatures.mkString(", ") + } + } + /** Do mRMR feature selection on LUNG data. 
*/ + test("Run ITFS on lung data (nPart = 10, nfeat = 10)") { + + val df = readCSVData(sqlContext, "test_lung_s3.csv") + val cols = df.columns + val pad = 2 + val allVectorsDense = true + val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head, + 10, 10, allVectorsDense, pad) + + assertResult("18, 22, 29, 125, 132, 150, 166, 242, 243, 269") { + model.selectedFeatures.mkString(", ") + } + } + + /** Do mRMR feature selection on LYMPHOMA data. */ + test("Run ITFS on lymphoma data (nPart = 10, nfeat = 10)") { + val df = readCSVData(sqlContext, "test_lymphoma_s3.csv") + val cols = df.columns + val pad = 2 + val allVectorsDense = true + val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head, + 10, 10, allVectorsDense, pad) + + assertResult("236, 393, 759, 2747, 2818, 2841, 2862, 3014, 3702, 3792") { + model.selectedFeatures.mkString(", ") + } + } + + /** Do mRMR feature selection on NCI data. */ + test("Run ITFS on nci data (nPart = 10, nfeat = 10)") { + + val df = readCSVData(sqlContext, "test_nci9_s3.csv") + val cols = df.columns + val pad = 2 + val allVectorsDense = true + val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head, + 10, 10, allVectorsDense, pad) + + assertResult("443, 755, 1369, 1699, 3483, 5641, 6290, 7674, 9399, 9576") { + model.selectedFeatures.mkString(", ") + } + } } \ No newline at end of file diff --git a/src/test/scala/org/apache/spark/ml/feature/TestHelper.scala b/src/test/scala/org/apache/spark/ml/feature/TestHelper.scala index ac0a8f0..4b42e73 100644 --- a/src/test/scala/org/apache/spark/ml/feature/TestHelper.scala +++ b/src/test/scala/org/apache/spark/ml/feature/TestHelper.scala @@ -29,7 +29,7 @@ object TestHelper { final val INDEX_SUFFIX: String = "_IDX" /** - * @return the discretizer fit to the data given the specified features to bin and label use as target. + * @return the feature selector fit to the data given the specified features to bin and label use as target. 
*/ def createSelectorModel(sqlContext: SQLContext, dataframe: Dataset[_], inputCols: Array[String], @@ -37,7 +37,7 @@ object TestHelper { nPartitions: Int = 100, numTopFeatures: Int = 20, allVectorsDense: Boolean = true, - padded: Int = 0): InfoThSelectorModel = { + padded: Int = 0 /* if minimum value is negative */): InfoThSelectorModel = { val featureAssembler = new VectorAssembler() .setInputCols(inputCols) .setOutputCol("features") @@ -73,7 +73,7 @@ object TestHelper { /** * The label column will have null values replaced with MISSING values in this case. - * @return the discretizer fit to the data given the specified features to bin and label use as target. + * @return the feature selector fit to the data given the specified features to bin and label use as target. */ def getSelectorModel(sqlContext: SQLContext, dataframe: DataFrame, inputCols: Array[String], labelColumn: String, @@ -121,53 +121,17 @@ object TestHelper { sc } - /** @return standard iris dataset from UCI repo. - */ - /*def readColonData(sqlContext: SQLContext): DataFrame = { - val data = SPARK_CTX.textFile(FILE_PREFIX + "iris.data") - val nullable = true - - val schema = (0 until 9712).map(i => StructField("var" + i, DoubleType, nullable)).toList :+ - StructField("colontype", StringType, nullable) - // ints and dates must be read as doubles - val rows = data.map(line => line.split(",").map(elem => elem.trim)) - .map(x => {Row.fromSeq(Seq(asDouble(x(0)), asDouble(x(1)), asDouble(x(2)), asDouble(x(3)), asString(x(4))))}) - - sqlContext.createDataFrame(rows, schema) - } - - /** @return standard iris dataset from UCI repo. + /** @return standard csv data from the repo. 
*/ - def readColonData2(sqlContext: SQLContext): DataFrame = { - val data = SPARK_CTX.textFile(FILE_PREFIX + "iris.data") - val nullable = true - val schema = StructType(List( - StructField("features", new VectorUDT, nullable), - StructField("class", DoubleType, nullable) - )) - val rows = data.map{line => - val split = line.split(",").map(elem => elem.trim) - val features = Vectors.dense(split.drop(1).map(_.toDouble)) - val label = split.head.toDouble - (features, label) - } - val asd = sqlContext.createDataFrame(rows, schema) - - }*/ - - - def readColonData(sqlContext: SQLContext): DataFrame = { + def readCSVData(sqlContext: SQLContext, file: String): DataFrame = { val df = sqlContext.read .format("com.databricks.spark.csv") .option("header", "true") // Use first line of all files as header .option("inferSchema", "true") // Automatically infer data types - .load(FILE_PREFIX + "test_colon_s3.csv") + .load(FILE_PREFIX + file) df } - - - /** @return dataset with 3 double columns. The first is the label column and contain null. */ def readNullLabelTestData(sqlContext: SQLContext): DataFrame = {