Skip to content

Commit

Permalink
[Compiler plugin] Support GroupBy.[minOf | maxOf]
Browse files Browse the repository at this point in the history
  • Loading branch information
koperagen committed Feb 10, 2025
1 parent a85e5ec commit b1f9bf3
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.dataframe.RowExpression
import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.annotations.Refine
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
import org.jetbrains.kotlinx.dataframe.columns.values
Expand Down Expand Up @@ -164,6 +165,8 @@ public fun <T, C : Comparable<C>> Grouped<T>.max(
public fun <T, C : Comparable<C>> Grouped<T>.max(vararg columns: KProperty<C?>, name: String? = null): DataFrame<T> {
    // Convert the property references with toColumnSet() and hand off to the
    // overload of max that takes a result name and a column-selecting lambda.
    return max(name) { columns.toColumnSet() }
}

@Refine
@Interpretable("GroupByMaxOf")
public fun <T, C : Comparable<C>> Grouped<T>.maxOf(
name: String? = null,
expression: RowExpression<T, C>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.dataframe.RowExpression
import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.annotations.Refine
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
import org.jetbrains.kotlinx.dataframe.columns.values
Expand Down Expand Up @@ -164,6 +165,8 @@ public fun <T, C : Comparable<C>> Grouped<T>.min(
public fun <T, C : Comparable<C>> Grouped<T>.min(vararg columns: KProperty<C?>, name: String? = null): DataFrame<T> {
    // Convert the property references with toColumnSet() and hand off to the
    // overload of min that takes a result name and a column-selecting lambda.
    return min(name) { columns.toColumnSet() }
}

@Refine
@Interpretable("GroupByMinOf")
public fun <T, C : Comparable<C>> Grouped<T>.minOf(
name: String? = null,
expression: RowExpression<T, C>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximat
import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy
import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore
import org.jetbrains.kotlinx.dataframe.plugin.impl.makeNullable
import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf
import org.jetbrains.kotlinx.dataframe.plugin.impl.type
import org.jetbrains.kotlinx.dataframe.plugin.interpret
Expand Down Expand Up @@ -181,3 +182,17 @@ class GroupByAdd : AbstractInterpreter<GroupBy>() {
return GroupBy(receiver.keys, receiver.groups.add(name, type.type, context = this))
}
}

/**
 * Base interpreter for GroupBy aggregations that produce one extra column
 * from a row expression.
 *
 * The interpreted schema is the receiver's key columns followed by a single
 * column built with [simpleColumnOf] and passed through [makeNullable].
 *
 * @property defaultName column name used when the `name` argument is `null`.
 */
abstract class GroupByAggregator(val defaultName: String) : AbstractSchemaModificationInterpreter() {
    val Arguments.receiver by groupBy()
    val Arguments.name: String? by arg(defaultValue = Present(null))
    val Arguments.expression by type()

    override fun Arguments.interpret(): PluginDataFrameSchema {
        // An explicitly supplied name wins; otherwise use the subclass default.
        val columnName = name ?: defaultName
        val resultColumn = makeNullable(simpleColumnOf(columnName, expression.type))
        val keyColumns = receiver.keys.columns()
        return PluginDataFrameSchema(keyColumns + resultColumn)
    }
}

/** [GroupByAggregator] whose aggregated column is named "max" when no explicit name is passed. */
class GroupByMaxOf : GroupByAggregator(defaultName = "max")
/** [GroupByAggregator] whose aggregated column is named "min" when no explicit name is passed. */
class GroupByMinOf : GroupByAggregator(defaultName = "min")
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByAdd
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByCount0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByInto
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMaxOf
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMinOf
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceExpression
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceInto
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReducePredicate
Expand Down Expand Up @@ -303,6 +305,8 @@ internal inline fun <reified T> String.load(): T {
"GroupByReducePredicate" -> GroupByReducePredicate()
"GroupByReduceExpression" -> GroupByReduceExpression()
"GroupByReduceInto" -> GroupByReduceInto()
"GroupByMaxOf" -> GroupByMaxOf()
"GroupByMinOf" -> GroupByMinOf()
else -> error("$this")
} as T
}
16 changes: 16 additions & 0 deletions plugins/kotlin-dataframe/testData/box/groupBy_maxOfMinOf.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.*

// Test entry point exercising maxOf / minOf on a grouped data frame.
// Returns "OK" once every statement below has completed.
fun box(): String {
// Both chains group by `a`, call add("id") { index() }, then aggregate the constant expression 123.
val df = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.add("id") { index() }.maxOf { 123 }
val df1 = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.add("id") { index() }.minOf { 123 }

// NOTE(review): these otherwise-unused reads presumably exist to prove that the `max` / `min`
// columns are resolvable on the results; confirm before removing them.
val max = df.max[0]
val min = df1.min[0]

// NOTE(review): compareSchemas() is declared outside this file; presumably it checks the
// inferred schema against the actual one. TODO confirm.
df.compareSchemas()
df1.compareSchemas()
return "OK"
}
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ public void testGroupBy_extractSchema() {
runTest("testData/box/groupBy_extractSchema.kt");
}

/** Invokes {@code runTest} on the test data file {@code testData/box/groupBy_maxOfMinOf.kt}. */
@Test
@TestMetadata("groupBy_maxOfMinOf.kt")
public void testGroupBy_maxOfMinOf() {
runTest("testData/box/groupBy_maxOfMinOf.kt");
}

@Test
@TestMetadata("groupBy_refine.kt")
public void testGroupBy_refine() {
Expand Down

0 comments on commit b1f9bf3

Please sign in to comment.