From 5ad4d0e300ef3bf706a6d5ca17a4b71aaa533581 Mon Sep 17 00:00:00 2001 From: Mathieu Gabelle <54168385+mgabelle@users.noreply.github.com> Date: Wed, 30 Oct 2024 11:14:54 +0100 Subject: [PATCH] feat: add queries to druid (#421) --- .../io/kestra/plugin/jdbc/druid/Queries.java | 58 ++++++++ .../plugin/jdbc/druid/DruidQueriesTest.java | 52 ++++++++ .../plugin/jdbc/AbstractJdbcQueries.java | 124 +++++++++++------- 3 files changed, 189 insertions(+), 45 deletions(-) create mode 100644 plugin-jdbc-druid/src/main/java/io/kestra/plugin/jdbc/druid/Queries.java create mode 100644 plugin-jdbc-druid/src/test/java/io/kestra/plugin/jdbc/druid/DruidQueriesTest.java diff --git a/plugin-jdbc-druid/src/main/java/io/kestra/plugin/jdbc/druid/Queries.java b/plugin-jdbc-druid/src/main/java/io/kestra/plugin/jdbc/druid/Queries.java new file mode 100644 index 00000000..9e80a4e1 --- /dev/null +++ b/plugin-jdbc-druid/src/main/java/io/kestra/plugin/jdbc/druid/Queries.java @@ -0,0 +1,58 @@ +package io.kestra.plugin.jdbc.druid; + +import io.kestra.core.models.annotations.Example; +import io.kestra.core.models.annotations.Plugin; +import io.kestra.core.models.tasks.RunnableTask; +import io.kestra.plugin.jdbc.AbstractCellConverter; +import io.kestra.plugin.jdbc.AbstractJdbcQueries; +import io.kestra.plugin.jdbc.AbstractJdbcQuery; +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.ToString; +import lombok.experimental.SuperBuilder; +import org.apache.calcite.avatica.remote.Driver; + +import java.sql.*; +import java.time.ZoneId; + +@SuperBuilder +@ToString +@EqualsAndHashCode +@Getter +@NoArgsConstructor +@Schema( + title = "Perform multiple queries on an Apache Druid database." +) +@Plugin( + examples = { + @Example( + title = "Multiple queries on an Apache Druid database.", + full = true, + code = """ + id: druid_queries + namespace: company.team + + tasks: + - id: queries + type: io.kestra.plugin.jdbc.druid.Queries + url: jdbc:avatica:remote:url=http://localhost:8888/druid/v2/sql/avatica/;transparent_reconnection=true + sql: | + SELECT * FROM wikiticker; SELECT * FROM product; + fetchType: STORE + """ + ) + } +) +public class Queries extends AbstractJdbcQueries implements RunnableTask { + @Override + protected AbstractCellConverter getCellConverter(ZoneId zoneId) { + return new DruidCellConverter(zoneId); + } + + @Override + public void registerDriver() throws SQLException { + DriverManager.registerDriver(new Driver()); + } +} diff --git a/plugin-jdbc-druid/src/test/java/io/kestra/plugin/jdbc/druid/DruidQueriesTest.java b/plugin-jdbc-druid/src/test/java/io/kestra/plugin/jdbc/druid/DruidQueriesTest.java new file mode 100644 index 00000000..4ea388c6 --- /dev/null +++ b/plugin-jdbc-druid/src/test/java/io/kestra/plugin/jdbc/druid/DruidQueriesTest.java @@ -0,0 +1,52 @@ +package io.kestra.plugin.jdbc.druid; + +import io.kestra.core.junit.annotations.KestraTest; +import io.kestra.core.models.property.Property; +import io.kestra.core.runners.RunContext; +import io.kestra.core.runners.RunContextFactory; +import io.kestra.plugin.jdbc.AbstractJdbcQueries; +import jakarta.inject.Inject; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static io.kestra.core.models.tasks.common.FetchType.FETCH; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; + +@KestraTest +public class DruidQueriesTest { + @Inject + RunContextFactory runContextFactory; + + @BeforeAll + public static void startServer() throws Exception { + DruidTestHelper.initServer(); + } + + @Test + void insertAndQuery() throws Exception { + RunContext runContext = runContextFactory.of(Map.of()); + + Queries task = Queries.builder() + .url("jdbc:avatica:remote:url=http://localhost:8888/druid/v2/sql/avatica/;transparent_reconnection=true") + .fetchType(FETCH) + .timeZoneId("Europe/Paris") + .parameters(Property.of(Map.of( + "limitLow", 2, + "limitHigh", 10)) + ) + .sql(""" + select * from products limit :limitLow; + select * from products limit :limitHigh; + """) + .build(); + + AbstractJdbcQueries.MultiQueryOutput runOutput = task.run(runContext); + assertThat(runOutput.getOutputs(), notNullValue()); + assertThat(runOutput.getOutputs().getFirst().getRows().size(), is(2)); + assertThat(runOutput.getOutputs().getLast().getRows().size(), is(10)); + } +} + diff --git a/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcQueries.java b/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcQueries.java index d56234a2..8324eb2c 100644 --- a/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcQueries.java +++ b/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcQueries.java @@ -12,6 +12,7 @@ import java.io.BufferedWriter; import java.io.File; import java.io.FileWriter; +import java.io.IOException; import java.sql.*; import java.util.*; import java.util.function.Consumer; @@ -29,14 +30,20 @@ public abstract class AbstractJdbcQueries extends AbstractJdbcBaseQuery implemen @Builder.Default protected Property transaction = Property.of(Boolean.TRUE); + @Getter(AccessLevel.NONE) + private Connection conn = null; + + @Getter(AccessLevel.NONE) + private PreparedStatement stmt = null; + + @Getter(AccessLevel.NONE) + private Savepoint savepoint = null; + public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Exception { Logger logger = runContext.logger(); AbstractCellConverter cellConverter = getCellConverter(this.zoneId()); final boolean isTransactional = this.transaction.as(runContext, Boolean.class); - Connection conn = null; - PreparedStatement stmt = null; - Savepoint savepoint = null; long totalSize = 0L; List outputList = new LinkedList<>(); @@ -44,7 +51,7 @@ public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Ex //Create connection in not autocommit mode to enable rollback on error conn = this.connection(runContext); conn.setAutoCommit(false); - savepoint = conn.setSavepoint(); + savepoint = initializeSavepoint(conn); String sqlRendered = runContext.render(this.sql, this.additionalVars); String[] queries = sqlRendered.split(";[^']"); @@ -61,40 +68,7 @@ public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Ex //Create Outputs while (hasMoreResult || stmt.getUpdateCount() != -1) { - try(ResultSet rs = stmt.getResultSet()) { - //When sql is not a select statement skip output creation - if(rs != null) { - AbstractJdbcQuery.Output.OutputBuilder output = AbstractJdbcQuery.Output.builder(); - //Populate result fro result set - long size = 0L; - switch (this.getFetchType()) { - case FETCH_ONE -> { - size = 1L; - output - .row(fetchResult(rs, cellConverter, conn)) - .size(size); - } - case STORE -> { - File tempFile = runContext.workingDir().createTempFile(".ion").toFile(); - try (BufferedWriter fileWriter = new BufferedWriter(new FileWriter(tempFile), FileSerde.BUFFER_SIZE)) { - size = fetchToFile(stmt, rs, fileWriter, cellConverter, conn); - } - output - .uri(runContext.storage().putFile(tempFile)) - .size(size); - } - case FETCH -> { - List> maps = new ArrayList<>(); - size = fetchResults(stmt, rs, maps, cellConverter, conn); - output - .rows(maps) - .size(size); - } - } - totalSize += size; - outputList.add(output.build()); - } - } + totalSize = extractResultsFromResultSet(runContext, cellConverter, totalSize, outputList); hasMoreResult = stmt.getMoreResults(); } } @@ -104,13 +78,73 @@ public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Ex return MultiQueryOutput.builder().outputs(outputList).build(); } catch (Exception e) { - if(isTransactional && conn != null && savepoint != null) { - conn.rollback(savepoint); - } + rollbackIfTransactional(isTransactional); throw new RuntimeException(e); } finally { - if(conn != null) { conn.close(); } - if(stmt != null) { stmt.close(); } + closeConnectionAndStatement(); + } + } + + private long extractResultsFromResultSet(RunContext runContext, AbstractCellConverter cellConverter, long totalSize, List outputList) throws SQLException, IOException { + try(ResultSet rs = stmt.getResultSet()) { + //When sql is not a select statement skip output creation + if(rs != null) { + Output.OutputBuilder output = Output.builder(); + //Populate result fro result set + long size = 0L; + switch (this.getFetchType()) { + case FETCH_ONE -> { + size = 1L; + output + .row(fetchResult(rs, cellConverter, conn)) + .size(size); + } + case STORE -> { + File tempFile = runContext.workingDir().createTempFile(".ion").toFile(); + try (BufferedWriter fileWriter = new BufferedWriter(new FileWriter(tempFile), FileSerde.BUFFER_SIZE)) { + size = fetchToFile(stmt, rs, fileWriter, cellConverter, conn); + } + output + .uri(runContext.storage().putFile(tempFile)) + .size(size); + } + case FETCH -> { + List> maps = new ArrayList<>(); + size = fetchResults(stmt, rs, maps, cellConverter, conn); + output + .rows(maps) + .size(size); + } + default -> throw new IllegalArgumentException("fetchType must be either FETCH, FETCH_ONE, STORE, or NONE"); + } + totalSize += size; + outputList.add(output.build()); + } + } + return totalSize; + } + + private void rollbackIfTransactional(boolean isTransactional) throws SQLException { + if(isTransactional && conn != null) { + if(savepoint != null) { + conn.rollback(savepoint); + return; + } + conn.rollback(); + } + } + + private void closeConnectionAndStatement() throws SQLException { + if(conn != null) { conn.close(); } + if(stmt != null) { stmt.close(); } + } + + private Savepoint initializeSavepoint(Connection conn) throws SQLException { + try { + return conn.setSavepoint(); + } catch (SQLException e) { + //Savepoint not supported by this driver + return null; } } @@ -142,7 +176,7 @@ private PreparedStatement createPreparedStatementAndPopulateParameters(RunContex } //Extract parameters in orders and replace them with '?' - String preparedSql = new String(sql); + String preparedSql = sql; Pattern pattern = Pattern.compile(" :\\w+"); Matcher matcher = pattern.matcher(preparedSql); @@ -154,7 +188,7 @@ private PreparedStatement createPreparedStatementAndPopulateParameters(RunContex preparedSql = matcher.replaceFirst( " ?"); matcher = pattern.matcher(preparedSql); } - PreparedStatement stmt = conn.prepareStatement(preparedSql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); + stmt = conn.prepareStatement(preparedSql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); for(int i=0; i