From 5f1a2d8698c94deb6dc16dbc8622c20104a19d0b Mon Sep 17 00:00:00 2001 From: Katarzyna Marek Date: Fri, 8 Dec 2023 13:50:01 +0100 Subject: [PATCH] bugfix: fix workaround for wrong spans in extension method call --- .../meta/internal/pc/MetalsInteractive.scala | 26 +++++++++-- .../scala/meta/internal/pc/PcCollector.scala | 35 ++++++++++---- .../Scala3DocumentHighlightSuite.scala | 46 +++++++++++++++++++ .../tests/hover/HoverScala3TypeSuite.scala | 14 ++++++ .../scala/tests/pc/PcDefinitionSuite.scala | 36 ++++++++++++++- 5 files changed, 143 insertions(+), 14 deletions(-) diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/MetalsInteractive.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/MetalsInteractive.scala index 8fbad48b1a3..f3465658c31 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/MetalsInteractive.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/MetalsInteractive.scala @@ -277,9 +277,9 @@ object MetalsInteractive: * * val a = MyIntOut(1).un@@even */ - case (a @ Apply(sel: Select, _)) :: _ - if sel.span.isZeroExtent && sel.symbol.is(Flags.ExtensionMethod) => - List((sel.symbol, a.typeOpt)) + case (a @ ExtensionMethodCall(symbol, app)) :: _ + if app.span.withStart(app.span.point).contains(pos.span) => + List((symbol, a.typeOpt)) case path @ head :: tail => if head.symbol.is(Synthetic) then @@ -345,6 +345,26 @@ object MetalsInteractive: } end ApplySelect + object ExtensionMethodCallSymbol: + def unapply(tree: Tree)(using Context): Option[(Ident | Select)] = + tree match + case tree: (Ident | Select) if tree.symbol.is(Flags.ExtensionMethod) => + Some(tree) + case TypeApply(tree: (Ident | Select), _) + if tree.symbol.is(Flags.ExtensionMethod) => + Some(tree) + case _ => None + end ExtensionMethodCallSymbol + + object ExtensionMethodCall: + def unapply(tree: Tree)(using Context): Option[(Symbol, Apply)] = + tree match + case app @ Apply(ExtensionMethodCallSymbol(tree), _) => + Some((tree.symbol, app)) + case Apply(tree, _) => unapply(tree) + case _ => None + end ExtensionMethodCall + object TreeApply: def unapply(tree: Tree): Option[(Tree, List[Tree])] = tree match diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala index 6bbe6bbdd56..032bf860b08 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala @@ -6,6 +6,8 @@ import scala.meta as m import scala.meta.internal.metals.CompilerOffsetParams import scala.meta.internal.mtags.MtagsEnrichments.* +import scala.meta.internal.pc.MetalsInteractive.ExtensionMethodCall +import scala.meta.internal.pc.MetalsInteractive.ExtensionMethodCallSymbol import scala.meta.pc.OffsetParams import scala.meta.pc.VirtualFileParams @@ -250,10 +252,10 @@ abstract class PcCollector[T]( * * val a = MyIntOut(1).<> */ - case (a @ Apply(sel: Select, _)) :: _ - if sel.span.isZeroExtent && sel.symbol.is(Flags.ExtensionMethod) => - val span = a.span.withStart(a.span.point) - Some(symbolAlternatives(sel.symbol), pos.withSpan(span)) + case ExtensionMethodCall(sym, app) :: _ + if app.span.withStart(app.span.point).contains(pos.span) => + val span = app.span.withStart(app.span.point) + Some(symbolAlternatives(sym), pos.withSpan(span)) case _ => None sought match @@ -424,7 +426,11 @@ abstract class PcCollector[T]( * All indentifiers such as: * val a = <> */ - case ident: Ident if ident.span.isCorrect && filter(ident) => + case ident: Ident + if ident.span.isCorrect && filter(ident) && !isExtensionMethodCall( + parent, + ident.symbol, + ) => // symbols will differ for params in different ext methods, but source pos will be the same if soughtFilter(_.sourcePos == ident.symbol.sourcePos) then @@ -468,15 +474,14 @@ abstract class PcCollector[T]( * * val a = MyIntOut(1).<> */ - case sel: Select - if sel.span.isZeroExtent && sel.symbol.is(Flags.ExtensionMethod) => + case ExtensionMethodCallSymbol(tree) => parent match case Some(a: Apply) => val span = a.span.withStart(a.span.point) - val amendedSelect = sel.withSpan(span) - if filter(amendedSelect) then + val amendedTree = tree.withSpan(span) + if filter(amendedTree) then occurences + collect( - amendedSelect, + amendedTree, pos.withSpan(span), ) else occurences @@ -639,6 +644,16 @@ abstract class PcCollector[T]( tree match case sel: Select => sel.sourcePos.withSpan(selectNameSpan(sel)) case _ => tree.sourcePos + + /** + * Those have wrong spans and we special case for them. + */ + private def isExtensionMethodCall(parent: Option[Tree], symbol: Symbol) = + symbol.is(Flags.ExtensionMethod) && parent.exists { + case _: TypeApply | _: Apply => true + case _ => false + } + end PcCollector object PcCollector: diff --git a/tests/cross/src/test/scala/tests/highlight/Scala3DocumentHighlightSuite.scala b/tests/cross/src/test/scala/tests/highlight/Scala3DocumentHighlightSuite.scala index 5d2783efdd4..1607169646d 100644 --- a/tests/cross/src/test/scala/tests/highlight/Scala3DocumentHighlightSuite.scala +++ b/tests/cross/src/test/scala/tests/highlight/Scala3DocumentHighlightSuite.scala @@ -339,4 +339,50 @@ class Scala3DocumentHighlightSuite extends BaseDocumentHighlightSuite { |""".stripMargin ) + check( + "i5921-1", + """|object Logarithms: + | opaque type Logarithm = Double + | extension [K](vmap: Logarithm) + | def <>(k: Logarithm): Logarithm = ??? + | + |object Test: + | val in: Logarithms.Logarithm = ??? + | in.<>(in) + |""".stripMargin + ) + + check( + "i5921-2", + """|object Logarithms: + | opaque type Logarithm = Double + | extension [K](vmap: Logarithm) + | def <>(k: Logarithm): Logarithm = ??? + | + |object Test: + | val in: Logarithms.Logarithm = ??? + | in.<>(in) + |""".stripMargin + ) + + check( + "i5921-3", + """|object Logarithms: + | opaque type Logarithm = Double + | extension [K](vmap: Logarithm) + | def <>(k: Logarithm): Logarithm = ??? + | (2.0).<>(1.0) + |""".stripMargin + ) + + check( + "i5921-4", + """|object Logarithms: + | opaque type Logarithm = Double + | extension [K](vmap: Logarithm) + | def <>(k: Logarithm): Logarithm = ??? + | (2.0).<>(1.0) + |""".stripMargin + ) + } diff --git a/tests/cross/src/test/scala/tests/hover/HoverScala3TypeSuite.scala b/tests/cross/src/test/scala/tests/hover/HoverScala3TypeSuite.scala index 9dca8761f8d..d9bbc1467ee 100644 --- a/tests/cross/src/test/scala/tests/hover/HoverScala3TypeSuite.scala +++ b/tests/cross/src/test/scala/tests/hover/HoverScala3TypeSuite.scala @@ -376,4 +376,18 @@ class HoverScala3TypeSuite extends BaseHoverSuite { """|extension (i: MyIntOut) def uneven: Boolean |""".stripMargin.hover ) + + check( + "i5921", + """|object Logarithms: + | trait Logarithm + | extension [K](vmap: Logarithm) + | def multiply(k: Logarithm): Logarithm = ??? + | + |object Test: + | val in: Logarithms.Logarithm = ??? + | in.multi@@ply(in) + |""".stripMargin, + "extension [K](vmap: Logarithm) def multiply(k: Logarithm): Logarithm".hover + ) } diff --git a/tests/cross/src/test/scala/tests/pc/PcDefinitionSuite.scala b/tests/cross/src/test/scala/tests/pc/PcDefinitionSuite.scala index 51ae37152eb..1889459ae09 100644 --- a/tests/cross/src/test/scala/tests/pc/PcDefinitionSuite.scala +++ b/tests/cross/src/test/scala/tests/pc/PcDefinitionSuite.scala @@ -560,7 +560,7 @@ class PcDefinitionSuite extends BasePcDefinitionSuite { ) check( - "i5630".tag(IgnoreScala2.and(IgnoreForScala3CompilerPC)), + "i5630".tag(IgnoreScala2), """|class MyIntOut(val value: Int) |object MyIntOut: | extension (i: MyIntOut) def <> = i.value % 2 == 1 @@ -569,4 +569,38 @@ class PcDefinitionSuite extends BasePcDefinitionSuite { |""".stripMargin ) + check( + "i5921".tag(IgnoreScala2), + """|object Logarithms: + | opaque type Logarithm = Double + | extension [K](vmap: Logarithm) + | def <>(k: Logarithm): Logarithm = ??? + | + |object Test: + | val in: Logarithms.Logarithm = ??? + | in.multi@@ply(in) + |""".stripMargin + ) + + check( + "i5921-1".tag(IgnoreScala2), + """|object Logarithms: + | opaque type Logarithm = Double + | extension [K](vmap: Logarithm) + | def <>(k: Logarithm): Logarithm = ??? + | (2.0).mult@@iply(1.0) + |""".stripMargin + ) + + check( + "i5921-2".tag(IgnoreScala2), + """|object Logarithms: + | opaque type Logarithm = Double + | extension [K](vmap: Logarithm) + | def multiply(k: Logarithm): Logarithm = ??? + | val <> = 1.0 + | (2.0).multiply(v@@v) + |""".stripMargin + ) + }