Skip to content

Commit

Permalink
feat: add queries to druid (#421)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgabelle authored Oct 30, 2024
1 parent c859e34 commit 5ad4d0e
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 45 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package io.kestra.plugin.jdbc.druid;

import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.tasks.RunnableTask;
import io.kestra.plugin.jdbc.AbstractCellConverter;
import io.kestra.plugin.jdbc.AbstractJdbcQueries;
import io.kestra.plugin.jdbc.AbstractJdbcQuery;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.experimental.SuperBuilder;
import org.apache.calcite.avatica.remote.Driver;

import java.sql.*;
import java.time.ZoneId;

/**
 * Runs multiple semicolon-separated SQL queries against an Apache Druid
 * database through the Avatica remote JDBC driver. The multi-statement
 * splitting, fetching, and output assembly live in {@link AbstractJdbcQueries}.
 */
@SuperBuilder
@ToString
@EqualsAndHashCode
@Getter
@NoArgsConstructor
@Schema(
title = "Perform multiple queries on an Apache Druid database."
)
@Plugin(
examples = {
@Example(
title = "Multiple queries on an Apache Druid database.",
full = true,
code = """
id: druid_queries
namespace: company.team
tasks:
- id: queries
type: io.kestra.plugin.jdbc.druid.Queries
url: jdbc:avatica:remote:url=http://localhost:8888/druid/v2/sql/avatica/;transparent_reconnection=true
sql: |
SELECT * FROM wikiticker; SELECT * FROM product;
fetchType: STORE
"""
)
}
)
public class Queries extends AbstractJdbcQueries implements RunnableTask<AbstractJdbcQueries.MultiQueryOutput> {
// Converts JDBC column values coming back from Druid/Avatica into Kestra
// output values, honoring the task's configured time zone.
@Override
protected AbstractCellConverter getCellConverter(ZoneId zoneId) {
return new DruidCellConverter(zoneId);
}

// Registers the Avatica remote driver with DriverManager so that
// jdbc:avatica:remote: URLs resolve to a driver at connect time.
@Override
public void registerDriver() throws SQLException {
DriverManager.registerDriver(new Driver());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package io.kestra.plugin.jdbc.druid;

import io.kestra.core.junit.annotations.KestraTest;
import io.kestra.core.models.property.Property;
import io.kestra.core.runners.RunContext;
import io.kestra.core.runners.RunContextFactory;
import io.kestra.plugin.jdbc.AbstractJdbcQueries;
import jakarta.inject.Inject;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import java.util.Map;

import static io.kestra.core.models.tasks.common.FetchType.FETCH;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.*;

@KestraTest
public class DruidQueriesTest {
    @Inject
    RunContextFactory runContextFactory;

    // Boot the embedded Druid test server once for the whole test class.
    @BeforeAll
    public static void startServer() throws Exception {
        DruidTestHelper.initServer();
    }

    @Test
    void insertAndQuery() throws Exception {
        var runContext = runContextFactory.of(Map.of());

        // Two SELECT statements in a single task, each bound to its own
        // named limit parameter, fetched directly into memory.
        var queriesTask = Queries.builder()
            .url("jdbc:avatica:remote:url=http://localhost:8888/druid/v2/sql/avatica/;transparent_reconnection=true")
            .fetchType(FETCH)
            .timeZoneId("Europe/Paris")
            .parameters(Property.of(Map.of(
                "limitLow", 2,
                "limitHigh", 10))
            )
            .sql("""
                select * from products limit :limitLow;
                select * from products limit :limitHigh;
                """)
            .build();

        var output = queriesTask.run(runContext);

        // One output per statement, sized by the corresponding limit.
        assertThat(output.getOutputs(), notNullValue());
        assertThat(output.getOutputs().getFirst().getRows(), hasSize(2));
        assertThat(output.getOutputs().getLast().getRows(), hasSize(10));
    }
}

Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.sql.*;
import java.util.*;
import java.util.function.Consumer;
Expand All @@ -29,22 +30,28 @@ public abstract class AbstractJdbcQueries extends AbstractJdbcBaseQuery implemen
// Whether all statements run inside one transaction (rolled back as a unit
// on failure). Defaults to true.
@Builder.Default
protected Property<Boolean> transaction = Property.of(Boolean.TRUE);

// Shared JDBC state for the lifetime of a single run(); getters are
// suppressed so these never surface through Lombok-generated accessors.
@Getter(AccessLevel.NONE)
private Connection conn = null;

@Getter(AccessLevel.NONE)
private PreparedStatement stmt = null;

// Savepoint taken right after the connection opens; stays null when the
// driver does not support savepoints (see initializeSavepoint).
@Getter(AccessLevel.NONE)
private Savepoint savepoint = null;

public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Exception {
Logger logger = runContext.logger();
AbstractCellConverter cellConverter = getCellConverter(this.zoneId());

final boolean isTransactional = this.transaction.as(runContext, Boolean.class);
Connection conn = null;
PreparedStatement stmt = null;
Savepoint savepoint = null;
long totalSize = 0L;
List<AbstractJdbcQuery.Output> outputList = new LinkedList<>();

try {
//Create connection in not autocommit mode to enable rollback on error
conn = this.connection(runContext);
conn.setAutoCommit(false);
savepoint = conn.setSavepoint();
savepoint = initializeSavepoint(conn);

String sqlRendered = runContext.render(this.sql, this.additionalVars);
String[] queries = sqlRendered.split(";[^']");
Expand All @@ -61,40 +68,7 @@ public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Ex

//Create Outputs
while (hasMoreResult || stmt.getUpdateCount() != -1) {
try(ResultSet rs = stmt.getResultSet()) {
//When sql is not a select statement skip output creation
if(rs != null) {
AbstractJdbcQuery.Output.OutputBuilder<?, ?> output = AbstractJdbcQuery.Output.builder();
//Populate result fro result set
long size = 0L;
switch (this.getFetchType()) {
case FETCH_ONE -> {
size = 1L;
output
.row(fetchResult(rs, cellConverter, conn))
.size(size);
}
case STORE -> {
File tempFile = runContext.workingDir().createTempFile(".ion").toFile();
try (BufferedWriter fileWriter = new BufferedWriter(new FileWriter(tempFile), FileSerde.BUFFER_SIZE)) {
size = fetchToFile(stmt, rs, fileWriter, cellConverter, conn);
}
output
.uri(runContext.storage().putFile(tempFile))
.size(size);
}
case FETCH -> {
List<Map<String, Object>> maps = new ArrayList<>();
size = fetchResults(stmt, rs, maps, cellConverter, conn);
output
.rows(maps)
.size(size);
}
}
totalSize += size;
outputList.add(output.build());
}
}
totalSize = extractResultsFromResultSet(runContext, cellConverter, totalSize, outputList);
hasMoreResult = stmt.getMoreResults();
}
}
Expand All @@ -104,13 +78,73 @@ public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Ex

return MultiQueryOutput.builder().outputs(outputList).build();
} catch (Exception e) {
if(isTransactional && conn != null && savepoint != null) {
conn.rollback(savepoint);
}
rollbackIfTransactional(isTransactional);
throw new RuntimeException(e);
} finally {
if(conn != null) { conn.close(); }
if(stmt != null) { stmt.close(); }
closeConnectionAndStatement();
}
}

/**
 * Builds one {@code Output} from the statement's current result set
 * according to the configured fetch type, appends it to {@code outputList},
 * and returns the running total size.
 *
 * @param runContext   used for temp-file creation and storage upload (STORE)
 * @param cellConverter converts JDBC column values to output values
 * @param totalSize    running row count accumulated so far
 * @param outputList   collector receiving one entry per SELECT statement
 * @return {@code totalSize} plus the rows produced by this result set
 * @throws SQLException on JDBC errors
 * @throws IOException  on temp-file errors in STORE mode
 */
private long extractResultsFromResultSet(RunContext runContext, AbstractCellConverter cellConverter, long totalSize, List<Output> outputList) throws SQLException, IOException {
    try (ResultSet rs = stmt.getResultSet()) {
        // A null ResultSet means the current statement was not a SELECT
        // (e.g. an update); skip output creation for it.
        if (rs != null) {
            Output.OutputBuilder<?, ?> output = Output.builder();
            // Populate result from the result set.
            long size = 0L;
            switch (this.getFetchType()) {
                case FETCH_ONE -> {
                    // Only the first row is materialized.
                    size = 1L;
                    output
                        .row(fetchResult(rs, cellConverter, conn))
                        .size(size);
                }
                case STORE -> {
                    // Stream every row into an .ion temp file and expose its storage URI.
                    File tempFile = runContext.workingDir().createTempFile(".ion").toFile();
                    try (BufferedWriter fileWriter = new BufferedWriter(new FileWriter(tempFile), FileSerde.BUFFER_SIZE)) {
                        size = fetchToFile(stmt, rs, fileWriter, cellConverter, conn);
                    }
                    output
                        .uri(runContext.storage().putFile(tempFile))
                        .size(size);
                }
                case FETCH -> {
                    // Materialize all rows in memory.
                    List<Map<String, Object>> maps = new ArrayList<>();
                    size = fetchResults(stmt, rs, maps, cellConverter, conn);
                    output
                        .rows(maps)
                        .size(size);
                }
                // NONE is a valid fetch type (the error message below says so,
                // and the pre-refactor switch treated it as a no-op): run the
                // statement but collect no rows, still recording an empty
                // output so outputs stay positionally aligned with queries.
                case NONE -> output.size(size);
                default -> throw new IllegalArgumentException("fetchType must be either FETCH, FETCH_ONE, STORE, or NONE");
            }
            totalSize += size;
            outputList.add(output.build());
        }
    }
    return totalSize;
}

/**
 * Undoes uncommitted work when the task runs transactionally: rolls back to
 * the initial savepoint when one exists, otherwise rolls back the whole
 * transaction. No-op outside a transaction or before a connection exists.
 */
private void rollbackIfTransactional(boolean isTransactional) throws SQLException {
    if (!isTransactional || conn == null) {
        return;
    }

    if (savepoint == null) {
        // Driver without savepoint support: full-transaction rollback.
        conn.rollback();
    } else {
        conn.rollback(savepoint);
    }
}

/**
 * Releases JDBC resources at the end of a run. The statement is closed
 * before its owning connection (closing the connection first could leave
 * the statement dangling, and an exception from it would skip the second
 * close entirely); the finally block guarantees the connection close is
 * attempted even when the statement close throws.
 */
private void closeConnectionAndStatement() throws SQLException {
    try {
        if (stmt != null) {
            stmt.close();
        }
    } finally {
        if (conn != null) {
            conn.close();
        }
    }
}

/**
 * Best-effort creation of a rollback savepoint on the freshly opened
 * connection. Some drivers (e.g. Druid's Avatica driver) do not support
 * savepoints; in that case {@code null} is returned and rollback falls back
 * to a full-transaction rollback.
 */
private Savepoint initializeSavepoint(Connection conn) throws SQLException {
    Savepoint sp;
    try {
        sp = conn.setSavepoint();
    } catch (SQLException ignored) {
        // Savepoints not supported by this driver — signal with null.
        sp = null;
    }
    return sp;
}

Expand Down Expand Up @@ -142,7 +176,7 @@ private PreparedStatement createPreparedStatementAndPopulateParameters(RunContex
}

//Extract parameters in orders and replace them with '?'
String preparedSql = new String(sql);
String preparedSql = sql;
Pattern pattern = Pattern.compile(" :\\w+");
Matcher matcher = pattern.matcher(preparedSql);

Expand All @@ -154,7 +188,7 @@ private PreparedStatement createPreparedStatementAndPopulateParameters(RunContex
preparedSql = matcher.replaceFirst( " ?");
matcher = pattern.matcher(preparedSql);
}
PreparedStatement stmt = conn.prepareStatement(preparedSql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
stmt = conn.prepareStatement(preparedSql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);

for(int i=0; i<params.size(); i++) {
stmt.setObject(i+1, namedParamsRendered.get(params.get(i)));
Expand Down

0 comments on commit 5ad4d0e

Please sign in to comment.