From 7ae1d289968bd8696ef20e3d5018f8709b1471c2 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Wed, 13 Sep 2023 11:07:54 +0800 Subject: [PATCH] rpmsg: add release cb and refcnt in endpoint to fix ept used-after-free if rpmsg service free the ept when has got the ept from the ept list in rpmsg_virtio_rx_callback, there is a used after free about the ept, so add refcnt to end point and call the rpmsg service release callback when ept callback fininshed. Signed-off-by: Bowen Wang --- lib/include/openamp/rpmsg.h | 7 +++++++ lib/rpmsg/rpmsg.c | 19 +++++++++++++++++++ lib/rpmsg/rpmsg_internal.h | 25 +++++++++++++++++++++++++ lib/rpmsg/rpmsg_virtio.c | 12 ++++++++++++ 4 files changed, 63 insertions(+) diff --git a/lib/include/openamp/rpmsg.h b/lib/include/openamp/rpmsg.h index 9cf1e7444..c13c600c4 100644 --- a/lib/include/openamp/rpmsg.h +++ b/lib/include/openamp/rpmsg.h @@ -50,6 +50,7 @@ struct rpmsg_device; /* Returns positive value on success or negative error value on failure */ typedef int (*rpmsg_ept_cb)(struct rpmsg_endpoint *ept, void *data, size_t len, uint32_t src, void *priv); +typedef void (*rpmsg_ept_release_cb)(struct rpmsg_endpoint *ept); typedef void (*rpmsg_ns_unbind_cb)(struct rpmsg_endpoint *ept); typedef void (*rpmsg_ns_bind_cb)(struct rpmsg_device *rdev, const char *name, uint32_t dest); @@ -73,6 +74,12 @@ struct rpmsg_endpoint { /** Address of the default remote endpoint binded */ uint32_t dest_addr; + /** Reference count of the endpoint */ + uint32_t refcnt; + + /** Callback to inform the user that the endpoint allocation can be safely removed */ + rpmsg_ept_release_cb release_cb; + /** * User rx callback, return value of this callback is reserved for future * use, for now, only allow RPMSG_SUCCESS as return value diff --git a/lib/rpmsg/rpmsg.c b/lib/rpmsg/rpmsg.c index 5a9237f47..e0132d15b 100644 --- a/lib/rpmsg/rpmsg.c +++ b/lib/rpmsg/rpmsg.c @@ -97,6 +97,23 @@ static int rpmsg_set_address(unsigned long *bitmap, int size, int addr) } } +void rpmsg_ept_incref(struct rpmsg_endpoint *ept) +{ + if (ept && ept->release_cb) { + ept->refcnt++; + } +} + +void rpmsg_ept_decref(struct rpmsg_endpoint *ept) +{ + if (ept && ept->release_cb) { + ept->refcnt--; + if (!ept->refcnt) { + ept->release_cb(ept); + } + } +} + int rpmsg_send_offchannel_raw(struct rpmsg_endpoint *ept, uint32_t src, uint32_t dst, const void *data, int len, int wait) @@ -246,6 +263,7 @@ static void rpmsg_unregister_endpoint(struct rpmsg_endpoint *ept) ept->addr); metal_list_del(&ept->node); ept->rdev = NULL; + rpmsg_ept_decref(ept); metal_mutex_release(&rdev->lock); } @@ -257,6 +275,7 @@ void rpmsg_register_endpoint(struct rpmsg_device *rdev, rpmsg_ns_unbind_cb ns_unbind_cb) { strncpy(ept->name, name ? name : "", sizeof(ept->name)); + ept->refcnt = 1; ept->addr = src; ept->dest_addr = dest; ept->cb = cb; diff --git a/lib/rpmsg/rpmsg_internal.h b/lib/rpmsg/rpmsg_internal.h index 6721ecf88..2c6687ca3 100644 --- a/lib/rpmsg/rpmsg_internal.h +++ b/lib/rpmsg/rpmsg_internal.h @@ -109,6 +109,31 @@ rpmsg_get_ept_from_addr(struct rpmsg_device *rdev, uint32_t addr) return rpmsg_get_endpoint(rdev, NULL, addr, RPMSG_ADDR_ANY); } +/** + * @internal + * + * @brief Increase the endpoint reference count + * + * This function is used to avoid calling ept_cb after release lock causes race condition + * it should be called under lock protection. + * + * @param ept pointer to rpmsg endpoint + * + */ +void rpmsg_ept_incref(struct rpmsg_endpoint *ept); + +/** + * @internal + * + * @brief Decrease the end point reference count + * + * This function is used to avoid calling ept_cb after release lock causes race condition + * it should be called under lock protection. + * + * @param ept pointer to rpmsg endpoint + */ +void rpmsg_ept_decref(struct rpmsg_endpoint *ept); + #if defined __cplusplus } #endif diff --git a/lib/rpmsg/rpmsg_virtio.c b/lib/rpmsg/rpmsg_virtio.c index ea4cc0d9e..0b48ca6a9 100644 --- a/lib/rpmsg/rpmsg_virtio.c +++ b/lib/rpmsg/rpmsg_virtio.c @@ -558,6 +558,7 @@ static void rpmsg_virtio_rx_callback(struct virtqueue *vq) /* Get the channel node from the remote device channels list. */ metal_mutex_acquire(&rdev->lock); ept = rpmsg_get_ept_from_addr(rdev, rp_hdr->dst); + rpmsg_ept_incref(ept); metal_mutex_release(&rdev->lock); if (ept) { @@ -576,6 +577,7 @@ static void rpmsg_virtio_rx_callback(struct virtqueue *vq) } metal_mutex_acquire(&rdev->lock); + rpmsg_ept_decref(ept); /* Check whether callback wants to hold buffer */ if (!(rp_hdr->reserved & RPMSG_BUF_HELD)) { @@ -616,6 +618,7 @@ static int rpmsg_virtio_ns_callback(struct rpmsg_endpoint *ept, void *data, struct rpmsg_ns_msg *ns_msg; uint32_t dest; char name[RPMSG_NAME_SIZE]; + bool ept_decref; (void)priv; (void)src; @@ -636,11 +639,20 @@ static int rpmsg_virtio_ns_callback(struct rpmsg_endpoint *ept, void *data, if (ns_msg->flags & RPMSG_NS_DESTROY) { if (_ept) _ept->dest_addr = RPMSG_ADDR_ANY; + if (_ept && _ept->release_cb) { + rpmsg_ept_incref(_ept); + ept_decref = true; + } metal_mutex_release(&rdev->lock); if (_ept && _ept->ns_unbind_cb) _ept->ns_unbind_cb(_ept); if (rdev->ns_unbind_cb) rdev->ns_unbind_cb(rdev, name, dest); + if (ept_decref) { + metal_mutex_acquire(&rdev->lock); + rpmsg_ept_decref(_ept); + metal_mutex_release(&rdev->lock); + } } else { if (!_ept) { /*