Skip to content

Commit

Permalink
bugfix: fix workaround for wrong spans in extension method call
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiaMarek committed Dec 11, 2023
1 parent 26a01d9 commit 5f1a2d8
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,9 @@ object MetalsInteractive:
*
* val a = MyIntOut(1).un@@even
*/
case (a @ Apply(sel: Select, _)) :: _
if sel.span.isZeroExtent && sel.symbol.is(Flags.ExtensionMethod) =>
List((sel.symbol, a.typeOpt))
case (a @ ExtensionMethodCall(symbol, app)) :: _
if app.span.withStart(app.span.point).contains(pos.span) =>
List((symbol, a.typeOpt))

case path @ head :: tail =>
if head.symbol.is(Synthetic) then
Expand Down Expand Up @@ -345,6 +345,26 @@ object MetalsInteractive:
}
end ApplySelect

object ExtensionMethodCallSymbol:
def unapply(tree: Tree)(using Context): Option[(Ident | Select)] =
tree match
case tree: (Ident | Select) if tree.symbol.is(Flags.ExtensionMethod) =>
Some(tree)
case TypeApply(tree: (Ident | Select), _)
if tree.symbol.is(Flags.ExtensionMethod) =>
Some(tree)
case _ => None
end ExtensionMethodCallSymbol

object ExtensionMethodCall:
def unapply(tree: Tree)(using Context): Option[(Symbol, Apply)] =
tree match
case app @ Apply(ExtensionMethodCallSymbol(tree), _) =>
Some((tree.symbol, app))
case Apply(tree, _) => unapply(tree)
case _ => None
end ExtensionMethodCall

object TreeApply:
def unapply(tree: Tree): Option[(Tree, List[Tree])] =
tree match
Expand Down
35 changes: 25 additions & 10 deletions mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import scala.meta as m

import scala.meta.internal.metals.CompilerOffsetParams
import scala.meta.internal.mtags.MtagsEnrichments.*
import scala.meta.internal.pc.MetalsInteractive.ExtensionMethodCall
import scala.meta.internal.pc.MetalsInteractive.ExtensionMethodCallSymbol
import scala.meta.pc.OffsetParams
import scala.meta.pc.VirtualFileParams

Expand Down Expand Up @@ -250,10 +252,10 @@ abstract class PcCollector[T](
*
* val a = MyIntOut(1).<<un@@even>>
*/
case (a @ Apply(sel: Select, _)) :: _
if sel.span.isZeroExtent && sel.symbol.is(Flags.ExtensionMethod) =>
val span = a.span.withStart(a.span.point)
Some(symbolAlternatives(sel.symbol), pos.withSpan(span))
case ExtensionMethodCall(sym, app) :: _
if app.span.withStart(app.span.point).contains(pos.span) =>
val span = app.span.withStart(app.span.point)
Some(symbolAlternatives(sym), pos.withSpan(span))
case _ => None

sought match
Expand Down Expand Up @@ -424,7 +426,11 @@ abstract class PcCollector[T](
* All indentifiers such as:
* val a = <<b>>
*/
case ident: Ident if ident.span.isCorrect && filter(ident) =>
case ident: Ident
if ident.span.isCorrect && filter(ident) && !isExtensionMethodCall(
parent,
ident.symbol,
) =>
// symbols will differ for params in different ext methods, but source pos will be the same
if soughtFilter(_.sourcePos == ident.symbol.sourcePos)
then
Expand Down Expand Up @@ -468,15 +474,14 @@ abstract class PcCollector[T](
*
* val a = MyIntOut(1).<<un@@even>>
*/
case sel: Select
if sel.span.isZeroExtent && sel.symbol.is(Flags.ExtensionMethod) =>
case ExtensionMethodCallSymbol(tree) =>
parent match
case Some(a: Apply) =>
val span = a.span.withStart(a.span.point)
val amendedSelect = sel.withSpan(span)
if filter(amendedSelect) then
val amendedTree = tree.withSpan(span)
if filter(amendedTree) then
occurences + collect(
amendedSelect,
amendedTree,
pos.withSpan(span),
)
else occurences
Expand Down Expand Up @@ -639,6 +644,16 @@ abstract class PcCollector[T](
tree match
case sel: Select => sel.sourcePos.withSpan(selectNameSpan(sel))
case _ => tree.sourcePos

/**
* Those have wrong spans and we special case for them.
*/
private def isExtensionMethodCall(parent: Option[Tree], symbol: Symbol) =
symbol.is(Flags.ExtensionMethod) && parent.exists {
case _: TypeApply | _: Apply => true
case _ => false
}

end PcCollector

object PcCollector:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,4 +339,50 @@ class Scala3DocumentHighlightSuite extends BaseDocumentHighlightSuite {
|""".stripMargin
)

check(
"i5921-1",
"""|object Logarithms:
| opaque type Logarithm = Double
| extension [K](vmap: Logarithm)
| def <<multiply>>(k: Logarithm): Logarithm = ???
|
|object Test:
| val in: Logarithms.Logarithm = ???
| in.<<multi@@ply>>(in)
|""".stripMargin
)

check(
"i5921-2",
"""|object Logarithms:
| opaque type Logarithm = Double
| extension [K](vmap: Logarithm)
| def <<mu@@ltiply>>(k: Logarithm): Logarithm = ???
|
|object Test:
| val in: Logarithms.Logarithm = ???
| in.<<multiply>>(in)
|""".stripMargin
)

check(
"i5921-3",
"""|object Logarithms:
| opaque type Logarithm = Double
| extension [K](vmap: Logarithm)
| def <<multiply>>(k: Logarithm): Logarithm = ???
| (2.0).<<mult@@iply>>(1.0)
|""".stripMargin
)

check(
"i5921-4",
"""|object Logarithms:
| opaque type Logarithm = Double
| extension [K](vmap: Logarithm)
| def <<mult@@iply>>(k: Logarithm): Logarithm = ???
| (2.0).<<multiply>>(1.0)
|""".stripMargin
)

}
14 changes: 14 additions & 0 deletions tests/cross/src/test/scala/tests/hover/HoverScala3TypeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -376,4 +376,18 @@ class HoverScala3TypeSuite extends BaseHoverSuite {
"""|extension (i: MyIntOut) def uneven: Boolean
|""".stripMargin.hover
)

check(
"i5921",
"""|object Logarithms:
| trait Logarithm
| extension [K](vmap: Logarithm)
| def multiply(k: Logarithm): Logarithm = ???
|
|object Test:
| val in: Logarithms.Logarithm = ???
| in.multi@@ply(in)
|""".stripMargin,
"extension [K](vmap: Logarithm) def multiply(k: Logarithm): Logarithm".hover
)
}
36 changes: 35 additions & 1 deletion tests/cross/src/test/scala/tests/pc/PcDefinitionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ class PcDefinitionSuite extends BasePcDefinitionSuite {
)

check(
"i5630".tag(IgnoreScala2.and(IgnoreForScala3CompilerPC)),
"i5630".tag(IgnoreScala2),
"""|class MyIntOut(val value: Int)
|object MyIntOut:
| extension (i: MyIntOut) def <<uneven>> = i.value % 2 == 1
Expand All @@ -569,4 +569,38 @@ class PcDefinitionSuite extends BasePcDefinitionSuite {
|""".stripMargin
)

check(
"i5921".tag(IgnoreScala2),
"""|object Logarithms:
| opaque type Logarithm = Double
| extension [K](vmap: Logarithm)
| def <<multiply>>(k: Logarithm): Logarithm = ???
|
|object Test:
| val in: Logarithms.Logarithm = ???
| in.multi@@ply(in)
|""".stripMargin
)

check(
"i5921-1".tag(IgnoreScala2),
"""|object Logarithms:
| opaque type Logarithm = Double
| extension [K](vmap: Logarithm)
| def <<multiply>>(k: Logarithm): Logarithm = ???
| (2.0).mult@@iply(1.0)
|""".stripMargin
)

check(
"i5921-2".tag(IgnoreScala2),
"""|object Logarithms:
| opaque type Logarithm = Double
| extension [K](vmap: Logarithm)
| def multiply(k: Logarithm): Logarithm = ???
| val <<vv>> = 1.0
| (2.0).multiply(v@@v)
|""".stripMargin
)

}

0 comments on commit 5f1a2d8

Please sign in to comment.