diff --git a/spark-doris-connector/pom.xml b/spark-doris-connector/pom.xml index 74a53cc0..4148a660 100644 --- a/spark-doris-connector/pom.xml +++ b/spark-doris-connector/pom.xml @@ -184,6 +184,15 @@ jackson-core ${fasterxml.jackson.version} + + + + com.mysql + mysql-connector-j + 8.0.33 + test + + diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java index de381ca1..7877bc88 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java @@ -20,6 +20,7 @@ public interface ConfigurationOptions { // doris fe node address String DORIS_FENODES = "doris.fenodes"; + String DORIS_QUERY_PORT = "doris.query.port"; String DORIS_DEFAULT_CLUSTER = "default_cluster"; @@ -70,7 +71,7 @@ public interface ConfigurationOptions { int SINK_BATCH_SIZE_DEFAULT = 100000; String DORIS_SINK_MAX_RETRIES = "doris.sink.max-retries"; - int SINK_MAX_RETRIES_DEFAULT = 1; + int SINK_MAX_RETRIES_DEFAULT = 0; String DORIS_MAX_FILTER_RATIO = "doris.max.filter.ratio"; diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java index 1631eeb3..3d5bf362 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java @@ -101,7 +101,6 @@ public class DorisStreamLoad implements Serializable { private String FIELD_DELIMITER; private final String LINE_DELIMITER; private boolean streamingPassthrough = false; - private final Integer batchSize; private final boolean enable2PC; private final Integer txnRetries; private final Integer txnIntervalMs; @@ -133,8 +132,6 @@ public DorisStreamLoad(SparkSettings settings) { LINE_DELIMITER = escapeString(streamLoadProp.getOrDefault("line_delimiter", "\n")); this.streamingPassthrough = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH, ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT); - this.batchSize = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE, - ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT); this.enable2PC = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC, ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT); this.txnRetries = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_RETRIES, @@ -215,7 +212,6 @@ public long load(Iterator rows, StructType schema) this.loadUrlStr = loadUrlStr; HttpPut httpPut = getHttpPut(label, loadUrlStr, enable2PC); RecordBatchInputStream recodeBatchInputStream = new RecordBatchInputStream(RecordBatch.newBuilder(rows) - .batchSize(batchSize) .format(fileType) .sep(FIELD_DELIMITER) .delim(LINE_DELIMITER) diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java index 4ce297f2..e471d5b8 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java @@ -36,11 +36,6 @@ public class RecordBatch { */ private final Iterator iterator; - /** - * batch size for single load - */ - private final int batchSize; - /** * stream 
load format */ @@ -63,10 +58,9 @@ public class RecordBatch { private final boolean addDoubleQuotes; - private RecordBatch(Iterator iterator, int batchSize, String format, String sep, byte[] delim, + private RecordBatch(Iterator iterator, String format, String sep, byte[] delim, StructType schema, boolean addDoubleQuotes) { this.iterator = iterator; - this.batchSize = batchSize; this.format = format; this.sep = sep; this.delim = delim; @@ -78,10 +72,6 @@ public Iterator getIterator() { return iterator; } - public int getBatchSize() { - return batchSize; - } - public String getFormat() { return format; } @@ -112,8 +102,6 @@ public static class Builder { private final Iterator iterator; - private int batchSize; - private String format; private String sep; @@ -128,11 +116,6 @@ public Builder(Iterator iterator) { this.iterator = iterator; } - public Builder batchSize(int batchSize) { - this.batchSize = batchSize; - return this; - } - public Builder format(String format) { this.format = format; return this; @@ -159,7 +142,7 @@ public Builder addDoubleQuotes(boolean addDoubleQuotes) { } public RecordBatch build() { - return new RecordBatch(iterator, batchSize, format, sep, delim, schema, addDoubleQuotes); + return new RecordBatch(iterator, format, sep, delim, schema, addDoubleQuotes); } } diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java index f70809b3..047ac3be 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java @@ -40,8 +40,6 @@ public class RecordBatchInputStream extends InputStream { public static final Logger LOG = LoggerFactory.getLogger(RecordBatchInputStream.class); - private static final int DEFAULT_BUF_SIZE = 4096; - /** * Load record batch */ @@ -55,12 +53,12 @@ public class RecordBatchInputStream extends InputStream { /** * record buffer */ - private ByteBuffer buffer = ByteBuffer.allocate(0); - /** - * record count has been read - */ - private int readCount = 0; + private ByteBuffer lineBuf = ByteBuffer.allocate(0);; + + private ByteBuffer delimBuf = ByteBuffer.allocate(0); + + private final byte[] delim; /** * streaming mode pass through data without process @@ -70,31 +68,42 @@ public class RecordBatchInputStream extends InputStream { public RecordBatchInputStream(RecordBatch recordBatch, boolean passThrough) { this.recordBatch = recordBatch; this.passThrough = passThrough; + this.delim = recordBatch.getDelim(); } @Override public int read() throws IOException { try { - if (buffer.remaining() == 0 && endOfBatch()) { - return -1; // End of stream + if (lineBuf.remaining() == 0 && endOfBatch()) { + return -1; + } + + if (delimBuf != null && delimBuf.remaining() > 0) { + return delimBuf.get() & 0xff; } } catch (DorisException e) { throw new IOException(e); } - return buffer.get() & 0xFF; + return lineBuf.get() & 0xFF; } @Override public int read(byte[] b, int off, int len) throws IOException { try { - if (buffer.remaining() == 0 && endOfBatch()) { - return -1; // End of stream + if (lineBuf.remaining() == 0 && endOfBatch()) { + return -1; + } + + if (delimBuf != null && delimBuf.remaining() > 0) { + int bytesRead = Math.min(len, delimBuf.remaining()); + delimBuf.get(b, off, bytesRead); + return bytesRead; } } catch (DorisException e) { throw new IOException(e); } - int bytesRead = 
Math.min(len, buffer.remaining()); - buffer.get(b, off, bytesRead); + int bytesRead = Math.min(len, lineBuf.remaining()); + lineBuf.get(b, off, bytesRead); return bytesRead; } @@ -108,11 +117,12 @@ public int read(byte[] b, int off, int len) throws IOException { */ public boolean endOfBatch() throws DorisException { Iterator iterator = recordBatch.getIterator(); - if (readCount >= recordBatch.getBatchSize() || !iterator.hasNext()) { - return true; + if (iterator.hasNext()) { + readNext(iterator); + return false; } - readNext(iterator); - return false; + delimBuf = null; + return true; } /** @@ -125,60 +135,15 @@ private void readNext(Iterator iterator) throws DorisException { if (!iterator.hasNext()) { throw new ShouldNeverHappenException(); } - byte[] delim = recordBatch.getDelim(); byte[] rowBytes = rowToByte(iterator.next()); if (isFirst) { - ensureCapacity(rowBytes.length); - buffer.put(rowBytes); - buffer.flip(); + delimBuf = null; + lineBuf = ByteBuffer.wrap(rowBytes); isFirst = false; } else { - ensureCapacity(delim.length + rowBytes.length); - buffer.put(delim); - buffer.put(rowBytes); - buffer.flip(); - } - readCount++; - } - - /** - * Check if the buffer has enough capacity. - * - * @param need required buffer space - */ - private void ensureCapacity(int need) { - - int capacity = buffer.capacity(); - - if (need <= capacity) { - buffer.clear(); - return; - } - - // need to extend - int newCapacity = calculateNewCapacity(capacity, need); - LOG.info("expand buffer, min cap: {}, now cap: {}, new cap: {}", need, capacity, newCapacity); - buffer = ByteBuffer.allocate(newCapacity); - - } - - /** - * Calculate new capacity for buffer expansion. - * - * @param capacity current buffer capacity - * @param minCapacity required min buffer space - * @return new capacity - */ - private int calculateNewCapacity(int capacity, int minCapacity) { - int newCapacity = 0; - if (capacity == 0) { - newCapacity = DEFAULT_BUF_SIZE; - - } - while (newCapacity < minCapacity) { - newCapacity = newCapacity << 1; + delimBuf = ByteBuffer.wrap(delim); + lineBuf = ByteBuffer.wrap(rowBytes); } - return newCapacity; } /** @@ -220,5 +185,4 @@ private byte[] rowToByte(InternalRow row) throws DorisException { } - } diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/jdbc/JdbcUtils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/jdbc/JdbcUtils.scala new file mode 100644 index 00000000..aab10328 --- /dev/null +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/jdbc/JdbcUtils.scala @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.spark.jdbc + +import java.sql.{Connection, DriverManager} +import java.util.Properties + +object JdbcUtils { + + def getJdbcUrl(host: String, port: Int): String = s"jdbc:mysql://$host:$port/information_schema" + + def getConnection(url: String, props: Properties): Connection = { + + DriverManager.getConnection(url, props) + } + + def getTruncateQuery(table: String): String = s"TRUNCATE TABLE $table" + +} diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala index 57d20f3b..b1e9d84d 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala @@ -47,8 +47,8 @@ class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Long], txnIds.foreach(txnId => Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTnxIntervalMs), logger) { dorisStreamLoad.commit(txnId) - } match { - case Success(_) => + } () match { + case Success(_) => // do nothing case Failure(_) => failedTxnIds += txnId } ) @@ -68,8 +68,8 @@ class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Long], txnIds.foreach(txnId => Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTnxIntervalMs), logger) { dorisStreamLoad.abortById(txnId) - } match { - case Success(_) => + } () match { + case Success(_) => // do nothing case Failure(_) => failedTxnIds += txnId }) if (failedTxnIds.nonEmpty) { diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisRelation.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisRelation.scala index 049d5a25..fe7e63d7 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisRelation.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisRelation.scala @@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -98,6 +98,7 @@ private[sql] class DorisRelation( } data.write.format(DorisSourceProvider.SHORT_NAME) .options(insertCfg) + .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) .save() } } diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala index 94fab9e6..ac04401f 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala @@ -17,7 +17,10 @@ package org.apache.doris.spark.sql -import org.apache.doris.spark.cfg.SparkSettings +import org.apache.commons.lang3.exception.ExceptionUtils +import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings} +import org.apache.doris.spark.exception.DorisException +import org.apache.doris.spark.jdbc.JdbcUtils import org.apache.doris.spark.sql.DorisSourceProvider.SHORT_NAME import org.apache.doris.spark.writer.DorisWriter import org.apache.spark.SparkConf @@ -28,7 +31,10 @@ import 
org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} import org.slf4j.{Logger, LoggerFactory} +import java.util.Properties import scala.collection.JavaConverters.mapAsJavaMapConverter +import scala.util.control.Breaks +import scala.util.{Failure, Success, Try} private[sql] class DorisSourceProvider extends DataSourceRegister with RelationProvider @@ -54,6 +60,13 @@ private[sql] class DorisSourceProvider extends DataSourceRegister val sparkSettings = new SparkSettings(sqlContext.sparkContext.getConf) sparkSettings.merge(Utils.params(parameters, logger).asJava) + + mode match { + case SaveMode.Overwrite => + truncateTable(sparkSettings) + case _: SaveMode => // do nothing + } + // init stream loader val writer = new DorisWriter(sparkSettings) writer.write(data) @@ -79,6 +92,50 @@ private[sql] class DorisSourceProvider extends DataSourceRegister sparkSettings.merge(Utils.params(parameters, logger).asJava) new DorisStreamLoadSink(sqlContext, sparkSettings) } + + private def truncateTable(sparkSettings: SparkSettings): Unit = { + + val feNodes = sparkSettings.getProperty(ConfigurationOptions.DORIS_FENODES) + val port = sparkSettings.getIntegerProperty(ConfigurationOptions.DORIS_QUERY_PORT) + require(feNodes != null && feNodes.nonEmpty, "doris.fenodes cannot be null or empty") + require(port != null, "doris.query.port cannot be null") + val feNodesArr = feNodes.split(",") + val breaks = new Breaks + + var success = false + var exOption: Option[Exception] = None + + breaks.breakable { + feNodesArr.foreach(feNode => { + Try { + val host = feNode.split(":")(0) + val url = JdbcUtils.getJdbcUrl(host, port) + val props = new Properties() + props.setProperty("user", sparkSettings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_USER)) + props.setProperty("password", sparkSettings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD)) + val conn = JdbcUtils.getConnection(url, props) + val statement = conn.createStatement() + val tableIdentifier = sparkSettings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER) + val query = JdbcUtils.getTruncateQuery(tableIdentifier) + statement.execute(query) + success = true + logger.info(s"truncate table $tableIdentifier success") + } match { + case Success(_) => breaks.break() + case Failure(e: Exception) => + exOption = Some(e) + logger.warn(s"truncate table failed on $feNode, error: {}", ExceptionUtils.getStackTrace(e)) + } + }) + + } + + if (!success) { + throw new DorisException("truncate table failed", exOption.get) + } + + } + } object DorisSourceProvider { diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala index 44baa95d..e806059c 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala @@ -32,7 +32,6 @@ import org.slf4j.LoggerFactory import java.sql.Timestamp import java.time.{LocalDateTime, ZoneOffset} import scala.collection.JavaConversions._ -import scala.collection.mutable private[spark] object SchemaUtils { private val logger = LoggerFactory.getLogger(SchemaUtils.getClass.getSimpleName.stripSuffix("$")) @@ -166,13 +165,14 @@ private[spark] object SchemaUtils { case dt: DecimalType => row.getDecimal(ordinal, dt.precision, dt.scale) case at: ArrayType => val arrayData = row.getArray(ordinal) - var i = 0 - val buffer = mutable.Buffer[Any]() - 
while (i < arrayData.numElements()) { - if (arrayData.isNullAt(i)) buffer += null else buffer += rowColumnValue(arrayData, i, at.elementType) - i += 1 + if (arrayData == null) DataUtil.NULL_VALUE + else if(arrayData.numElements() == 0) "[]" + else { + (0 until arrayData.numElements()).map(i => { + if (arrayData.isNullAt(i)) null else rowColumnValue(arrayData, i, at.elementType) + }).mkString("[", ",", "]") } - s"[${buffer.mkString(",")}]" + case mt: MapType => val mapData = row.getMap(ordinal) val keys = mapData.keyArray() diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala index 54976a7d..89103892 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala @@ -34,6 +34,7 @@ import scala.util.{Failure, Success, Try} private[spark] object Utils { /** * quote column name + * * @param colName column name * @return quoted column name */ @@ -41,8 +42,9 @@ private[spark] object Utils { /** * compile a filter to Doris FE filter format. - * @param filter filter to be compile - * @param dialect jdbc dialect to translate value to sql format + * + * @param filter filter to be compile + * @param dialect jdbc dialect to translate value to sql format * @param inValueLengthLimit max length of in value array * @return if Doris FE can handle this filter, return None if Doris FE can not handled it. */ @@ -87,6 +89,7 @@ private[spark] object Utils { /** * Escape special characters in SQL string literals. + * * @param value The string to be escaped. * @return Escaped string. */ @@ -95,6 +98,7 @@ private[spark] object Utils { /** * Converts value to SQL expression. + * * @param value The value to be converted. * @return Converted value. */ @@ -108,16 +112,17 @@ private[spark] object Utils { /** * check parameters validation and process it. + * * @param parameters parameters from rdd and spark conf - * @param logger slf4j logger + * @param logger slf4j logger * @return processed parameters */ def params(parameters: Map[String, String], logger: Logger) = { // '.' seems to be problematic when specifying the options val dottedParams = parameters.map { case (k, v) => - if (k.startsWith("sink.properties.") || k.startsWith("doris.sink.properties.")){ - (k,v) - }else { + if (k.startsWith("sink.properties.") || k.startsWith("doris.sink.properties.")) { + (k, v) + } else { (k.replace('_', '.'), v) } } @@ -141,7 +146,7 @@ private[spark] object Utils { case (k, v) => if (k.startsWith("doris.")) (k, v) else ("doris." 
+ k, v) - }.map{ + }.map { case (ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD, _) => logger.error(s"${ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD} cannot use in Doris Datasource.") throw new DorisException(s"${ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD} cannot use in" + @@ -165,13 +170,14 @@ private[spark] object Utils { // validate path is available finalParams.getOrElse(ConfigurationOptions.DORIS_TABLE_IDENTIFIER, - throw new DorisException("table identifier must be specified for doris table identifier.")) + throw new DorisException("table identifier must be specified for doris table identifier.")) finalParams } @tailrec - def retry[R, T <: Throwable : ClassTag](retryTimes: Int, interval: Duration, logger: Logger)(f: => R): Try[R] = { + def retry[R, T <: Throwable : ClassTag](retryTimes: Int, interval: Duration, logger: Logger) + (f: => R)(h: => Unit): Try[R] = { assert(retryTimes >= 0) val result = Try(f) result match { @@ -182,7 +188,8 @@ private[spark] object Utils { logger.warn(s"Execution failed caused by: ", exception) logger.warn(s"$retryTimes times retry remaining, the next attempt will be in ${interval.toMillis} ms") LockSupport.parkNanos(interval.toNanos) - retry(retryTimes - 1, interval, logger)(f) + h + retry(retryTimes - 1, interval, logger)(f)(h) case Failure(exception) => Failure(exception) } } diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala index b829fef6..770d7009 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala @@ -21,7 +21,6 @@ import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings} import org.apache.doris.spark.listener.DorisTransactionListener import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad} import org.apache.doris.spark.sql.Utils -import org.apache.spark.TaskContext import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.StructType @@ -32,10 +31,10 @@ import java.io.IOException import java.time.Duration import java.util import java.util.Objects -import java.util.concurrent.locks.LockSupport import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.util.{Failure, Success, Try} +import scala.collection.mutable.ArrayBuffer +import scala.util.{Failure, Success} class DorisWriter(settings: SparkSettings) extends Serializable { @@ -44,9 +43,18 @@ class DorisWriter(settings: SparkSettings) extends Serializable { private val sinkTaskPartitionSize: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TASK_PARTITION_SIZE) private val sinkTaskUseRepartition: Boolean = settings.getProperty(ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION, ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION_DEFAULT.toString).toBoolean + + private val maxRetryTimes: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_MAX_RETRIES, + ConfigurationOptions.SINK_MAX_RETRIES_DEFAULT) + private val batchSize: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE, + ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT) private val batchInterValMs: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS, ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS_DEFAULT) + if (maxRetryTimes > 0) { + 
logger.info(s"batch retry enabled, size is $batchSize, interval is $batchInterValMs") + } + private val enable2PC: Boolean = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC, ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT) private val sinkTxnIntervalMs: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS, @@ -77,7 +85,6 @@ class DorisWriter(settings: SparkSettings) extends Serializable { private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType) => Long): Unit = { val sc = dataFrame.sqlContext.sparkContext - logger.info(s"applicationAttemptId: ${sc.applicationAttemptId.getOrElse(-1)}") val preCommittedTxnAcc = sc.collectionAccumulator[Long]("preCommittedTxnAcc") if (enable2PC) { sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader, sinkTxnIntervalMs, sinkTxnRetries)) @@ -89,19 +96,22 @@ class DorisWriter(settings: SparkSettings) extends Serializable { resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize) } resultRdd.foreachPartition(iterator => { - val intervalNanos = Duration.ofMillis(batchInterValMs.toLong).toNanos + while (iterator.hasNext) { - Try { - loadFunc(iterator.asJava, schema) - } match { - case Success(txnId) => if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc) + val batchIterator = new BatchIterator[InternalRow](iterator, batchSize, maxRetryTimes > 0) + val retry = Utils.retry[Int, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) _ + retry(loadFunc(batchIterator.asJava, schema))(batchIterator.reset()) match { + case Success(txnId) => + if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc) + batchIterator.close() case Failure(e) => if (enable2PC) handleLoadFailure(preCommittedTxnAcc) + batchIterator.close() throw new IOException( s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node.", e) } - LockSupport.parkNanos(intervalNanos) } + }) } @@ -120,10 +130,10 @@ class DorisWriter(settings: SparkSettings) extends Serializable { } val abortFailedTxnIds = mutable.Buffer[Long]() acc.value.asScala.foreach(txnId => { - Utils.retry[Unit, Exception](sinkTxnRetries, Duration.ofMillis(sinkTxnIntervalMs), logger) { + Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) { dorisStreamLoader.abortById(txnId) - } match { - case Success(_) => + }() match { + case Success(_) => // do nothing case Failure(_) => abortFailedTxnIds += txnId } }) @@ -131,5 +141,60 @@ class DorisWriter(settings: SparkSettings) extends Serializable { acc.reset() } + /** + * iterator for batch load + * if retry time is greater than zero, enable batch retry and put batch data into buffer + * + * @param iterator parent iterator + * @param batchSize batch size + * @param batchRetryEnable whether enable batch retry + * @tparam T data type + */ + private class BatchIterator[T](iterator: Iterator[T], batchSize: Int, batchRetryEnable: Boolean) extends Iterator[T] { + + private val buffer: ArrayBuffer[T] = if (batchRetryEnable) new ArrayBuffer[T](batchSize) else ArrayBuffer.empty[T] + + private var recordCount = 0 + + private var isReset = false + + override def hasNext: Boolean = recordCount < batchSize && iterator.hasNext + + override def next(): T = { + recordCount += 1 + if (batchRetryEnable) { + if (isReset && buffer.nonEmpty) { + buffer(recordCount) + } else { + val elem = iterator.next + buffer += elem + elem + } + } else { + iterator.next + } + } + 
+ /** + * reset record count for re-read + */ + def reset(): Unit = { + recordCount = 0 + isReset = true + logger.info("batch iterator is reset") + } + + /** + * clear buffer when buffer is not empty + */ + def close(): Unit = { + if (buffer.nonEmpty) { + buffer.clear() + logger.info("buffer is cleared and batch iterator is closed") + } + } + + } + }
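
The reworked RecordBatchInputStream drops the old grow-and-copy buffer (and its per-batch row cap) and instead keeps two small ByteBuffers: lineBuf wraps the bytes of the current serialized row and delimBuf wraps the line delimiter, and every read drains delimBuf before lineBuf, so the data streams out as row, delimiter, row, delimiter, ... without ever materializing more than one row. A minimal sketch of that interleaving, independent of the Doris row serialization (the class name and the Iterator[Array[Byte]] input are illustrative, not the connector's API):

import java.io.InputStream
import java.nio.ByteBuffer

// Streams rows as: row1, delim, row2, delim, ... without building one large buffer.
class DelimitedRowStream(rows: Iterator[Array[Byte]], delim: Array[Byte]) extends InputStream {

  private var lineBuf: ByteBuffer = ByteBuffer.allocate(0)
  private var delimBuf: ByteBuffer = ByteBuffer.allocate(0)
  private var first = true

  // Pull the next row; the delimiter is emitted before every row except the first.
  private def advance(): Boolean = {
    if (!rows.hasNext) return false
    delimBuf = if (first) { first = false; ByteBuffer.allocate(0) } else ByteBuffer.wrap(delim)
    lineBuf = ByteBuffer.wrap(rows.next())
    true
  }

  override def read(): Int = {
    if (lineBuf.remaining() == 0 && !advance()) return -1
    if (delimBuf.remaining() > 0) delimBuf.get() & 0xff
    else lineBuf.get() & 0xff
  }
}

The end-of-stream behavior matches the patch: once the row iterator is exhausted the delimiter buffer is not refilled, so no trailing delimiter is written.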
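
SaveMode.Overwrite is now honored by truncating the target table before the stream load: the provider builds a MySQL-protocol JDBC URL from an FE host and the new doris.query.port option (JdbcUtils.getJdbcUrl) and executes the statement from JdbcUtils.getTruncateQuery, trying each configured FE node until one succeeds. Stripped of that per-FE retry loop, the operation amounts to roughly the following; the literal host, port, credentials and table name are placeholders for the doris.fenodes, doris.query.port, doris.request.auth.* and doris.table.identifier settings (9030 is the usual FE query port):

import java.sql.DriverManager
import java.util.Properties

// Placeholder values standing in for the Spark connector options listed above.
val feHost = "127.0.0.1"
val queryPort = 9030
val props = new Properties()
props.setProperty("user", "root")
props.setProperty("password", "")

val url = s"jdbc:mysql://$feHost:$queryPort/information_schema"
val conn = DriverManager.getConnection(url, props)
try conn.createStatement().execute("TRUNCATE TABLE example_db.example_tbl")
finally conn.close()

Note that the pom change adds mysql-connector-j only in test scope, so a MySQL-protocol JDBC driver still has to be available on the application classpath when Overwrite is used.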
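
The ArrayType branch in SchemaUtils now handles null and empty arrays explicitly and builds the bracketed string with mkString instead of a mutable buffer. Its effective formatting rule, written out as a standalone helper (the name formatArray is mine, and DataUtil.NULL_VALUE is assumed here to be the stream-load null marker \N):

// Rough equivalent of the new ARRAY column formatting.
def formatArray(values: Seq[Any]): String =
  if (values == null) "\\N"
  else if (values.isEmpty) "[]"
  else values.map(v => if (v == null) "null" else v.toString).mkString("[", ",", "]")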
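
Utils.retry gains a second parameter list: f is still the action to retry, and the new by-name h runs before every re-attempt. A condensed sketch of the control flow (simplified: the real method also takes a logger and is bounded to a specific Throwable subtype via a ClassTag):

import java.time.Duration
import java.util.concurrent.locks.LockSupport
import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}

object RetrySketch {
  // f is the action to run; h is a hook executed before every re-attempt,
  // e.g. rewinding a batch iterator so the same rows are loaded again.
  @tailrec
  def retry[R](retryTimes: Int, interval: Duration)(f: => R)(h: => Unit): Try[R] =
    Try(f) match {
      case Success(result) => Success(result)
      case Failure(_) if retryTimes > 0 =>
        LockSupport.parkNanos(interval.toNanos)
        h
        retry(retryTimes - 1, interval)(f)(h)
      case Failure(e) => Failure(e)
    }
}

In DorisWriter the method is partially applied and the hook rewinds the batch iterator, retry(loadFunc(batchIterator.asJava, schema))(batchIterator.reset()), while call sites that need no hook, such as the transaction listener, simply pass an empty second argument list.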
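
The new BatchIterator is what makes that replay possible: it hands out at most doris.sink.batch.size rows per load call and, when doris.sink.max-retries is greater than zero, keeps those rows in an ArrayBuffer so the same batch can be re-sent after reset(); close() drops the buffer once the load succeeds. A simplified sketch of the idea (the class name is mine, and the index bookkeeping is rearranged slightly from the patch for clarity):

import scala.collection.mutable.ArrayBuffer

// Hands out at most batchSize elements per pass; with buffering enabled the
// elements are kept so a failed pass can be replayed after reset().
class ReplayableBatchIterator[T](underlying: Iterator[T], batchSize: Int, bufferEnabled: Boolean)
  extends Iterator[T] {

  private val buffer = ArrayBuffer.empty[T]
  private var pos = 0
  private var replaying = false

  override def hasNext: Boolean =
    if (replaying && pos < buffer.size) pos < batchSize
    else pos < batchSize && underlying.hasNext

  override def next(): T = {
    val elem =
      if (replaying && pos < buffer.size) buffer(pos)
      else {
        val e = underlying.next()
        if (bufferEnabled) buffer += e
        e
      }
    pos += 1
    elem
  }

  def reset(): Unit = { pos = 0; replaying = true } // replay the buffered rows
  def close(): Unit = buffer.clear()                // drop the batch after success
}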