From 9dd57b004afd91320d493c18ff53eb91bea3125d Mon Sep 17 00:00:00 2001
From: gnehil
Date: Wed, 25 Oct 2023 15:45:20 +0800
Subject: [PATCH] [improvement] batch load retry (#148)

Co-authored-by: gnehil
---
 .../doris/spark/cfg/ConfigurationOptions.java |  2 +-
 .../doris/spark/load/DorisStreamLoad.java     |  4 -
 .../apache/doris/spark/load/RecordBatch.java  | 21 +----
 .../spark/load/RecordBatchInputStream.java    | 16 +---
 .../listener/DorisTransactionListener.scala   |  8 +-
 .../org/apache/doris/spark/sql/Utils.scala    | 27 ++++--
 .../doris/spark/writer/DorisWriter.scala      | 91 ++++++++++++++++---
 7 files changed, 107 insertions(+), 62 deletions(-)

diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
index a6767f08..a144fb8c 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
@@ -70,7 +70,7 @@ public interface ConfigurationOptions {
     int SINK_BATCH_SIZE_DEFAULT = 100000;
 
     String DORIS_SINK_MAX_RETRIES = "doris.sink.max-retries";
-    int SINK_MAX_RETRIES_DEFAULT = 1;
+    int SINK_MAX_RETRIES_DEFAULT = 0;
 
     String DORIS_MAX_FILTER_RATIO = "doris.max.filter.ratio";
 
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
index c524a4c6..338ffbef 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
@@ -98,7 +98,6 @@ public class DorisStreamLoad implements Serializable {
     private String FIELD_DELIMITER;
     private final String LINE_DELIMITER;
     private boolean streamingPassthrough = false;
-    private final Integer batchSize;
     private final boolean enable2PC;
     private final Integer txnRetries;
     private final Integer txnIntervalMs;
@@ -128,8 +127,6 @@ public DorisStreamLoad(SparkSettings settings) {
         LINE_DELIMITER = escapeString(streamLoadProp.getOrDefault("line_delimiter", "\n"));
         this.streamingPassthrough = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH,
                 ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT);
-        this.batchSize = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE,
-                ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT);
         this.enable2PC = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC,
                 ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT);
         this.txnRetries = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_RETRIES,
@@ -200,7 +197,6 @@ public int load(Iterator rows, StructType schema)
         this.loadUrlStr = loadUrlStr;
         HttpPut httpPut = getHttpPut(label, loadUrlStr, enable2PC);
         RecordBatchInputStream recodeBatchInputStream = new RecordBatchInputStream(RecordBatch.newBuilder(rows)
-                .batchSize(batchSize)
                 .format(fileType)
                 .sep(FIELD_DELIMITER)
                 .delim(LINE_DELIMITER)
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
index 4ce297f2..e471d5b8 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
@@ -36,11 +36,6 @@ public class RecordBatch {
      */
     private final Iterator iterator;
 
-    /**
-     * batch size for single load
-     */
-    private final int batchSize;
-
     /**
     * stream load format
      */
@@ -63,10 +58,9 @@ public class RecordBatch {
 
     private final boolean addDoubleQuotes;
 
-    private RecordBatch(Iterator iterator, int batchSize, String format, String sep, byte[] delim,
+    private RecordBatch(Iterator iterator, String format, String sep, byte[] delim,
                         StructType schema, boolean addDoubleQuotes) {
         this.iterator = iterator;
-        this.batchSize = batchSize;
         this.format = format;
         this.sep = sep;
         this.delim = delim;
@@ -78,10 +72,6 @@ public Iterator getIterator() {
         return iterator;
     }
 
-    public int getBatchSize() {
-        return batchSize;
-    }
-
     public String getFormat() {
         return format;
     }
@@ -112,8 +102,6 @@ public static class Builder {
 
         private final Iterator iterator;
 
-        private int batchSize;
-
         private String format;
 
         private String sep;
@@ -128,11 +116,6 @@ public Builder(Iterator iterator) {
             this.iterator = iterator;
         }
-        public Builder batchSize(int batchSize) {
-            this.batchSize = batchSize;
-            return this;
-        }
-
         public Builder format(String format) {
             this.format = format;
             return this;
         }
@@ -159,7 +142,7 @@ public Builder addDoubleQuotes(boolean addDoubleQuotes) {
         }
 
         public RecordBatch build() {
-            return new RecordBatch(iterator, batchSize, format, sep, delim, schema, addDoubleQuotes);
+            return new RecordBatch(iterator, format, sep, delim, schema, addDoubleQuotes);
         }
     }
 
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
index a361c399..047ac3be 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
@@ -60,11 +60,6 @@ public class RecordBatchInputStream extends InputStream {
 
     private final byte[] delim;
 
-    /**
-     * record count has been read
-     */
-    private int readCount = 0;
-
     /**
      * streaming mode pass through data without process
      */
@@ -122,12 +117,12 @@ public int read(byte[] b, int off, int len) throws IOException {
      */
     public boolean endOfBatch() throws DorisException {
         Iterator iterator = recordBatch.getIterator();
-        if (readCount >= recordBatch.getBatchSize() || !iterator.hasNext()) {
-            delimBuf = null;
-            return true;
+        if (iterator.hasNext()) {
+            readNext(iterator);
+            return false;
         }
-        readNext(iterator);
-        return false;
+        delimBuf = null;
+        return true;
     }
 
     /**
@@ -149,7 +144,6 @@ private void readNext(Iterator iterator) throws DorisException {
             delimBuf = ByteBuffer.wrap(delim);
             lineBuf = ByteBuffer.wrap(rowBytes);
         }
-        readCount++;
     }
 
     /**
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala
index e5991de3..e670a30b 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala
@@ -47,8 +47,8 @@ class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Int], d
       txnIds.foreach(txnId =>
         Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTnxIntervalMs), logger) {
           dorisStreamLoad.commit(txnId)
-        } match {
-          case Success(_) =>
+        } () match {
+          case Success(_) => // do nothing
           case Failure(_) => failedTxnIds += txnId
         }
       )
@@ -68,8 +68,8 @@ class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Int], d
      txnIds.foreach(txnId =>
        Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTnxIntervalMs), logger) {
          dorisStreamLoad.abortById(txnId)
-        } match {
-          case Success(_) =>
+        } () match {
+          case Success(_) => // do nothing
          case Failure(_) => failedTxnIds += txnId
        })
      if (failedTxnIds.nonEmpty) {
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
index 54976a7d..89103892 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
@@ -34,6 +34,7 @@ import scala.util.{Failure, Success, Try}
 private[spark] object Utils {
   /**
    * quote column name
+   *
    * @param colName column name
    * @return quoted column name
    */
@@ -41,8 +42,9 @@ private[spark] object Utils {
 
   /**
    * compile a filter to Doris FE filter format.
-   * @param filter filter to be compile
-   * @param dialect jdbc dialect to translate value to sql format
+   *
+   * @param filter             filter to be compile
+   * @param dialect            jdbc dialect to translate value to sql format
    * @param inValueLengthLimit max length of in value array
    * @return if Doris FE can handle this filter, return None if Doris FE can not handled it.
    */
@@ -87,6 +89,7 @@ private[spark] object Utils {
 
   /**
    * Escape special characters in SQL string literals.
+   *
    * @param value The string to be escaped.
    * @return Escaped string.
    */
@@ -95,6 +98,7 @@ private[spark] object Utils {
 
   /**
    * Converts value to SQL expression.
+   *
    * @param value The value to be converted.
    * @return Converted value.
    */
@@ -108,16 +112,17 @@ private[spark] object Utils {
 
   /**
    * check parameters validation and process it.
+   *
    * @param parameters parameters from rdd and spark conf
-   * @param logger slf4j logger
+   * @param logger     slf4j logger
    * @return processed parameters
    */
   def params(parameters: Map[String, String], logger: Logger) = {
     // '.' seems to be problematic when specifying the options
     val dottedParams = parameters.map { case (k, v) =>
-      if (k.startsWith("sink.properties.") || k.startsWith("doris.sink.properties.")){
-        (k,v)
-      }else {
+      if (k.startsWith("sink.properties.") || k.startsWith("doris.sink.properties.")) {
+        (k, v)
+      } else {
         (k.replace('_', '.'), v)
       }
     }
@@ -141,7 +146,7 @@ private[spark] object Utils {
       case (k, v) => if (k.startsWith("doris.")) (k, v) else ("doris." + k, v)
-    }.map{
+    }.map {
       case (ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD, _) =>
        logger.error(s"${ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD} cannot use in Doris Datasource.")
        throw new DorisException(s"${ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD} cannot use in" +
@@ -165,13 +170,14 @@ private[spark] object Utils {
 
     // validate path is available
     finalParams.getOrElse(ConfigurationOptions.DORIS_TABLE_IDENTIFIER,
-        throw new DorisException("table identifier must be specified for doris table identifier."))
+      throw new DorisException("table identifier must be specified for doris table identifier."))
 
     finalParams
   }
 
   @tailrec
-  def retry[R, T <: Throwable : ClassTag](retryTimes: Int, interval: Duration, logger: Logger)(f: => R): Try[R] = {
+  def retry[R, T <: Throwable : ClassTag](retryTimes: Int, interval: Duration, logger: Logger)
+                                         (f: => R)(h: => Unit): Try[R] = {
     assert(retryTimes >= 0)
     val result = Try(f)
     result match {
@@ -182,7 +188,8 @@ private[spark] object Utils {
         logger.warn(s"Execution failed caused by: ", exception)
         logger.warn(s"$retryTimes times retry remaining, the next attempt will be in ${interval.toMillis} ms")
         LockSupport.parkNanos(interval.toNanos)
-        retry(retryTimes - 1, interval, logger)(f)
+        h
+        retry(retryTimes - 1, interval, logger)(f)(h)
       case Failure(exception) => Failure(exception)
     }
   }
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
index a8c414ec..6498bea6 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
@@ -21,7 +21,6 @@ import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
 import org.apache.doris.spark.listener.DorisTransactionListener
 import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
 import org.apache.doris.spark.sql.Utils
-import org.apache.spark.TaskContext
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types.StructType
@@ -32,10 +31,10 @@ import java.io.IOException
 import java.time.Duration
 import java.util
 import java.util.Objects
-import java.util.concurrent.locks.LockSupport
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.util.{Failure, Success, Try}
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
 
 class DorisWriter(settings: SparkSettings) extends Serializable {
@@ -44,9 +43,18 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
   private val sinkTaskPartitionSize: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TASK_PARTITION_SIZE)
   private val sinkTaskUseRepartition: Boolean = settings.getProperty(ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION,
     ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION_DEFAULT.toString).toBoolean
+
+  private val maxRetryTimes: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_MAX_RETRIES,
+    ConfigurationOptions.SINK_MAX_RETRIES_DEFAULT)
+  private val batchSize: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE,
+    ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT)
 
   private val batchInterValMs: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS,
     ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS_DEFAULT)
+  if (maxRetryTimes > 0) {
+    logger.info(s"batch retry enabled, size is $batchSize, interval is $batchInterValMs")
+  }
+
   private val enable2PC: Boolean = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC,
     ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT)
   private val sinkTxnIntervalMs: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS,
@@ -77,7 +85,6 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
 
   private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType) => Int): Unit = {
     val sc = dataFrame.sqlContext.sparkContext
-    logger.info(s"applicationAttemptId: ${sc.applicationAttemptId.getOrElse(-1)}")
     val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc")
     if (enable2PC) {
       sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader, sinkTxnIntervalMs, sinkTxnRetries))
@@ -89,19 +96,22 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
       resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
     }
     resultRdd.foreachPartition(iterator => {
-      val intervalNanos = Duration.ofMillis(batchInterValMs.toLong).toNanos
+
       while (iterator.hasNext) {
-        Try {
-          loadFunc(iterator.asJava, schema)
-        } match {
-          case Success(txnId) => if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc)
+        val batchIterator = new BatchIterator[InternalRow](iterator, batchSize, maxRetryTimes > 0)
+        val retry = Utils.retry[Int, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) _
+        retry(loadFunc(batchIterator.asJava, schema))(batchIterator.reset()) match {
+          case Success(txnId) =>
+            if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc)
+            batchIterator.close()
          case Failure(e) =>
            if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
+            batchIterator.close()
            throw new IOException(
              s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node.", e)
        }
-        LockSupport.parkNanos(intervalNanos)
      }
+
    })
  }
 
@@ -120,10 +130,10 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
     }
     val abortFailedTxnIds = mutable.Buffer[Int]()
     acc.value.asScala.foreach(txnId => {
-      Utils.retry[Unit, Exception](sinkTxnRetries, Duration.ofMillis(sinkTxnIntervalMs), logger) {
+      Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
        dorisStreamLoader.abortById(txnId)
-      } match {
-        case Success(_) =>
+      }() match {
+        case Success(_) => // do nothing
        case Failure(_) => abortFailedTxnIds += txnId
      }
    })
@@ -131,5 +141,60 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
     acc.reset()
   }
 
+  /**
+   * iterator for batch load
+   * if retry time is greater than zero, enable batch retry and put batch data into buffer
+   *
+   * @param iterator parent iterator
+   * @param batchSize batch size
+   * @param batchRetryEnable whether enable batch retry
+   * @tparam T data type
+   */
+  private class BatchIterator[T](iterator: Iterator[T], batchSize: Int, batchRetryEnable: Boolean) extends Iterator[T] {
+
+    private val buffer: ArrayBuffer[T] = if (batchRetryEnable) new ArrayBuffer[T](batchSize) else ArrayBuffer.empty[T]
+
+    private var recordCount = 0
+
+    private var isReset = false
+
+    override def hasNext: Boolean = recordCount < batchSize && iterator.hasNext
+
+    override def next(): T = {
+      recordCount += 1
+      if (batchRetryEnable) {
+        if (isReset && buffer.nonEmpty) {
+          buffer(recordCount)
+        } else {
+          val elem = iterator.next
+          buffer += elem
+          elem
+        }
+      } else {
+        iterator.next
+      }
+    }
+
+    /**
+     * reset record count for re-read
+     */
+    def reset(): Unit = {
+      recordCount = 0
+      isReset = true
+      logger.info("batch iterator is reset")
+    }
+
+    /**
+     * clear buffer when buffer is not empty
+     */
+    def close(): Unit = {
+      if (buffer.nonEmpty) {
+        buffer.clear()
+        logger.info("buffer is cleared and batch iterator is closed")
+      }
+    }
+
+  }
+
 }
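Note for reviewers: the interaction between the new curried Utils.retry(...)(f)(h) signature and the BatchIterator added in DorisWriter can be hard to follow from the diff alone. The self-contained Scala sketch below illustrates the intended flow: each attempt consumes up to batchSize rows from a buffering iterator, and a failed attempt runs a recovery hook that resets the iterator so the retried load replays the same buffered rows. BatchRetrySketch, BufferedBatch, and flakyLoad are illustrative stand-ins written for this note and are not connector classes; only the retry-then-reset shape mirrors the patch, and the buffer handling is deliberately simplified rather than copied from BatchIterator.

import java.time.Duration
import java.util.concurrent.locks.LockSupport

import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Success, Try}

object BatchRetrySketch {

  // Shaped like the patched Utils.retry: run f; on failure run the recovery
  // hook h (here: reset the batch iterator), wait, and try again.
  @tailrec
  def retry[R](retryTimes: Int, interval: Duration)(f: => R)(h: => Unit): Try[R] =
    Try(f) match {
      case Success(result) => Success(result)
      case Failure(exception) if retryTimes > 0 =>
        println(s"attempt failed: ${exception.getMessage}, $retryTimes retries remaining")
        LockSupport.parkNanos(interval.toNanos)
        h
        retry(retryTimes - 1, interval)(f)(h)
      case failure => failure
    }

  // Simplified analogue of the BatchIterator added in DorisWriter: it remembers
  // the rows it has handed out so a retried load can re-read the same batch.
  class BufferedBatch[T](underlying: Iterator[T], batchSize: Int) extends Iterator[T] {

    private val buffer = ArrayBuffer.empty[T]
    private var pos = 0
    private var replay = false

    override def hasNext: Boolean =
      if (replay) pos < buffer.size
      else pos < batchSize && underlying.hasNext

    override def next(): T = {
      val elem =
        if (replay) buffer(pos)
        else { val e = underlying.next(); buffer += e; e }
      pos += 1
      elem
    }

    def reset(): Unit = { pos = 0; replay = true } // replay the buffered batch
    def close(): Unit = buffer.clear()             // drop the batch once loaded
  }

  def main(args: Array[String]): Unit = {
    val batch = new BufferedBatch((1 to 5).iterator, batchSize = 5)
    var attempts = 0

    // Stand-in for loadFunc: fails once to exercise the reset-and-replay path.
    def flakyLoad(rows: Iterator[Int]): Int = {
      attempts += 1
      val loaded = rows.toList
      if (attempts == 1) throw new RuntimeException("simulated stream load failure")
      println(s"attempt $attempts loaded ${loaded.size} rows: $loaded")
      42 // pretend this is the transaction id returned by the load
    }

    retry(2, Duration.ofMillis(100))(flakyLoad(batch))(batch.reset()) match {
      case Success(txnId) => batch.close(); println(s"commit txn $txnId")
      case Failure(e) => batch.close(); throw e
    }
  }
}

The point the patch relies on is that the recovery hook runs before the next attempt: that is why Utils.retry gains the extra (h: => Unit) parameter list and why DorisWriter passes batchIterator.reset() there, while a successful or permanently failed batch is released via batchIterator.close().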