Fix data loss due to internal retries (#145)
gnehil authored Oct 8, 2023
1 parent ad5d62f commit 5410651
Showing 3 changed files with 106 additions and 24 deletions.
@@ -51,6 +51,7 @@
import java.io.IOException;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
@@ -64,13 +65,14 @@
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.LockSupport;
import java.util.function.Consumer;


/**
* DorisStreamLoad
**/
public class DorisStreamLoad implements Serializable {
private static final String NULL_VALUE = "\\N";

private static final Logger LOG = LoggerFactory.getLogger(DorisStreamLoad.class);

@@ -97,7 +99,9 @@ public class DorisStreamLoad {
private final String LINE_DELIMITER;
private boolean streamingPassthrough = false;
private final Integer batchSize;
private boolean enable2PC;
private final boolean enable2PC;
private final Integer txnRetries;
private final Integer txnIntervalMs;

public DorisStreamLoad(SparkSettings settings) {
String[] dbTable = settings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER).split("\\.");
@@ -128,6 +132,10 @@ public DorisStreamLoad(SparkSettings settings) {
ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT);
this.enable2PC = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC,
ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT);
this.txnRetries = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_RETRIES,
ConfigurationOptions.DORIS_SINK_TXN_RETRIES_DEFAULT);
this.txnIntervalMs = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS,
ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS_DEFAULT);
}

public String getLoadUrlStr() {
@@ -202,7 +210,19 @@ public int load(Iterator<InternalRow> rows, StructType schema)
HttpResponse httpResponse = httpClient.execute(httpPut);
loadResponse = new LoadResponse(httpResponse);
} catch (IOException e) {
throw new RuntimeException(e);
if (enable2PC) {
int retries = txnRetries;
while (retries > 0) {
try {
abortByLabel(label);
retries = 0;
} catch (StreamLoadException ex) {
LockSupport.parkNanos(Duration.ofMillis(txnIntervalMs).toNanos());
retries--;
}
}
}
throw new StreamLoadException("load execute failed", e);
}

if (loadResponse.status != HttpStatus.SC_OK) {
@@ -274,40 +294,85 @@ public void commit(int txnId) throws StreamLoadException {

}

public void abort(int txnId) throws StreamLoadException {
/**
* abort transaction by id
*
* @param txnId transaction id
* @throws StreamLoadException
*/
public void abortById(int txnId) throws StreamLoadException {

LOG.info("start abort transaction {}.", txnId);

try {
doAbort(httpPut -> httpPut.setHeader("txn_id", String.valueOf(txnId)));
} catch (StreamLoadException e) {
LOG.error("abort transaction by id: {} failed.", txnId);
throw e;
}

LOG.info("abort transaction {} succeed.", txnId);

}

/**
* abort transaction by label
*
* @param label label
* @throws StreamLoadException
*/
public void abortByLabel(String label) throws StreamLoadException {

LOG.info("start abort transaction by label: {}.", label);

try {
doAbort(httpPut -> httpPut.setHeader("label", label));
} catch (StreamLoadException e) {
LOG.error("abort transaction by label: {} failed.", label);
throw e;
}

LOG.info("abort transaction by label {} succeed.", label);

}

/**
* execute abort
*
* @param putConsumer http put process function
* @throws StreamLoadException
*/
private void doAbort(Consumer<HttpPut> putConsumer) throws StreamLoadException {

try (CloseableHttpClient client = getHttpClient()) {
String abortUrl = String.format(abortUrlPattern, getBackend(), db, tbl);
HttpPut httpPut = new HttpPut(abortUrl);
addCommonHeader(httpPut);
httpPut.setHeader("txn_operation", "abort");
httpPut.setHeader("txn_id", String.valueOf(txnId));
putConsumer.accept(httpPut);

CloseableHttpResponse response = client.execute(httpPut);
int statusCode = response.getStatusLine().getStatusCode();
if (statusCode != 200 || response.getEntity() == null) {
LOG.warn("abort transaction response: " + response.getStatusLine().toString());
throw new StreamLoadException("Fail to abort transaction " + txnId + " with url " + abortUrl);
LOG.error("abort transaction response: " + response.getStatusLine().toString());
throw new IOException("Fail to abort transaction with url " + abortUrl);
}

String loadResult = EntityUtils.toString(response.getEntity());
Map<String, String> res = MAPPER.readValue(loadResult, new TypeReference<HashMap<String, String>>() {
});
if (!"Success".equals(res.get("status"))) {
if (ResponseUtil.isCommitted(res.get("msg"))) {
throw new StreamLoadException("try abort committed transaction, " + "do you recover from old savepoint?");
throw new IOException("try abort committed transaction");
}
LOG.warn("Fail to abort transaction. txnId: {}, error: {}", txnId, res.get("msg"));
LOG.error("Fail to abort transaction. error: {}", res.get("msg"));
throw new IOException(String.format("Fail to abort transaction. error: %s", res.get("msg")));
}

} catch (IOException e) {
throw new StreamLoadException(e);
}

LOG.info("abort transaction {} succeed.", txnId);

}

public Map<String, String> getStreamLoadProp(SparkSettings sparkSettings) {
@@ -386,12 +451,21 @@ private String escapeString(String hexData) {
return hexData;
}

/**
* add common header to http request
*
* @param httpReq http request
*/
private void addCommonHeader(HttpRequestBase httpReq) {
httpReq.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded);
httpReq.setHeader(HttpHeaders.EXPECT, "100-continue");
httpReq.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; charset=UTF-8");
}

/**
* handle stream sink data pass through
* if load format is json, set read_json_by_line to true and remove strip_outer_array parameter
*/
private void handleStreamPassThrough() {

if ("json".equalsIgnoreCase(fileType)) {
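Taken together, the DorisStreamLoad changes do two things: the abort API is split into `abortById` / `abortByLabel` on top of a shared `doAbort(Consumer<HttpPut>)`, and the catch block in `load` now cleans up after itself when two-phase commit is enabled, aborting the in-flight load by its label with up to `txnRetries` attempts spaced `txnIntervalMs` apart before rethrowing. The following is a minimal, self-contained Scala sketch of that bounded-retry abort pattern; `abortByLabel` is passed in as a function and the name `abortWithRetries` is illustrative, not part of the connector.

```scala
import java.time.Duration
import java.util.concurrent.locks.LockSupport

// Bounded retry of a transaction abort, mirroring the new catch block in DorisStreamLoad.load.
// `abortByLabel` is assumed to throw on failure, as the Java method above does.
def abortWithRetries(label: String,
                     abortByLabel: String => Unit,
                     txnRetries: Int,
                     txnIntervalMs: Long): Unit = {
  var retries = txnRetries
  while (retries > 0) {
    try {
      abortByLabel(label) // ask Doris to abort the precommitted transaction by label
      retries = 0         // success: stop retrying
    } catch {
      case _: Exception =>
        // wait txnIntervalMs before the next attempt
        LockSupport.parkNanos(Duration.ofMillis(txnIntervalMs).toNanos)
        retries -= 1
    }
  }
}
```

Parking the thread with `LockSupport.parkNanos`, as the Java code does, avoids the checked `InterruptedException` that `Thread.sleep` would otherwise force the catch block to handle.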
@@ -67,7 +67,7 @@ class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Int], d
logger.info("job run failed, start aborting transactions")
txnIds.foreach(txnId =>
Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTnxIntervalMs), logger) {
dorisStreamLoad.abort(txnId)
dorisStreamLoad.abortById(txnId)
} match {
case Success(_) =>
case Failure(_) => failedTxnIds += txnId
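The listener change is only the rename from `abort` to `abortById`, but the surrounding call shows the pattern used throughout this commit: each abort is wrapped in `Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTnxIntervalMs), logger) { ... }` and the resulting `Try` decides whether the transaction id lands in the failed list. The connector's own `Utils.retry` is not shown in this diff; the sketch below is a hypothetical stand-in with the same call shape (retry count, interval, logger, body returning a `Try`), intended only to illustrate how such a helper can be written.

```scala
import java.time.Duration
import org.slf4j.Logger
import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}

// Hypothetical retry helper with the same call shape as Utils.retry: run `f`, and on
// failure retry up to `retryTimes` more times, sleeping `interval` between attempts.
def retrySketch[T](retryTimes: Int, interval: Duration, logger: Logger)(f: => T): Try[T] = {
  @tailrec
  def loop(remaining: Int): Try[T] = Try(f) match {
    case Success(value) => Success(value)
    case Failure(e) if remaining > 0 =>
      logger.warn(s"attempt failed, $remaining retries left", e)
      Thread.sleep(interval.toMillis)
      loop(remaining - 1)
    case failure => failure
  }
  loop(retryTimes)
}

// Usage mirroring the listener:
//   retrySketch(sinkTxnRetries, Duration.ofMillis(sinkTxnIntervalMs), logger) {
//     dorisStreamLoad.abortById(txnId)
//   } match {
//     case Success(_) =>
//     case Failure(_) => failedTxnIds += txnId
//   }
```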
@@ -21,6 +21,7 @@ import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
import org.apache.doris.spark.listener.DorisTransactionListener
import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
import org.apache.doris.spark.sql.Utils
import org.apache.spark.TaskContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
@@ -31,46 +32,52 @@ import java.io.IOException
import java.time.Duration
import java.util
import java.util.Objects
import java.util.concurrent.locks.LockSupport
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.{Failure, Success}
import scala.util.{Failure, Success, Try}

class DorisWriter(settings: SparkSettings) extends Serializable {

private val logger: Logger = LoggerFactory.getLogger(classOf[DorisWriter])

val batchSize: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE,
ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT)
private val maxRetryTimes: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_MAX_RETRIES,
ConfigurationOptions.SINK_MAX_RETRIES_DEFAULT)
private val sinkTaskPartitionSize: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TASK_PARTITION_SIZE)
private val sinkTaskUseRepartition: Boolean = settings.getProperty(ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION,
ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION_DEFAULT.toString).toBoolean
private val batchInterValMs: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS,
ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS_DEFAULT)

private val enable2PC: Boolean = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC,
ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT);
ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT)
private val sinkTxnIntervalMs: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS,
ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS_DEFAULT)
private val sinkTxnRetries: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_RETRIES,
ConfigurationOptions.DORIS_SINK_TXN_RETRIES_DEFAULT)

private val dorisStreamLoader: DorisStreamLoad = CachedDorisStreamLoadClient.getOrCreate(settings)

/**
* write data in batch mode
*
* @param dataFrame source dataframe
*/
def write(dataFrame: DataFrame): Unit = {
doWrite(dataFrame, dorisStreamLoader.load)
}

/**
* write data in stream mode
*
* @param dataFrame source dataframe
*/
def writeStream(dataFrame: DataFrame): Unit = {
doWrite(dataFrame, dorisStreamLoader.loadStream)
}

private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType) => Int): Unit = {



val sc = dataFrame.sqlContext.sparkContext
logger.info(s"applicationAttemptId: ${sc.applicationAttemptId.getOrElse(-1)}")
val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc")
if (enable2PC) {
sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader, sinkTxnIntervalMs, sinkTxnRetries))
@@ -82,17 +89,18 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
}
resultRdd.foreachPartition(iterator => {
val intervalNanos = Duration.ofMillis(batchInterValMs.toLong).toNanos
while (iterator.hasNext) {
// do load batch with retries
Utils.retry[Int, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
Try {
loadFunc(iterator.asJava, schema)
} match {
case Success(txnId) => if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc)
case Failure(e) =>
if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
throw new IOException(
s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e)
s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node.", e)
}
LockSupport.parkNanos(intervalNanos)
}
})

@@ -113,7 +121,7 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
val abortFailedTxnIds = mutable.Buffer[Int]()
acc.value.asScala.foreach(txnId => {
Utils.retry[Unit, Exception](sinkTxnRetries, Duration.ofMillis(sinkTxnIntervalMs), logger) {
dorisStreamLoader.abort(txnId)
dorisStreamLoader.abortById(txnId)
} match {
case Success(_) =>
case Failure(_) => abortFailedTxnIds += txnId
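The DorisWriter change is the data-loss fix named in the commit title. The old code retried `loadFunc(iterator.asJava, schema)` in place with `Utils.retry`, but `iterator` is the partition's row iterator and each attempt consumes rows from it, so a retry only sees whatever rows are left after the failed attempt; the rows already pulled are silently dropped. The new code wraps each load in a plain `Try`, aborts on failure, and rethrows, leaving re-execution to Spark's task retry, which replays the whole partition. A self-contained Scala sketch of the failure mode, with illustrative names only:

```scala
import scala.util.Try

// A "load" that pulls a batch of rows from a shared iterator and fails on its first attempt.
var firstAttempt = true
def flakyLoad(rows: Iterator[Int]): Int = {
  var pulled = 0
  while (rows.hasNext && pulled < 3) { rows.next(); pulled += 1 } // rows are consumed here
  if (firstAttempt) { firstAttempt = false; throw new RuntimeException("BE write failed") }
  pulled
}

val rows = Iterator(1, 2, 3, 4, 5, 6)

// "Internal retry": the second attempt runs against an iterator that has already lost rows 1..3,
// so the load "succeeds" even though half of the partition was never written anywhere.
val result = Try(flakyLoad(rows)).orElse(Try(flakyLoad(rows)))
println(result) // Success(3) -- only rows 4..6 reached the second attempt
```

Under 2PC the `Failure` branch additionally calls `handleLoadFailure(preCommittedTxnAcc)`, so transactions precommitted by earlier batches of the partition are aborted rather than left for the job-level commit.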
