Skip to content

Commit bf9e707

Browse files
committed
add flag
1 parent 8f4c0d4 commit bf9e707

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

+11
Original file line numberDiff line numberDiff line change
@@ -3548,6 +3548,15 @@ object SQLConf {
35483548
// show full stacktrace in tests but hide in production by default.
35493549
.createWithDefault(!Utils.isTesting)
35503550

3551+
val PYSPARK_ARROW_VALIDATE_SCHEMA =
3552+
buildConf("spark.sql.execution.arrow.pyspark.validateSchema.enabled")
3553+
.doc(
3554+
"When true, validate the schema of Arrow batches returned by mapInArrow, mapInPandas " +
3555+
"and DataSource against the expected schema to ensure that they are compatible.")
3556+
.version("4.1.0")
3557+
.booleanConf
3558+
.createWithDefault(true)
3559+
35513560
val PYTHON_UDF_ARROW_ENABLED =
35523561
buildConf("spark.sql.execution.pythonUDF.arrow.enabled")
35533562
.doc("Enable Arrow optimization in regular Python UDFs. This optimization " +
@@ -6448,6 +6457,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
64486457

64496458
def pysparkSimplifiedTraceback: Boolean = getConf(PYSPARK_SIMPLIFIED_TRACEBACK)
64506459

6460+
def pysparkArrowValidateSchema: Boolean = getConf(PYSPARK_ARROW_VALIDATE_SCHEMA)
6461+
64516462
def pandasGroupedMapAssignColumnsByName: Boolean =
64526463
getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME)
64536464

sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala

+9-11
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
2525
import org.apache.spark.sql.catalyst.expressions._
2626
import org.apache.spark.sql.errors.QueryExecutionErrors
2727
import org.apache.spark.sql.execution.metric.SQLMetric
28+
import org.apache.spark.sql.internal.SQLConf
2829
import org.apache.spark.sql.types.{DataType, StructField, StructType}
2930
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
3031

@@ -77,17 +78,14 @@ class MapInBatchEvaluatorFactory(
7778
val unsafeProj = UnsafeProjection.create(output, output)
7879

7980
columnarBatchIter.flatMap { batch =>
80-
// Ensure the schema matches the expected schema
81-
val actualSchema = batch.column(0).dataType()
82-
val strictCheck = true
83-
val isCompatible = if (strictCheck) {
84-
DataType.equalsIgnoreNullability(actualSchema, outputSchema)
85-
} else {
86-
outputSchema.sameType(actualSchema)
87-
}
88-
if (!isCompatible) {
89-
throw QueryExecutionErrors.arrowDataTypeMismatchError(
90-
PythonEvalType.toString(pythonEvalType), Seq(outputSchema), Seq(actualSchema))
81+
if (SQLConf.get.pysparkArrowValidateSchema) {
82+
// Ensure the schema matches the expected schema
83+
val actualSchema = batch.column(0).dataType()
84+
val isCompatible = DataType.equalsIgnoreCompatibleNullability(actualSchema, outputSchema)
85+
if (!isCompatible) {
86+
throw QueryExecutionErrors.arrowDataTypeMismatchError(
87+
PythonEvalType.toString(pythonEvalType), Seq(outputSchema), Seq(actualSchema))
88+
}
9189
}
9290

9391
// Scalar Iterator UDF returns a StructType column in ColumnarBatch, select

0 commit comments

Comments
 (0)