diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt index 84432017e4..a265a20f9e 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt @@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.dataframe.RowExpression import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload import org.jetbrains.kotlinx.dataframe.annotations.Interpretable +import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.toColumnSet import org.jetbrains.kotlinx.dataframe.columns.values @@ -164,6 +165,8 @@ public fun > Grouped.max( public fun > Grouped.max(vararg columns: KProperty, name: String? = null): DataFrame = max(name) { columns.toColumnSet() } +@Refine +@Interpretable("GroupByMaxOf") public fun > Grouped.maxOf( name: String? = null, expression: RowExpression, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt index 74a949ee31..d1cae852aa 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt @@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.dataframe.RowExpression import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload import org.jetbrains.kotlinx.dataframe.annotations.Interpretable +import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.toColumnSet import org.jetbrains.kotlinx.dataframe.columns.values @@ -164,6 +165,8 @@ public fun > Grouped.min( public fun > Grouped.min(vararg columns: KProperty, name: String? = null): DataFrame = min(name) { columns.toColumnSet() } +@Refine +@Interpretable("GroupByMinOf") public fun > Grouped.minOf( name: String? = null, expression: RowExpression, diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt index eea04ef476..72ae4bd346 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt @@ -22,6 +22,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximat import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore +import org.jetbrains.kotlinx.dataframe.plugin.impl.makeNullable import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf import org.jetbrains.kotlinx.dataframe.plugin.impl.type import org.jetbrains.kotlinx.dataframe.plugin.interpret @@ -181,3 +182,17 @@ class GroupByAdd : AbstractInterpreter() { return GroupBy(receiver.keys, receiver.groups.add(name, type.type, context = this)) } } + +abstract class GroupByAggregator(val defaultName: String) : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() + val Arguments.name: String? by arg(defaultValue = Present(null)) + val Arguments.expression by type() + + override fun Arguments.interpret(): PluginDataFrameSchema { + val aggregated = makeNullable(simpleColumnOf(name ?: defaultName, expression.type)) + return PluginDataFrameSchema(receiver.keys.columns() + aggregated) + } +} + +class GroupByMaxOf : GroupByAggregator(defaultName = "max") +class GroupByMinOf : GroupByAggregator(defaultName = "min") diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index 5c76f7f9f9..2f091af117 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -91,6 +91,8 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByAdd import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByCount0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByInto +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMaxOf +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMinOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceExpression import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceInto import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReducePredicate @@ -303,6 +305,8 @@ internal inline fun String.load(): T { "GroupByReducePredicate" -> GroupByReducePredicate() "GroupByReduceExpression" -> GroupByReduceExpression() "GroupByReduceInto" -> GroupByReduceInto() + "GroupByMaxOf" -> GroupByMaxOf() + "GroupByMinOf" -> GroupByMinOf() else -> error("$this") } as T } diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_maxOfMinOf.kt b/plugins/kotlin-dataframe/testData/box/groupBy_maxOfMinOf.kt new file mode 100644 index 0000000000..2d372f69a3 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/groupBy_maxOfMinOf.kt @@ -0,0 +1,16 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + val df = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.add("id") { index() }.maxOf { 123 } + val df1 = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.add("id") { index() }.minOf { 123 } + + val max = df.max[0] + val min = df1.min[0] + + df.compareSchemas() + df1.compareSchemas() + return "OK" +} diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index edef7a8c3f..77424539f7 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -238,6 +238,12 @@ public void testGroupBy_extractSchema() { runTest("testData/box/groupBy_extractSchema.kt"); } + @Test + @TestMetadata("groupBy_maxOfMinOf.kt") + public void testGroupBy_maxOfMinOf() { + runTest("testData/box/groupBy_maxOfMinOf.kt"); + } + @Test @TestMetadata("groupBy_refine.kt") public void testGroupBy_refine() {