diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsParquetThriftCompatibilitySuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsParquetThriftCompatibilitySuite.scala
index 5353bd139c3..e1133e7387f 100644
--- a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsParquetThriftCompatibilitySuite.scala
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsParquetThriftCompatibilitySuite.scala
@@ -19,9 +19,56 @@ spark-rapids-shim-json-lines ***/
 package org.apache.spark.sql.rapids.suites
 
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.execution.datasources.parquet.ParquetThriftCompatibilitySuite
 import org.apache.spark.sql.rapids.utils.RapidsSQLTestsBaseTrait
 
 class RapidsParquetThriftCompatibilitySuite extends ParquetThriftCompatibilitySuite
-  with RapidsSQLTestsBaseTrait {}
+  with RapidsSQLTestsBaseTrait {
+
+  test("Read Parquet file generated by parquet-thrift Rapids") {
+
+    val parquetFilePath =
+      "test-data/parquet-thrift-compat.snappy.parquet"
+
+    checkAnswer(spark.read.parquet(testFile(parquetFilePath)), (0 until 10).map { i =>
+      val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS")
+
+      val nonNullablePrimitiveValues = Seq(
+        i % 2 == 0,
+        i.toByte,
+        (i + 1).toShort,
+        i + 2,
+        i.toLong * 10,
+        i.toDouble + 0.2d,
+        // Thrift `BINARY` values are actually unencoded `STRING` values, and thus are always
+        // treated as `BINARY (UTF8)` in parquet-thrift, since parquet-thrift always assume
+        // Thrift `STRING`s are encoded using UTF-8.
+        s"val_$i",
+        s"val_$i",
+        // Thrift ENUM values are converted to Parquet binaries containing UTF-8 strings
+        suits(i % 4))
+
+      val nullablePrimitiveValues = if (i % 3 == 0) {
+        Seq.fill(nonNullablePrimitiveValues.length)(null)
+      } else {
+        nonNullablePrimitiveValues
+      }
+
+      val complexValues = Seq(
+        Seq.tabulate(3)(n => s"arr_${i + n}"),
+        // Thrift `SET`s are converted to Parquet `LIST`s
+        Seq(i),
+        Seq.tabulate(3)(n => (i + n: Integer) -> s"val_${i + n}").toMap,
+        Seq.tabulate(3) { n =>
+          (i + n) -> Seq.tabulate(3) { m =>
+            Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}")
+          }
+        }.toMap)
+
+      Row(nonNullablePrimitiveValues ++ nullablePrimitiveValues ++ complexValues: _*)
+    })
+  }
+
+}
 
 
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala
index 8b76e350fef..85bd47a5f3b 100644
--- a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala
@@ -110,8 +110,8 @@ class RapidsTestSettings extends BackendTestSettings {
     .exclude("schema mismatch failure error message for parquet reader", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11434"))
     .exclude("schema mismatch failure error message for parquet vectorized reader", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11446"))
   enableSuite[RapidsParquetThriftCompatibilitySuite]
-    .exclude("Read Parquet file generated by parquet-thrift", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11381"))
-    .exclude("SPARK-10136 list of primitive list", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11381"))
+    .exclude("Read Parquet file generated by parquet-thrift", ADJUST_UT("https://github.com/NVIDIA/spark-rapids/pull/11591"))
+    .exclude("SPARK-10136 list of primitive list", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11589"))
   enableSuite[RapidsParquetVectorizedSuite]
   enableSuite[RapidsRegexpExpressionsSuite]
   enableSuite[RapidsStringExpressionsSuite]