Skip to content

Commit

Permalink
[VL] Enable AtLeastNNonNulls function
Browse files Browse the repository at this point in the history
  • Loading branch information
zhli1142015 committed Sep 25, 2024
1 parent 2330dc4 commit e9d6dc9
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ object CHExpressionUtil {
DATE_FROM_UNIX_DATE -> DefaultValidator(),
MONOTONICALLY_INCREASING_ID -> DefaultValidator(),
SPARK_PARTITION_ID -> DefaultValidator(),
AT_LEAST_N_NON_NULLS -> DefaultValidator(),
URL_DECODE -> DefaultValidator(),
URL_ENCODE -> DefaultValidator(),
FORMAT_STRING -> FormatStringValidator(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
newExpr)
}

override def genAtLeastNNonNullsTransformer(
substraitExprName: String,
children: Seq[ExpressionTransformer],
original: AtLeastNNonNulls): ExpressionTransformer = {
GenericExpressionTransformer(
substraitExprName,
Seq(LiteralTransformer(Literal(original.n))) ++ children,
original)
}

/** Transform Uuid to Substrait. */
override def genUuidTransformer(
substraitExprName: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1417,4 +1417,25 @@ abstract class ScalarFunctionsValidateSuite extends FunctionsValidateSuite {
// Scale < 0 should round down even on integral values
compareResultsAgainstVanillaSpark("select round(44, -1)", true, { _ => })
}

test("test internal function: AtLeastNNonNulls") {
// AtLeastNNonNulls is called by drop DataFrameNafunction
withTempPath {
path =>
val input = Seq[(String, java.lang.Integer, java.lang.Double)](
("Bob", 16, 176.5),
("Alice", null, 164.3),
("David", 60, null),
("Nina", 25, Double.NaN),
("Amy", null, null),
(null, null, null)
).toDF("name", "age", "height")
val rows = input.collect()
input.write.parquet(path.getCanonicalPath)

val df = spark.read.parquet(path.getCanonicalPath).na.drop(2, Seq("age", "height"))
checkAnswer(df, rows(0) :: Nil)
checkGlutenOperatorMatch[FilterExecTransformer](df)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ trait SparkPlanExecApi {
GenericExpressionTransformer(substraitExprName, Seq(left, right), original)
}

def genAtLeastNNonNullsTransformer(
substraitExprName: String,
children: Seq[ExpressionTransformer],
original: AtLeastNNonNulls): ExpressionTransformer = {
throw new GlutenNotSupportException("AtLeastNNonNulls is not supported")
}

def genUuidTransformer(substraitExprName: String, original: Uuid): ExpressionTransformer = {
GenericExpressionTransformer(substraitExprName, Seq(), original)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,12 @@ object ExpressionConverter extends SQLConfHelper with Logging {
replaceWithExpressionTransformer0(n.right, attributeSeq, expressionsMap),
n
)
case a: AtLeastNNonNulls =>
BackendsApiManager.getSparkPlanExecApiInstance.genAtLeastNNonNullsTransformer(
substraitExprName,
a.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)),
a
)
case m: MakeTimestamp =>
BackendsApiManager.getSparkPlanExecApiInstance.genMakeTimestampTransformer(
substraitExprName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ object ExpressionMappings {
Sig[PromotePrecision](PROMOTE_PRECISION),
Sig[MonotonicallyIncreasingID](MONOTONICALLY_INCREASING_ID),
Sig[SparkPartitionID](SPARK_PARTITION_ID),
Sig[AtLeastNNonNulls](AT_LEAST_N_NON_NULLS),
Sig[WidthBucket](WIDTH_BUCKET),
Sig[ReplicateRows](REPLICATE_ROWS),
Sig[RaiseError](RAISE_ERROR),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ object ExpressionNames {
final val REPLICATE_ROWS = "replicaterows"
final val RAISE_ERROR = "raise_error"
final val VERSION = "version"
final val AT_LEAST_N_NON_NULLS = "at_least_n_non_nulls"

// Directly use child expression transformer
final val KNOWN_NULLABLE = "known_nullable"
Expand Down

0 comments on commit e9d6dc9

Please sign in to comment.