Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add queries to druid #421

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package io.kestra.plugin.jdbc.druid;

import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.tasks.RunnableTask;
import io.kestra.plugin.jdbc.AbstractCellConverter;
import io.kestra.plugin.jdbc.AbstractJdbcQueries;
import io.kestra.plugin.jdbc.AbstractJdbcQuery;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.experimental.SuperBuilder;
import org.apache.calcite.avatica.remote.Driver;

import java.sql.*;
import java.time.ZoneId;

@SuperBuilder
@ToString
@EqualsAndHashCode
@Getter
@NoArgsConstructor
@Schema(
title = "Perform multiple queries on an Apache Druid database."
)
@Plugin(
examples = {
@Example(
title = "Multiple queries on an Apache Druid database.",
full = true,
code = """
id: druid_queries
namespace: company.team

tasks:
- id: queries
type: io.kestra.plugin.jdbc.druid.Queries
url: jdbc:avatica:remote:url=http://localhost:8888/druid/v2/sql/avatica/;transparent_reconnection=true
sql: |
SELECT * FROM wikiticker; SELECT * FROM product;
fetchType: STORE
"""
)
}
)
public class Queries extends AbstractJdbcQueries implements RunnableTask<AbstractJdbcQueries.MultiQueryOutput> {
@Override
protected AbstractCellConverter getCellConverter(ZoneId zoneId) {
return new DruidCellConverter(zoneId);
}

@Override
public void registerDriver() throws SQLException {
DriverManager.registerDriver(new Driver());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package io.kestra.plugin.jdbc.druid;

import io.kestra.core.junit.annotations.KestraTest;
import io.kestra.core.models.property.Property;
import io.kestra.core.runners.RunContext;
import io.kestra.core.runners.RunContextFactory;
import io.kestra.plugin.jdbc.AbstractJdbcQueries;
import jakarta.inject.Inject;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import java.util.Map;

import static io.kestra.core.models.tasks.common.FetchType.FETCH;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.*;

@KestraTest
public class DruidQueriesTest {
@Inject
RunContextFactory runContextFactory;

@BeforeAll
public static void startServer() throws Exception {
DruidTestHelper.initServer();
}

@Test
void insertAndQuery() throws Exception {
RunContext runContext = runContextFactory.of(Map.of());

Queries task = Queries.builder()
.url("jdbc:avatica:remote:url=http://localhost:8888/druid/v2/sql/avatica/;transparent_reconnection=true")
.fetchType(FETCH)
.timeZoneId("Europe/Paris")
.parameters(Property.of(Map.of(
"limitLow", 2,
"limitHigh", 10))
)
.sql("""
select * from products limit :limitLow;
select * from products limit :limitHigh;
""")
.build();

AbstractJdbcQueries.MultiQueryOutput runOutput = task.run(runContext);
assertThat(runOutput.getOutputs(), notNullValue());
assertThat(runOutput.getOutputs().getFirst().getRows().size(), is(2));
assertThat(runOutput.getOutputs().getLast().getRows().size(), is(10));
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.sql.*;
import java.util.*;
import java.util.function.Consumer;
Expand All @@ -29,22 +30,28 @@ public abstract class AbstractJdbcQueries extends AbstractJdbcBaseQuery implemen
@Builder.Default
protected Property<Boolean> transaction = Property.of(Boolean.TRUE);

@Getter(AccessLevel.NONE)
private Connection conn = null;

@Getter(AccessLevel.NONE)
private PreparedStatement stmt = null;

@Getter(AccessLevel.NONE)
private Savepoint savepoint = null;

public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Exception {
Logger logger = runContext.logger();
AbstractCellConverter cellConverter = getCellConverter(this.zoneId());

final boolean isTransactional = this.transaction.as(runContext, Boolean.class);
Connection conn = null;
PreparedStatement stmt = null;
Savepoint savepoint = null;
long totalSize = 0L;
List<AbstractJdbcQuery.Output> outputList = new LinkedList<>();

try {
//Create connection in not autocommit mode to enable rollback on error
conn = this.connection(runContext);
conn.setAutoCommit(false);
savepoint = conn.setSavepoint();
savepoint = initializeSavepoint(conn);

String sqlRendered = runContext.render(this.sql, this.additionalVars);
String[] queries = sqlRendered.split(";[^']");
Expand All @@ -61,40 +68,7 @@ public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Ex

//Create Outputs
while (hasMoreResult || stmt.getUpdateCount() != -1) {
try(ResultSet rs = stmt.getResultSet()) {
//When sql is not a select statement skip output creation
if(rs != null) {
AbstractJdbcQuery.Output.OutputBuilder<?, ?> output = AbstractJdbcQuery.Output.builder();
//Populate result fro result set
long size = 0L;
switch (this.getFetchType()) {
case FETCH_ONE -> {
size = 1L;
output
.row(fetchResult(rs, cellConverter, conn))
.size(size);
}
case STORE -> {
File tempFile = runContext.workingDir().createTempFile(".ion").toFile();
try (BufferedWriter fileWriter = new BufferedWriter(new FileWriter(tempFile), FileSerde.BUFFER_SIZE)) {
size = fetchToFile(stmt, rs, fileWriter, cellConverter, conn);
}
output
.uri(runContext.storage().putFile(tempFile))
.size(size);
}
case FETCH -> {
List<Map<String, Object>> maps = new ArrayList<>();
size = fetchResults(stmt, rs, maps, cellConverter, conn);
output
.rows(maps)
.size(size);
}
}
totalSize += size;
outputList.add(output.build());
}
}
totalSize = extractResultsFromResultSet(runContext, cellConverter, totalSize, outputList);
hasMoreResult = stmt.getMoreResults();
}
}
Expand All @@ -104,13 +78,73 @@ public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Ex

return MultiQueryOutput.builder().outputs(outputList).build();
} catch (Exception e) {
if(isTransactional && conn != null && savepoint != null) {
conn.rollback(savepoint);
}
rollbackIfTransactional(isTransactional);
throw new RuntimeException(e);
} finally {
if(conn != null) { conn.close(); }
if(stmt != null) { stmt.close(); }
closeConnectionAndStatement();
}
}

private long extractResultsFromResultSet(RunContext runContext, AbstractCellConverter cellConverter, long totalSize, List<Output> outputList) throws SQLException, IOException {
try(ResultSet rs = stmt.getResultSet()) {
//When sql is not a select statement skip output creation
if(rs != null) {
Output.OutputBuilder<?, ?> output = Output.builder();
//Populate result fro result set
long size = 0L;
switch (this.getFetchType()) {
case FETCH_ONE -> {
size = 1L;
output
.row(fetchResult(rs, cellConverter, conn))
.size(size);
}
case STORE -> {
File tempFile = runContext.workingDir().createTempFile(".ion").toFile();
try (BufferedWriter fileWriter = new BufferedWriter(new FileWriter(tempFile), FileSerde.BUFFER_SIZE)) {
size = fetchToFile(stmt, rs, fileWriter, cellConverter, conn);
}
output
.uri(runContext.storage().putFile(tempFile))
.size(size);
}
case FETCH -> {
List<Map<String, Object>> maps = new ArrayList<>();
size = fetchResults(stmt, rs, maps, cellConverter, conn);
output
.rows(maps)
.size(size);
}
default -> throw new IllegalArgumentException("fetchType must be either FETCH, FETCH_ONE, STORE, or NONE");
}
totalSize += size;
outputList.add(output.build());
}
}
return totalSize;
}

private void rollbackIfTransactional(boolean isTransactional) throws SQLException {
if(isTransactional && conn != null) {
if(savepoint != null) {
conn.rollback(savepoint);
return;
}
conn.rollback();
}
}

private void closeConnectionAndStatement() throws SQLException {
if(conn != null) { conn.close(); }
if(stmt != null) { stmt.close(); }
}

private Savepoint initializeSavepoint(Connection conn) throws SQLException {
try {
return conn.setSavepoint();
} catch (SQLException e) {
//Savepoint not supported by this driver
return null;
}
}

Expand Down Expand Up @@ -142,7 +176,7 @@ private PreparedStatement createPreparedStatementAndPopulateParameters(RunContex
}

//Extract parameters in orders and replace them with '?'
String preparedSql = new String(sql);
String preparedSql = sql;
Pattern pattern = Pattern.compile(" :\\w+");
Matcher matcher = pattern.matcher(preparedSql);

Expand All @@ -154,7 +188,7 @@ private PreparedStatement createPreparedStatementAndPopulateParameters(RunContex
preparedSql = matcher.replaceFirst( " ?");
matcher = pattern.matcher(preparedSql);
}
PreparedStatement stmt = conn.prepareStatement(preparedSql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
stmt = conn.prepareStatement(preparedSql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);

for(int i=0; i<params.size(); i++) {
stmt.setObject(i+1, namedParamsRendered.get(params.get(i)));
Expand Down