Skip to content

Commit

Permalink
fix(queries): refactor regex for named parameters (#449)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgabelle authored Dec 2, 2024
1 parent 18c445b commit de56c68
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
import java.util.Map;
import java.util.Properties;

import static io.kestra.core.models.tasks.common.FetchType.FETCH;
import static io.kestra.core.models.tasks.common.FetchType.FETCH_ONE;
import static io.kestra.core.models.tasks.common.FetchType.*;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.*;
import static org.junit.jupiter.api.Assertions.assertThrows;
Expand Down Expand Up @@ -89,6 +88,75 @@ void testMultiSelectWithParameters() throws Exception {
assertThat("selected laptop", laptops.getFirst().get("brand"), is("Apple"));
}

@Test
void testMultiSelectWithParametersWithColonCloseTo() throws Exception {
RunContext runContext = runContextFactory.of(Collections.emptyMap());

Queries setup = Queries.builder()
.url(getUrl())
.username(getUsername())
.password(getPassword())
.fetchType(NONE)
.timeZoneId("Europe/Paris")
.sql("""
DROP TABLE IF EXISTS myusers CASCADE;
CREATE TABLE myusers (
id SERIAL PRIMARY KEY,
name VARCHAR(255) NOT NULL,
email VARCHAR(255) UNIQUE NOT NULL,
last_login TIMESTAMP
);
DROP TABLE IF EXISTS mylogs;
CREATE TABLE mylogs (
log_id SERIAL PRIMARY KEY,
user_email VARCHAR(255) NOT NULL,
action VARCHAR(255) NOT NULL,
timestamp TIMESTAMP NOT NULL,
FOREIGN KEY (user_email) REFERENCES myusers(email)
);
""")
.build();
setup.run(runContext);

Map<String, Object> parameters = Map.of(
"name1", "John Doe",
"name2", "Jane Smith",
"name3", "Alice Johnson",
"email1", "johndoe@example.com",
"email2", "janesmith@example.com",
"email3", "alicejohnson@example.com",
"action1", "login",
"action2", "login",
"action3", "login"
);

Queries insertAndSelect = Queries.builder()
.url(getUrl())
.username(getUsername())
.password(getPassword())
.fetchType(FETCH)
.timeZoneId("Europe/Paris")
.sql("""
INSERT INTO myusers (name, email, last_login)
VALUES
(:name1, :email1, NOW()),
(:name2, :email2, NOW()),
(:name3, :email3, NOW());
INSERT INTO mylogs (user_email, action, timestamp)
VALUES
(:email1, :action1, NOW() - INTERVAL '10 MINUTES'),
(:email2, :action2, NOW() - INTERVAL '20 MINUTES'),
(:email3, :action3, NOW() - INTERVAL '30 MINUTES');
SELECT * FROM mylogs;
SELECT * FROM myusers;
""")
.parameters(Property.of(parameters))
.build();

Queries.MultiQueryOutput output = insertAndSelect.run(runContext);
assertThat(output.getOutputs().size(), is(2));
}

@Test
void testMultiQueriesOnlySelectOutputs() throws Exception {
RunContext runContext = runContextFactory.of(Collections.emptyMap());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,15 @@ private PreparedStatement createPreparedStatementAndPopulateParameters(RunContex

//Extract parameters in orders and replace them with '?'
String preparedSql = sql;
Pattern pattern = Pattern.compile(" :\\w+");
Pattern pattern = Pattern.compile(":\\w+");
Matcher matcher = pattern.matcher(preparedSql);

List<String> params = new LinkedList<>();

while (matcher.find()) {
String param = matcher.group();
params.add(param.substring(2));
preparedSql = matcher.replaceFirst( " ?");
params.add(param.substring(1));
preparedSql = matcher.replaceFirst( "?");
matcher = pattern.matcher(preparedSql);
}
stmt = createPreparedStatement(conn, preparedSql);
Expand Down

0 comments on commit de56c68

Please sign in to comment.