
[SPARK-51739][PYTHON] Validate Arrow schema from mapInArrow & mapInPandas & DataSource #50531


Closed · wants to merge 10 commits
17 changes: 13 additions & 4 deletions python/pyspark/sql/pandas/serializers.py
@@ -20,6 +20,7 @@
"""

from itertools import groupby
from typing import TYPE_CHECKING, Optional

import pyspark
from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
@@ -48,6 +49,10 @@
IntegerType,
)

if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa


class SpecialLengths:
END_OF_DATA_SECTION = -1
@@ -472,15 +477,20 @@ def arrow_to_pandas(self, arrow_column, idx):
)
return s

def _create_struct_array(self, df, arrow_struct_type, spark_type=None):
def _create_struct_array(
self,
df: "pd.DataFrame",
arrow_struct_type: "pa.StructType",
spark_type: Optional[StructType] = None,
):
"""
Create an Arrow StructArray from the given pandas.DataFrame and arrow struct type.

Parameters
----------
df : pandas.DataFrame
A pandas DataFrame
arrow_struct_type : pyarrow.DataType
arrow_struct_type : pyarrow.StructType
pyarrow struct type

Returns
@@ -518,8 +528,7 @@ def _create_struct_array(self, df, arrow_struct_type, spark_type=None):
for i, field in enumerate(arrow_struct_type)
]

struct_names = [field.name for field in arrow_struct_type]
return pa.StructArray.from_arrays(struct_arrs, struct_names)
return pa.StructArray.from_arrays(struct_arrs, fields=list(arrow_struct_type))
Review comment from the PR author: correctly handle non-nullable fields required by the arrow_struct_type schema.
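For context, a minimal standalone pyarrow sketch (not part of the patch) of why this matters: constructing the StructArray from names= discards the declared nullability, while fields= preserves it.

import pyarrow as pa

arr = pa.array([1, 2, 3], type=pa.int64())

# With names=, pyarrow creates the struct fields itself and marks them nullable,
# dropping any non-nullable constraint declared in the expected struct type.
by_name = pa.StructArray.from_arrays([arr], names=["a"])
assert by_name.type.field(0).nullable

# With fields=, the declared field metadata (including nullable=False) is kept.
expected_field = pa.field("a", pa.int64(), nullable=False)
by_field = pa.StructArray.from_arrays([arr], fields=[expected_field])
assert not by_field.type.field(0).nullable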


def _create_batch(self, series):
"""
40 changes: 39 additions & 1 deletion python/pyspark/sql/tests/arrow/test_arrow_map.py
@@ -124,7 +124,7 @@ def test_empty_rows(self):
def empty_rows(_):
return iter([pa.RecordBatch.from_pandas(pd.DataFrame({"a": []}))])

self.assertEqual(self.spark.range(10).mapInArrow(empty_rows, "a int").count(), 0)
self.assertEqual(self.spark.range(10).mapInArrow(empty_rows, "a double").count(), 0)

def test_chain_map_in_arrow(self):
def func(iterator):
@@ -175,6 +175,44 @@ def test_negative_and_zero_batch_size(self):
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
MapInArrowTests.test_map_in_arrow(self)

def test_nested_extraneous_field(self):
def func(iterator):
for _ in iterator:
struct_arr = pa.StructArray.from_arrays([[1, 2], [3, 4]], names=["a", "b"])
yield pa.RecordBatch.from_arrays([struct_arr], ["x"])

df = self.spark.range(1)
with self.assertRaisesRegex(Exception, r"ARROW_TYPE_MISMATCH.*SQL_MAP_ARROW_ITER_UDF"):
df.mapInArrow(func, "x struct<b:int>").collect()

def test_top_level_wrong_order(self):
def func(iterator):
for _ in iterator:
yield pa.RecordBatch.from_arrays([[1], [2]], ["b", "a"])

df = self.spark.range(1)
with self.assertRaisesRegex(Exception, r"ARROW_TYPE_MISMATCH.*SQL_MAP_ARROW_ITER_UDF"):
df.mapInArrow(func, "a int, b int").collect()

def test_nullability_widen(self):
def func(iterator):
for _ in iterator:
yield pa.RecordBatch.from_arrays([[1]], ["a"])

df = self.spark.range(1)
with self.assertRaisesRegex(Exception, r"ARROW_TYPE_MISMATCH.*SQL_MAP_ARROW_ITER_UDF"):
df.mapInArrow(func, "a int not null").collect()

def test_nullability_narrow(self):
def func(iterator):
for _ in iterator:
yield pa.RecordBatch.from_arrays(
[[1]], pa.schema([pa.field("a", pa.int32(), nullable=False)])
)

df = self.spark.range(1)
df.mapInArrow(func, "a int").collect()


class MapInArrowTests(MapInArrowTestsMixin, ReusedSQLTestCase):
@classmethod
25 changes: 25 additions & 0 deletions python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -24,6 +24,8 @@
from pyspark.sql import Row
from pyspark.sql.functions import col, encode, lit
from pyspark.errors import PythonException
from pyspark.sql.session import SparkSession
from pyspark.sql.types import StructType
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pandas,
@@ -42,6 +44,8 @@
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
class MapInPandasTestsMixin:
spark: SparkSession

@staticmethod
def identity_dataframes_iter(*columns: str):
def func(iterator):
@@ -128,6 +132,27 @@ def func(iterator):
expected = df.collect()
self.assertEqual(actual, expected)

def test_not_null(self):
def func(iterator):
for _ in iterator:
yield pd.DataFrame({"a": [1, 2]})

schema = "a long not null"
df = self.spark.range(1).mapInPandas(func, schema)
self.assertEqual(df.schema, StructType.fromDDL(schema))
self.assertEqual(df.collect(), [Row(1), Row(2)])

def test_violate_not_null(self):
def func(iterator):
for _ in iterator:
yield pd.DataFrame({"a": [1, None]})

schema = "a long not null"
df = self.spark.range(1).mapInPandas(func, schema)
self.assertEqual(df.schema, StructType.fromDDL(schema))
with self.assertRaisesRegex(Exception, "is null"):
df.collect()

def test_different_output_length(self):
def func(iterator):
for _ in iterator:
SQLConf.scala
@@ -3548,6 +3548,15 @@ object SQLConf {
// show full stacktrace in tests but hide in production by default.
.createWithDefault(!Utils.isTesting)

val PYSPARK_ARROW_VALIDATE_SCHEMA =
buildConf("spark.sql.execution.arrow.pyspark.validateSchema.enabled")
.doc(
"When true, validate the schema of Arrow batches returned by mapInArrow, mapInPandas " +
"and DataSource against the expected schema to ensure that they are compatible.")
.version("4.1.0")
.booleanConf
.createWithDefault(true)

val PYTHON_UDF_ARROW_ENABLED =
buildConf("spark.sql.execution.pythonUDF.arrow.enabled")
.doc("Enable Arrow optimization in regular Python UDFs. This optimization " +
@@ -6448,6 +6457,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def pysparkSimplifiedTraceback: Boolean = getConf(PYSPARK_SIMPLIFIED_TRACEBACK)

def pysparkArrowValidateSchema: Boolean = getConf(PYSPARK_ARROW_VALIDATE_SCHEMA)

def pandasGroupedMapAssignColumnsByName: Boolean =
getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME)

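For reference, a minimal usage sketch of the new flag (assuming an active SparkSession named spark); the error class is the one asserted in the tests above.

import pyarrow as pa

def wrong_schema(iterator):
    for _ in iterator:
        # Declared output below is "a int", but this batch produces a column named "b".
        yield pa.RecordBatch.from_arrays([pa.array([1])], ["b"])

# Validation is on by default in this patch; the mismatch is expected to surface as
# an ARROW_TYPE_MISMATCH error when the query runs.
spark.conf.set("spark.sql.execution.arrow.pyspark.validateSchema.enabled", "true")
# spark.range(1).mapInArrow(wrong_schema, "a int").collect()  # raises ARROW_TYPE_MISMATCH

# Setting the flag to false restores the previous, unvalidated behavior.
spark.conf.set("spark.sql.execution.arrow.pyspark.validateSchema.enabled", "false")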
UserDefinedPythonDataSource.scala
@@ -163,6 +163,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
toAttributes(outputSchema),
Seq((ChainedPythonFunctions(Seq(pythonUDF.func)), pythonUDF.resultId.id)),
inputSchema,
outputSchema,
conf.arrowMaxRecordsPerBatch,
pythonEvalType,
conf.sessionLocalTimeZone,
MapInBatchEvaluatorFactory.scala
@@ -20,17 +20,20 @@ package org.apache.spark.sql.execution.python
import scala.jdk.CollectionConverters._

import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory, TaskContext}
import org.apache.spark.api.python.ChainedPythonFunctions
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}

class MapInBatchEvaluatorFactory(
output: Seq[Attribute],
chainedFunc: Seq[(ChainedPythonFunctions, Long)],
outputTypes: StructType,
inputSchema: StructType,
outputSchema: DataType,
batchSize: Int,
pythonEvalType: Int,
sessionLocalTimeZone: String,
@@ -63,7 +66,7 @@
chainedFunc,
pythonEvalType,
argOffsets,
StructType(Array(StructField("struct", outputTypes))),
StructType(Array(StructField("struct", inputSchema))),
sessionLocalTimeZone,
largeVarTypes,
pythonRunnerConf,
@@ -75,6 +78,18 @@
val unsafeProj = UnsafeProjection.create(output, output)

columnarBatchIter.flatMap { batch =>
if (SQLConf.get.pysparkArrowValidateSchema) {
// Ensure the schema matches the expected schema, but allowing nullable fields in the
// output schema to become non-nullable in the actual schema.
val actualSchema = batch.column(0).dataType()
val isCompatible =
DataType.equalsIgnoreCompatibleNullability(from = actualSchema, to = outputSchema)
if (!isCompatible) {
throw QueryExecutionErrors.arrowDataTypeMismatchError(
PythonEvalType.toString(pythonEvalType), Seq(outputSchema), Seq(actualSchema))
}
}

// Scalar Iterator UDF returns a StructType column in ColumnarBatch, select
// the children here
val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
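For intuition, a rough Python analogue (illustrative only, not the Catalyst implementation) of the compatibility rule applied above: names and types must match, and a field may only become less nullable in the actual schema than in the declared output schema. The real check, DataType.equalsIgnoreCompatibleNullability, also recurses into arrays, maps and nested structs.

import pyarrow as pa

def nullability_compatible(actual: pa.StructType, expected: pa.StructType) -> bool:
    # Same number of fields, same names and types, and no field that is nullable
    # in the actual schema while declared non-nullable in the expected schema.
    if actual.num_fields != expected.num_fields:
        return False
    for a, e in zip(actual, expected):
        if a.name != e.name or a.type != e.type:
            return False
        if a.nullable and not e.nullable:
            return False
    return True

expected = pa.struct([pa.field("a", pa.int32(), nullable=True)])
actual_narrower = pa.struct([pa.field("a", pa.int32(), nullable=False)])
assert nullability_compatible(actual_narrower, expected)      # narrowing is allowed
assert not nullability_compatible(expected, actual_narrower)  # widening is rejected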
MapInBatchExec.scala
@@ -56,6 +56,7 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
output,
chainedFunc,
child.schema,
pythonUDF.dataType,
conf.arrowMaxRecordsPerBatch,
pythonEvalType,
conf.sessionLocalTimeZone,