From bc585180b912a5c91b353092155138caa7d9f6e6 Mon Sep 17 00:00:00 2001 From: Marcus Rosti Date: Tue, 17 Dec 2024 00:42:03 -0800 Subject: [PATCH] Appends prediction columns to transform schema (#60) * Appends prediction columns to transform schema * fixes the comment --- .../isolationforest/IsolationForest.scala | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala b/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala index e885339..1864926 100644 --- a/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala +++ b/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala @@ -7,7 +7,7 @@ import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} import org.apache.spark.ml.Estimator -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} import org.apache.spark.sql.Dataset import org.apache.spark.{HashPartitioner, TaskContext} @@ -187,8 +187,8 @@ class IsolationForest(override val uid: String) extends Estimator[IsolationFores /** * Validates the input schema and transforms it into the output schema. It validates that the - * input DataFrame has a $(featuresCol) of the correct type. In this case, the output schema is - * identical to the input schema. + * input DataFrame has a $(featuresCol) of the correct type. In this case, the output schema appends + * the output columns to the input schema. * * @param schema The schema of the DataFrame containing the data to be fit. * @return The schema of the DataFrame containing the data to be fit. @@ -200,7 +200,16 @@ class IsolationForest(override val uid: String) extends Estimator[IsolationFores require(schema($(featuresCol)).dataType == VectorType, s"Input column ${$(featuresCol)} is not of required type ${VectorType}") - val outputFields = schema.fields + val outputFields: Array[StructField] = schema.fields ++ Array( + StructField( + name = s"$predictionCol", + dataType = DoubleType + ), + StructField( + name = s"$scoreCol", + dataType = DoubleType + ) + ) StructType(outputFields) }