[fix] streaming write execution plan error (#135)
* fix streaming write error and add a JSON data passthrough option
* handle stream passthrough: force read_json_by_line to true when the format is json
gnehil authored Sep 5, 2023
1 parent 71af841 commit 11f4976
Showing 4 changed files with 135 additions and 20 deletions.
@@ -100,4 +100,10 @@ public interface ConfigurationOptions {
String DORIS_SINK_ENABLE_2PC = "doris.sink.enable-2pc";
boolean DORIS_SINK_ENABLE_2PC_DEFAULT = false;

/**
* pass through json data when sink to doris in streaming mode
*/
String DORIS_SINK_STREAMING_PASSTHROUGH = "doris.sink.streaming.passthrough";
boolean DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT = false;

}
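The new option is supplied like any other connector setting. Below is a minimal sketch of a structured streaming write with passthrough enabled; apart from doris.sink.streaming.passthrough, the option keys (doris.fenodes, doris.table.identifier, user, password), the "doris" format short name, and the Kafka source are assumptions about the connector's usual configuration, not taken from this commit.

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: enabling the new passthrough option on a streaming write.
// All option keys except doris.sink.streaming.passthrough are assumed here.
val spark = SparkSession.builder().appName("doris-passthrough-demo").getOrCreate()

val source = spark.readStream
  .format("kafka")                                    // assumed source; any streaming source works
  .option("kafka.bootstrap.servers", "broker:9092")
  .option("subscribe", "events")
  .load()
  .selectExpr("CAST(value AS STRING)")                // value already holds one JSON object per record

source.writeStream
  .format("doris")                                     // assumed data source short name
  .option("doris.fenodes", "fe_host:8030")             // assumed FE address key
  .option("doris.table.identifier", "db.table")        // assumed table identifier key
  .option("user", "root")
  .option("password", "")
  .option("doris.sink.streaming.passthrough", "true")  // option added in this commit
  .option("checkpointLocation", "/tmp/doris-ckpt")
  .start()
  .awaitTermination()
```

With passthrough enabled, the sink forwards the single string column of each row to stream load verbatim instead of re-serializing it (see the passthrough method added below).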
@@ -96,6 +96,8 @@ public class DorisStreamLoad implements Serializable {

private boolean readJsonByLine = false;

private boolean streamingPassthrough = false;

public DorisStreamLoad(SparkSettings settings) {
String[] dbTable = settings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER).split("\\.");
this.db = dbTable[0];
@@ -121,6 +123,8 @@ public DorisStreamLoad(SparkSettings settings) {
}
}
LINE_DELIMITER = escapeString(streamLoadProp.getOrDefault("line_delimiter", "\n"));
this.streamingPassthrough = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH,
ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT);
}

public String getLoadUrlStr() {
@@ -196,6 +200,38 @@ public List<Integer> loadV2(List<List<Object>> rows, String[] dfColumns, Boolean

}

public List<Integer> loadStream(List<List<Object>> rows, String[] dfColumns, Boolean enable2PC)
throws StreamLoadException, JsonProcessingException {

List<String> loadData;

if (this.streamingPassthrough) {
handleStreamPassThrough();
loadData = passthrough(rows);
} else {
loadData = parseLoadData(rows, dfColumns);
}

List<Integer> txnIds = new ArrayList<>(loadData.size());

try {
for (String data : loadData) {
txnIds.add(load(data, enable2PC));
}
} catch (StreamLoadException e) {
if (enable2PC && !txnIds.isEmpty()) {
LOG.error("load batch failed, abort previously pre-committed transactions");
for (Integer txnId : txnIds) {
abort(txnId);
}
}
throw e;
}

return txnIds;

}

public int load(String value, Boolean enable2PC) throws StreamLoadException {

String label = generateLoadLabel();
@@ -442,4 +478,18 @@ private String escapeString(String hexData) {
return hexData;
}

private void handleStreamPassThrough() {

if ("json".equalsIgnoreCase(fileType)) {
LOG.info("handle stream pass through, force set read_json_by_line is true for json format");
streamLoadProp.put("read_json_by_line", "true");
streamLoadProp.remove("strip_outer_array");
}

}

private List<String> passthrough(List<List<Object>> values) {
return values.stream().map(list -> list.get(0).toString()).collect(Collectors.toList());
}

}
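A condensed sketch of what the new passthrough branch in loadStream amounts to, assuming each incoming row carries its already serialized record in the first column; this mirrors passthrough and handleStreamPassThrough above and is for illustration only.

```scala
// Illustration of the passthrough branch in loadStream (not the connector's API).
val rows: Seq[Seq[AnyRef]] = Seq(
  Seq("""{"id": 1, "name": "alice"}"""),
  Seq("""{"id": 2, "name": "bob"}""")
)

// passthrough: take the first column of every row verbatim, no re-serialization.
val loadData: Seq[String] = rows.map(_.head.toString)

// handleStreamPassThrough: for the json file type the payload is one JSON object
// per line, so read_json_by_line is forced on and strip_outer_array is dropped.
val fileType = "json"
val streamLoadProp = scala.collection.mutable.Map("strip_outer_array" -> "true")
if (fileType.equalsIgnoreCase("json")) {
  streamLoadProp.put("read_json_by_line", "true")
  streamLoadProp.remove("strip_outer_array")
}
```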
@@ -34,7 +34,7 @@ private[sql] class DorisStreamLoadSink(sqlContext: SQLContext, settings: SparkSe
if (batchId <= latestBatchId) {
logger.info(s"Skipping already committed batch $batchId")
} else {
- writer.write(data)
+ writer.writeStream(data)
latestBatchId = batchId
}
}
@@ -22,6 +22,9 @@ import org.apache.doris.spark.listener.DorisTransactionListener
import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
import org.apache.doris.spark.sql.Utils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.CollectionAccumulator
import org.slf4j.{Logger, LoggerFactory}

import java.io.IOException
@@ -76,34 +79,90 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
* flush data to Doris and do retry when flush error
*
*/
- def flush(batch: Iterable[util.List[Object]], dfColumns: Array[String]): Unit = {
+ def flush(batch: Seq[util.List[Object]], dfColumns: Array[String]): Unit = {
Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
- dorisStreamLoader.loadV2(batch.toList.asJava, dfColumns, enable2PC)
+ dorisStreamLoader.loadV2(batch.asJava, dfColumns, enable2PC)
} match {
- case Success(txnIds) => if (enable2PC) txnIds.asScala.foreach(txnId => preCommittedTxnAcc.add(txnId))
+ case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc)
case Failure(e) =>
- if (enable2PC) {
- // if task run failed, acc value will not be returned to driver,
- // should abort all pre committed transactions inside the task
- logger.info("load task failed, start aborting previously pre-committed transactions")
- val abortFailedTxnIds = mutable.Buffer[Int]()
- preCommittedTxnAcc.value.asScala.foreach(txnId => {
- Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
- dorisStreamLoader.abort(txnId)
- } match {
- case Success(_) =>
- case Failure(_) => abortFailedTxnIds += txnId
- }
- })
- if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(","))
- preCommittedTxnAcc.reset()
- }
+ if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
throw new IOException(
s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e)
}
}

}

def writeStream(dataFrame: DataFrame): Unit = {

val sc = dataFrame.sqlContext.sparkContext
val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc")
if (enable2PC) {
sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader))
}

var resultRdd = dataFrame.queryExecution.toRdd
val schema = dataFrame.schema
val dfColumns = dataFrame.columns
if (Objects.nonNull(sinkTaskPartitionSize)) {
resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
}
resultRdd
.foreachPartition(partition => {
partition
.grouped(batchSize)
.foreach(batch =>
flush(batch, dfColumns))
})

/**
* flush data to Doris and do retry when flush error
*
*/
def flush(batch: Seq[InternalRow], dfColumns: Array[String]): Unit = {
Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
dorisStreamLoader.loadStream(convertToObjectList(batch, schema), dfColumns, enable2PC)
} match {
case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc)
case Failure(e) =>
if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
throw new IOException(
s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e)
}
}

def convertToObjectList(rows: Seq[InternalRow], schema: StructType): util.List[util.List[Object]] = {
rows.map(row => {
row.toSeq(schema).map(_.asInstanceOf[AnyRef]).toList.asJava
}).asJava
}

}

private def handleLoadSuccess(txnIds: mutable.Buffer[Integer], acc: CollectionAccumulator[Int]): Unit = {
txnIds.foreach(txnId => acc.add(txnId))
}

def handleLoadFailure(acc: CollectionAccumulator[Int]): Unit = {
// if task run failed, acc value will not be returned to driver,
// should abort all pre committed transactions inside the task
logger.info("load task failed, start aborting previously pre-committed transactions")
if (acc.isZero) {
logger.info("no pre-committed transactions, skip abort")
return
}
val abortFailedTxnIds = mutable.Buffer[Int]()
acc.value.asScala.foreach(txnId => {
Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
dorisStreamLoader.abort(txnId)
} match {
case Success(_) =>
case Failure(_) => abortFailedTxnIds += txnId
}
})
if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(","))
acc.reset()
}


}
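writeStream operates on dataFrame.queryExecution.toRdd, so each partition yields Catalyst InternalRow objects rather than external rows, and convertToObjectList widens them back to plain objects for DorisStreamLoad. A small standalone sketch of that per-row conversion (illustration only; string fields come back as UTF8String):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

// Sketch of the per-row conversion done by convertToObjectList: an InternalRow is
// turned into a Seq of AnyRef using the batch schema, ready to be handed to
// DorisStreamLoad as List<List<Object>>.
val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("name", StringType)
))
val row: InternalRow = InternalRow(1, UTF8String.fromString("alice"))
val objects: Seq[AnyRef] = row.toSeq(schema).map(_.asInstanceOf[AnyRef])
// objects now holds Seq(1, alice), with the string value still boxed as UTF8String
```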
