From d250a5f31d094b87113fae4e8b6e1fb8ff836c96 Mon Sep 17 00:00:00 2001 From: Eric Long Date: Thu, 2 Jan 2025 19:15:13 +0800 Subject: [PATCH] kernel/nft: use nftables handle returned via `--echo` --- src/kernel/linux/mod.rs | 53 ++++++++++++++++------------------------- src/kernel/linux/nft.rs | 18 ++++++++++---- src/kernel/mod.rs | 6 ++--- src/kernel/rtnl.rs | 12 +++++----- 4 files changed, 42 insertions(+), 47 deletions(-) diff --git a/src/kernel/linux/mod.rs b/src/kernel/linux/mod.rs index 6d1ba94..9393812 100644 --- a/src/kernel/linux/mod.rs +++ b/src/kernel/linux/mod.rs @@ -10,19 +10,18 @@ use futures::join; use itertools::Itertools; use nft::Nftables; use nftables::batch::Batch; -use nftables::helper::NftablesError; -use nftables::schema::{NfCmd, NfObject, Nftables as NftablesReq}; +use nftables::schema::{NfCmd, NfListObject, NfObject, Nftables as NftablesReq}; use serde::{Deserialize, Serialize}; use std::borrow::Cow; +use std::collections::BTreeSet; use std::future::pending; #[derive(Debug, Serialize, Deserialize)] pub struct Linux { nft: Nftables, #[serde(skip)] - rtnl: Option, + rtnl: Option>, rtnl_args: RtNetlinkArgs, - counter: u64, } impl Linux { @@ -32,13 +31,12 @@ impl Linux { nft: Nftables::new(table, chain, hooked, hook_priority).await?, rtnl: None, rtnl_args: rtnl, - counter: 0, }) } } impl Kernel for Linux { - type Handle = u64; + type Handle = BTreeSet; // TODO: order async fn apply(&mut self, spec: &Flowspec, info: &RouteInfo<'_>) -> Result { @@ -63,19 +61,27 @@ impl Kernel for Linux { .collect::>() }); - let handle = self.counter; - self.counter += 1; let nftables = NftablesReq { objects: rules .into_iter() - .map(|x| NfObject::CmdObject(NfCmd::Add(self.nft.make_new_rule(x.into(), Some(handle))))) + .map(|x| NfObject::CmdObject(NfCmd::Add(self.nft.make_new_rule(x.into())))) .collect(), }; + let result = self.nft.apply_and_return_ruleset(&nftables).await?; + + let handle: Self::Handle = (result.objects.iter()) + .filter_map(|x| { + if let NfObject::CmdObject(NfCmd::Add(NfListObject::Rule(rule))) = x { + Some(rule.handle.unwrap()) + } else { + None + } + }) + .collect(); - self.nft.apply_ruleset(&nftables).await?; if let Some((next_hop, table_id)) = rt_info { let rtnl = self.rtnl.as_mut().expect("RtNetlink should be initialized"); - let real_table_id = rtnl.add(handle, spec, next_hop).await?; + let real_table_id = rtnl.add(handle.clone(), spec, next_hop).await?; assert_eq!(table_id, real_table_id, "table ID mismatch"); } @@ -83,36 +89,17 @@ impl Kernel for Linux { } async fn remove(&mut self, handle: Self::Handle) -> Result<()> { - #[derive(Debug, Deserialize)] - struct MyNftables { - nftables: Vec, - } - #[derive(Debug, Deserialize)] - struct MyNftObject { - rule: Option, - } - #[derive(Debug, Deserialize)] - struct MyNftRule { - comment: Option, - handle: u32, - } let mut batch = Batch::new(); - let s = self.nft.get_current_ruleset_raw().await?; - let MyNftables { nftables } = serde_json::from_str(&s).map_err(NftablesError::NftInvalidJson)?; - nftables - .into_iter() - .filter_map(|x| x.rule) - .filter(|x| x.comment.as_ref().is_some_and(|y| y == &handle.to_string())) - .for_each(|x| batch.delete(self.nft.make_rule_handle(x.handle))); + for h in handle.iter().copied() { + batch.delete(self.nft.make_rule_handle(h)); + } self.nft.apply_ruleset(&batch.to_nftables()).await?; - if let Some(rtnl) = &mut self.rtnl { rtnl.del(handle).await?; if rtnl.is_empty() { self.rtnl = None; } } - Ok(()) } diff --git a/src/kernel/linux/nft.rs b/src/kernel/linux/nft.rs index 40ec3b8..a792521 100644 --- a/src/kernel/linux/nft.rs +++ b/src/kernel/linux/nft.rs @@ -1,3 +1,4 @@ +use super::Linux; use crate::bgp::flow::{Bitmask, BitmaskFlags, Component, ComponentKind, Flowspec, Numeric, NumericFlags, Op, Ops}; use crate::bgp::route::{ExtCommunity, Ipv6ExtCommunity, RouteInfo, TrafficFilterAction, TrafficFilterActionKind}; use crate::kernel::rtnl::{RtNetlink, RtNetlinkArgs}; @@ -6,7 +7,9 @@ use crate::net::{Afi, IpPrefix}; use crate::util::{Intersect, TruthTable}; use nftables::batch::Batch; use nftables::expr::Expression::{Number as NUM, String as STRING}; -use nftables::helper::{apply_ruleset_async, get_current_ruleset_raw_async, DEFAULT_NFT}; +use nftables::helper::{ + apply_and_return_ruleset_async, apply_ruleset_async, get_current_ruleset_raw_async, DEFAULT_NFT, +}; use nftables::schema::Nftables as NftablesReq; use nftables::{expr, schema, stmt, types}; use num::Integer; @@ -57,14 +60,14 @@ impl Nftables { pub fn make_new_rule( &self, stmts: Cow<'static, [stmt::Statement]>, - comment: Option, + // comment: Option, ) -> schema::NfListObject<'static> { schema::NfListObject::Rule(schema::Rule { family: types::NfFamily::INet, table: self.table.clone(), chain: self.chain.clone(), expr: stmts, - comment: comment.map(|x| x.to_string().into()), + // comment: comment.map(|x| x.to_string().into()), ..Default::default() }) } @@ -79,6 +82,7 @@ impl Nftables { }) } + #[expect(unused)] pub async fn get_current_ruleset_raw(&self) -> Result { let args = ["-n", "-s", "list", "chain", "inet", &self.table, &self.chain]; Ok(get_current_ruleset_raw_async(DEFAULT_NFT, args).await?) @@ -88,6 +92,10 @@ impl Nftables { Ok(apply_ruleset_async(n).await?) } + pub async fn apply_and_return_ruleset(&self, n: &NftablesReq<'_>) -> Result> { + Ok(apply_and_return_ruleset_async(n).await?) + } + pub async fn terminate(self) { let mut batch = Batch::new(); batch.delete(schema::NfListObject::Chain(schema::Chain { @@ -273,7 +281,7 @@ impl RouteInfo<'_> { pub(super) fn to_nft_stmts( &self, afi: Afi, - rtnl: &mut Option, + rtnl: &mut Option>, rtnl_args: &RtNetlinkArgs, ) -> Option<(StatementBranch<'static>, Option<(IpAddr, u32)>)> { let set = (self.ext_comm.iter().copied()) @@ -320,7 +328,7 @@ impl TrafficFilterAction { fn to_nft_stmts( self, afi: Afi, - rtnl: &mut Option, + rtnl: &mut Option>, rtnl_args: &RtNetlinkArgs, ) -> (StatementBlock<'static>, Option<(IpAddr, u32)>, bool) { use TrafficFilterAction::*; diff --git a/src/kernel/mod.rs b/src/kernel/mod.rs index 226a316..d7c4145 100644 --- a/src/kernel/mod.rs +++ b/src/kernel/mod.rs @@ -21,7 +21,7 @@ use thiserror::Error; /// Interface between BGP flowspec and the OS. pub trait Kernel: Sized { /// Type representing a flowspec's counterpart in kernel. - type Handle; + type Handle: Eq + Ord; /// Apply a flowspec to kernel. fn apply(&mut self, spec: &Flowspec, info: &RouteInfo<'_>) -> impl Future>; @@ -96,13 +96,13 @@ impl Kernel for KernelAdapter { } } -#[derive(Debug, Display, Clone, Serialize, Deserialize)] +#[derive(Debug, Display, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum KernelHandle { #[strum(to_string = "()")] Noop, #[cfg(linux)] - #[strum(to_string = "{0}")] + #[strum(to_string = "{0:?}")] Linux(::Handle), } diff --git a/src/kernel/rtnl.rs b/src/kernel/rtnl.rs index 0b66602..99351ed 100644 --- a/src/kernel/rtnl.rs +++ b/src/kernel/rtnl.rs @@ -1,4 +1,4 @@ -use super::Result; +use super::{Kernel, Result}; use crate::bgp::flow::{Component, ComponentKind, Flowspec}; use crate::net::{Afi, IpPrefix}; use clap::Args; @@ -22,16 +22,16 @@ use tokio::select; use tokio::time::{interval, Interval}; #[derive(Debug)] -pub struct RtNetlink { +pub struct RtNetlink { args: RtNetlinkArgs, handle: Handle, msgs: UnboundedReceiver<(NetlinkMessage, rtnetlink::sys::SocketAddr)>, - routes: BTreeMap)>, + routes: BTreeMap)>, rules: BTreeMap>, timer: Interval, } -impl RtNetlink { +impl RtNetlink { pub fn new(args: RtNetlinkArgs) -> io::Result { let (conn, handle, msgs) = rtnetlink::new_connection()?; let scan_time = args.route_scan_time; @@ -46,7 +46,7 @@ impl RtNetlink { }) } - pub async fn add(&mut self, id: u64, spec: &Flowspec, next_hop: IpAddr) -> Result { + pub async fn add(&mut self, id: K::Handle, spec: &Flowspec, next_hop: IpAddr) -> Result { let prefix = spec .component_set() .get(&ComponentKind::DstPrefix) @@ -102,7 +102,7 @@ impl RtNetlink { .unwrap_or(self.args.init_table_id) } - pub async fn del(&mut self, id: u64) -> Result<()> { + pub async fn del(&mut self, id: K::Handle) -> Result<()> { let Some((prefix, _, table_id, _)) = self.routes.remove(&id) else { return Ok(()); };