diff --git a/src/lib.rs b/src/lib.rs index 3dcc264..675219f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,25 +28,29 @@ use nix::{ sockopt::{ReceiveTimeout, SendTimeout, SocketError}, AddressFamily, GetSockOpt, MsgFlags, SetSockOpt, SockFlag, SockType, }, - unistd::close, }; -use std::fs::File; -use std::io::{Error, ErrorKind, Read, Result, Write}; -use std::mem::{self, size_of}; +use std::mem::size_of; use std::net::Shutdown; use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; use std::time::Duration; +use std::{fs::File, os::fd::OwnedFd}; +use std::{ + io::{Error, ErrorKind, Read, Result, Write}, + os::fd::{AsFd, BorrowedFd}, +}; pub use libc::{VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, VMADDR_CID_LOCAL}; pub use nix::sys::socket::{SockaddrLike, VsockAddr}; -fn new_socket() -> Result<RawFd> { - Ok(socket( +fn new_socket() -> Result<OwnedFd> { + let fd = socket( AddressFamily::Vsock, SockType::Stream, SockFlag::SOCK_CLOEXEC, None, - )?) + )?; + // SAFETY: We just created a new file descriptor, so we can take ownership of it. + unsafe { Ok(OwnedFd::from_raw_fd(fd)) } } /// An iterator that infinitely accepts connections on a VsockListener. @@ -64,9 +68,9 @@ impl<'a> Iterator for Incoming<'a> { } /// A virtio socket server, listening for connections. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct VsockListener { - socket: RawFd, + socket: OwnedFd, } impl VsockListener { @@ -81,10 +85,10 @@ impl VsockListener { let socket = new_socket()?; - bind(socket, addr)?; + bind(socket.as_raw_fd(), addr)?; // rust stdlib uses a 128 connection backlog - listen(socket, 128)?; + listen(socket.as_raw_fd(), 128)?; Ok(Self { socket }) } @@ -96,12 +100,14 @@ impl VsockListener { /// The local socket address of the listener. pub fn local_addr(&self) -> Result<VsockAddr> { - Ok(getsockname(self.socket)?) + Ok(getsockname(self.socket.as_raw_fd())?) } /// Create a new independently owned handle to the underlying socket. pub fn try_clone(&self) -> Result<Self> { - Ok(self.clone()) + Ok(Self { + socket: self.socket.try_clone()?, + }) } /// Accept a new incoming connection from this listener. @@ -116,7 +122,7 @@ impl VsockListener { let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t; let socket = unsafe { accept4( - self.socket, + self.socket.as_raw_fd(), &mut vsock_addr as *mut _ as *mut sockaddr, &mut vsock_addr_len, SOCK_CLOEXEC, @@ -139,7 +145,7 @@ impl VsockListener { /// Retrieve the latest error associated with the underlying socket. pub fn take_error(&self) -> Result<Option<Error>> { - let error = SocketError.get(self.socket)?; + let error = SocketError.get(self.socket.as_raw_fd())?; Ok(if error == 0 { None } else { @@ -150,7 +156,7 @@ impl VsockListener { /// Move this stream in and out of nonblocking mode. pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> { let mut nonblocking: i32 = if nonblocking { 1 } else { 0 }; - if unsafe { ioctl(self.socket, FIONBIO, &mut nonblocking) } < 0 { + if unsafe { ioctl(self.socket.as_raw_fd(), FIONBIO, &mut nonblocking) } < 0 { Err(Error::last_os_error()) } else { Ok(()) @@ -160,34 +166,34 @@ impl VsockListener { impl AsRawFd for VsockListener { fn as_raw_fd(&self) -> RawFd { - self.socket + self.socket.as_raw_fd() + } +} + +impl AsFd for VsockListener { + fn as_fd(&self) -> BorrowedFd { + self.socket.as_fd() } } impl FromRawFd for VsockListener { unsafe fn from_raw_fd(socket: RawFd) -> Self { - Self { socket } + Self { + socket: OwnedFd::from_raw_fd(socket), + } } } impl IntoRawFd for VsockListener { fn into_raw_fd(self) -> RawFd { - let fd = self.socket; - mem::forget(self); - fd - } -} - -impl Drop for VsockListener { - fn drop(&mut self) { - let _ = close(self.socket); + self.socket.into_raw_fd() } } /// A virtio stream between a local and a remote socket. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct VsockStream { - socket: RawFd, + socket: OwnedFd, } impl VsockStream { @@ -200,9 +206,9 @@ impl VsockStream { )); } - let sock = new_socket()?; - connect(sock, addr)?; - Ok(unsafe { Self::from_raw_fd(sock) }) + let socket = new_socket()?; + connect(socket.as_raw_fd(), addr)?; + Ok(Self { socket }) } /// Open a connection to a remote host with specified cid and port. @@ -212,12 +218,12 @@ impl VsockStream { /// Virtio socket address of the remote peer associated with this connection. pub fn peer_addr(&self) -> Result<VsockAddr> { - Ok(getpeername(self.socket)?) + Ok(getpeername(self.socket.as_raw_fd())?) } /// Virtio socket address of the local address associated with this connection. pub fn local_addr(&self) -> Result<VsockAddr> { - Ok(getsockname(self.socket)?) + Ok(getsockname(self.socket.as_raw_fd())?) } /// Shutdown the read, write, or both halves of this connection. @@ -227,29 +233,31 @@ impl VsockStream { Shutdown::Read => socket::Shutdown::Read, Shutdown::Both => socket::Shutdown::Both, }; - Ok(shutdown(self.socket, how)?) + Ok(shutdown(self.socket.as_raw_fd(), how)?) } /// Create a new independently owned handle to the underlying socket. pub fn try_clone(&self) -> Result<Self> { - Ok(self.clone()) + Ok(Self { + socket: self.socket.try_clone()?, + }) } /// Set the timeout on read operations. pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> { let timeout = Self::timeval_from_duration(dur)?.into(); - Ok(ReceiveTimeout.set(self.socket, &timeout)?) + Ok(ReceiveTimeout.set(self.socket.as_raw_fd(), &timeout)?) } /// Set the timeout on write operations. pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> { let timeout = Self::timeval_from_duration(dur)?.into(); - Ok(SendTimeout.set(self.socket, &timeout)?) + Ok(SendTimeout.set(self.socket.as_raw_fd(), &timeout)?) } /// Retrieve the latest error associated with the underlying socket. pub fn take_error(&self) -> Result<Option<Error>> { - let error = SocketError.get(self.socket)?; + let error = SocketError.get(self.socket.as_raw_fd())?; Ok(if error == 0 { None } else { @@ -260,7 +268,7 @@ impl VsockStream { /// Move this stream in and out of nonblocking mode. pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> { let mut nonblocking: i32 = if nonblocking { 1 } else { 0 }; - if unsafe { ioctl(self.socket, FIONBIO, &mut nonblocking) } < 0 { + if unsafe { ioctl(self.socket.as_raw_fd(), FIONBIO, &mut nonblocking) } < 0 { Err(Error::last_os_error()) } else { Ok(()) @@ -319,13 +327,13 @@ impl Write for VsockStream { impl Read for &VsockStream { fn read(&mut self, buf: &mut [u8]) -> Result<usize> { - Ok(recv(self.socket, buf, MsgFlags::empty())?) + Ok(recv(self.socket.as_raw_fd(), buf, MsgFlags::empty())?) } } impl Write for &VsockStream { fn write(&mut self, buf: &[u8]) -> Result<usize> { - Ok(send(self.socket, buf, MsgFlags::MSG_NOSIGNAL)?) + Ok(send(self.socket.as_raw_fd(), buf, MsgFlags::MSG_NOSIGNAL)?) } fn flush(&mut self) -> Result<()> { @@ -335,27 +343,27 @@ impl Write for &VsockStream { impl AsRawFd for VsockStream { fn as_raw_fd(&self) -> RawFd { - self.socket + self.socket.as_raw_fd() + } +} + +impl AsFd for VsockStream { + fn as_fd(&self) -> BorrowedFd { + self.socket.as_fd() } } impl FromRawFd for VsockStream { unsafe fn from_raw_fd(socket: RawFd) -> Self { - Self { socket } + Self { + socket: OwnedFd::from_raw_fd(socket), + } } } impl IntoRawFd for VsockStream { fn into_raw_fd(self) -> RawFd { - let fd = self.socket; - mem::forget(self); - fd - } -} - -impl Drop for VsockStream { - fn drop(&mut self) { - let _ = close(self.socket); + self.socket.into_raw_fd() } }