Skip to content

Commit

Permalink
Implement Flag select and filter conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsrnhld committed Jan 5, 2024
1 parent 2bed2b7 commit ebfcb31
Show file tree
Hide file tree
Showing 19 changed files with 497 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,14 @@
import com.bakdata.conquery.models.exceptions.ConceptConfigurationException;
import com.bakdata.conquery.models.query.filter.event.FlagColumnsFilterNode;
import com.bakdata.conquery.models.query.queryplan.filter.FilterNode;
import com.bakdata.conquery.sql.conversion.cqelement.concept.ConceptCteStep;
import com.bakdata.conquery.sql.conversion.cqelement.concept.FilterContext;
import com.bakdata.conquery.sql.conversion.model.filter.SqlFilters;
import com.bakdata.conquery.sql.conversion.model.select.FlagSqlAggregator;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
import io.dropwizard.validation.ValidationMethod;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;

Expand All @@ -31,6 +36,7 @@
*
* The selected flags are logically or-ed.
*/
@Getter
@CPSType(base = Filter.class, id = "FLAGS")
@RequiredArgsConstructor(onConstructor_ = {@JsonCreator})
@ToString
Expand Down Expand Up @@ -87,4 +93,14 @@ public boolean isAllColumnsOfSameTable() {
public boolean isAllColumnsBoolean() {
	// Valid only if every configured flag column is of type BOOLEAN; vacuously true without flags.
	for (Column flagColumn : flags.values()) {
		if (!MajorTypeId.BOOLEAN.equals(flagColumn.getType())) {
			return false;
		}
	}
	return true;
}

/**
 * Converts this filter into its SQL representation by delegating to {@link FlagSqlAggregator}.
 *
 * @param filterContext context carrying the selected flag keys as filter value
 * @return the SQL filters produced by the aggregator
 */
@Override
public SqlFilters convertToSqlFilter(FilterContext<String[]> filterContext) {
	final FlagSqlAggregator flagAggregator = FlagSqlAggregator.create(this, filterContext);
	return flagAggregator.getSqlFilters();
}

/**
 * @return the CTE steps needed to convert this filter; {@link ConceptCteStep#EVENT_FILTER} is requested as optional step.
 */
@Override
public Set<ConceptCteStep> getRequiredSqlSteps() {
	final Set<ConceptCteStep> requiredSteps = ConceptCteStep.withOptionalSteps(ConceptCteStep.EVENT_FILTER);
	return requiredSteps;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
import com.bakdata.conquery.models.events.MajorTypeId;
import com.bakdata.conquery.models.query.queryplan.aggregators.Aggregator;
import com.bakdata.conquery.models.query.queryplan.aggregators.specific.FlagsAggregator;
import com.bakdata.conquery.sql.conversion.cqelement.concept.SelectContext;
import com.bakdata.conquery.sql.conversion.model.select.FlagSqlAggregator;
import com.bakdata.conquery.sql.conversion.model.select.SqlSelects;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
import io.dropwizard.validation.ValidationMethod;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;

Expand All @@ -24,6 +28,7 @@
*
* The selected flags are logically or-ed.
*/
@Getter
@CPSType(base = Select.class, id = "FLAGS")
@RequiredArgsConstructor(onConstructor_ = {@JsonCreator})
@ToString
Expand Down Expand Up @@ -55,4 +60,9 @@ public boolean isAllColumnsOfSameTable() {
public boolean isAllColumnsBoolean() {
	// Reject the configuration as soon as one flag column is not of type BOOLEAN.
	for (Column flagColumn : flags.values()) {
		if (!MajorTypeId.BOOLEAN.equals(flagColumn.getType())) {
			return false;
		}
	}
	return true;
}

/**
 * Converts this select into its SQL representation by delegating to {@link FlagSqlAggregator}.
 *
 * @param selectContext context of the select conversion
 * @return the SQL selects produced by the aggregator
 */
@Override
public SqlSelects convertToSqlSelects(SelectContext selectContext) {
	final FlagSqlAggregator flagAggregator = FlagSqlAggregator.create(this, selectContext);
	return flagAggregator.getSqlSelects();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public abstract class ResultInfo {
public abstract String defaultColumnName(PrintSettings printSettings);

/**
 * @return the {@link ResultType} of this result column. Declared with a wildcard instead of a raw type,
 * so callers keep generic type checking.
 */
@ToString.Include
public abstract ResultType<?> getType();

@ToString.Include
public abstract Set<SemanticType> getSemantics();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,11 @@ protected String print(PrintSettings cfg, @NonNull Object f) {
public String getFromResultSet(ResultSet resultSet, int columnIndex, ResultSetProcessor resultSetProcessor) throws SQLException {
	// Reading the column is delegated to the processor.
	final String value = resultSetProcessor.getString(resultSet, columnIndex);
	return value;
}

/**
 * Reads the column at the given index as a list of strings.
 *
 * @param resultSet          result set positioned on the row to read
 * @param columnIndex        index of the column to read
 * @param resultSetProcessor processor that performs the actual extraction
 * @return the string list extracted by the processor
 * @throws SQLException if reading from the result set fails
 */
@Override
protected List<String> getFromResultSetAsList(ResultSet resultSet, int columnIndex, ResultSetProcessor resultSetProcessor) throws SQLException {
	final List<String> values = resultSetProcessor.getStringList(resultSet, columnIndex);
	return values;
}
}

@CPSType(id = "MONEY", base = ResultType.class)
Expand Down Expand Up @@ -277,6 +282,7 @@ public BigDecimal readIntermediateValue(PrintSettings cfg, Number f) {
@Getter
@EqualsAndHashCode(callSuper = false)
public static class ListT<T> extends ResultType<List<T>> {

@NonNull
private final ResultType<T> elementType;

Expand Down Expand Up @@ -307,10 +313,10 @@ public String typeInfo() {

@Override
public List<T> getFromResultSet(ResultSet resultSet, int columnIndex, ResultSetProcessor resultSetProcessor) throws SQLException {
if (elementType instanceof DateRangeT) {
if (elementType.getClass() == DateRangeT.class || elementType.getClass() == StringT.class) {
return elementType.getFromResultSetAsList(resultSet, columnIndex, resultSetProcessor);
}
// TODO handle all other list types properly by
// TODO handle all other list types properly
throw new UnsupportedOperationException("Other result type lists not supported for now.");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.bakdata.conquery.models.datasets.concepts.ValidityDate;
import com.bakdata.conquery.sql.conversion.model.ColumnDateRange;
import org.jooq.Condition;
import org.jooq.DataType;
import org.jooq.Field;
import org.jooq.Name;
import org.jooq.Param;
Expand All @@ -30,6 +31,20 @@ public String getMaxDateExpression() {
return MAX_DATE_VALUE;
}

/**
 * Builds a {@code CAST(<field> AS <type name>)} function call.
 * <p>
 * The argument is assembled as plain SQL from the string representation of the field and the name of the
 * target type.
 *
 * @param field the field to cast
 * @param type  the target data type
 * @return a field of the target type wrapping the cast expression
 */
@Override
public <T> Field<T> cast(Field<?> field, DataType<T> type) {
	final String castArgument = String.format("%s AS %s", field, type.getName());
	return DSL.function("CAST", type.getType(), DSL.field(castArgument));
}

/**
 * Builds a call of the SQL function {@code char} with the given character code as argument.
 *
 * @param character the numeric code of the character
 * @return a string field wrapping the function call
 */
@Override
public Field<String> toChar(int character) {
	final Field<Integer> characterCode = DSL.val(character);
	return DSL.function("char", String.class, characterCode);
}

@Override
public Condition dateRestriction(ColumnDateRange dateRestriction, ColumnDateRange validityDate) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.bakdata.conquery.models.datasets.concepts.ValidityDate;
import com.bakdata.conquery.sql.conversion.model.ColumnDateRange;
import org.jooq.Condition;
import org.jooq.DataType;
import org.jooq.DatePart;
import org.jooq.Field;
import org.jooq.Name;
Expand All @@ -31,6 +32,16 @@ public String getMaxDateExpression() {
return INFINITY_DATE_VALUE;
}

/**
 * Casts the given field to the given data type using jOOQ's built-in cast.
 *
 * @param field the field to cast
 * @param type  the target data type
 * @return the cast field
 */
@Override
public <T> Field<T> cast(Field<?> field, DataType<T> type) {
	final Field<T> castField = DSL.cast(field, type);
	return castField;
}

/**
 * Converts the given character code to a string field using jOOQ's {@code chr} function.
 *
 * @param character the numeric code of the character
 * @return a string field wrapping the function call
 */
@Override
public Field<String> toChar(int character) {
	final Field<String> characterField = DSL.chr(character);
	return characterField;
}

@Override
public String getMinDateExpression() {
return MINUS_INFINITY_DATE_VALUE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.bakdata.conquery.sql.conversion.model.ColumnDateRange;
import com.bakdata.conquery.sql.conversion.model.QueryStep;
import org.jooq.Condition;
import org.jooq.DataType;
import org.jooq.Field;
import org.jooq.Name;
import org.jooq.Record;
Expand All @@ -29,6 +30,10 @@ public interface SqlFunctionProvider {

String getMaxDateExpression();

/**
 * Casts the given field to the given data type.
 *
 * @param field the field to cast
 * @param type  the data type the field should be cast to
 * @return a field of the target type representing the cast expression
 */
<T> Field<T> cast(Field<?> field, DataType<T> type);

/**
 * Creates a string field that converts the given numeric character code into the corresponding character
 * via a SQL function call.
 *
 * @param character the numeric code of the character
 * @return a string field representing the character
 */
Field<String> toChar(int character);

/**
* A date restriction condition is true if the following holds: dateRestrictionStart <= validityDateEnd and dateRestrictionEnd >= validityDateStart
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.bakdata.conquery.sql.conversion.model.filter;

import java.util.List;

import lombok.RequiredArgsConstructor;
import org.jooq.Condition;
import org.jooq.Field;
import org.jooq.impl.DSL;

/**
 * Event-level where condition that is satisfied if at least one of the given flag fields is true.
 */
@RequiredArgsConstructor
public class FlagCondition implements WhereCondition {

	private final List<Field<Boolean>> flagFields;

	/**
	 * @return the disjunction of the "is true" conditions of all flag fields
	 * @throws IllegalArgumentException if the flag field list is empty
	 */
	@Override
	public Condition condition() {
		Condition anyFlagTrue = null;
		for (Field<Boolean> flagField : flagFields) {
			final Condition flagIsTrue = DSL.condition(flagField).isTrue();
			// or-combine from left to right, the first condition starts the chain
			anyFlagTrue = anyFlagTrue == null ? flagIsTrue : anyFlagTrue.or(flagIsTrue);
		}
		if (anyFlagTrue == null) {
			throw new IllegalArgumentException("Can't construct a FlagCondition with an empty flag field list.");
		}
		return anyFlagTrue;
	}

	@Override
	public ConditionType type() {
		return ConditionType.EVENT;
	}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
package com.bakdata.conquery.sql.conversion.model.select;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import com.bakdata.conquery.models.datasets.Column;
import com.bakdata.conquery.models.datasets.concepts.filters.specific.FlagFilter;
import com.bakdata.conquery.models.datasets.concepts.select.connector.specific.FlagSelect;
import com.bakdata.conquery.models.identifiable.NamedImpl;
import com.bakdata.conquery.sql.conversion.cqelement.concept.ConceptCteStep;
import com.bakdata.conquery.sql.conversion.cqelement.concept.FilterContext;
import com.bakdata.conquery.sql.conversion.cqelement.concept.SelectContext;
import com.bakdata.conquery.sql.conversion.dialect.SqlFunctionProvider;
import com.bakdata.conquery.sql.conversion.model.SqlTables;
import com.bakdata.conquery.sql.conversion.model.filter.FlagCondition;
import com.bakdata.conquery.sql.conversion.model.filter.WhereClauses;
import com.bakdata.conquery.sql.execution.ResultSetProcessor;
import lombok.Value;
import org.jooq.Condition;
import org.jooq.Field;
import org.jooq.Param;
import org.jooq.impl.DSL;
import org.jooq.impl.SQLDataType;

/**
* {@link FlagSelect} conversion aggregates the keys of the flags of a {@link FlagSelect} into a string aggregation.
* <p>
* If any value of the respective flag column is true, the flag key will be part of the string aggregation. <br>
* If no value is true, an empty string will be added as value, because a {@code null} value would cause the whole string aggregation to be {@code null} too. <br>
* Each value will be followed by the {@link ResultSetProcessor#UNIT_SEPARATOR}.
*
* <pre>
* {@code
* "group_select" as (
* select
* "pid",
* (
* case when max(CAST("preprocessing"."a" AS integer)) = 1 then 'A' else '' end || char(31)
* || case when max(CAST("preprocessing"."b" AS integer)) = 1 then 'B' else '' end || char(31)
* || case when max(CAST("preprocessing"."c" AS integer)) = 1 then 'C' else '' end || char(31)
* ) "flags_selects-1"
* from "preprocessing"
* group by "pid"
* )
* }
* </pre>
*
* <hr>
* <p>
* {@link FlagFilter} conversion keeps only those events for which at least 1 of the selected flag columns has a true value.
*
* <pre>
* {@code
* "event_filter" as (
* select "pid"
* from "preprocessing"
* where (
* "preprocessing"."b" = true
* or "preprocessing"."c" = true
* )
* )
* }
* </pre>
*/
@Value
public class FlagSqlAggregator implements SqlAggregator {

	private static final Param<Integer> NUMERIC_TRUE_VAL = DSL.val(1);
	private static final Param<String> EMPTY_STRING = DSL.val("");

	SqlSelects sqlSelects;
	WhereClauses whereClauses;

	/**
	 * Creates the aggregator for a {@link FlagSelect}.
	 * <p>
	 * The resulting {@link SqlSelects} contain one preprocessing select per flag column, an aggregation select that
	 * concatenates the keys of all flags and a final select referencing the aggregation. The where clauses are left empty.
	 *
	 * @param flagSelect    the select to convert
	 * @param selectContext context of the select conversion
	 * @return the aggregator holding the created selects
	 */
	public static FlagSqlAggregator create(FlagSelect flagSelect, SelectContext selectContext) {

		SqlFunctionProvider functionProvider = selectContext.getParentContext().getSqlDialect().getFunctionProvider();
		SqlTables<ConceptCteStep> conceptTables = selectContext.getConceptTables();

		String rootTable = conceptTables.getPredecessor(ConceptCteStep.PREPROCESSING);
		Map<String, SqlSelect> rootSelects = createFlagRootSelectMap(flagSelect, rootTable);

		String alias = selectContext.getNameGenerator().selectName(flagSelect);
		FieldWrapper<String> flagAggregation = createFlagSelect(alias, conceptTables, functionProvider, rootSelects);

		ExtractingSqlSelect<String> finalSelect = flagAggregation.createAliasedReference(conceptTables.getPredecessor(ConceptCteStep.FINAL));

		SqlSelects sqlSelects = SqlSelects.builder().preprocessingSelects(rootSelects.values())
												.aggregationSelect(flagAggregation)
												.finalSelect(finalSelect)
												.build();

		return new FlagSqlAggregator(sqlSelects, WhereClauses.builder().build());
	}

	/**
	 * Creates the aggregator for a {@link FlagFilter}.
	 * <p>
	 * Only the columns of the flags selected in the filter value are added as preprocessing selects. They are
	 * combined into a {@link FlagCondition} that is registered as event filter.
	 *
	 * @param flagFilter    the filter to convert
	 * @param filterContext context carrying the keys of the selected flags as value
	 * @return the aggregator holding the created selects and where clauses
	 * @throws IllegalArgumentException if a selected flag key is not defined by the filter
	 */
	public static FlagSqlAggregator create(FlagFilter flagFilter, FilterContext<String[]> filterContext) {
		SqlTables<ConceptCteStep> conceptTables = filterContext.getConceptTables();
		String rootTable = conceptTables.getPredecessor(ConceptCteStep.PREPROCESSING);

		// collected via Collectors.toList() on purpose: the target type List<SqlSelect> drives the element type inference
		List<SqlSelect> rootSelects =
				getRequiredColumnNames(flagFilter.getFlags(), filterContext.getValue())
						.stream()
						.map(columnName -> new ExtractingSqlSelect<>(rootTable, columnName, Boolean.class))
						.collect(Collectors.toList());
		SqlSelects selects = SqlSelects.builder()
									   .preprocessingSelects(rootSelects)
									   .build();

		List<Field<Boolean>> flagFields = rootSelects.stream()
													 .map(sqlSelect -> conceptTables.<Boolean>qualifyOnPredecessor(ConceptCteStep.EVENT_FILTER, sqlSelect.aliased()))
													 .toList();
		FlagCondition flagCondition = new FlagCondition(flagFields);
		WhereClauses whereClauses = WhereClauses.builder()
												.eventFilter(flagCondition)
												.build();

		return new FlagSqlAggregator(selects, whereClauses);
	}

	/**
	 * @return A mapping between a flags key and the corresponding {@link ExtractingSqlSelect} that will be created to reference the flag's column.
	 */
	private static Map<String, SqlSelect> createFlagRootSelectMap(FlagSelect flagSelect, String rootTable) {
		return flagSelect.getFlags()
						 .entrySet().stream()
						 .collect(Collectors.toMap(
								 Map.Entry::getKey,
								 entry -> new ExtractingSqlSelect<>(rootTable, entry.getValue().getName(), Boolean.class)
						 ));
	}

	/**
	 * Creates the aliased string aggregation over all flags: for each flag, its key (or an empty string if no value of
	 * the flag column is true) followed by the {@link ResultSetProcessor#UNIT_SEPARATOR}.
	 */
	private static FieldWrapper<String> createFlagSelect(
			String alias,
			SqlTables<ConceptCteStep> conceptTables,
			SqlFunctionProvider functionProvider,
			Map<String, SqlSelect> flagRootSelectMap
	) {
		Map<String, Field<Boolean>> flagFieldsMap = createRootSelectReferences(conceptTables, flagRootSelectMap);

		// the separator is the same for every flag, so it is created only once
		Field<String> separator = functionProvider.toChar(ResultSetProcessor.UNIT_SEPARATOR);

		List<Field<String>> flagAggregations = new ArrayList<>();
		for (Map.Entry<String, Field<Boolean>> entry : flagFieldsMap.entrySet()) {
			Field<Boolean> boolColumn = entry.getValue();
			// booleans are cast to integers, so max() = 1 tells whether any value of the column is true
			Condition anyTrue = DSL.max(functionProvider.cast(boolColumn, SQLDataType.INTEGER))
								   .eq(NUMERIC_TRUE_VAL);
			// we have to prevent null values because then the whole String aggregation is null
			String flagName = entry.getKey();
			Field<String> flag = DSL.when(anyTrue, DSL.val(flagName))
									.otherwise(EMPTY_STRING);
			// append separator
			Field<String> withSeparator = DSL.field("%s || %s".formatted(flag, separator), String.class);
			flagAggregations.add(withSeparator);
		}

		return new FieldWrapper<>(
				DSL.concat(flagAggregations.toArray(Field[]::new)).as(alias)
		);
	}

	/**
	 * @return A mapping between a flags key and the flag's root select qualified on the predecessor of the {@link ConceptCteStep#AGGREGATION_SELECT} step.
	 */
	private static Map<String, Field<Boolean>> createRootSelectReferences(SqlTables<ConceptCteStep> conceptTables, Map<String, SqlSelect> flagRootSelectMap) {
		return flagRootSelectMap.entrySet().stream()
								.collect(Collectors.toMap(
										Map.Entry::getKey,
										entry -> conceptTables.qualifyOnPredecessor(ConceptCteStep.AGGREGATION_SELECT, entry.getValue().aliased())
								));
	}

	/**
	 * @return Column names of a given flags map that match the selected flags of the filter value.
	 * @throws IllegalArgumentException if a selected flag is not a key of the flags map
	 */
	private static List<String> getRequiredColumnNames(Map<String, Column> flags, String[] selectedFlags) {
		return Arrays.stream(selectedFlags)
					 .map(selectedFlag -> requireFlagColumn(flags, selectedFlag))
					 .map(NamedImpl::getName)
					 .toList();
	}

	/**
	 * Looks up the column of a selected flag and fails with a descriptive message instead of a bare
	 * {@link NullPointerException} if the flag is unknown.
	 */
	private static Column requireFlagColumn(Map<String, Column> flags, String selectedFlag) {
		Column column = flags.get(selectedFlag);
		if (column == null) {
			throw new IllegalArgumentException(
					"Unknown flag '%s'. Known flags are: %s".formatted(selectedFlag, flags.keySet())
			);
		}
		return column;
	}

}
Loading

0 comments on commit ebfcb31

Please sign in to comment.