From dafdc5beec70ad10a5334eaf523bed1a7f14ab17 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Wed, 13 Sep 2023 11:07:54 +0800 Subject: [PATCH] rpmsg: add release cb and refcnt in endpoint to fix ept used-after-free if rpmsg service free the ept when has got the ept from the ept list in rpmsg_virtio_rx_callback, there is a used after free about the ept, so add refcnt to end point and call the rpmsg service release callback when ept callback fininshed. Signed-off-by: Bowen Wang --- lib/include/openamp/rpmsg.h | 7 +++++++ lib/rpmsg/rpmsg.c | 22 +++++++++++++++++++++- lib/rpmsg/rpmsg_internal.h | 25 +++++++++++++++++++++++++ lib/rpmsg/rpmsg_virtio.c | 16 ++++++++++++++++ 4 files changed, 69 insertions(+), 1 deletion(-) diff --git a/lib/include/openamp/rpmsg.h b/lib/include/openamp/rpmsg.h index 35573ca78..17eaddec2 100644 --- a/lib/include/openamp/rpmsg.h +++ b/lib/include/openamp/rpmsg.h @@ -50,6 +50,7 @@ struct rpmsg_device; /* Returns positive value on success or negative error value on failure */ typedef int (*rpmsg_ept_cb)(struct rpmsg_endpoint *ept, void *data, size_t len, uint32_t src, void *priv); +typedef void (*rpmsg_ept_release_cb)(struct rpmsg_endpoint *ept); typedef void (*rpmsg_ns_unbind_cb)(struct rpmsg_endpoint *ept); typedef void (*rpmsg_ns_bind_cb)(struct rpmsg_device *rdev, const char *name, uint32_t dest); @@ -73,6 +74,12 @@ struct rpmsg_endpoint { /** Address of the default remote endpoint binded */ uint32_t dest_addr; + /** Reference count for determining whether the endpoint can be deallocated */ + uint32_t refcnt; + + /** Callback to inform the user that the endpoint allocation can be safely removed */ + rpmsg_ept_release_cb release_cb; + /** * User rx callback, return value of this callback is reserved for future * use, for now, only allow RPMSG_SUCCESS as return value diff --git a/lib/rpmsg/rpmsg.c b/lib/rpmsg/rpmsg.c index a9f1c1e84..3be947d3c 100644 --- a/lib/rpmsg/rpmsg.c +++ b/lib/rpmsg/rpmsg.c @@ -97,6 +97,25 @@ static int rpmsg_set_address(unsigned long *bitmap, int size, int addr) } } +void rpmsg_ept_incref(struct rpmsg_endpoint *ept) +{ + if (ept) + ept->refcnt++; +} + +void rpmsg_ept_decref(struct rpmsg_endpoint *ept) +{ + if (ept) { + ept->refcnt--; + if (!ept->refcnt) { + if (ept->release_cb) + ept->release_cb(ept); + else + ept->rdev = NULL; + } + } +} + int rpmsg_send_offchannel_raw(struct rpmsg_endpoint *ept, uint32_t src, uint32_t dst, const void *data, int len, int wait) @@ -245,7 +264,7 @@ static void rpmsg_unregister_endpoint(struct rpmsg_endpoint *ept) rpmsg_release_address(rdev->bitmap, RPMSG_ADDR_BMP_SIZE, ept->addr); metal_list_del(&ept->node); - ept->rdev = NULL; + rpmsg_ept_decref(ept); metal_mutex_release(&rdev->lock); } @@ -257,6 +276,7 @@ void rpmsg_register_endpoint(struct rpmsg_device *rdev, rpmsg_ns_unbind_cb ns_unbind_cb) { strncpy(ept->name, name ? name : "", sizeof(ept->name)); + ept->refcnt = 1; ept->addr = src; ept->dest_addr = dest; ept->cb = cb; diff --git a/lib/rpmsg/rpmsg_internal.h b/lib/rpmsg/rpmsg_internal.h index 6721ecf88..2c6687ca3 100644 --- a/lib/rpmsg/rpmsg_internal.h +++ b/lib/rpmsg/rpmsg_internal.h @@ -109,6 +109,31 @@ rpmsg_get_ept_from_addr(struct rpmsg_device *rdev, uint32_t addr) return rpmsg_get_endpoint(rdev, NULL, addr, RPMSG_ADDR_ANY); } +/** + * @internal + * + * @brief Increase the endpoint reference count + * + * This function is used to avoid calling ept_cb after release lock causes race condition + * it should be called under lock protection. + * + * @param ept pointer to rpmsg endpoint + * + */ +void rpmsg_ept_incref(struct rpmsg_endpoint *ept); + +/** + * @internal + * + * @brief Decrease the end point reference count + * + * This function is used to avoid calling ept_cb after release lock causes race condition + * it should be called under lock protection. + * + * @param ept pointer to rpmsg endpoint + */ +void rpmsg_ept_decref(struct rpmsg_endpoint *ept); + #if defined __cplusplus } #endif diff --git a/lib/rpmsg/rpmsg_virtio.c b/lib/rpmsg/rpmsg_virtio.c index d9ee8c4c0..e80bf66dd 100644 --- a/lib/rpmsg/rpmsg_virtio.c +++ b/lib/rpmsg/rpmsg_virtio.c @@ -514,6 +514,7 @@ static void rpmsg_virtio_rx_callback(struct virtqueue *vq) /* Get the channel node from the remote device channels list. */ metal_mutex_acquire(&rdev->lock); ept = rpmsg_get_ept_from_addr(rdev, rp_hdr->dst); + rpmsg_ept_incref(ept); metal_mutex_release(&rdev->lock); if (ept) { @@ -532,6 +533,7 @@ static void rpmsg_virtio_rx_callback(struct virtqueue *vq) } metal_mutex_acquire(&rdev->lock); + rpmsg_ept_decref(ept); /* Check whether callback wants to hold buffer */ if (!(rp_hdr->reserved & RPMSG_BUF_HELD)) { @@ -571,6 +573,7 @@ static int rpmsg_virtio_ns_callback(struct rpmsg_endpoint *ept, void *data, struct rpmsg_endpoint *_ept; struct rpmsg_ns_msg *ns_msg; uint32_t dest; + bool ept_to_release; char name[RPMSG_NAME_SIZE]; (void)priv; @@ -589,14 +592,27 @@ static int rpmsg_virtio_ns_callback(struct rpmsg_endpoint *ept, void *data, metal_mutex_acquire(&rdev->lock); _ept = rpmsg_get_endpoint(rdev, name, RPMSG_ADDR_ANY, dest); + /* + * If ept-release callback is not implemented, ns_unbind_cb() can free the ept. + * Test _ept->release_cb before calling ns_unbind_cb() callbacks. + */ + ept_to_release = _ept && _ept->release_cb; + if (ns_msg->flags & RPMSG_NS_DESTROY) { if (_ept) _ept->dest_addr = RPMSG_ADDR_ANY; + if (ept_to_release) + rpmsg_ept_incref(_ept); metal_mutex_release(&rdev->lock); if (_ept && _ept->ns_unbind_cb) _ept->ns_unbind_cb(_ept); if (rdev->ns_unbind_cb) rdev->ns_unbind_cb(rdev, name, dest); + if (ept_to_release) { + metal_mutex_acquire(&rdev->lock); + rpmsg_ept_decref(_ept); + metal_mutex_release(&rdev->lock); + } } else { if (!_ept) { /*