Skip to content

Commit 1b3e4b0

Browse files
Allow relaying messages to self (#2834)
Allow sending messages to self Fixes corner cases caused by compact encoding of node ids. Every message to be relayed now follows the same path and `MessageRelay` can relay to self.
1 parent c866be3 commit 1b3e4b0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+343
-358
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala

+5-4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package fr.acinq.eclair.crypto
1818

1919
import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
2020
import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto}
21+
import fr.acinq.eclair.EncodedNodeId
2122
import fr.acinq.eclair.wire.protocol._
2223
import grizzled.slf4j.Logging
2324
import scodec.Attempt
@@ -341,14 +342,14 @@ object Sphinx extends Logging {
341342
object RouteBlinding {
342343

343344
/**
344-
* @param publicKey introduction node's public key (which cannot be blinded since the sender need to find a route to it).
345+
* @param nodeId introduction node's id (which cannot be blinded since the sender need to find a route to it).
345346
* @param blindedPublicKey blinded public key, which hides the real public key.
346347
* @param blindingEphemeralKey blinding tweak that can be used by the receiving node to derive the private key that
347348
* matches the blinded public key.
348349
* @param encryptedPayload encrypted payload that can be decrypted with the introduction node's private key and the
349350
* blinding ephemeral key.
350351
*/
351-
case class IntroductionNode(publicKey: PublicKey, blindedPublicKey: PublicKey, blindingEphemeralKey: PublicKey, encryptedPayload: ByteVector)
352+
case class IntroductionNode(nodeId: EncodedNodeId, blindedPublicKey: PublicKey, blindingEphemeralKey: PublicKey, encryptedPayload: ByteVector)
352353

353354
/**
354355
* @param blindedPublicKey blinded public key, which hides the real public key.
@@ -363,7 +364,7 @@ object Sphinx extends Logging {
363364
* matches the blinded public key.
364365
* @param blindedNodes blinded nodes (including the introduction node).
365366
*/
366-
case class BlindedRoute(introductionNodeId: PublicKey, blindingKey: PublicKey, blindedNodes: Seq[BlindedNode]) {
367+
case class BlindedRoute(introductionNodeId: EncodedNodeId, blindingKey: PublicKey, blindedNodes: Seq[BlindedNode]) {
367368
require(blindedNodes.nonEmpty, "blinded route must not be empty")
368369
val introductionNode: IntroductionNode = IntroductionNode(introductionNodeId, blindedNodes.head.blindedPublicKey, blindingKey, blindedNodes.head.encryptedPayload)
369370
val subsequentNodes: Seq[BlindedNode] = blindedNodes.tail
@@ -398,7 +399,7 @@ object Sphinx extends Logging {
398399
e = e.multiply(PrivateKey(Crypto.sha256(blindingKey.value ++ sharedSecret.bytes)))
399400
(BlindedNode(blindedPublicKey, encryptedPayload ++ mac), blindingKey)
400401
}.unzip
401-
BlindedRouteDetails(BlindedRoute(publicKeys.head, blindingKeys.head, blindedHops), blindingKeys.last)
402+
BlindedRouteDetails(BlindedRoute(EncodedNodeId(publicKeys.head), blindingKeys.head, blindedHops), blindingKeys.last)
402403
}
403404

404405
/**

eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala

+7-4
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ object MessageRelay {
5959
case class Disconnected(messageId: ByteVector32) extends Failure {
6060
override def toString: String = "Peer is not connected"
6161
}
62-
case class UnknownOutgoingChannel(messageId: ByteVector32, outgoingChannelId: ShortChannelId) extends Failure {
63-
override def toString: String = s"Unknown outgoing channel: $outgoingChannelId"
62+
case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure {
63+
override def toString: String = s"Unknown channel: $channelId"
6464
}
6565
case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure {
6666
override def toString: String = s"Message dropped: $reason"
@@ -99,6 +99,8 @@ private class MessageRelay(nodeParams: NodeParams,
9999

100100
def queryNextNodeId(msg: OnionMessage, nextNode: Either[ShortChannelId, EncodedNodeId]): Behavior[Command] = {
101101
nextNode match {
102+
case Left(outgoingChannelId) if outgoingChannelId == ShortChannelId.toSelf =>
103+
withNextNodeId(msg, nodeParams.nodeId)
102104
case Left(outgoingChannelId) =>
103105
register ! Register.GetNextNodeId(context.messageAdapter(WrappedOptionalNodeId), outgoingChannelId)
104106
waitForNextNodeId(msg, outgoingChannelId)
@@ -110,14 +112,15 @@ private class MessageRelay(nodeParams: NodeParams,
110112
}
111113
}
112114

113-
private def waitForNextNodeId(msg: OnionMessage, outgoingChannelId: ShortChannelId): Behavior[Command] =
115+
private def waitForNextNodeId(msg: OnionMessage, channelId: ShortChannelId): Behavior[Command] = {
114116
Behaviors.receiveMessagePartial {
115117
case WrappedOptionalNodeId(None) =>
116-
replyTo_opt.foreach(_ ! UnknownOutgoingChannel(messageId, outgoingChannelId))
118+
replyTo_opt.foreach(_ ! UnknownChannel(messageId, channelId))
117119
Behaviors.stopped
118120
case WrappedOptionalNodeId(Some(nextNodeId)) =>
119121
withNextNodeId(msg, nextNodeId)
120122
}
123+
}
121124

122125
private def withNextNodeId(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = {
123126
if (nextNodeId == nodeParams.nodeId) {

eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala

+3-9
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ object RouteNodeIdsSerializer extends ConvertClassSerializer[Route](route => {
357357
case Some(hop: NodeHop) if channelNodeIds.nonEmpty => Seq(hop.nextNodeId)
358358
case Some(hop: NodeHop) => Seq(hop.nodeId, hop.nextNodeId)
359359
case Some(hop: BlindedHop) if channelNodeIds.nonEmpty => hop.route.blindedNodeIds.tail
360-
case Some(hop: BlindedHop) => hop.route.introductionNodeId +: hop.route.blindedNodeIds.tail
360+
case Some(hop: BlindedHop) => hop.nodeId +: hop.route.blindedNodeIds.tail
361361
case None => Nil
362362
}
363363
RouteNodeIdsJson(route.amount, channelNodeIds ++ finalNodeIds)
@@ -468,14 +468,8 @@ object InvoiceSerializer extends MinimalSerializer({
468468
UnknownFeatureSerializer
469469
)),
470470
JField("blindedPaths", JArray(p.blindedPaths.map(path => {
471-
val introductionNode = path.route match {
472-
case OfferTypes.BlindedPath(route) => route.introductionNodeId.toString
473-
case OfferTypes.CompactBlindedPath(shortIdDir, _, _) => s"${if (shortIdDir.isNode1) '0' else '1'}x${shortIdDir.scid.toString}"
474-
}
475-
val blindedNodes = path.route match {
476-
case OfferTypes.BlindedPath(route) => route.blindedNodes
477-
case OfferTypes.CompactBlindedPath(_, _, nodes) => nodes
478-
}
471+
val introductionNode = path.route.introductionNodeId.toString
472+
val blindedNodes = path.route.blindedNodes
479473
JObject(List(
480474
JField("introductionNodeId", JString(introductionNode)),
481475
JField("blindedNodeIds", JArray(blindedNodes.map(n => JString(n.blindedPublicKey.toString)).toList))

eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala

+40-55
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ import fr.acinq.eclair.wire.protocol._
2727
import scodec.bits.ByteVector
2828
import scodec.{Attempt, DecodeResult}
2929

30-
import scala.annotation.tailrec
3130
import scala.concurrent.duration.FiniteDuration
3231

3332
object OnionMessages {
@@ -44,23 +43,29 @@ object OnionMessages {
4443
timeout: FiniteDuration,
4544
maxAttempts: Int)
4645

47-
case class IntermediateNode(nodeId: PublicKey, outgoingChannel_opt: Option[ShortChannelId] = None, padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) {
48-
def toTlvStream(nextNodeId: PublicKey, nextBlinding_opt: Option[PublicKey] = None): TlvStream[RouteBlindingEncryptedDataTlv] =
46+
case class IntermediateNode(publicKey: PublicKey, encodedNodeId: EncodedNodeId, outgoingChannel_opt: Option[ShortChannelId] = None, padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) {
47+
def toTlvStream(nextNodeId: EncodedNodeId, nextBlinding_opt: Option[PublicKey] = None): TlvStream[RouteBlindingEncryptedDataTlv] =
4948
TlvStream(Set[Option[RouteBlindingEncryptedDataTlv]](
5049
padding.map(Padding),
5150
outgoingChannel_opt.map(OutgoingChannelId).orElse(Some(OutgoingNodeId(nextNodeId))),
5251
nextBlinding_opt.map(NextBlinding)
5352
).flatten, customTlvs)
5453
}
5554

55+
object IntermediateNode {
56+
def apply(publicKey: PublicKey): IntermediateNode = IntermediateNode(publicKey, EncodedNodeId(publicKey))
57+
}
58+
5659
// @formatter:off
5760
sealed trait Destination {
58-
def nodeId: PublicKey
61+
def introductionNodeId: EncodedNodeId
5962
}
6063
case class BlindedPath(route: Sphinx.RouteBlinding.BlindedRoute) extends Destination {
61-
override def nodeId: PublicKey = route.introductionNodeId
64+
override def introductionNodeId: EncodedNodeId = route.introductionNodeId
65+
}
66+
case class Recipient(nodeId: PublicKey, pathId: Option[ByteVector], padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) extends Destination {
67+
override def introductionNodeId: EncodedNodeId = EncodedNodeId(nodeId)
6268
}
63-
case class Recipient(nodeId: PublicKey, pathId: Option[ByteVector], padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) extends Destination
6469
// @formatter:on
6570

6671
// @formatter:off
@@ -75,11 +80,11 @@ object OnionMessages {
7580
}
7681
// @formatter:on
7782

78-
private def buildIntermediatePayloads(intermediateNodes: Seq[IntermediateNode], lastNodeId: PublicKey, lastBlinding_opt: Option[PublicKey] = None): Seq[ByteVector] = {
83+
private def buildIntermediatePayloads(intermediateNodes: Seq[IntermediateNode], lastNodeId: EncodedNodeId, lastBlinding_opt: Option[PublicKey] = None): Seq[ByteVector] = {
7984
if (intermediateNodes.isEmpty) {
8085
Nil
8186
} else {
82-
val intermediatePayloads = intermediateNodes.dropRight(1).zip(intermediateNodes.tail).map { case (hop, nextNode) => hop.toTlvStream(nextNode.nodeId) }
87+
val intermediatePayloads = intermediateNodes.dropRight(1).zip(intermediateNodes.tail).map { case (hop, nextNode) => hop.toTlvStream(nextNode.encodedNodeId) }
8388
val lastPayload = intermediateNodes.last.toTlvStream(lastNodeId, lastBlinding_opt)
8489
(intermediatePayloads :+ lastPayload).map(tlvs => RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(tlvs).require.bytes)
8590
}
@@ -88,33 +93,22 @@ object OnionMessages {
8893
def buildRoute(blindingSecret: PrivateKey,
8994
intermediateNodes: Seq[IntermediateNode],
9095
recipient: Recipient): Sphinx.RouteBlinding.BlindedRoute = {
91-
val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, recipient.nodeId)
96+
val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, EncodedNodeId(recipient.nodeId))
9297
val tlvs: Set[RouteBlindingEncryptedDataTlv] = Set(recipient.padding.map(Padding), recipient.pathId.map(PathId)).flatten
9398
val lastPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(tlvs, recipient.customTlvs)).require.bytes
94-
Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId) :+ recipient.nodeId, intermediatePayloads :+ lastPayload).route
99+
Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.publicKey) :+ recipient.nodeId, intermediatePayloads :+ lastPayload).route
95100
}
96101

97-
private[message] def buildRouteFrom(originKey: PrivateKey,
98-
blindingSecret: PrivateKey,
102+
private[message] def buildRouteFrom(blindingSecret: PrivateKey,
99103
intermediateNodes: Seq[IntermediateNode],
100-
destination: Destination): Option[Sphinx.RouteBlinding.BlindedRoute] = {
104+
destination: Destination): Sphinx.RouteBlinding.BlindedRoute = {
101105
destination match {
102-
case recipient: Recipient => Some(buildRoute(blindingSecret, intermediateNodes, recipient))
103-
case BlindedPath(route) if route.introductionNodeId == originKey.publicKey =>
104-
RouteBlindingEncryptedDataCodecs.decode(originKey, route.blindingKey, route.blindedNodes.head.encryptedPayload) match {
105-
case Left(_) => None
106-
case Right(decoded) =>
107-
decoded.tlvs.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId] match {
108-
case Some(RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.Plain(nextNodeId))) =>
109-
Some(Sphinx.RouteBlinding.BlindedRoute(nextNodeId, decoded.nextBlinding, route.blindedNodes.tail))
110-
case _ => None // TODO: allow compact node id and OutgoingChannelId
111-
}
112-
}
113-
case BlindedPath(route) if intermediateNodes.isEmpty => Some(route)
106+
case recipient: Recipient => buildRoute(blindingSecret, intermediateNodes, recipient)
107+
case BlindedPath(route) if intermediateNodes.isEmpty => route
114108
case BlindedPath(route) =>
115109
val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, route.introductionNodeId, Some(route.blindingKey))
116-
val routePrefix = Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId), intermediatePayloads).route
117-
Some(Sphinx.RouteBlinding.BlindedRoute(routePrefix.introductionNodeId, routePrefix.blindingKey, routePrefix.blindedNodes ++ route.blindedNodes))
110+
val routePrefix = Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.publicKey), intermediatePayloads).route
111+
Sphinx.RouteBlinding.BlindedRoute(routePrefix.introductionNodeId, routePrefix.blindingKey, routePrefix.blindedNodes ++ route.blindedNodes)
118112
}
119113
}
120114

@@ -134,32 +128,28 @@ object OnionMessages {
134128
* @param content List of TLVs to send to the recipient of the message
135129
* @return The node id to send the onion to and the onion containing the message
136130
*/
137-
def buildMessage(nodeKey: PrivateKey,
138-
sessionKey: PrivateKey,
131+
def buildMessage(sessionKey: PrivateKey,
139132
blindingSecret: PrivateKey,
140133
intermediateNodes: Seq[IntermediateNode],
141134
destination: Destination,
142-
content: TlvStream[OnionMessagePayloadTlv]): Either[BuildMessageError, (PublicKey, OnionMessage)] = {
143-
buildRouteFrom(nodeKey, blindingSecret, intermediateNodes, destination) match {
144-
case None => Left(InvalidDestination(destination))
145-
case Some(route) =>
146-
val lastPayload = MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(content.records + EncryptedData(route.encryptedPayloads.last), content.unknown)).require.bytes
147-
val payloads = route.encryptedPayloads.dropRight(1).map(encTlv => MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(EncryptedData(encTlv))).require.bytes) :+ lastPayload
148-
val payloadSize = payloads.map(_.length + Sphinx.MacLength).sum
149-
val packetSize = if (payloadSize <= 1300) {
150-
1300
151-
} else if (payloadSize <= 32768) {
152-
32768
153-
} else if (payloadSize > 65432) {
154-
// A payload of size 65432 corresponds to a total lightning message size of 65535.
155-
return Left(MessageTooLarge(payloadSize))
156-
} else {
157-
payloadSize.toInt
158-
}
159-
// Since we are setting the packet size based on the payload, the onion creation should never fail (hence the `.get`).
160-
val Sphinx.PacketAndSecrets(packet, _) = Sphinx.create(sessionKey, packetSize, route.blindedNodes.map(_.blindedPublicKey), payloads, None).get
161-
Right((route.introductionNodeId, OnionMessage(route.blindingKey, packet)))
135+
content: TlvStream[OnionMessagePayloadTlv]): Either[BuildMessageError, OnionMessage] = {
136+
val route = buildRouteFrom(blindingSecret, intermediateNodes, destination)
137+
val lastPayload = MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(content.records + EncryptedData(route.encryptedPayloads.last), content.unknown)).require.bytes
138+
val payloads = route.encryptedPayloads.dropRight(1).map(encTlv => MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(EncryptedData(encTlv))).require.bytes) :+ lastPayload
139+
val payloadSize = payloads.map(_.length + Sphinx.MacLength).sum
140+
val packetSize = if (payloadSize <= 1300) {
141+
1300
142+
} else if (payloadSize <= 32768) {
143+
32768
144+
} else if (payloadSize > 65432) {
145+
// A payload of size 65432 corresponds to a total lightning message size of 65535.
146+
return Left(MessageTooLarge(payloadSize))
147+
} else {
148+
payloadSize.toInt
162149
}
150+
// Since we are setting the packet size based on the payload, the onion creation should never fail (hence the `.get`).
151+
val Sphinx.PacketAndSecrets(packet, _) = Sphinx.create(sessionKey, packetSize, route.blindedNodes.map(_.blindedPublicKey), payloads, None).get
152+
Right(OnionMessage(route.blindingKey, packet))
163153
}
164154

165155
// @formatter:off
@@ -199,7 +189,6 @@ object OnionMessages {
199189
}
200190
}
201191

202-
@tailrec
203192
def process(privateKey: PrivateKey, msg: OnionMessage): Action = {
204193
val blindedPrivateKey = Sphinx.RouteBlinding.derivePrivateKey(privateKey, msg.blindingKey)
205194
decryptOnion(blindedPrivateKey, msg.onionRoutingPacket) match {
@@ -210,11 +199,7 @@ object OnionMessages {
210199
decryptEncryptedData(privateKey, msg.blindingKey, encryptedData) match {
211200
case Left(f) => DropMessage(f)
212201
case Right(DecodedEncryptedData(blindedPayload, nextBlinding)) => nextPacket_opt match {
213-
case Some(nextPacket) => validateRelayPayload(payload, blindedPayload, nextBlinding, nextPacket) match {
214-
case SendMessage(Right(EncodedNodeId.Plain(publicKey)), nextMsg) if publicKey == privateKey.publicKey => process(privateKey, nextMsg) // TODO: remove and rely on MessageRelay
215-
case SendMessage(Left(outgoingChannelId), nextMsg) if outgoingChannelId == ShortChannelId.toSelf => process(privateKey, nextMsg) // TODO: remove and rely on MessageRelay
216-
case action => action
217-
}
202+
case Some(nextPacket) => validateRelayPayload(payload, blindedPayload, nextBlinding, nextPacket)
218203
case None => validateFinalPayload(payload, blindedPayload)
219204
}
220205
}

0 commit comments

Comments
 (0)