Skip to content

Commit

Permalink
rpmsg: add release cb and refcnt in endpoint to fix ept used-after-free
Browse files Browse the repository at this point in the history
if rpmsg service free the ept when has got the ept from the ept
list in rpmsg_virtio_rx_callback, there is a used after free about
the ept, so add refcnt to end point and call the rpmsg service
release callback when ept callback fininshed.

Signed-off-by: Bowen Wang <wangbowen6@xiaomi.com>
  • Loading branch information
CV-Bowen authored and Tao Yin committed Oct 26, 2023
1 parent 4e13c09 commit 7ae1d28
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 0 deletions.
7 changes: 7 additions & 0 deletions lib/include/openamp/rpmsg.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ struct rpmsg_device;
/* Returns positive value on success or negative error value on failure */
typedef int (*rpmsg_ept_cb)(struct rpmsg_endpoint *ept, void *data,
size_t len, uint32_t src, void *priv);
typedef void (*rpmsg_ept_release_cb)(struct rpmsg_endpoint *ept);
typedef void (*rpmsg_ns_unbind_cb)(struct rpmsg_endpoint *ept);
typedef void (*rpmsg_ns_bind_cb)(struct rpmsg_device *rdev,
const char *name, uint32_t dest);
Expand All @@ -73,6 +74,12 @@ struct rpmsg_endpoint {
/** Address of the default remote endpoint binded */
uint32_t dest_addr;

/** Reference count of the endpoint */
uint32_t refcnt;

/** Callback to inform the user that the endpoint allocation can be safely removed */
rpmsg_ept_release_cb release_cb;

/**
* User rx callback, return value of this callback is reserved for future
* use, for now, only allow RPMSG_SUCCESS as return value
Expand Down
19 changes: 19 additions & 0 deletions lib/rpmsg/rpmsg.c
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,23 @@ static int rpmsg_set_address(unsigned long *bitmap, int size, int addr)
}
}

void rpmsg_ept_incref(struct rpmsg_endpoint *ept)
{
if (ept && ept->release_cb) {
ept->refcnt++;
}
}

void rpmsg_ept_decref(struct rpmsg_endpoint *ept)
{
if (ept && ept->release_cb) {
ept->refcnt--;
if (!ept->refcnt) {
ept->release_cb(ept);
}
}
}

int rpmsg_send_offchannel_raw(struct rpmsg_endpoint *ept, uint32_t src,
uint32_t dst, const void *data, int len,
int wait)
Expand Down Expand Up @@ -246,6 +263,7 @@ static void rpmsg_unregister_endpoint(struct rpmsg_endpoint *ept)
ept->addr);
metal_list_del(&ept->node);
ept->rdev = NULL;
rpmsg_ept_decref(ept);
metal_mutex_release(&rdev->lock);
}

Expand All @@ -257,6 +275,7 @@ void rpmsg_register_endpoint(struct rpmsg_device *rdev,
rpmsg_ns_unbind_cb ns_unbind_cb)
{
strncpy(ept->name, name ? name : "", sizeof(ept->name));
ept->refcnt = 1;
ept->addr = src;
ept->dest_addr = dest;
ept->cb = cb;
Expand Down
25 changes: 25 additions & 0 deletions lib/rpmsg/rpmsg_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,31 @@ rpmsg_get_ept_from_addr(struct rpmsg_device *rdev, uint32_t addr)
return rpmsg_get_endpoint(rdev, NULL, addr, RPMSG_ADDR_ANY);
}

/**
* @internal
*
* @brief Increase the endpoint reference count
*
* This function is used to avoid calling ept_cb after release lock causes race condition
* it should be called under lock protection.
*
* @param ept pointer to rpmsg endpoint
*
*/
void rpmsg_ept_incref(struct rpmsg_endpoint *ept);

/**
* @internal
*
* @brief Decrease the end point reference count
*
* This function is used to avoid calling ept_cb after release lock causes race condition
* it should be called under lock protection.
*
* @param ept pointer to rpmsg endpoint
*/
void rpmsg_ept_decref(struct rpmsg_endpoint *ept);

#if defined __cplusplus
}
#endif
Expand Down
12 changes: 12 additions & 0 deletions lib/rpmsg/rpmsg_virtio.c
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ static void rpmsg_virtio_rx_callback(struct virtqueue *vq)
/* Get the channel node from the remote device channels list. */
metal_mutex_acquire(&rdev->lock);
ept = rpmsg_get_ept_from_addr(rdev, rp_hdr->dst);
rpmsg_ept_incref(ept);
metal_mutex_release(&rdev->lock);

if (ept) {
Expand All @@ -576,6 +577,7 @@ static void rpmsg_virtio_rx_callback(struct virtqueue *vq)
}

metal_mutex_acquire(&rdev->lock);
rpmsg_ept_decref(ept);

/* Check whether callback wants to hold buffer */
if (!(rp_hdr->reserved & RPMSG_BUF_HELD)) {
Expand Down Expand Up @@ -616,6 +618,7 @@ static int rpmsg_virtio_ns_callback(struct rpmsg_endpoint *ept, void *data,
struct rpmsg_ns_msg *ns_msg;
uint32_t dest;
char name[RPMSG_NAME_SIZE];
bool ept_decref;

(void)priv;
(void)src;
Expand All @@ -636,11 +639,20 @@ static int rpmsg_virtio_ns_callback(struct rpmsg_endpoint *ept, void *data,
if (ns_msg->flags & RPMSG_NS_DESTROY) {
if (_ept)
_ept->dest_addr = RPMSG_ADDR_ANY;
if (_ept && _ept->release_cb) {
rpmsg_ept_incref(_ept);
ept_decref = true;
}
metal_mutex_release(&rdev->lock);
if (_ept && _ept->ns_unbind_cb)
_ept->ns_unbind_cb(_ept);
if (rdev->ns_unbind_cb)
rdev->ns_unbind_cb(rdev, name, dest);
if (ept_decref) {
metal_mutex_acquire(&rdev->lock);
rpmsg_ept_decref(_ept);
metal_mutex_release(&rdev->lock);
}
} else {
if (!_ept) {
/*
Expand Down

0 comments on commit 7ae1d28

Please sign in to comment.