Skip to content

Commit

Permalink
[Compiler plugin] Add checkers that report compile time schema as inf…
Browse files Browse the repository at this point in the history
…o warnings to observe implicit schema generation
  • Loading branch information
koperagen committed Jan 30, 2025
1 parent bddf7bf commit 68fecef
Show file tree
Hide file tree
Showing 11 changed files with 287 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class FirDataFrameExtensionRegistrar(
private val path: String?,
val schemasDirectory: String?,
val isTest: Boolean,
val dumpSchemas: Boolean,
) : FirExtensionRegistrar() {
@OptIn(FirExtensionApiInternals::class)
override fun ExtensionRegistrarContext.configurePlugin() {
Expand All @@ -76,7 +77,7 @@ class FirDataFrameExtensionRegistrar(
+::TokenGenerator
+::DataRowSchemaSupertype
+{ it: FirSession ->
ExpressionAnalysisAdditionalChecker(it, jsonCache(it), schemasDirectory, isTest)
ExpressionAnalysisAdditionalChecker(it, jsonCache(it), schemasDirectory, isTest, dumpSchemas)
}
}

Expand All @@ -93,7 +94,9 @@ class FirDataFrameComponentRegistrar : CompilerPluginRegistrar() {
override fun ExtensionStorage.registerExtensions(configuration: CompilerConfiguration) {
val schemasDirectory = configuration.get(SCHEMAS)
val path = configuration.get(PATH)
FirExtensionRegistrarAdapter.registerExtension(FirDataFrameExtensionRegistrar(path, schemasDirectory, isTest = false))
FirExtensionRegistrarAdapter.registerExtension(
FirDataFrameExtensionRegistrar(path, schemasDirectory, isTest = false, dumpSchemas = true)
)
IrGenerationExtension.registerExtension(IrBodyFiller(path, schemasDirectory))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,38 @@

package org.jetbrains.kotlinx.dataframe.plugin.extensions

import com.intellij.psi.PsiElement
import org.jetbrains.kotlin.KtSourceElement
import org.jetbrains.kotlin.diagnostics.AbstractSourceElementPositioningStrategy
import org.jetbrains.kotlin.diagnostics.DiagnosticFactory1DelegateProvider
import org.jetbrains.kotlin.diagnostics.DiagnosticReporter
import org.jetbrains.kotlin.diagnostics.KtDiagnosticFactory1
import org.jetbrains.kotlin.diagnostics.Severity
import org.jetbrains.kotlin.diagnostics.SourceElementPositioningStrategies
import org.jetbrains.kotlin.diagnostics.error1
import org.jetbrains.kotlin.diagnostics.reportOn
import org.jetbrains.kotlin.diagnostics.warning1
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.analysis.checkers.MppCheckerKind
import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
import org.jetbrains.kotlin.fir.analysis.checkers.declaration.DeclarationCheckers
import org.jetbrains.kotlin.fir.analysis.checkers.declaration.FirPropertyChecker
import org.jetbrains.kotlin.fir.analysis.checkers.declaration.FirSimpleFunctionChecker
import org.jetbrains.kotlin.fir.analysis.checkers.expression.ExpressionCheckers
import org.jetbrains.kotlin.fir.analysis.checkers.expression.FirFunctionCallChecker
import org.jetbrains.kotlin.fir.analysis.checkers.expression.FirPropertyAccessExpressionChecker
import org.jetbrains.kotlin.fir.analysis.extensions.FirAdditionalCheckersExtension
import org.jetbrains.kotlin.fir.caches.FirCache
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.flatten
import org.jetbrains.kotlinx.dataframe.plugin.pluginDataFrameSchema
import org.jetbrains.kotlin.fir.declarations.FirProperty
import org.jetbrains.kotlin.fir.declarations.FirSimpleFunction
import org.jetbrains.kotlin.fir.declarations.hasAnnotation
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
import org.jetbrains.kotlin.fir.expressions.FirPropertyAccessExpression
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
import org.jetbrains.kotlin.fir.resolve.fullyExpandedType
import org.jetbrains.kotlin.fir.types.ConeClassLikeType
import org.jetbrains.kotlin.fir.types.ConeKotlinType
import org.jetbrains.kotlin.fir.types.FirTypeProjectionWithVariance
import org.jetbrains.kotlin.fir.types.coneType
import org.jetbrains.kotlin.fir.types.isSubtypeOf
Expand All @@ -39,17 +51,30 @@ import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.psi.KtElement
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.type
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.flatten
import org.jetbrains.kotlinx.dataframe.plugin.pluginDataFrameSchema
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names
import org.jetbrains.kotlinx.dataframe.plugin.utils.isDataFrame
import org.jetbrains.kotlinx.dataframe.plugin.utils.isGroupBy

class ExpressionAnalysisAdditionalChecker(
session: FirSession,
cache: FirCache<String, PluginDataFrameSchema, KotlinTypeFacade>,
schemasDirectory: String?,
isTest: Boolean,
dumpSchemas: Boolean
) : FirAdditionalCheckersExtension(session) {
override val expressionCheckers: ExpressionCheckers = object : ExpressionCheckers() {
override val functionCallCheckers: Set<FirFunctionCallChecker> = setOf(Checker(cache, schemasDirectory, isTest))
override val functionCallCheckers: Set<FirFunctionCallChecker> = setOfNotNull(
Checker(cache, schemasDirectory, isTest), FunctionCallSchemaReporter.takeIf { dumpSchemas }
)
override val propertyAccessExpressionCheckers: Set<FirPropertyAccessExpressionChecker> = setOfNotNull(
PropertyAccessSchemaReporter.takeIf { dumpSchemas }
)
}
override val declarationCheckers: DeclarationCheckers = object : DeclarationCheckers() {
override val propertyCheckers: Set<FirPropertyChecker> = setOfNotNull(PropertySchemaReporter.takeIf { dumpSchemas })
override val simpleFunctionCheckers: Set<FirSimpleFunctionChecker> = setOfNotNull(FunctionDeclarationSchemaReporter.takeIf { dumpSchemas })
}
}

Expand Down Expand Up @@ -132,3 +157,84 @@ private class Checker(
}
}
}

private data object PropertySchemaReporter : FirPropertyChecker(mppKind = MppCheckerKind.Common) {
val SCHEMA by info1<KtElement, String>(SourceElementPositioningStrategies.DECLARATION_NAME)

override fun check(declaration: FirProperty, context: CheckerContext, reporter: DiagnosticReporter) {
context.sessionContext {
declaration.returnTypeRef.coneType.let { type ->
reportSchema(reporter, declaration.source, SCHEMA, type, context)
}
}
}
}

private data object FunctionCallSchemaReporter : FirFunctionCallChecker(mppKind = MppCheckerKind.Common) {
val SCHEMA by info1<KtElement, String>(SourceElementPositioningStrategies.REFERENCED_NAME_BY_QUALIFIED)

override fun check(expression: FirFunctionCall, context: CheckerContext, reporter: DiagnosticReporter) {
if (expression.calleeReference.name in setOf(Name.identifier("let"), Name.identifier("run"))) return
val initializer = expression.resolvedType
context.sessionContext {
reportSchema(reporter, expression.source, SCHEMA, initializer, context)
}
}
}

private data object PropertyAccessSchemaReporter : FirPropertyAccessExpressionChecker(mppKind = MppCheckerKind.Common) {
val SCHEMA by info1<KtElement, String>(SourceElementPositioningStrategies.REFERENCED_NAME_BY_QUALIFIED)

override fun check(
expression: FirPropertyAccessExpression,
context: CheckerContext,
reporter: DiagnosticReporter
) {
val initializer = expression.resolvedType
context.sessionContext {
reportSchema(reporter, expression.source, SCHEMA, initializer, context)
}
}
}

private data object FunctionDeclarationSchemaReporter : FirSimpleFunctionChecker(mppKind = MppCheckerKind.Common) {
val SCHEMA by info1<KtElement, String>(SourceElementPositioningStrategies.DECLARATION_SIGNATURE)

override fun check(declaration: FirSimpleFunction, context: CheckerContext, reporter: DiagnosticReporter) {
val type = declaration.returnTypeRef.coneType
context.sessionContext {
reportSchema(reporter, declaration.source, SCHEMA, type, context)
}
}
}

private fun SessionContext.reportSchema(
reporter: DiagnosticReporter,
source: KtSourceElement?,
factory: KtDiagnosticFactory1<String>,
type: ConeKotlinType,
context: CheckerContext,
) {
val schema: PluginDataFrameSchema? = if (type.isDataFrame(session)) {
type.typeArguments.getOrNull(0)?.let {
pluginDataFrameSchema(it)
}
} else if (type.isGroupBy(session)) {
null
} else {
null
}
if (schema != null && source != null) {
reporter.reportOn(source, factory, "\n" + schema.toString(), context)
}
}

fun CheckerContext.sessionContext(f: SessionContext.() -> Unit) {
SessionContext(session).f()
}

inline fun <reified P : PsiElement, A> info1(
positioningStrategy: AbstractSourceElementPositioningStrategy = SourceElementPositioningStrategies.DEFAULT
): DiagnosticFactory1DelegateProvider<A> {
return DiagnosticFactory1DelegateProvider(Severity.INFO, positioningStrategy, P::class)
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ import kotlin.reflect.KType
import kotlin.reflect.KTypeProjection
import kotlin.reflect.KVariance

interface KotlinTypeFacade {
val session: FirSession
interface KotlinTypeFacade : SessionContext {
val resolutionPath: String? get() = null
val cache: FirCache<String, PluginDataFrameSchema, KotlinTypeFacade>
val schemasDirectory: String?
Expand Down Expand Up @@ -99,6 +98,14 @@ interface KotlinTypeFacade {
}
}

interface SessionContext {
val session: FirSession
}

fun SessionContext(session: FirSession) = object : SessionContext {
override val session: FirSession = session
}

private val List = "List".collectionsId()

private fun ConeKotlinType.isBuiltinType(classId: ClassId, isNullable: Boolean?): Boolean {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.jetbrains.kotlin.fir.analysis.checkers.fullyExpandedClassId
import org.jetbrains.kotlin.fir.types.ConeKotlinType
import org.jetbrains.kotlin.fir.types.ConeNullability
import org.jetbrains.kotlin.fir.types.isNullable
import org.jetbrains.kotlin.fir.types.renderReadable
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
import org.jetbrains.kotlinx.dataframe.plugin.extensions.wrap
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TypeApproximation
Expand All @@ -28,16 +29,17 @@ data class PluginDataFrameSchema(
}

private fun List<SimpleCol>.asString(indent: String = ""): String {
if (isEmpty()) return "$indent<empty compile time schema>"
return joinToString("\n") {
val col = when (it) {
is SimpleFrameColumn -> {
"${it.name}*\n" + it.columns().asString("$indent ")
"${it.name}: *\n" + it.columns().asString("$indent ")
}
is SimpleColumnGroup -> {
"${it.name}\n" + it.columns().asString("$indent ")
"${it.name}:\n" + it.columns().asString("$indent ")
}
is SimpleDataColumn -> {
"${it.name}: ${it.type}"
"${it.name}: ${it.type.type.renderReadable()}"
}
}
"$indent$col"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.utils.addToStdlib.runIf
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.plugin.extensions.SessionContext
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnPathApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.DataFrameCallableId
Expand Down Expand Up @@ -320,7 +321,7 @@ interface InterpretationErrorReporter {
}
}

fun KotlinTypeFacade.pluginDataFrameSchema(schemaTypeArg: ConeTypeProjection): PluginDataFrameSchema {
fun SessionContext.pluginDataFrameSchema(schemaTypeArg: ConeTypeProjection): PluginDataFrameSchema {
val schema = if (schemaTypeArg.isStarProjection) {
PluginDataFrameSchema.EMPTY
} else {
Expand All @@ -330,7 +331,7 @@ fun KotlinTypeFacade.pluginDataFrameSchema(schemaTypeArg: ConeTypeProjection): P
return schema
}

fun KotlinTypeFacade.pluginDataFrameSchema(coneClassLikeType: ConeClassLikeType): PluginDataFrameSchema {
fun SessionContext.pluginDataFrameSchema(coneClassLikeType: ConeClassLikeType): PluginDataFrameSchema {
val symbol = coneClassLikeType.toSymbol(session) as? FirRegularClassSymbol ?: return PluginDataFrameSchema.EMPTY
val declarationSymbols = if (symbol.isLocal && symbol.resolvedSuperTypes.firstOrNull() != session.builtinTypes.anyType.type) {
val rootSchemaSymbol = symbol.resolvedSuperTypes.first().toSymbol(session) as? FirRegularClassSymbol
Expand Down Expand Up @@ -394,7 +395,7 @@ private fun KotlinTypeFacade.columnWithPathApproximations(result: FirPropertyAcc
}
}

private fun KotlinTypeFacade.columnOf(it: FirPropertySymbol, mapping: Map<FirTypeParameterSymbol, ConeTypeProjection>): SimpleCol? {
private fun SessionContext.columnOf(it: FirPropertySymbol, mapping: Map<FirTypeParameterSymbol, ConeTypeProjection>): SimpleCol? {
val annotation = it.getAnnotationByClassId(Names.COLUMN_NAME_ANNOTATION, session)
val columnName = (annotation?.argumentMapping?.mapping?.get(Names.COLUMN_NAME_ARGUMENT) as? FirLiteralExpression)?.value as? String
val name = columnName ?: it.name.identifier
Expand Down Expand Up @@ -443,14 +444,14 @@ private fun KotlinTypeFacade.columnOf(it: FirPropertySymbol, mapping: Map<FirTyp
}
}

private fun KotlinTypeFacade.shouldBeConvertedToColumnGroup(it: FirPropertySymbol) =
private fun SessionContext.shouldBeConvertedToColumnGroup(it: FirPropertySymbol) =
isDataRow(it) ||
it.resolvedReturnType.toRegularClassSymbol(session)?.hasAnnotation(Names.DATA_SCHEMA_CLASS_ID, session) == true

private fun isDataRow(it: FirPropertySymbol) =
it.resolvedReturnType.classId == Names.DATA_ROW_CLASS_ID

private fun KotlinTypeFacade.shouldBeConvertedToFrameColumn(it: FirPropertySymbol) =
private fun SessionContext.shouldBeConvertedToFrameColumn(it: FirPropertySymbol) =
isDataFrame(it) ||
(it.resolvedReturnType.classId == Names.LIST &&
it.resolvedReturnType.typeArguments[0].type?.toRegularClassSymbol(session)?.hasAnnotation(Names.DATA_SCHEMA_CLASS_ID, session) == true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
package org.jetbrains.kotlinx.dataframe.plugin.utils

import org.jetbrains.kotlin.builtins.StandardNames
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.analysis.checkers.fullyExpandedClassId
import org.jetbrains.kotlin.fir.types.ConeKotlinType
import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
Expand Down Expand Up @@ -72,3 +75,6 @@ private fun KClass<*>.classId(): ClassId {
val className = fqName.substringAfterLast(".")
return ClassId(FqName(packageFqName), Name.identifier(className))
}

fun ConeKotlinType.isDataFrame(session: FirSession) = fullyExpandedClassId(session) == Names.DF_CLASS_ID
fun ConeKotlinType.isGroupBy(session: FirSession) = fullyExpandedClassId(session) == Names.GROUP_BY_CLASS_ID
Loading

0 comments on commit 68fecef

Please sign in to comment.