Skip to content

Commit

Permalink
kernel/nft: use nftables handle returned via --echo
Browse files Browse the repository at this point in the history
  • Loading branch information
hack3ric committed Jan 2, 2025
1 parent 2602e84 commit d250a5f
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 47 deletions.
53 changes: 20 additions & 33 deletions src/kernel/linux/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,18 @@ use futures::join;
use itertools::Itertools;
use nft::Nftables;
use nftables::batch::Batch;
use nftables::helper::NftablesError;
use nftables::schema::{NfCmd, NfObject, Nftables as NftablesReq};
use nftables::schema::{NfCmd, NfListObject, NfObject, Nftables as NftablesReq};
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::collections::BTreeSet;
use std::future::pending;

#[derive(Debug, Serialize, Deserialize)]
pub struct Linux {
nft: Nftables,
#[serde(skip)]
rtnl: Option<RtNetlink>,
rtnl: Option<RtNetlink<Self>>,
rtnl_args: RtNetlinkArgs,
counter: u64,
}

impl Linux {
Expand All @@ -32,13 +31,12 @@ impl Linux {
nft: Nftables::new(table, chain, hooked, hook_priority).await?,
rtnl: None,
rtnl_args: rtnl,
counter: 0,
})
}
}

impl Kernel for Linux {
type Handle = u64;
type Handle = BTreeSet<u32>;

// TODO: order
async fn apply(&mut self, spec: &Flowspec, info: &RouteInfo<'_>) -> Result<Self::Handle> {
Expand All @@ -63,56 +61,45 @@ impl Kernel for Linux {
.collect::<Vec<_>>()
});

let handle = self.counter;
self.counter += 1;
let nftables = NftablesReq {
objects: rules
.into_iter()
.map(|x| NfObject::CmdObject(NfCmd::Add(self.nft.make_new_rule(x.into(), Some(handle)))))
.map(|x| NfObject::CmdObject(NfCmd::Add(self.nft.make_new_rule(x.into()))))
.collect(),
};
let result = self.nft.apply_and_return_ruleset(&nftables).await?;

let handle: Self::Handle = (result.objects.iter())
.filter_map(|x| {
if let NfObject::CmdObject(NfCmd::Add(NfListObject::Rule(rule))) = x {
Some(rule.handle.unwrap())
} else {
None
}
})
.collect();

self.nft.apply_ruleset(&nftables).await?;
if let Some((next_hop, table_id)) = rt_info {
let rtnl = self.rtnl.as_mut().expect("RtNetlink should be initialized");
let real_table_id = rtnl.add(handle, spec, next_hop).await?;
let real_table_id = rtnl.add(handle.clone(), spec, next_hop).await?;
assert_eq!(table_id, real_table_id, "table ID mismatch");
}

Ok(handle)
}

async fn remove(&mut self, handle: Self::Handle) -> Result<()> {
#[derive(Debug, Deserialize)]
struct MyNftables {
nftables: Vec<MyNftObject>,
}
#[derive(Debug, Deserialize)]
struct MyNftObject {
rule: Option<MyNftRule>,
}
#[derive(Debug, Deserialize)]
struct MyNftRule {
comment: Option<String>,
handle: u32,
}
let mut batch = Batch::new();
let s = self.nft.get_current_ruleset_raw().await?;
let MyNftables { nftables } = serde_json::from_str(&s).map_err(NftablesError::NftInvalidJson)?;
nftables
.into_iter()
.filter_map(|x| x.rule)
.filter(|x| x.comment.as_ref().is_some_and(|y| y == &handle.to_string()))
.for_each(|x| batch.delete(self.nft.make_rule_handle(x.handle)));
for h in handle.iter().copied() {
batch.delete(self.nft.make_rule_handle(h));
}
self.nft.apply_ruleset(&batch.to_nftables()).await?;

if let Some(rtnl) = &mut self.rtnl {
rtnl.del(handle).await?;
if rtnl.is_empty() {
self.rtnl = None;
}
}

Ok(())
}

Expand Down
18 changes: 13 additions & 5 deletions src/kernel/linux/nft.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::Linux;
use crate::bgp::flow::{Bitmask, BitmaskFlags, Component, ComponentKind, Flowspec, Numeric, NumericFlags, Op, Ops};
use crate::bgp::route::{ExtCommunity, Ipv6ExtCommunity, RouteInfo, TrafficFilterAction, TrafficFilterActionKind};
use crate::kernel::rtnl::{RtNetlink, RtNetlinkArgs};
Expand All @@ -6,7 +7,9 @@ use crate::net::{Afi, IpPrefix};
use crate::util::{Intersect, TruthTable};
use nftables::batch::Batch;
use nftables::expr::Expression::{Number as NUM, String as STRING};
use nftables::helper::{apply_ruleset_async, get_current_ruleset_raw_async, DEFAULT_NFT};
use nftables::helper::{
apply_and_return_ruleset_async, apply_ruleset_async, get_current_ruleset_raw_async, DEFAULT_NFT,
};
use nftables::schema::Nftables as NftablesReq;
use nftables::{expr, schema, stmt, types};
use num::Integer;
Expand Down Expand Up @@ -57,14 +60,14 @@ impl Nftables {
pub fn make_new_rule(
&self,
stmts: Cow<'static, [stmt::Statement]>,
comment: Option<impl ToString>,
// comment: Option<impl ToString>,
) -> schema::NfListObject<'static> {
schema::NfListObject::Rule(schema::Rule {
family: types::NfFamily::INet,
table: self.table.clone(),
chain: self.chain.clone(),
expr: stmts,
comment: comment.map(|x| x.to_string().into()),
// comment: comment.map(|x| x.to_string().into()),
..Default::default()
})
}
Expand All @@ -79,6 +82,7 @@ impl Nftables {
})
}

#[expect(unused)]
pub async fn get_current_ruleset_raw(&self) -> Result<String> {
let args = ["-n", "-s", "list", "chain", "inet", &self.table, &self.chain];
Ok(get_current_ruleset_raw_async(DEFAULT_NFT, args).await?)
Expand All @@ -88,6 +92,10 @@ impl Nftables {
Ok(apply_ruleset_async(n).await?)
}

pub async fn apply_and_return_ruleset(&self, n: &NftablesReq<'_>) -> Result<NftablesReq<'static>> {
Ok(apply_and_return_ruleset_async(n).await?)
}

pub async fn terminate(self) {
let mut batch = Batch::new();
batch.delete(schema::NfListObject::Chain(schema::Chain {
Expand Down Expand Up @@ -273,7 +281,7 @@ impl RouteInfo<'_> {
pub(super) fn to_nft_stmts(
&self,
afi: Afi,
rtnl: &mut Option<RtNetlink>,
rtnl: &mut Option<RtNetlink<Linux>>,
rtnl_args: &RtNetlinkArgs,
) -> Option<(StatementBranch<'static>, Option<(IpAddr, u32)>)> {
let set = (self.ext_comm.iter().copied())
Expand Down Expand Up @@ -320,7 +328,7 @@ impl TrafficFilterAction {
fn to_nft_stmts(
self,
afi: Afi,
rtnl: &mut Option<RtNetlink>,
rtnl: &mut Option<RtNetlink<Linux>>,
rtnl_args: &RtNetlinkArgs,
) -> (StatementBlock<'static>, Option<(IpAddr, u32)>, bool) {
use TrafficFilterAction::*;
Expand Down
6 changes: 3 additions & 3 deletions src/kernel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use thiserror::Error;
/// Interface between BGP flowspec and the OS.
pub trait Kernel: Sized {
/// Type representing a flowspec's counterpart in kernel.
type Handle;
type Handle: Eq + Ord;

/// Apply a flowspec to kernel.
fn apply(&mut self, spec: &Flowspec, info: &RouteInfo<'_>) -> impl Future<Output = Result<Self::Handle>>;
Expand Down Expand Up @@ -96,13 +96,13 @@ impl Kernel for KernelAdapter {
}
}

#[derive(Debug, Display, Clone, Serialize, Deserialize)]
#[derive(Debug, Display, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum KernelHandle {
#[strum(to_string = "()")]
Noop,

#[cfg(linux)]
#[strum(to_string = "{0}")]
#[strum(to_string = "{0:?}")]
Linux(<Linux as Kernel>::Handle),
}

Expand Down
12 changes: 6 additions & 6 deletions src/kernel/rtnl.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::Result;
use super::{Kernel, Result};
use crate::bgp::flow::{Component, ComponentKind, Flowspec};
use crate::net::{Afi, IpPrefix};
use clap::Args;
Expand All @@ -22,16 +22,16 @@ use tokio::select;
use tokio::time::{interval, Interval};

#[derive(Debug)]
pub struct RtNetlink {
pub struct RtNetlink<K: Kernel> {
args: RtNetlinkArgs,
handle: Handle,
msgs: UnboundedReceiver<(NetlinkMessage<RouteNetlinkMessage>, rtnetlink::sys::SocketAddr)>,
routes: BTreeMap<u64, (IpPrefix, IpAddr, u32, Vec<RouteAttribute>)>,
routes: BTreeMap<K::Handle, (IpPrefix, IpAddr, u32, Vec<RouteAttribute>)>,
rules: BTreeMap<u32, BTreeSet<IpPrefix>>,
timer: Interval,
}

impl RtNetlink {
impl<K: Kernel> RtNetlink<K> {
pub fn new(args: RtNetlinkArgs) -> io::Result<Self> {
let (conn, handle, msgs) = rtnetlink::new_connection()?;
let scan_time = args.route_scan_time;
Expand All @@ -46,7 +46,7 @@ impl RtNetlink {
})
}

pub async fn add(&mut self, id: u64, spec: &Flowspec, next_hop: IpAddr) -> Result<u32> {
pub async fn add(&mut self, id: K::Handle, spec: &Flowspec, next_hop: IpAddr) -> Result<u32> {
let prefix = spec
.component_set()
.get(&ComponentKind::DstPrefix)
Expand Down Expand Up @@ -102,7 +102,7 @@ impl RtNetlink {
.unwrap_or(self.args.init_table_id)
}

pub async fn del(&mut self, id: u64) -> Result<()> {
pub async fn del(&mut self, id: K::Handle) -> Result<()> {
let Some((prefix, _, table_id, _)) = self.routes.remove(&id) else {
return Ok(());
};
Expand Down

0 comments on commit d250a5f

Please sign in to comment.