diff --git a/spark-doris-connector/pom.xml b/spark-doris-connector/pom.xml
index 74a53cc0..4148a660 100644
--- a/spark-doris-connector/pom.xml
+++ b/spark-doris-connector/pom.xml
@@ -184,6 +184,15 @@
             <artifactId>jackson-core</artifactId>
             <version>${fasterxml.jackson.version}</version>
         </dependency>
+
+        <dependency>
+            <groupId>com.mysql</groupId>
+            <artifactId>mysql-connector-j</artifactId>
+            <version>8.0.33</version>
+            <scope>test</scope>
+        </dependency>
+
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
index de381ca1..7877bc88 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
@@ -20,6 +20,7 @@
public interface ConfigurationOptions {
// doris fe node address
String DORIS_FENODES = "doris.fenodes";
+ String DORIS_QUERY_PORT = "doris.query.port";
String DORIS_DEFAULT_CLUSTER = "default_cluster";
@@ -70,7 +71,7 @@ public interface ConfigurationOptions {
int SINK_BATCH_SIZE_DEFAULT = 100000;
String DORIS_SINK_MAX_RETRIES = "doris.sink.max-retries";
- int SINK_MAX_RETRIES_DEFAULT = 1;
+ int SINK_MAX_RETRIES_DEFAULT = 0;
String DORIS_MAX_FILTER_RATIO = "doris.max.filter.ratio";
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
index 1631eeb3..3d5bf362 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
@@ -101,7 +101,6 @@ public class DorisStreamLoad implements Serializable {
private String FIELD_DELIMITER;
private final String LINE_DELIMITER;
private boolean streamingPassthrough = false;
- private final Integer batchSize;
private final boolean enable2PC;
private final Integer txnRetries;
private final Integer txnIntervalMs;
@@ -133,8 +132,6 @@ public DorisStreamLoad(SparkSettings settings) {
LINE_DELIMITER = escapeString(streamLoadProp.getOrDefault("line_delimiter", "\n"));
this.streamingPassthrough = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH,
ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT);
- this.batchSize = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE,
- ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT);
this.enable2PC = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC,
ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT);
this.txnRetries = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_RETRIES,
@@ -215,7 +212,6 @@ public long load(Iterator rows, StructType schema)
this.loadUrlStr = loadUrlStr;
HttpPut httpPut = getHttpPut(label, loadUrlStr, enable2PC);
RecordBatchInputStream recodeBatchInputStream = new RecordBatchInputStream(RecordBatch.newBuilder(rows)
- .batchSize(batchSize)
.format(fileType)
.sep(FIELD_DELIMITER)
.delim(LINE_DELIMITER)
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
index 4ce297f2..e471d5b8 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
@@ -36,11 +36,6 @@ public class RecordBatch {
*/
private final Iterator iterator;
- /**
- * batch size for single load
- */
- private final int batchSize;
-
/**
* stream load format
*/
@@ -63,10 +58,9 @@ public class RecordBatch {
private final boolean addDoubleQuotes;
- private RecordBatch(Iterator iterator, int batchSize, String format, String sep, byte[] delim,
+ private RecordBatch(Iterator iterator, String format, String sep, byte[] delim,
StructType schema, boolean addDoubleQuotes) {
this.iterator = iterator;
- this.batchSize = batchSize;
this.format = format;
this.sep = sep;
this.delim = delim;
@@ -78,10 +72,6 @@ public Iterator getIterator() {
return iterator;
}
- public int getBatchSize() {
- return batchSize;
- }
-
public String getFormat() {
return format;
}
@@ -112,8 +102,6 @@ public static class Builder {
private final Iterator iterator;
- private int batchSize;
-
private String format;
private String sep;
@@ -128,11 +116,6 @@ public Builder(Iterator iterator) {
this.iterator = iterator;
}
- public Builder batchSize(int batchSize) {
- this.batchSize = batchSize;
- return this;
- }
-
public Builder format(String format) {
this.format = format;
return this;
@@ -159,7 +142,7 @@ public Builder addDoubleQuotes(boolean addDoubleQuotes) {
}
public RecordBatch build() {
- return new RecordBatch(iterator, batchSize, format, sep, delim, schema, addDoubleQuotes);
+ return new RecordBatch(iterator, format, sep, delim, schema, addDoubleQuotes);
}
}
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
index f70809b3..047ac3be 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
@@ -40,8 +40,6 @@ public class RecordBatchInputStream extends InputStream {
public static final Logger LOG = LoggerFactory.getLogger(RecordBatchInputStream.class);
- private static final int DEFAULT_BUF_SIZE = 4096;
-
/**
* Load record batch
*/
@@ -55,12 +53,12 @@ public class RecordBatchInputStream extends InputStream {
/**
* record buffer
*/
- private ByteBuffer buffer = ByteBuffer.allocate(0);
- /**
- * record count has been read
- */
- private int readCount = 0;
+ private ByteBuffer lineBuf = ByteBuffer.allocate(0);
+
+ private ByteBuffer delimBuf = ByteBuffer.allocate(0);
+
+ private final byte[] delim;
/**
* streaming mode pass through data without process
@@ -70,31 +68,42 @@ public class RecordBatchInputStream extends InputStream {
public RecordBatchInputStream(RecordBatch recordBatch, boolean passThrough) {
this.recordBatch = recordBatch;
this.passThrough = passThrough;
+ this.delim = recordBatch.getDelim();
}
@Override
public int read() throws IOException {
try {
- if (buffer.remaining() == 0 && endOfBatch()) {
- return -1; // End of stream
+ if (lineBuf.remaining() == 0 && endOfBatch()) {
+ return -1;
+ }
+
+ if (delimBuf != null && delimBuf.remaining() > 0) {
+ return delimBuf.get() & 0xff;
}
} catch (DorisException e) {
throw new IOException(e);
}
- return buffer.get() & 0xFF;
+ return lineBuf.get() & 0xFF;
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
try {
- if (buffer.remaining() == 0 && endOfBatch()) {
- return -1; // End of stream
+ if (lineBuf.remaining() == 0 && endOfBatch()) {
+ return -1;
+ }
+
+ if (delimBuf != null && delimBuf.remaining() > 0) {
+ int bytesRead = Math.min(len, delimBuf.remaining());
+ delimBuf.get(b, off, bytesRead);
+ return bytesRead;
}
} catch (DorisException e) {
throw new IOException(e);
}
- int bytesRead = Math.min(len, buffer.remaining());
- buffer.get(b, off, bytesRead);
+ int bytesRead = Math.min(len, lineBuf.remaining());
+ lineBuf.get(b, off, bytesRead);
return bytesRead;
}
@@ -108,11 +117,12 @@ public int read(byte[] b, int off, int len) throws IOException {
*/
public boolean endOfBatch() throws DorisException {
Iterator iterator = recordBatch.getIterator();
- if (readCount >= recordBatch.getBatchSize() || !iterator.hasNext()) {
- return true;
+ if (iterator.hasNext()) {
+ readNext(iterator);
+ return false;
}
- readNext(iterator);
- return false;
+ delimBuf = null;
+ return true;
}
/**
@@ -125,60 +135,15 @@ private void readNext(Iterator iterator) throws DorisException {
if (!iterator.hasNext()) {
throw new ShouldNeverHappenException();
}
- byte[] delim = recordBatch.getDelim();
byte[] rowBytes = rowToByte(iterator.next());
if (isFirst) {
- ensureCapacity(rowBytes.length);
- buffer.put(rowBytes);
- buffer.flip();
+ delimBuf = null;
+ lineBuf = ByteBuffer.wrap(rowBytes);
isFirst = false;
} else {
- ensureCapacity(delim.length + rowBytes.length);
- buffer.put(delim);
- buffer.put(rowBytes);
- buffer.flip();
- }
- readCount++;
- }
-
- /**
- * Check if the buffer has enough capacity.
- *
- * @param need required buffer space
- */
- private void ensureCapacity(int need) {
-
- int capacity = buffer.capacity();
-
- if (need <= capacity) {
- buffer.clear();
- return;
- }
-
- // need to extend
- int newCapacity = calculateNewCapacity(capacity, need);
- LOG.info("expand buffer, min cap: {}, now cap: {}, new cap: {}", need, capacity, newCapacity);
- buffer = ByteBuffer.allocate(newCapacity);
-
- }
-
- /**
- * Calculate new capacity for buffer expansion.
- *
- * @param capacity current buffer capacity
- * @param minCapacity required min buffer space
- * @return new capacity
- */
- private int calculateNewCapacity(int capacity, int minCapacity) {
- int newCapacity = 0;
- if (capacity == 0) {
- newCapacity = DEFAULT_BUF_SIZE;
-
- }
- while (newCapacity < minCapacity) {
- newCapacity = newCapacity << 1;
+ delimBuf = ByteBuffer.wrap(delim);
+ lineBuf = ByteBuffer.wrap(rowBytes);
}
- return newCapacity;
}
/**
@@ -220,5 +185,4 @@ private byte[] rowToByte(InternalRow row) throws DorisException {
}
-
}
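
Note on the RecordBatchInputStream change: instead of growing a batch-sized buffer, the stream now serves one row at a time from lineBuf and, between rows, the delimiter from delimBuf, so memory stays bounded by a single row. A minimal standalone sketch of the same interleaving idea (hypothetical class, not the connector's code):

```scala
import java.io.InputStream
import java.nio.ByteBuffer

// Sketch: interleave row payloads with a delimiter, one row in memory at a time.
class RowStream(rows: Iterator[Array[Byte]], delim: Array[Byte]) extends InputStream {
  private var lineBuf  = ByteBuffer.allocate(0)
  private var delimBuf = ByteBuffer.allocate(0)
  private var first    = true

  override def read(): Int = {
    if (delimBuf.hasRemaining) return delimBuf.get() & 0xff
    if (!lineBuf.hasRemaining) {
      if (!rows.hasNext) return -1                 // end of stream, no trailing delimiter
      // emit a delimiter before every row except the first
      delimBuf = if (first) ByteBuffer.allocate(0) else ByteBuffer.wrap(delim)
      lineBuf = ByteBuffer.wrap(rows.next())
      first = false
      if (delimBuf.hasRemaining) return delimBuf.get() & 0xff
    }
    lineBuf.get() & 0xff
  }
}
```

Reading a stream built over rows "a", "b", "c" with delimiter "\n" yields the bytes of "a\nb\nc"; the connector's version produces the same byte sequence but pulls rows through endOfBatch()/readNext() and serializes each InternalRow to CSV or JSON first.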
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/jdbc/JdbcUtils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/jdbc/JdbcUtils.scala
new file mode 100644
index 00000000..aab10328
--- /dev/null
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/jdbc/JdbcUtils.scala
@@ -0,0 +1,34 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.jdbc
+
+import java.sql.{Connection, DriverManager}
+import java.util.Properties
+
+object JdbcUtils {
+
+ def getJdbcUrl(host: String, port: Int): String = s"jdbc:mysql://$host:$port/information_schema"
+
+ def getConnection(url: String, props: Properties): Connection = {
+
+ DriverManager.getConnection(url, props)
+ }
+
+ def getTruncateQuery(table: String): String = s"TRUNCATE TABLE $table"
+
+}
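
A quick usage sketch for the new helper, assuming an FE reachable on its MySQL-protocol query port (host, port, credentials and table name below are placeholders):

```scala
import java.util.Properties
import org.apache.doris.spark.jdbc.JdbcUtils

val props = new Properties()
props.setProperty("user", "root")      // placeholder credentials
props.setProperty("password", "")

// jdbc:mysql://fe-host:9030/information_schema
val url  = JdbcUtils.getJdbcUrl("fe-host", 9030)
val conn = JdbcUtils.getConnection(url, props)
try {
  conn.createStatement().execute(JdbcUtils.getTruncateQuery("db.tbl")) // TRUNCATE TABLE db.tbl
} finally {
  conn.close()
}
```

DorisSourceProvider uses exactly these three calls when SaveMode.Overwrite is requested (see below).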
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala
index 57d20f3b..b1e9d84d 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala
@@ -47,8 +47,8 @@ class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Long],
txnIds.foreach(txnId =>
Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTnxIntervalMs), logger) {
dorisStreamLoad.commit(txnId)
- } match {
- case Success(_) =>
+ } () match {
+ case Success(_) => // do nothing
case Failure(_) => failedTxnIds += txnId
}
)
@@ -68,8 +68,8 @@ class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Long],
txnIds.foreach(txnId =>
Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTnxIntervalMs), logger) {
dorisStreamLoad.abortById(txnId)
- } match {
- case Success(_) =>
+ } () match {
+ case Success(_) => // do nothing
case Failure(_) => failedTxnIds += txnId
})
if (failedTxnIds.nonEmpty) {
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisRelation.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisRelation.scala
index 049d5a25..fe7e63d7 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisRelation.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisRelation.scala
@@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}
import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -98,6 +98,7 @@ private[sql] class DorisRelation(
}
data.write.format(DorisSourceProvider.SHORT_NAME)
.options(insertCfg)
+ .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
.save()
}
}
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
index 94fab9e6..ac04401f 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
@@ -17,7 +17,10 @@
package org.apache.doris.spark.sql
-import org.apache.doris.spark.cfg.SparkSettings
+import org.apache.commons.lang3.exception.ExceptionUtils
+import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
+import org.apache.doris.spark.exception.DorisException
+import org.apache.doris.spark.jdbc.JdbcUtils
import org.apache.doris.spark.sql.DorisSourceProvider.SHORT_NAME
import org.apache.doris.spark.writer.DorisWriter
import org.apache.spark.SparkConf
@@ -28,7 +31,10 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.slf4j.{Logger, LoggerFactory}
+import java.util.Properties
import scala.collection.JavaConverters.mapAsJavaMapConverter
+import scala.util.control.Breaks
+import scala.util.{Failure, Success, Try}
private[sql] class DorisSourceProvider extends DataSourceRegister
with RelationProvider
@@ -54,6 +60,13 @@ private[sql] class DorisSourceProvider extends DataSourceRegister
val sparkSettings = new SparkSettings(sqlContext.sparkContext.getConf)
sparkSettings.merge(Utils.params(parameters, logger).asJava)
+
+ mode match {
+ case SaveMode.Overwrite =>
+ truncateTable(sparkSettings)
+ case _: SaveMode => // do nothing
+ }
+
// init stream loader
val writer = new DorisWriter(sparkSettings)
writer.write(data)
@@ -79,6 +92,50 @@ private[sql] class DorisSourceProvider extends DataSourceRegister
sparkSettings.merge(Utils.params(parameters, logger).asJava)
new DorisStreamLoadSink(sqlContext, sparkSettings)
}
+
+ private def truncateTable(sparkSettings: SparkSettings): Unit = {
+
+ val feNodes = sparkSettings.getProperty(ConfigurationOptions.DORIS_FENODES)
+ val port = sparkSettings.getIntegerProperty(ConfigurationOptions.DORIS_QUERY_PORT)
+ require(feNodes != null && feNodes.nonEmpty, "doris.fenodes cannot be null or empty")
+ require(port != null, "doris.query.port cannot be null")
+ val feNodesArr = feNodes.split(",")
+ val breaks = new Breaks
+
+ var success = false
+ var exOption: Option[Exception] = None
+
+ breaks.breakable {
+ feNodesArr.foreach(feNode => {
+ Try {
+ val host = feNode.split(":")(0)
+ val url = JdbcUtils.getJdbcUrl(host, port)
+ val props = new Properties()
+ props.setProperty("user", sparkSettings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_USER))
+ props.setProperty("password", sparkSettings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD))
+ val conn = JdbcUtils.getConnection(url, props)
+ val statement = conn.createStatement()
+ val tableIdentifier = sparkSettings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER)
+ val query = JdbcUtils.getTruncateQuery(tableIdentifier)
+ statement.execute(query)
+ success = true
+ logger.info(s"truncate table $tableIdentifier success")
+ } match {
+ case Success(_) => breaks.break()
+ case Failure(e: Exception) =>
+ exOption = Some(e)
+ logger.warn(s"truncate table failed on $feNode, error: {}", ExceptionUtils.getStackTrace(e))
+ }
+ })
+
+ }
+
+ if (!success) {
+ throw new DorisException("truncate table failed", exOption.get)
+ }
+
+ }
+
}
object DorisSourceProvider {
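
From the user's side, overwrite mode now truncates the target table over JDBC before stream-loading the data. A hedged end-to-end example (values are placeholders; doris.query.port is the FE MySQL-protocol port introduced by this change, user/password are the connector's usual auth options):

```scala
import org.apache.spark.sql.{DataFrame, SaveMode}

// df: an existing DataFrame whose schema matches the Doris table
def overwriteToDoris(df: DataFrame): Unit = {
  df.write
    .format("doris")
    .option("doris.fenodes", "fe-host:8030")      // FE HTTP port, used for stream load
    .option("doris.query.port", "9030")           // FE MySQL-protocol port, used for TRUNCATE
    .option("doris.table.identifier", "db.tbl")
    .option("user", "root")                       // placeholder credentials
    .option("password", "")
    .mode(SaveMode.Overwrite)                     // truncate first, then load
    .save()
}
```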
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
index 44baa95d..e806059c 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
@@ -32,7 +32,6 @@ import org.slf4j.LoggerFactory
import java.sql.Timestamp
import java.time.{LocalDateTime, ZoneOffset}
import scala.collection.JavaConversions._
-import scala.collection.mutable
private[spark] object SchemaUtils {
private val logger = LoggerFactory.getLogger(SchemaUtils.getClass.getSimpleName.stripSuffix("$"))
@@ -166,13 +165,14 @@ private[spark] object SchemaUtils {
case dt: DecimalType => row.getDecimal(ordinal, dt.precision, dt.scale)
case at: ArrayType =>
val arrayData = row.getArray(ordinal)
- var i = 0
- val buffer = mutable.Buffer[Any]()
- while (i < arrayData.numElements()) {
- if (arrayData.isNullAt(i)) buffer += null else buffer += rowColumnValue(arrayData, i, at.elementType)
- i += 1
+ if (arrayData == null) DataUtil.NULL_VALUE
+ else if(arrayData.numElements() == 0) "[]"
+ else {
+ (0 until arrayData.numElements()).map(i => {
+ if (arrayData.isNullAt(i)) null else rowColumnValue(arrayData, i, at.elementType)
+ }).mkString("[", ",", "]")
}
- s"[${buffer.mkString(",")}]"
+
case mt: MapType =>
val mapData = row.getMap(ordinal)
val keys = mapData.keyArray()
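
The new ArrayType branch renders a null array as the null sentinel, an empty array as [], and otherwise joins the element values into a bracketed, comma-separated string. A tiny standalone sketch of that mapping (the \N sentinel is assumed to match DataUtil.NULL_VALUE):

```scala
// Sketch of the array-to-string rule used for stream load; not the connector's code.
def arrayToLoadString(arr: Seq[Any], nullValue: String = "\\N"): String =
  if (arr == null) nullValue
  else if (arr.isEmpty) "[]"
  else arr.mkString("[", ",", "]")   // null elements render as the literal "null"

arrayToLoadString(null)              // \N
arrayToLoadString(Seq.empty)         // []
arrayToLoadString(Seq(1, null, 3))   // [1,null,3]
```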
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
index 54976a7d..89103892 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
@@ -34,6 +34,7 @@ import scala.util.{Failure, Success, Try}
private[spark] object Utils {
/**
* quote column name
+ *
* @param colName column name
* @return quoted column name
*/
@@ -41,8 +42,9 @@ private[spark] object Utils {
/**
* compile a filter to Doris FE filter format.
- * @param filter filter to be compile
- * @param dialect jdbc dialect to translate value to sql format
+ *
+ * @param filter filter to be compile
+ * @param dialect jdbc dialect to translate value to sql format
* @param inValueLengthLimit max length of in value array
* @return if Doris FE can handle this filter, return None if Doris FE can not handled it.
*/
@@ -87,6 +89,7 @@ private[spark] object Utils {
/**
* Escape special characters in SQL string literals.
+ *
* @param value The string to be escaped.
* @return Escaped string.
*/
@@ -95,6 +98,7 @@ private[spark] object Utils {
/**
* Converts value to SQL expression.
+ *
* @param value The value to be converted.
* @return Converted value.
*/
@@ -108,16 +112,17 @@ private[spark] object Utils {
/**
* check parameters validation and process it.
+ *
* @param parameters parameters from rdd and spark conf
- * @param logger slf4j logger
+ * @param logger slf4j logger
* @return processed parameters
*/
def params(parameters: Map[String, String], logger: Logger) = {
// '.' seems to be problematic when specifying the options
val dottedParams = parameters.map { case (k, v) =>
- if (k.startsWith("sink.properties.") || k.startsWith("doris.sink.properties.")){
- (k,v)
- }else {
+ if (k.startsWith("sink.properties.") || k.startsWith("doris.sink.properties.")) {
+ (k, v)
+ } else {
(k.replace('_', '.'), v)
}
}
@@ -141,7 +146,7 @@ private[spark] object Utils {
case (k, v) =>
if (k.startsWith("doris.")) (k, v)
else ("doris." + k, v)
- }.map{
+ }.map {
case (ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD, _) =>
logger.error(s"${ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD} cannot use in Doris Datasource.")
throw new DorisException(s"${ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD} cannot use in" +
@@ -165,13 +170,14 @@ private[spark] object Utils {
// validate path is available
finalParams.getOrElse(ConfigurationOptions.DORIS_TABLE_IDENTIFIER,
- throw new DorisException("table identifier must be specified for doris table identifier."))
+ throw new DorisException("table identifier must be specified for doris table identifier."))
finalParams
}
@tailrec
- def retry[R, T <: Throwable : ClassTag](retryTimes: Int, interval: Duration, logger: Logger)(f: => R): Try[R] = {
+ def retry[R, T <: Throwable : ClassTag](retryTimes: Int, interval: Duration, logger: Logger)
+ (f: => R)(h: => Unit): Try[R] = {
assert(retryTimes >= 0)
val result = Try(f)
result match {
@@ -182,7 +188,8 @@ private[spark] object Utils {
logger.warn(s"Execution failed caused by: ", exception)
logger.warn(s"$retryTimes times retry remaining, the next attempt will be in ${interval.toMillis} ms")
LockSupport.parkNanos(interval.toNanos)
- retry(retryTimes - 1, interval, logger)(f)
+ h
+ retry(retryTimes - 1, interval, logger)(f)(h)
case Failure(exception) => Failure(exception)
}
}
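
The retry helper now takes a second curried block: a recovery hook executed before every re-attempt (DorisWriter passes batchIterator.reset() there, and the transaction listener passes an empty ()). Utils is connector-internal (private[spark]); the sketch below only illustrates the shape of the new signature with a hypothetical flaky operation:

```scala
import java.time.Duration
import org.slf4j.LoggerFactory
import scala.util.{Failure, Success}

val logger = LoggerFactory.getLogger("retry-demo")
var attempts = 0

// Up to 3 retries, 100 ms apart; the second block runs before each re-attempt
// (the place to rewind buffers, reopen streams, and so on).
Utils.retry[Int, Exception](3, Duration.ofMillis(100), logger) {
  attempts += 1
  if (attempts < 3) throw new RuntimeException(s"attempt $attempts failed")
  attempts
} {
  logger.info("resetting state before the next attempt")
} match {
  case Success(n) => logger.info(s"succeeded after $n attempts")
  case Failure(e) => logger.error("gave up", e)
}
```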
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
index b829fef6..770d7009 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
@@ -21,7 +21,6 @@ import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
import org.apache.doris.spark.listener.DorisTransactionListener
import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
import org.apache.doris.spark.sql.Utils
-import org.apache.spark.TaskContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
@@ -32,10 +31,10 @@ import java.io.IOException
import java.time.Duration
import java.util
import java.util.Objects
-import java.util.concurrent.locks.LockSupport
import scala.collection.JavaConverters._
import scala.collection.mutable
-import scala.util.{Failure, Success, Try}
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
class DorisWriter(settings: SparkSettings) extends Serializable {
@@ -44,9 +43,18 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
private val sinkTaskPartitionSize: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TASK_PARTITION_SIZE)
private val sinkTaskUseRepartition: Boolean = settings.getProperty(ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION,
ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION_DEFAULT.toString).toBoolean
+
+ private val maxRetryTimes: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_MAX_RETRIES,
+ ConfigurationOptions.SINK_MAX_RETRIES_DEFAULT)
+ private val batchSize: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE,
+ ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT)
private val batchInterValMs: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS,
ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS_DEFAULT)
+ if (maxRetryTimes > 0) {
+ logger.info(s"batch retry enabled, size is $batchSize, interval is $batchInterValMs")
+ }
+
private val enable2PC: Boolean = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC,
ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT)
private val sinkTxnIntervalMs: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS,
@@ -77,7 +85,6 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType) => Long): Unit = {
val sc = dataFrame.sqlContext.sparkContext
- logger.info(s"applicationAttemptId: ${sc.applicationAttemptId.getOrElse(-1)}")
val preCommittedTxnAcc = sc.collectionAccumulator[Long]("preCommittedTxnAcc")
if (enable2PC) {
sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader, sinkTxnIntervalMs, sinkTxnRetries))
@@ -89,19 +96,22 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
}
resultRdd.foreachPartition(iterator => {
- val intervalNanos = Duration.ofMillis(batchInterValMs.toLong).toNanos
+
while (iterator.hasNext) {
- Try {
- loadFunc(iterator.asJava, schema)
- } match {
- case Success(txnId) => if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc)
+ val batchIterator = new BatchIterator[InternalRow](iterator, batchSize, maxRetryTimes > 0)
+ val retry = Utils.retry[Long, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) _
+ retry(loadFunc(batchIterator.asJava, schema))(batchIterator.reset()) match {
+ case Success(txnId) =>
+ if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc)
+ batchIterator.close()
case Failure(e) =>
if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
+ batchIterator.close()
throw new IOException(
s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node.", e)
}
- LockSupport.parkNanos(intervalNanos)
}
+
})
}
@@ -120,10 +130,10 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
}
val abortFailedTxnIds = mutable.Buffer[Long]()
acc.value.asScala.foreach(txnId => {
- Utils.retry[Unit, Exception](sinkTxnRetries, Duration.ofMillis(sinkTxnIntervalMs), logger) {
+ Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
dorisStreamLoader.abortById(txnId)
- } match {
- case Success(_) =>
+ }() match {
+ case Success(_) => // do nothing
case Failure(_) => abortFailedTxnIds += txnId
}
})
@@ -131,5 +141,60 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
acc.reset()
}
+ /**
+ * iterator for batch load
+ * if the retry count is greater than zero, batch retry is enabled and the batch data is buffered so it can be replayed
+ *
+ * @param iterator parent iterator
+ * @param batchSize batch size
+ * @param batchRetryEnable whether enable batch retry
+ * @tparam T data type
+ */
+ private class BatchIterator[T](iterator: Iterator[T], batchSize: Int, batchRetryEnable: Boolean) extends Iterator[T] {
+
+ private val buffer: ArrayBuffer[T] = if (batchRetryEnable) new ArrayBuffer[T](batchSize) else ArrayBuffer.empty[T]
+
+ private var recordCount = 0
+
+ private var isReset = false
+
+ override def hasNext: Boolean = {
+ if (recordCount >= batchSize) false
+ else if (isReset && recordCount < buffer.size) true
+ else iterator.hasNext
+ }
+
+ override def next(): T = {
+ recordCount += 1
+ if (batchRetryEnable) {
+ if (isReset && recordCount <= buffer.size) {
+ // replay the buffered element on retry
+ buffer(recordCount - 1)
+ } else {
+ val elem = iterator.next
+ buffer += elem
+ elem
+ }
+ } else {
+ iterator.next
+ }
+ }
+
+ /**
+ * reset record count for re-read
+ */
+ def reset(): Unit = {
+ recordCount = 0
+ isReset = true
+ logger.info("batch iterator is reset")
+ }
+
+ /**
+ * clear buffer when buffer is not empty
+ */
+ def close(): Unit = {
+ if (buffer.nonEmpty) {
+ buffer.clear()
+ logger.info("buffer is cleared and batch iterator is closed")
+ }
+ }
+
+ }
+
}
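
Taken together, batch retry is controlled entirely through sink options: a positive doris.sink.max-retries enables the buffering BatchIterator, the batch size bounds how many rows each attempt buffers, and the batch interval spaces re-attempts. A hedged configuration example (the exact batch size/interval keys follow the connector's usual doris.sink.* naming and are assumptions here):

```scala
// df: an existing DataFrame whose schema matches the Doris table
df.write
  .format("doris")
  .option("doris.fenodes", "fe-host:8030")
  .option("doris.table.identifier", "db.tbl")
  .option("user", "root")
  .option("password", "")
  .option("doris.sink.max-retries", "3")           // > 0 enables buffered batch retry (default is now 0)
  .option("doris.sink.batch.size", "100000")       // rows buffered per attempt (assumed key)
  .option("doris.sink.batch.interval.ms", "1000")  // pause before each re-attempt (assumed key)
  .save()
```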