Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: cleanup AbstractJdbcQueries ensuring no memory leak #468

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,30 @@
import io.kestra.core.models.property.Property;
import io.kestra.core.runners.RunContext;
import io.kestra.core.serializers.FileSerde;
import lombok.*;
import io.kestra.core.utils.Rethrow;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.experimental.SuperBuilder;
import org.slf4j.Logger;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.sql.*;
import java.util.*;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Savepoint;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand All @@ -30,15 +44,6 @@ public abstract class AbstractJdbcQueries extends AbstractJdbcBaseQuery implemen
@Builder.Default
protected Property<Boolean> transaction = Property.of(Boolean.TRUE);

@Getter(AccessLevel.NONE)
private Connection conn = null;

@Getter(AccessLevel.NONE)
private PreparedStatement stmt = null;

@Getter(AccessLevel.NONE)
private Savepoint savepoint = null;

public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Exception {
Logger logger = runContext.logger();
AbstractCellConverter cellConverter = getCellConverter(this.zoneId());
Expand All @@ -47,72 +52,90 @@ public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Ex
long totalSize = 0L;
List<AbstractJdbcQuery.Output> outputList = new LinkedList<>();

try {
//Create connection in not autocommit mode to enable rollback on error
conn = this.connection(runContext);
conn.setAutoCommit(false);
savepoint = initializeSavepoint(conn);
//Create connection in not autocommit mode to enable rollback on error
Connection connection = null;
Savepoint savepoint = null;
try {
connection = this.connection(runContext);
savepoint = initializeSavepoint(connection);

connection.setAutoCommit(false);

String sqlRendered = runContext.render(this.sql, this.additionalVars);
String[] queries = sqlRendered.split(";[^']");

for(String query : queries) {
for (String query : queries) {
//Create statement, execute
stmt = createPreparedStatementAndPopulateParameters(runContext, conn, query);
stmt.setFetchSize(this.getFetchSize());
logger.debug("Starting query: {}", query);
stmt.execute();

if(!isTransactional) {
conn.commit();
try (PreparedStatement stmt = prepareStatement(runContext, connection, query)) {
stmt.setFetchSize(this.getFetchSize());
logger.debug("Starting query: {}", query);
stmt.execute();
if (!isTransactional) {
connection.commit();
}
totalSize = extractResultsFromResultSet(connection, stmt, runContext, cellConverter, totalSize, outputList);
}
totalSize = extractResultsFromResultSet(runContext, cellConverter, totalSize, outputList);
}
conn.commit();

connection.commit();
runContext.metric(Counter.of("fetch.size", totalSize, this.tags()));

return MultiQueryOutput.builder().outputs(outputList).build();
} catch (Exception e) {
rollbackIfTransactional(isTransactional);
rollbackIfTransactional(connection, savepoint, isTransactional);
throw new RuntimeException(e);
} finally {
closeConnectionAndStatement(runContext);
safelyCloseConnection(runContext, connection);
}
}

private long extractResultsFromResultSet(RunContext runContext, AbstractCellConverter cellConverter, long totalSize, List<Output> outputList) throws SQLException, IOException {
try(ResultSet rs = stmt.getResultSet()) {
private static void safelyCloseConnection(final RunContext runContext, final Connection connection) {
try {
if (connection != null) {
connection.close();
}
} catch (SQLException e) {
runContext.logger().warn("Issue when closing the connection : {}", e.getMessage());
}
}

private long extractResultsFromResultSet(final Connection connection,
final PreparedStatement stmt,
final RunContext runContext,
final AbstractCellConverter cellConverter,
long totalSize,
final List<Output> outputList) throws SQLException, IOException {
try (ResultSet rs = stmt.getResultSet()) {
//When sql is not a select statement skip output creation
if(rs != null) {
if (rs != null) {
Output.OutputBuilder<?, ?> output = Output.builder();
//Populate result fro result set
long size = 0L;
switch (this.getFetchType()) {
case FETCH_ONE -> {
size = 1L;
output
.row(fetchResult(rs, cellConverter, conn))
.row(fetchResult(rs, cellConverter, connection))
.size(size);
}
case STORE -> {
File tempFile = runContext.workingDir().createTempFile(".ion").toFile();
try (BufferedWriter fileWriter = new BufferedWriter(new FileWriter(tempFile), FileSerde.BUFFER_SIZE)) {
size = fetchToFile(stmt, rs, fileWriter, cellConverter, conn);
size = fetchToFile(stmt, rs, fileWriter, cellConverter, connection);
}
output
.uri(runContext.storage().putFile(tempFile))
.size(size);
}
case FETCH -> {
List<Map<String, Object>> maps = new ArrayList<>();
size = fetchResults(stmt, rs, maps, cellConverter, conn);
size = fetchResults(stmt, rs, maps, cellConverter, connection);
output
.rows(maps)
.size(size);
}
case NONE -> runContext.logger().info("fetchType is set to NONE, no output will be returned");
default -> throw new IllegalArgumentException("fetchType must be either FETCH, FETCH_ONE, STORE, or NONE");
default ->
throw new IllegalArgumentException("fetchType must be either FETCH, FETCH_ONE, STORE, or NONE");
}
totalSize += size;
outputList.add(output.build());
Expand All @@ -121,26 +144,19 @@ private long extractResultsFromResultSet(RunContext runContext, AbstractCellConv
return totalSize;
}

private void rollbackIfTransactional(boolean isTransactional) throws SQLException {
if(isTransactional && conn != null) {
if(savepoint != null) {
conn.rollback(savepoint);
private static void rollbackIfTransactional(final Connection connection,
final Savepoint savepoint,
final boolean isTransactional) throws SQLException {
if (isTransactional) {
if (savepoint != null) {
connection.rollback(savepoint);
return;
}
conn.rollback();
connection.rollback();
}
}

private void closeConnectionAndStatement(RunContext runContext) {
try {
if(conn != null && !conn.isClosed()) { conn.close(); }
if(stmt != null && !stmt.isClosed()) { stmt.close(); }
} catch (SQLException e) {
runContext.logger().warn("Issue when closing the connection : {}", e.getMessage());
}
}

private Savepoint initializeSavepoint(Connection conn) throws SQLException {
private static Savepoint initializeSavepoint(final Connection conn) {
try {
return conn.setSavepoint();
} catch (SQLException e) {
Expand Down Expand Up @@ -168,37 +184,43 @@ public static class MultiQueryOutput implements io.kestra.core.models.tasks.Outp
List<AbstractJdbcQuery.Output> outputs;
}

private PreparedStatement createPreparedStatementAndPopulateParameters(RunContext runContext, Connection conn, String sql) throws SQLException, IllegalVariableEvaluationException {
//Inject named parameters (ex: ':param')
Map<String, Object> namedParamsRendered = this.getParameters() == null ? null : this.getParameters().asMap(runContext, String.class, Object.class);
private PreparedStatement prepareStatement(final RunContext runContext,
final Connection conn,
final String sql) throws SQLException, IllegalVariableEvaluationException {

if(namedParamsRendered == null || namedParamsRendered.isEmpty()) {
// Inject named parameters (ex: ':param')
Optional<Map<String, Object>> namedParamsRendered = Optional
.ofNullable(this.getParameters())
.map(Rethrow.throwFunction(it -> it.asMap(runContext, String.class, Object.class)));

if (namedParamsRendered.isEmpty()) {
return createPreparedStatement(conn, sql);
}

//Extract parameters in orders and replace them with '?'
String preparedSql = sql;
Pattern pattern = Pattern.compile(":\\w+");
Pattern pattern = Pattern.compile(":\\w+");
Matcher matcher = pattern.matcher(preparedSql);

List<String> params = new LinkedList<>();

while (matcher.find()) {
String param = matcher.group();
params.add(param.substring(1));
preparedSql = matcher.replaceFirst( "?");
preparedSql = matcher.replaceFirst("?");
matcher = pattern.matcher(preparedSql);
}
stmt = createPreparedStatement(conn, preparedSql);

for(int i=0; i<params.size(); i++) {
stmt.setObject(i+1, namedParamsRendered.get(params.get(i)));
PreparedStatement stmt = createPreparedStatement(conn, preparedSql);

for (int i = 0; i < params.size(); i++) {
stmt.setObject(i + 1, namedParamsRendered.get().get(params.get(i)));
}

return stmt;
}

protected PreparedStatement createPreparedStatement(Connection conn, String preparedSql) throws SQLException {
return conn.prepareStatement(preparedSql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
protected PreparedStatement createPreparedStatement(final Connection conn, final String sql) throws SQLException {
return conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
}
}
Loading