From 11f4976be1aae4c041ea66e3e83487aa2614c947 Mon Sep 17 00:00:00 2001
From: gnehil
Date: Tue, 5 Sep 2023 13:56:25 +0800
Subject: [PATCH] [fix] streaming write execution plan error (#135)

* fix streaming write error and add a json data pass-through option

* handle stream pass-through: force read_json_by_line to true when the format is json
---
 .../doris/spark/cfg/ConfigurationOptions.java |  6 ++
 .../doris/spark/load/DorisStreamLoad.java     | 50 ++++++++++
 .../doris/spark/sql/DorisStreamLoadSink.scala |  2 +-
 .../doris/spark/writer/DorisWriter.scala      | 97 +++++++++++++++----
 4 files changed, 135 insertions(+), 20 deletions(-)

diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
index 2ab200d8..09c0416f 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
@@ -100,4 +100,10 @@ public interface ConfigurationOptions {
     String DORIS_SINK_ENABLE_2PC = "doris.sink.enable-2pc";
     boolean DORIS_SINK_ENABLE_2PC_DEFAULT = false;
 
+    /**
+     * pass through json data when sink to doris in streaming mode
+     */
+    String DORIS_SINK_STREAMING_PASSTHROUGH = "doris.sink.streaming.passthrough";
+    boolean DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT = false;
+
 }
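The sketch below is not part of the patch; it only illustrates how the new doris.sink.streaming.passthrough option might be supplied from a structured streaming job. The object name, the Kafka source, and the other option keys (doris.fenodes, doris.table.identifier, user, password, doris.sink.properties.format, the doris format name) follow the connector's usual configuration and are assumptions to adapt to your deployment. The data is expected to be JSON; when the sink format is json, handleStreamPassThrough (further down in this patch) also forces read_json_by_line to true.

import org.apache.spark.sql.SparkSession

object StreamingPassthroughSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("doris-streaming-passthrough").getOrCreate()

    // Source is assumed to deliver one JSON document per row in a single string column.
    val source = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "broker:9092")
      .option("subscribe", "events")
      .load()
      .selectExpr("CAST(value AS STRING) AS value")

    val query = source.writeStream
      .format("doris")
      .option("doris.fenodes", "fe-host:8030")
      .option("doris.table.identifier", "example_db.example_table")
      .option("user", "root")
      .option("password", "")
      .option("doris.sink.properties.format", "json")
      // new option introduced by this patch: forward the json column as-is
      .option("doris.sink.streaming.passthrough", "true")
      .option("checkpointLocation", "/tmp/doris-sink-checkpoint")
      .start()

    query.awaitTermination()
  }
}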
if ("json".equalsIgnoreCase(fileType)) { + LOG.info("handle stream pass through, force set read_json_by_line is true for json format"); + streamLoadProp.put("read_json_by_line", "true"); + streamLoadProp.remove("strip_outer_array"); + } + + } + + private List passthrough(List> values) { + return values.stream().map(list -> list.get(0).toString()).collect(Collectors.toList()); + } + } diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala index 342e940e..d1a2b748 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala @@ -34,7 +34,7 @@ private[sql] class DorisStreamLoadSink(sqlContext: SQLContext, settings: SparkSe if (batchId <= latestBatchId) { logger.info(s"Skipping already committed batch $batchId") } else { - writer.write(data) + writer.writeStream(data) latestBatchId = batchId } } diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala index 2b918e88..e32267ee 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala @@ -22,6 +22,9 @@ import org.apache.doris.spark.listener.DorisTransactionListener import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad} import org.apache.doris.spark.sql.Utils import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.CollectionAccumulator import org.slf4j.{Logger, LoggerFactory} import java.io.IOException @@ -76,28 +79,13 @@ class DorisWriter(settings: SparkSettings) extends Serializable { * flush data to Doris and do retry when flush error * */ - def flush(batch: Iterable[util.List[Object]], dfColumns: Array[String]): Unit = { + def flush(batch: Seq[util.List[Object]], dfColumns: Array[String]): Unit = { Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) { - dorisStreamLoader.loadV2(batch.toList.asJava, dfColumns, enable2PC) + dorisStreamLoader.loadV2(batch.asJava, dfColumns, enable2PC) } match { - case Success(txnIds) => if (enable2PC) txnIds.asScala.foreach(txnId => preCommittedTxnAcc.add(txnId)) + case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc) case Failure(e) => - if (enable2PC) { - // if task run failed, acc value will not be returned to driver, - // should abort all pre committed transactions inside the task - logger.info("load task failed, start aborting previously pre-committed transactions") - val abortFailedTxnIds = mutable.Buffer[Int]() - preCommittedTxnAcc.value.asScala.foreach(txnId => { - Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) { - dorisStreamLoader.abort(txnId) - } match { - case Success(_) => - case Failure(_) => abortFailedTxnIds += txnId - } - }) - if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(",")) - preCommittedTxnAcc.reset() - } + if (enable2PC) handleLoadFailure(preCommittedTxnAcc) throw new IOException( s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the 
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
index 2b918e88..e32267ee 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
@@ -22,6 +22,9 @@ import org.apache.doris.spark.listener.DorisTransactionListener
 import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
 import org.apache.doris.spark.sql.Utils
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.CollectionAccumulator
 import org.slf4j.{Logger, LoggerFactory}
 
 import java.io.IOException
@@ -76,28 +79,13 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
    * flush data to Doris and do retry when flush error
    *
    */
-  def flush(batch: Iterable[util.List[Object]], dfColumns: Array[String]): Unit = {
+  def flush(batch: Seq[util.List[Object]], dfColumns: Array[String]): Unit = {
     Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
-      dorisStreamLoader.loadV2(batch.toList.asJava, dfColumns, enable2PC)
+      dorisStreamLoader.loadV2(batch.asJava, dfColumns, enable2PC)
     } match {
-      case Success(txnIds) => if (enable2PC) txnIds.asScala.foreach(txnId => preCommittedTxnAcc.add(txnId))
+      case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc)
       case Failure(e) =>
-        if (enable2PC) {
-          // if task run failed, acc value will not be returned to driver,
-          // should abort all pre committed transactions inside the task
-          logger.info("load task failed, start aborting previously pre-committed transactions")
-          val abortFailedTxnIds = mutable.Buffer[Int]()
-          preCommittedTxnAcc.value.asScala.foreach(txnId => {
-            Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
-              dorisStreamLoader.abort(txnId)
-            } match {
-              case Success(_) =>
-              case Failure(_) => abortFailedTxnIds += txnId
-            }
-          })
-          if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(","))
-          preCommittedTxnAcc.reset()
-        }
+        if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
         throw new IOException(
           s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e)
     }
@@ -105,5 +93,76 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
 
   }
 
+  def writeStream(dataFrame: DataFrame): Unit = {
+
+    val sc = dataFrame.sqlContext.sparkContext
+    val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc")
+    if (enable2PC) {
+      sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader))
+    }
+
+    var resultRdd = dataFrame.queryExecution.toRdd
+    val schema = dataFrame.schema
+    val dfColumns = dataFrame.columns
+    if (Objects.nonNull(sinkTaskPartitionSize)) {
+      resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
+    }
+    resultRdd
+      .foreachPartition(partition => {
+        partition
+          .grouped(batchSize)
+          .foreach(batch =>
+            flush(batch, dfColumns))
+      })
+
+    /**
+     * flush data to Doris and do retry when flush error
+     *
+     */
+    def flush(batch: Seq[InternalRow], dfColumns: Array[String]): Unit = {
+      Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
+        dorisStreamLoader.loadStream(convertToObjectList(batch, schema), dfColumns, enable2PC)
+      } match {
+        case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc)
+        case Failure(e) =>
+          if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
+          throw new IOException(
+            s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e)
+      }
+    }
+
+    def convertToObjectList(rows: Seq[InternalRow], schema: StructType): util.List[util.List[Object]] = {
+      rows.map(row => {
+        row.toSeq(schema).map(_.asInstanceOf[AnyRef]).toList.asJava
+      }).asJava
+    }
+
+  }
+
+  private def handleLoadSuccess(txnIds: mutable.Buffer[Integer], acc: CollectionAccumulator[Int]): Unit = {
+    txnIds.foreach(txnId => acc.add(txnId))
+  }
+
+  def handleLoadFailure(acc: CollectionAccumulator[Int]): Unit = {
+    // if task run failed, acc value will not be returned to driver,
+    // should abort all pre committed transactions inside the task
+    logger.info("load task failed, start aborting previously pre-committed transactions")
+    if (acc.isZero) {
+      logger.info("no pre-committed transactions, skip abort")
+      return
+    }
+    val abortFailedTxnIds = mutable.Buffer[Int]()
+    acc.value.asScala.foreach(txnId => {
+      Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
+        dorisStreamLoader.abort(txnId)
+      } match {
+        case Success(_) =>
+        case Failure(_) => abortFailedTxnIds += txnId
+      }
+    })
+    if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(","))
+    acc.reset()
+  }
+
 
 }
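A side note on the 2PC handling above: handleLoadFailure aborts the pre-committed transactions inside the failing task because a Spark user accumulator only delivers updates from successful tasks back to the driver, so the driver-side DorisTransactionListener would never see transaction ids added by a task that threw. The standalone sketch below demonstrates that visibility rule under an assumed local master; the object name is illustrative and this is not connector code.

import org.apache.spark.sql.SparkSession

object AccumulatorVisibilitySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("acc-visibility").getOrCreate()
    val sc = spark.sparkContext
    val acc = sc.collectionAccumulator[Int]("preCommittedTxnAcc")

    // Successful tasks: their accumulator updates are merged back on the driver.
    sc.parallelize(1 to 4, 2).foreachPartition(it => it.foreach(x => acc.add(x)))
    println(s"after successful job: ${acc.value}") // contains 1, 2, 3, 4

    acc.reset()

    // Failing tasks: updates from failed tasks are discarded, so a driver-side
    // listener cannot abort what it never saw; the task itself has to do it,
    // which is what handleLoadFailure does in DorisWriter.
    try {
      sc.parallelize(1 to 4, 2).foreachPartition { it =>
        it.foreach(x => acc.add(x))
        throw new RuntimeException("simulated stream load failure")
      }
    } catch {
      case _: Exception => // the job fails as expected
    }
    println(s"after failed job: ${acc.value}") // empty: the driver never received the ids

    spark.stop()
  }
}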