Skip to content

Commit

Permalink
[Compiler plugin] Support groupBy.[first | last | maxBy | minBy].into…
Browse files Browse the repository at this point in the history
…(columnName)
  • Loading branch information
koperagen committed Feb 10, 2025
1 parent 6088e7b commit 2a6f2d0
Show file tree
Hide file tree
Showing 12 changed files with 82 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.RowFilter
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.columns.ColumnPath
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
Expand Down Expand Up @@ -55,8 +56,10 @@ public fun <T> DataFrame<T>.firstOrNull(predicate: RowFilter<T>): DataRow<T>? =

// region GroupBy

/**
 * Reduces this [GroupBy] by picking, for every group, the row returned by `firstOrNull()`.
 *
 * The compiler plugin handles this call with the `GroupByReducePredicate` interpreter.
 *
 * @return a [ReducedGroupBy] built by `reduce` from the selected rows.
 */
@Interpretable("GroupByReducePredicate")
public fun <T, G> GroupBy<T, G>.first(): ReducedGroupBy<T, G> {
    return reduce { firstOrNull() }
}

/**
 * Reduces this [GroupBy] by picking, for every group, the row returned by `firstOrNull(predicate)`.
 *
 * The compiler plugin handles this call with the `GroupByReducePredicate` interpreter.
 *
 * @param predicate row filter forwarded to `firstOrNull` for each group.
 * @return a [ReducedGroupBy] built by `reduce` from the selected rows.
 */
@Interpretable("GroupByReducePredicate")
public fun <T, G> GroupBy<T, G>.first(
    predicate: RowFilter<G>,
): ReducedGroupBy<T, G> {
    return reduce { firstOrNull(predicate) }
}

// endregion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ public inline fun <T, G, reified V> ReducedGroupBy<T, G>.into(
noinline expression: RowExpression<G, V>,
): DataFrame<G> = into(column.columnName, expression)

/**
 * Delegates to the expression-taking `into` overload, using the receiver of the expression
 * lambda (`this`) as the value stored under [columnName].
 *
 * The compiler plugin handles this call with the `GroupByReduceInto` interpreter.
 *
 * @param columnName name passed on to the delegated `into` call.
 * @return the [DataFrame] produced by the delegated `into` call.
 */
@Refine
@Interpretable("GroupByReduceInto")
public fun <T, G> ReducedGroupBy<T, G>.into(columnName: String): DataFrame<G> {
    return into(columnName) { this }
}

@AccessApiOverload
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.RowFilter
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.columns.ColumnPath
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
Expand Down Expand Up @@ -56,8 +57,10 @@ public fun <T> DataFrame<T>.last(): DataRow<T> {

// region GroupBy

/**
 * Reduces this [GroupBy] by picking, for every group, the row returned by `lastOrNull()`.
 *
 * The compiler plugin handles this call with the `GroupByReducePredicate` interpreter.
 *
 * @return a [ReducedGroupBy] built by `reduce` from the selected rows.
 */
@Interpretable("GroupByReducePredicate")
public fun <T, G> GroupBy<T, G>.last(): ReducedGroupBy<T, G> {
    return reduce { lastOrNull() }
}

/**
 * Reduces this [GroupBy] by picking, for every group, the row returned by `lastOrNull(predicate)`.
 *
 * The compiler plugin handles this call with the `GroupByReducePredicate` interpreter.
 *
 * @param predicate row filter forwarded to `lastOrNull` for each group.
 * @return a [ReducedGroupBy] built by `reduce` from the selected rows.
 */
@Interpretable("GroupByReducePredicate")
public fun <T, G> GroupBy<T, G>.last(
    predicate: RowFilter<G>,
): ReducedGroupBy<T, G> {
    return reduce { lastOrNull(predicate) }
}

// endregion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.RowExpression
import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
import org.jetbrains.kotlinx.dataframe.columns.values
Expand Down Expand Up @@ -168,6 +169,7 @@ public fun <T, C : Comparable<C>> Grouped<T>.maxOf(
expression: RowExpression<T, C>,
): DataFrame<T> = Aggregators.max.aggregateOfDelegated(this, name) { maxOfOrNull(expression) }

/**
 * Reduces this [GroupBy] by picking, for every group, the row returned by
 * `maxByOrNull(rowExpression)`.
 *
 * The compiler plugin handles this call with the `GroupByReduceExpression` interpreter.
 *
 * @param rowExpression row expression forwarded to `maxByOrNull` for each group.
 * @return a [ReducedGroupBy] built by `reduce` from the selected rows.
 */
@Interpretable("GroupByReduceExpression")
public fun <T, G, R : Comparable<R>> GroupBy<T, G>.maxBy(
    rowExpression: RowExpression<G, R?>,
): ReducedGroupBy<T, G> {
    return reduce { maxByOrNull(rowExpression) }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.RowExpression
import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
import org.jetbrains.kotlinx.dataframe.columns.values
Expand Down Expand Up @@ -168,6 +169,7 @@ public fun <T, C : Comparable<C>> Grouped<T>.minOf(
expression: RowExpression<T, C>,
): DataFrame<T> = Aggregators.min.aggregateOfDelegated(this, name) { minOfOrNull(expression) }

/**
 * Reduces this [GroupBy] by picking, for every group, the row returned by
 * `minByOrNull(rowExpression)`.
 *
 * The compiler plugin handles this call with the `GroupByReduceExpression` interpreter.
 *
 * @param rowExpression row expression forwarded to `minByOrNull` for each group.
 * @return a [ReducedGroupBy] built by `reduce` from the selected rows.
 */
@Interpretable("GroupByReduceExpression")
public fun <T, G, R : Comparable<R>> GroupBy<T, G>.minBy(
    rowExpression: RowExpression<G, R?>,
): ReducedGroupBy<T, G> {
    return reduce { minByOrNull(rowExpression) }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ fun KotlinTypeFacade.simpleColumnOf(name: String, type: ConeKotlinType): SimpleC
}
}

private fun KotlinTypeFacade.makeNullable(column: SimpleCol): SimpleCol {
internal fun KotlinTypeFacade.makeNullable(column: SimpleCol): SimpleCol {
return when (column) {
is SimpleColumnGroup -> {
SimpleColumnGroup(column.name, column.columns().map { makeNullable(it) })
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package org.jetbrains.kotlinx.dataframe.plugin.impl.api

import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy
import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore
import org.jetbrains.kotlinx.dataframe.plugin.impl.makeNullable

/**
 * Interpreter registered under the name `GroupByReducePredicate`.
 *
 * The interpreted result is the receiver [GroupBy] itself, returned unchanged.
 */
class GroupByReducePredicate : AbstractInterpreter<GroupBy>() {
    val Arguments.receiver by groupBy()

    // Bound through `ignore()`; the value is never read by `interpret()`.
    val Arguments.predicate by ignore()

    override fun Arguments.interpret(): GroupBy = receiver
}

/**
 * Interpreter registered under the name `GroupByReduceExpression`.
 *
 * The interpreted result is the receiver [GroupBy] itself, returned unchanged.
 */
class GroupByReduceExpression : AbstractInterpreter<GroupBy>() {
    val Arguments.receiver by groupBy()

    // Bound through `ignore()`; the value is never read by `interpret()`.
    val Arguments.rowExpression by ignore()

    override fun Arguments.interpret(): GroupBy = receiver
}

/**
 * Interpreter registered under the name `GroupByReduceInto`.
 *
 * Builds the resulting schema from the receiver [GroupBy]: its key columns followed by a
 * single column group named `columnName` that wraps the group columns and is passed
 * through `makeNullable`.
 */
class GroupByReduceInto : AbstractSchemaModificationInterpreter() {
    val Arguments.receiver by groupBy()
    val Arguments.columnName: String by arg()

    override fun Arguments.interpret(): PluginDataFrameSchema {
        // Wrap the columns of the groups schema into one named column group.
        val reducedRow = SimpleColumnGroup(columnName, receiver.groups.columns())
        val nullableReducedRow = makeNullable(reducedRow)

        // Key columns come first, the wrapped group column is appended last.
        val resultColumns = receiver.keys.columns() + nullableReducedRow
        return PluginDataFrameSchema(resultColumns)
    }
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
package org.jetbrains.kotlinx.dataframe.plugin.impl.api

import org.jetbrains.kotlinx.dataframe.plugin.InterpretationErrorReporter
import org.jetbrains.kotlinx.dataframe.plugin.interpret
import org.jetbrains.kotlinx.dataframe.plugin.loadInterpreter
import org.jetbrains.kotlin.fir.expressions.FirAnonymousFunctionExpression
import org.jetbrains.kotlin.fir.expressions.FirExpression
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
import org.jetbrains.kotlin.fir.expressions.FirReturnExpression
import org.jetbrains.kotlin.fir.types.ConeKotlinType
import org.jetbrains.kotlin.fir.types.resolvedType
import org.jetbrains.kotlinx.dataframe.plugin.InterpretationErrorReporter
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter
Expand All @@ -23,8 +21,11 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.add
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy
import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore
import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf
import org.jetbrains.kotlinx.dataframe.plugin.impl.type
import org.jetbrains.kotlinx.dataframe.plugin.interpret
import org.jetbrains.kotlinx.dataframe.plugin.loadInterpreter

/**
 * Pair of schemas that the plugin tracks for a grouped data frame.
 *
 * @property keys schema of the key columns.
 * @property groups schema of the groups.
 */
class GroupBy(
    val keys: PluginDataFrameSchema,
    val groups: PluginDataFrameSchema,
)

Expand Down Expand Up @@ -173,6 +174,7 @@ class GroupByToDataFrame : AbstractSchemaModificationInterpreter() {
class GroupByAdd : AbstractInterpreter<GroupBy>() {
val Arguments.receiver: GroupBy by groupBy()
val Arguments.name: String by arg()
val Arguments.infer by ignore()
val Arguments.type: TypeApproximation by type(name("expression"))

override fun Arguments.interpret(): GroupBy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ fun <T> KotlinTypeFacade.interpret(
assert(expectedReturnType.toString() == GroupBy::class.qualifiedName!!) {
"'$name' should be ${GroupBy::class.qualifiedName!!}, but plugin expect $expectedReturnType"
}

// Also valid for ReducedGroupBy: it takes the same two type arguments, read below as keys and groups.
val resolvedType = it.expression.resolvedType.fullyExpandedType(session)
val keys = pluginDataFrameSchema(resolvedType.typeArguments[0])
val groups = pluginDataFrameSchema(resolvedType.typeArguments[1])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByAdd
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByCount0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByInto
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceExpression
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceInto
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReducePredicate
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Merge0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MergeId
Expand Down Expand Up @@ -297,6 +300,9 @@ internal inline fun <reified T> String.load(): T {
"MergeBy1" -> MergeBy1()
"ReorderColumnsByName" -> ReorderColumnsByName()
"GroupByCount0" -> GroupByCount0()
"GroupByReducePredicate" -> GroupByReducePredicate()
"GroupByReduceExpression" -> GroupByReduceExpression()
"GroupByReduceInto" -> GroupByReduceInto()
else -> error("$this")
} as T
}
16 changes: 16 additions & 0 deletions plugins/kotlin-dataframe/testData/box/reducedGroupBy.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.*

/**
 * Box test for `groupBy.[first | last | maxBy | minBy].into(columnName)`.
 *
 * Every chain ends with `compareSchemas()`; the function returns "OK" once all of them
 * have run.
 */
fun box(): String {
    val groupBy = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.add("id") { index() }
    groupBy.maxBy { id }.into("group").compareSchemas()
    // Previously an exact duplicate of the maxBy line above; now covers minBy with the same expression.
    groupBy.minBy { id }.into("group").compareSchemas()
    groupBy.first { id == 1 }.into("group").compareSchemas()
    groupBy.first().into("group").compareSchemas()
    groupBy.last { id == 1 }.into("group").compareSchemas()
    groupBy.last().into("group").compareSchemas()
    // Boolean is Comparable, so a Boolean-valued row expression is accepted by minBy as well.
    groupBy.minBy { id == 1 }.into("group").compareSchemas()
    return "OK"
}
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,12 @@ public void testRead_localFile() {
runTest("testData/box/read_localFile.kt");
}

/** Runs the box test stored in {@code testData/box/reducedGroupBy.kt}. */
@Test
@TestMetadata("reducedGroupBy.kt")
public void testReducedGroupBy() {
    final String testDataPath = "testData/box/reducedGroupBy.kt";
    runTest(testDataPath);
}

@Test
@TestMetadata("remove.kt")
public void testRemove() {
Expand Down

0 comments on commit 2a6f2d0

Please sign in to comment.