Skip to content

Commit

Permalink
[GLUTEN-7351][CORE] Code cleanup for Gluten session extensions
Browse files Browse the repository at this point in the history
Closes #7351
  • Loading branch information
beliefer authored Sep 26, 2024
1 parent 9093e85 commit 230e605
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.network.util.JavaUtils
import org.apache.spark.softaffinity.SoftAffinityListener
import org.apache.spark.sql.execution.ui.{GlutenEventUtils, GlutenSQLAppStatusListener}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.SPARK_SESSION_EXTENSIONS
import org.apache.spark.task.TaskResources
import org.apache.spark.util.SparkResourceUtil

Expand Down Expand Up @@ -127,14 +128,14 @@ private[gluten] class GlutenDriverPlugin extends DriverPlugin with Logging {
}

private def setPredefinedConfigs(sc: SparkContext, conf: SparkConf): Unit = {
// sql extensions
val extensions = if (conf.contains(GlutenSessionExtensions.SPARK_SESSION_EXTS_KEY)) {
s"${conf.get(GlutenSessionExtensions.SPARK_SESSION_EXTS_KEY)}," +
// Spark SQL extensions
val extensions = if (conf.contains(SPARK_SESSION_EXTENSIONS.key)) {
s"${conf.get(SPARK_SESSION_EXTENSIONS.key)}," +
s"${GlutenSessionExtensions.GLUTEN_SESSION_EXTENSION_NAME}"
} else {
s"${GlutenSessionExtensions.GLUTEN_SESSION_EXTENSION_NAME}"
}
conf.set(GlutenSessionExtensions.SPARK_SESSION_EXTS_KEY, extensions)
conf.set(SPARK_SESSION_EXTENSIONS.key, extensions)

// adaptive custom cost evaluator class
if (GlutenConfig.getConf.enableGluten && GlutenConfig.getConf.enableGlutenCostEvaluator) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ import org.apache.gluten.backend.Backend
import org.apache.gluten.extension.injector.RuleInjector

import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.internal.StaticSQLConf

import java.util.Objects

private[gluten] class GlutenSessionExtensions extends (SparkSessionExtensions => Unit) {
override def apply(exts: SparkSessionExtensions): Unit = {
Expand All @@ -33,7 +30,5 @@ private[gluten] class GlutenSessionExtensions extends (SparkSessionExtensions =>
}

private[gluten] object GlutenSessionExtensions {
val SPARK_SESSION_EXTS_KEY: String = StaticSQLConf.SPARK_SESSION_EXTENSIONS.key
val GLUTEN_SESSION_EXTENSION_NAME: String =
Objects.requireNonNull(classOf[GlutenSessionExtensions].getCanonicalName)
val GLUTEN_SESSION_EXTENSION_NAME: String = classOf[GlutenSessionExtensions].getCanonicalName
}

0 comments on commit 230e605

Please sign in to comment.