diff --git a/.gitignore b/.gitignore index bca6f9b..ac452f2 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ # will have compiled files and executables debug/ target/ +temp/ # These are backup files generated by rustfmt **/*.rs.bk diff --git a/Cargo.lock b/Cargo.lock index 10c0a6e..763535d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -184,6 +184,15 @@ dependencies = [ "inout", ] +[[package]] +name = "circuit_macro" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "clap" version = "4.5.19" @@ -216,6 +225,7 @@ dependencies = [ "anyhow", "bincode", "blake3", + "circuit_macro", "curve25519-dalek", "garble_lang", "hex", diff --git a/Cargo.toml b/Cargo.toml index c8a525f..c711d20 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,8 @@ resolver = "2" members = [ "benchmark", "compute", - "vm", + "vm", + "circuit_macro", ] [workspace.package] diff --git a/benchmark/benches/benchmarks.rs b/benchmark/benches/benchmarks.rs index 0dfc14e..7070e6f 100644 --- a/benchmark/benches/benchmarks.rs +++ b/benchmark/benches/benchmarks.rs @@ -678,6 +678,90 @@ fn gateway_encrypted_le() -> Result<(), Box> { Ok(()) } +fn tfhe_encrypted_division() -> Result<(), Box> { + use tfhe::prelude::*; + use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint128}; + // Basic configuration to use homomorphic integers + let config = ConfigBuilder::default().build(); + + // Key generation + let (client_key, server_keys) = generate_keys(config); + + let clear_a = 12345678910u128; + let clear_b = 1234; + + // Encrypting the input data using the (private) client_key + let encrypted_a = FheUint128::try_encrypt(clear_a, &client_key).unwrap(); + let encrypted_b = FheUint128::try_encrypt(clear_b, &client_key).unwrap(); + + // On the server side: + set_server_key(server_keys); + + // Clear equivalent computations: 12345678910 * 1234 + let encrypted_res_mul = &encrypted_a / &encrypted_b; + + let clear_res: u128 = encrypted_res_mul.decrypt(&client_key); + assert_eq!(clear_res, clear_a / clear_b); + + Ok(()) +} + +fn gateway_encrypted_division() -> Result<(), Box> { + use compute::uint::GarbledUint128; + + let clear_a = 12345678910u128; + let clear_b = 1234; + + let a: GarbledUint128 = clear_a.into(); + let b: GarbledUint128 = clear_b.into(); + + let result: u128 = (&a / &b).into(); + assert_eq!(result, clear_a / clear_b); + Ok(()) +} + +fn tfhe_encrypted_modulus() -> Result<(), Box> { + use tfhe::prelude::*; + use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint128}; + // Basic configuration to use homomorphic integers + let config = ConfigBuilder::default().build(); + + // Key generation + let (client_key, server_keys) = generate_keys(config); + + let clear_a = 12345678910u128; + let clear_b = 1234; + + // Encrypting the input data using the (private) client_key + let encrypted_a = FheUint128::try_encrypt(clear_a, &client_key).unwrap(); + let encrypted_b = FheUint128::try_encrypt(clear_b, &client_key).unwrap(); + + // On the server side: + set_server_key(server_keys); + + // Clear equivalent computations: 12345678910 * 1234 + let encrypted_res_mul = &encrypted_a % &encrypted_b; + + let clear_res: u128 = encrypted_res_mul.decrypt(&client_key); + assert_eq!(clear_res, clear_a % clear_b); + + Ok(()) +} + +fn gateway_encrypted_modulus() -> Result<(), Box> { + use compute::uint::GarbledUint128; + + let clear_a = 12345678910u128; + let clear_b = 1234; + + let a: GarbledUint128 = clear_a.into(); + let b: GarbledUint128 = clear_b.into(); + + let result: u128 = (&a % &b).into(); + assert_eq!(result, clear_a % clear_b); + Ok(()) +} + fn tfhe_encrypted_mux() { use tfhe::boolean::prelude::*; // We generate a set of client/server keys, using the default parameters: @@ -711,7 +795,7 @@ fn gateway_encrypted_mux() { let b: GarbledBoolean = bool2.into(); let c: GarbledBoolean = bool3.into(); - let result = a.mux(&b, &c); + let result = GarbledBoolean::mux(&a, &b, &c); let result: bool = result.into(); assert_eq!(result, if bool1 { bool2 } else { bool3 }); } @@ -914,6 +998,32 @@ fn benchmark_tfhe_encrypted_mux(c: &mut Criterion) { c.bench_function("tfhe_encrypted_mux", |b| b.iter(tfhe_encrypted_mux)); } +// Benchmark 35: Benchmarking benchmark_gateway_encrypted_division +fn benchmark_gateway_encrypted_division(c: &mut Criterion) { + c.bench_function("gateway_encrypted_division", |b| { + b.iter(gateway_encrypted_division) + }); +} + +// Benchmark 36: Benchmarking benchmark_tfhe_encrypted_division +fn benchmark_tfhe_encrypted_division(c: &mut Criterion) { + c.bench_function("tfhe_encrypted_division", |b| { + b.iter(tfhe_encrypted_division) + }); +} + +// Benchmark 37: Benchmarking benchmark_gateway_encrypted_modulus +fn benchmark_gateway_encrypted_modulus(c: &mut Criterion) { + c.bench_function("gateway_encrypted_modulus", |b| { + b.iter(gateway_encrypted_modulus) + }); +} + +// Benchmark 38: Benchmarking benchmark_tfhe_encrypted_modulus +fn benchmark_tfhe_encrypted_modulus(c: &mut Criterion) { + c.bench_function("tfhe_encrypted_modulus", |b| b.iter(tfhe_encrypted_modulus)); +} + // Configure Criterion with a sample size of 10 fn custom_criterion() -> Criterion { Criterion::default().sample_size(10) @@ -924,9 +1034,13 @@ criterion_group!( name = benches; config = custom_criterion(); targets = + benchmark_gateway_encrypted_division, + benchmark_tfhe_encrypted_division, + benchmark_gateway_encrypted_modulus, + benchmark_tfhe_encrypted_modulus, + benchmark_gateway_encrypted_mux, benchmark_tfhe_encrypted_mux, - benchmark_gateway_encrypted_addition, benchmark_tfhe_encrypted_addition, benchmark_gateway_encrypted_subtraction, diff --git a/circuit_macro/.gitignore b/circuit_macro/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/circuit_macro/.gitignore @@ -0,0 +1 @@ +/target diff --git a/circuit_macro/Cargo.toml b/circuit_macro/Cargo.toml new file mode 100644 index 0000000..190dfae --- /dev/null +++ b/circuit_macro/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "circuit_macro" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "2.0", features = ["full"] } +quote = "1.0" +proc-macro2 = "1.0" diff --git a/circuit_macro/src/lib.rs b/circuit_macro/src/lib.rs new file mode 100644 index 0000000..f78f01b --- /dev/null +++ b/circuit_macro/src/lib.rs @@ -0,0 +1,457 @@ +extern crate proc_macro; +use proc_macro::TokenStream; +use quote::quote; +use syn::{ + parse_macro_input, BinOp, Expr, ExprBinary, ExprIf, ExprUnary, FnArg, ItemFn, Pat, PatType, +}; + +#[proc_macro_attribute] +pub fn circuit(attr: TokenStream, item: TokenStream) -> TokenStream { + let mode = parse_macro_input!(attr as syn::Ident).to_string(); // Retrieve the mode (e.g., "compile" or "execute") + generate_macro(item, &mode) +} + +/// Generates the macro code based on the mode (either "compile" or "execute") +fn generate_macro(item: TokenStream, mode: &str) -> TokenStream { + let input_fn = parse_macro_input!(item as ItemFn); + let fn_name = &input_fn.sig.ident; // Function name + let inputs = &input_fn.sig.inputs; // Function input parameters + + // get the type of the first input parameter + let type_name = if let FnArg::Typed(PatType { ty, .. }) = &inputs[0] { + quote! {#ty} + } else { + panic!("Expected typed argument"); + }; + + // get the type of the first output parameter + let output_type = if let syn::ReturnType::Type(_, ty) = &input_fn.sig.output { + quote! {#ty} + } else { + panic!("Expected typed return type"); + }; + + // We need to extract each input's identifier + let mapped_inputs = inputs.iter().map(|input| { + if let FnArg::Typed(PatType { pat, .. }) = input { + if let Pat::Ident(pat_ident) = &**pat { + let var_name = &pat_ident.ident; + quote! { + let #var_name = &context.input(&#var_name.clone().into()); + } + } else { + quote! {} + } + } else { + quote! {} + } + }); + + // Replace "+" with context.add and handle if/else in the function body + let transformed_block = modify_body(*input_fn.block); + + // Collect parameter names dynamically + let param_names: Vec<_> = inputs + .iter() + .map(|input| { + if let FnArg::Typed(PatType { pat, .. }) = input { + if let Pat::Ident(pat_ident) = &**pat { + pat_ident.ident.clone() + } else { + panic!("Expected identifier pattern"); + } + } else { + panic!("Expected typed argument"); + } + }) + .collect(); + + // Dynamically generate the `generate` function calls using the parameter names + let match_arms = quote! { + match std::any::type_name::<#type_name>() { + "bool" => generate::<1, #type_name>(#(#param_names),*), + "u8" => generate::<8, #type_name>(#(#param_names),*), + "u16" => generate::<16, #type_name>(#(#param_names),*), + "u32" => generate::<32, #type_name>(#(#param_names),*), + "u64" => generate::<64, #type_name>(#(#param_names),*), + "u128" => generate::<128, #type_name>(#(#param_names),*), + _ => panic!("Unsupported type"), + } + }; + + // Set the output type and operation logic based on mode + let output_type = if mode == "compile" { + quote! {(Circuit, Vec)} + } else { + quote! {#output_type} + }; + + let operation = if mode == "compile" { + quote! { + (context.compile(&output), context.inputs().to_vec()) + } + } else { + quote! { + let compiled_circuit = context.compile(&output.into()); + let result = context.execute::(&compiled_circuit).expect("Execution failed"); + result.into() + } + }; + + // Build the function body with circuit context, compile, and execute + let expanded = quote! { + #[allow(non_camel_case_types, non_snake_case, clippy::builtin_type_shadow, clippy::too_many_arguments)] + fn #fn_name<#type_name>(#inputs) -> #output_type + where + #type_name: Into> + From> + + Into> + From> + + Into> + From> + + Into> + From> + + Into> + From> + + Into> + From> + + Clone, + { + fn generate(#inputs) -> #output_type + where + #type_name: Into> + From> + Clone, + { + let mut context = CircuitBuilder::default(); + #(#mapped_inputs)* + + // Use the transformed function block (with context.add and if/else replacements) + let output = { #transformed_block }; + + #operation + } + + #match_arms + } + }; + + // Print the expanded code to stderr + // println!("Generated code:\n{}", expanded); + + TokenStream::from(expanded) +} + +/// Traverse and transform the function body, replacing binary operators and if/else expressions. +fn modify_body(block: syn::Block) -> syn::Block { + let stmts = block + .stmts + .into_iter() + .map(|stmt| { + match stmt { + syn::Stmt::Expr(expr, semi_opt) => { + syn::Stmt::Expr(replace_expressions(expr), semi_opt) + } + syn::Stmt::Local(mut local) => { + if let Some(local_init) = &mut local.init { + // Replace the initializer expression + local_init.expr = Box::new(replace_expressions(*local_init.expr.clone())); + } + syn::Stmt::Local(local) + } + other => other, + } + }) + .collect(); + + syn::Block { + stmts, + brace_token: syn::token::Brace::default(), + } +} + +/// Replaces binary operators and if/else expressions with appropriate context calls. +fn replace_expressions(expr: Expr) -> Expr { + match expr { + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::Eq(_), + .. + }) => { + let left_expr = replace_expressions(*left); + let right_expr = replace_expressions(*right); + syn::parse_quote! {{ + let left = #left_expr; + let right = #right_expr; + &context.eq(&left.into(), &right.into()) + }} + } + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::Ne(_), + .. + }) => { + let left_expr = replace_expressions(*left); + let right_expr = replace_expressions(*right); + syn::parse_quote! {{ + let left = #left_expr; + let right = #right_expr; + &context.ne(&left.into(), &right.into()) + }} + } + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::Gt(_), + .. + }) => { + let left_expr = replace_expressions(*left); + let right_expr = replace_expressions(*right); + syn::parse_quote! {{ + let left = #left_expr; + let right = #right_expr; + &context.gt(&left.into(), &right.into()) + }} + } + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::Ge(_), + .. + }) => { + let left_expr = replace_expressions(*left); + let right_expr = replace_expressions(*right); + syn::parse_quote! {{ + let left = #left_expr; + let right = #right_expr; + &context.ge(&left.into(), &right.into()) + }} + } + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::Lt(_), + .. + }) => { + let left_expr = replace_expressions(*left); + let right_expr = replace_expressions(*right); + syn::parse_quote! {{ + let left = #left_expr; + let right = #right_expr; + &context.lt(&left.into(), &right.into()) + }} + } + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::Le(_), + .. + }) => { + let left_expr = replace_expressions(*left); + let right_expr = replace_expressions(*right); + syn::parse_quote! {{ + let left = #left_expr; + let right = #right_expr; + &context.le(&left.into(), &right.into()) + }} + } + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::Add(_), + .. + }) => { + let left_expr = replace_expressions(*left); + let right_expr = replace_expressions(*right); + syn::parse_quote! {{ + let left = #left_expr; + let right = #right_expr; + &context.add(&left.into(), &right.into()) + }} + } + + /* + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::AddAssign(_), + .. + }) => { + syn::parse_quote! { + &context.add(&#left, &#right) + } + } + */ + // subtraction + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::Sub(_), + .. + }) => { + let left_expr = replace_expressions(*left); + let right_expr = replace_expressions(*right); + syn::parse_quote! {{ + let left = #left_expr; + let right = #right_expr; + &context.sub(&left.into(), &right.into()) + }} + } + // multiplication + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::Mul(_), + .. + }) => { + let left_expr = replace_expressions(*left); + let right_expr = replace_expressions(*right); + syn::parse_quote! {{ + let left = #left_expr; + let right = #right_expr; + &context.mul(&left.into(), &right.into()) + }} + } + // division - TODO: Implement division + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::Div(_), + .. + }) => { + let left_expr = replace_expressions(*left); + let right_expr = replace_expressions(*right); + syn::parse_quote! {{ + let left = #left_expr; + let right = #right_expr; + &context.div(&left.into(), &right.into()) + }} + } + // modulo - TODO: Implement modulo + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::Rem(_), + .. + }) => { + let left_expr = replace_expressions(*left); + let right_expr = replace_expressions(*right); + syn::parse_quote! {{ + let left = #left_expr; + let right = #right_expr; + &context.rem(&left.into(), &right.into()) + }} + } + // logical AND + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::And(_), + .. + }) => { + let left_expr = replace_expressions(*left); + let right_expr = replace_expressions(*right); + syn::parse_quote! {{ + let left = #left_expr; + let right = #right_expr; + &context.land(&left, &right) + }} + } + + // logical OR + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::Or(_), + .. + }) => { + let left_expr = replace_expressions(*left); + let right_expr = replace_expressions(*right); + syn::parse_quote! {{ + let left = #left_expr; + let right = #right_expr; + &context.lor(&left, &right) + }} + } + + // bitwise AND + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::BitAnd(_), + .. + }) => { + let left_expr = replace_expressions(*left); + let right_expr = replace_expressions(*right); + syn::parse_quote! {{ + let left = #left_expr; + let right = #right_expr; + &context.and(&left.into(), &right.into()) + }} + } + // bitwise OR + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::BitOr(_), + .. + }) => { + let left_expr = replace_expressions(*left); + let right_expr = replace_expressions(*right); + syn::parse_quote! {{ + let left = #left_expr; + let right = #right_expr; + &context.or(&left.into(), &right.into()) + }} + } + // bitwise XOR + Expr::Binary(ExprBinary { + left, + right, + op: BinOp::BitXor(_), + .. + }) => { + let left_expr = replace_expressions(*left); + let right_expr = replace_expressions(*right); + syn::parse_quote! {{ + let left = #left_expr; + let right = #right_expr; + &context.xor(&left.into(), &right.into()) + }} + } + // bitwise NOT + Expr::Unary(ExprUnary { + op: syn::UnOp::Not(_), + expr, + .. + }) => { + let single_expr = replace_expressions(*expr); + syn::parse_quote! {{ + let single = #single_expr; + &context.not(&single.into()) + }} + } + Expr::If(ExprIf { + cond, + then_branch, + else_branch, + .. + }) => { + if let Some((_, else_branch)) = else_branch { + let then_expr = modify_body(then_branch.clone()); + + let else_expr = match *else_branch { + syn::Expr::Block(syn::ExprBlock { block, .. }) => modify_body(block.clone()), + _ => panic!("Expected a block in else branch"), + }; + + let cond = replace_expressions(*cond.clone()); + + syn::parse_quote! {{ + let if_true = #then_expr; + let if_false = #else_expr; + let cond = #cond; + &context.mux(cond, if_true, if_false) + }} + } else { + panic!("Expected else branch for if expression"); + } + } + + other => other, + } +} diff --git a/compute/Cargo.toml b/compute/Cargo.toml index a8b1581..b6c0549 100644 --- a/compute/Cargo.toml +++ b/compute/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] +circuit_macro = { path = "../circuit_macro" } tracing = { workspace = true, features = ["log"] } anyhow = { workspace = true } tandem = { git = "https://github.com/sine-fdn/tandem.git" } diff --git a/compute/examples/access_control.rs b/compute/examples/access_control.rs new file mode 100644 index 0000000..ac00d5c --- /dev/null +++ b/compute/examples/access_control.rs @@ -0,0 +1,26 @@ +use compute::prelude::*; + +/// Determines if a user has the required access level to enter a restricted area. +/// +/// # Parameters +/// - `user_level`: The access level of the current user. +/// - `required_level`: The minimum access level required for the restricted area. +/// +/// # Returns +/// - `bool`: Returns `true` if the user's level is greater than or equal to the required level, +/// indicating they have the necessary access, otherwise `false`. +/// +/// # Example +/// This example demonstrates verifying if a user with level 5 can access an area that requires level 4. +#[circuit(execute)] +fn has_access(user_level: u8, required_level: u8) -> bool { + user_level >= required_level +} + +fn main() { + let user_level = 5_u8; + let required_level = 4_u8; + + let result = has_access(user_level, required_level); + println!("Does the user have access? {}", result); // Expected: true +} diff --git a/compute/examples/add_two_numbers.rs b/compute/examples/add_two_numbers.rs new file mode 100644 index 0000000..7ef6358 --- /dev/null +++ b/compute/examples/add_two_numbers.rs @@ -0,0 +1,14 @@ +use compute::prelude::*; + +fn main() -> Result<(), Box> { + let clear_a = 12297829382473034410u128; + let clear_b = 424242424242u128; + + let a: GarbledUint128 = clear_a.into(); + let b: GarbledUint128 = clear_b.into(); + + let result = &a + &b; + let result: u128 = result.into(); + assert_eq!(result, clear_a + clear_b); + Ok(()) +} diff --git a/compute/examples/discount_eligibility.rs b/compute/examples/discount_eligibility.rs new file mode 100644 index 0000000..88f472c --- /dev/null +++ b/compute/examples/discount_eligibility.rs @@ -0,0 +1,26 @@ +use compute::prelude::*; + +/// Checks if a customer's purchase amount qualifies for a discount. +/// +/// # Parameters +/// - `purchase_amount`: The total amount of the customer's purchase. +/// - `discount_threshold`: The minimum amount required to be eligible for a discount. +/// +/// # Returns +/// - `bool`: Returns `true` if the purchase amount is greater than or equal to the discount threshold, +/// otherwise `false`. +/// +/// # Example +/// This example demonstrates checking if a purchase of 100 qualifies for a discount with a threshold of 80. +#[circuit(execute)] +fn qualifies_for_discount(purchase_amount: u16, discount_threshold: u16) -> bool { + purchase_amount >= discount_threshold +} + +fn main() { + let purchase_amount = 100_u16; + let discount_threshold = 80_u16; + + let result = qualifies_for_discount(purchase_amount, discount_threshold); + println!("Does the purchase qualify for a discount? {}", result); // Expected: true +} diff --git a/compute/examples/loan_eligibility.rs b/compute/examples/loan_eligibility.rs new file mode 100644 index 0000000..86e51de --- /dev/null +++ b/compute/examples/loan_eligibility.rs @@ -0,0 +1,101 @@ +use compute::prelude::*; + +/// Evaluates a loan application based on income, credit score, debt-to-income ratio, and other requirements. +/// +/// The logic follows a tiered approach: +/// - "Full Approval": If income, credit score, and debt-to-income ratio meet the highest criteria. +/// - "Conditional Approval": If income or credit score partially meet the requirements. +/// - "Denied": If none of the criteria are met. +/// +/// # Parameters +/// - `income`: The applicant's income level. +/// - `credit_score`: The applicant's credit score. +/// - `debt_ratio`: The applicant's debt-to-income ratio (in percentage). +/// - `HIGH_INCOME_REQ`: The high income requirement for full approval. +/// - `MIN_INCOME_REQ`: The minimum income requirement for conditional approval. +/// - `MIN_CREDIT_SCORE`: The minimum credit score requirement for conditional approval. +/// - `MAX_DEBT_RATIO`: The maximum debt-to-income ratio allowed for full approval. +/// - `MAX_CONDITIONAL_DEBT_RATIO`: The maximum debt-to-income ratio for conditional approval. +/// - `FULLY_APPROVED`: The status code for full approval. +/// - `CONDITIONAL_APPROVED`: The status code for conditional approval. +/// - `DENIED`: The status code for denial. +/// +/// # Returns +/// - `u8`: Returns 2 for "Full Approval," 1 for "Conditional Approval," and 0 for "Denied." +/// +/// # Example +/// This example demonstrates evaluating an applicant with an income of 75,000, a credit score of 680, +/// and a debt-to-income ratio of 30%. The requirements are: +/// - Full approval requires income >= 70,000, credit score >= 720, and debt ratio <= 35%. +/// - Conditional approval requires credit score >= 650 and income >= 50,000 and debt ratio <= 40. + +#[circuit(execute)] +fn evaluate_loan_application( + income: u32, + credit_score: u32, + debt_ratio: u32, + HIGH_INCOME_REQ: u32, + MIN_INCOME_REQ: u32, + MIN_CREDIT_SCORE: u32, + MAX_DEBT_RATIO: u32, + MAX_CONDITIONAL_DEBT_RATIO: u32, + FULLY_APPROVED: u32, + CONDITIONAL_APPROVED: u32, + DENIED: u32, +) -> u32 { + // Check for Full Approval + if income >= HIGH_INCOME_REQ && credit_score >= MIN_CREDIT_SCORE && debt_ratio <= MAX_DEBT_RATIO + { + FULLY_APPROVED + } else { + let income_and_credit_score = income >= MIN_INCOME_REQ && credit_score >= MIN_CREDIT_SCORE; + // Check for Conditional Approval + if income_and_credit_score && debt_ratio <= MAX_CONDITIONAL_DEBT_RATIO { + CONDITIONAL_APPROVED + } else { + // Denied if neither criteria met + DENIED + } + } +} + +fn main() { + enum LoanStatus { + Denied, + ConditionalApproval, + FullApproval, + } + + // Approval requirements passed as parameters + const HIGH_INCOME_REQ: u32 = 70000_u32; + const MIN_INCOME_REQ: u32 = 50000_u32; + const MIN_CREDIT_SCORE: u32 = 650_u32; + const MAX_DEBT_RATIO: u32 = 35_u32; + const MAX_CONDITIONAL_DEBT_RATIO: u32 = 40_u32; + + // Example applicant data + let income = 75000_u32; + let credit_score = 680_u32; + let debt_ratio = 30_u32; + + let result = evaluate_loan_application( + income, + credit_score, + debt_ratio, + HIGH_INCOME_REQ, + MIN_INCOME_REQ, + MIN_CREDIT_SCORE, + MAX_DEBT_RATIO, + MAX_CONDITIONAL_DEBT_RATIO, + LoanStatus::FullApproval as u32, + LoanStatus::ConditionalApproval as u32, + LoanStatus::Denied as u32, + ); + + // Output the decision based on result + match result { + 2 => println!("Loan Status: Full Approval"), + 1 => println!("Loan Status: Conditional Approval"), + _ => println!("Loan Status: Denied"), + } +} diff --git a/compute/examples/password_requirements.rs b/compute/examples/password_requirements.rs new file mode 100644 index 0000000..62aaf30 --- /dev/null +++ b/compute/examples/password_requirements.rs @@ -0,0 +1,29 @@ +use compute::prelude::*; + +/// Validates if the provided password length meets the minimum required length. +/// +/// # Parameters +/// - `password_length`: The length of the password to be checked. +/// - `min_length`: The minimum acceptable length for the password. +/// +/// # Returns +/// - `bool`: Returns `true` if the password length is greater than or equal to the minimum length, +/// indicating the password meets the strength requirement, otherwise `false`. +/// +/// # Example +/// This example demonstrates verifying if a password with 12 characters meets a minimum length of 8 characters. +#[circuit(execute)] +fn password_strength(password_length: u8, min_length: u8) -> bool { + password_length >= min_length +} + +fn main() { + let password_length = 12_u8; + let min_length = 8_u8; + + let result = password_strength(password_length, min_length); + println!( + "Does the password meet the strength requirement? {}", + result + ); // Expected: true +} diff --git a/compute/examples/spending.rs b/compute/examples/spending.rs new file mode 100644 index 0000000..d64dd15 --- /dev/null +++ b/compute/examples/spending.rs @@ -0,0 +1,30 @@ +use compute::prelude::*; + +/// Determines if the remaining budget after expenses is within the allowable limit. +/// +/// # Parameters +/// - `budget`: The initial budget available. +/// - `spent`: The amount already spent from the budget. +/// - `max_allowed`: The maximum allowable remaining budget. +/// +/// # Returns +/// - `bool`: Returns `true` if the remaining budget is less than or equal to the maximum allowable, +/// indicating spending is within the limits, otherwise `false`. +/// +/// # Example +/// This example demonstrates checking if a remaining budget of 3000 (after spending 2000 from a budget of 5000) +/// stays within the maximum allowable limit of 3000. +#[circuit(execute)] +fn can_spend(budget: u16, spent: u16, max_allowed: u16) -> bool { + let remaining_budget = budget - spent; + remaining_budget <= max_allowed +} + +fn main() { + let budget = 5000_u16; + let spent = 2000_u16; + let max_allowed = 3000_u16; + + let result = can_spend(budget, spent, max_allowed); + println!("Is spending within the allowable limit? {}", result); // Expected: true +} diff --git a/compute/examples/temperature.rs b/compute/examples/temperature.rs new file mode 100644 index 0000000..a07baee --- /dev/null +++ b/compute/examples/temperature.rs @@ -0,0 +1,31 @@ +use compute::prelude::*; + +/// Checks if the current temperature is within the specified minimum and maximum range. +/// +/// # Parameters +/// - `current_temp`: The current temperature of the room. +/// - `min_temp`: The minimum acceptable temperature. +/// - `max_temp`: The maximum acceptable temperature. +/// +/// # Returns +/// - `bool`: Returns `true` if the current temperature is between the minimum and maximum values, +/// indicating it is within an acceptable range, otherwise `false`. +/// +/// # Example +/// This example demonstrates verifying if a room with a temperature of 70°F is within the range of 65°F to 75°F. +#[circuit(execute)] +fn within_temperature_range(current_temp: u8, min_temp: u8, max_temp: u8) -> bool { + let above_min = current_temp >= min_temp; + let below_max = current_temp <= max_temp; + + above_min && below_max +} + +fn main() { + let current_temp = 70_u8; + let min_temp = 65_u8; + let max_temp = 75_u8; + + let result = within_temperature_range(current_temp, min_temp, max_temp); + println!("Is the temperature within range? {}", result); // Expected: true +} diff --git a/compute/examples/threshold_voting.rs b/compute/examples/threshold_voting.rs new file mode 100644 index 0000000..e81d220 --- /dev/null +++ b/compute/examples/threshold_voting.rs @@ -0,0 +1,26 @@ +use compute::prelude::*; + +/// Determines if a candidate has received enough votes to pass the specified threshold. +/// +/// # Parameters +/// - `votes`: The number of votes the candidate has received. +/// - `threshold`: The minimum number of votes required for the candidate to pass. +/// +/// # Returns +/// - `bool`: Returns `true` if the votes are greater than or equal to the threshold, +/// indicating the candidate has met the required vote count, otherwise `false`. +/// +/// # Example +/// This example demonstrates checking if a candidate with 150 votes meets a threshold of 100 votes. +#[circuit(execute)] +fn has_enough_votes(votes: u8, threshold: u8) -> bool { + votes >= threshold +} + +fn main() { + let votes = 150_u8; + let threshold = 100_u8; + + let result = has_enough_votes(votes, threshold); + println!("Does the candidate have enough votes? {}", result); // Expected: true +} diff --git a/compute/src/executor.rs b/compute/src/executor.rs index b492559..6f84c11 100644 --- a/compute/src/executor.rs +++ b/compute/src/executor.rs @@ -6,6 +6,15 @@ use tandem::Circuit; use crate::evaluator::{Evaluator, GatewayEvaluator}; use crate::garbler::{Garbler, GatewayGarbler}; +/// A static Lazy instance for holding the singleton LocalSimulator. +static SINGLETON_EXECUTOR: Lazy> = + Lazy::new(|| Arc::new(LocalSimulator) as Arc); + +/// Provides access to the singleton Executor instance. +pub fn get_executor() -> Arc { + SINGLETON_EXECUTOR.clone() +} + pub trait Executor { /// Executes the 2 Party MPC protocol. /// @@ -22,13 +31,20 @@ pub trait Executor { input_contributor: &[bool], input_evaluator: &[bool], ) -> Result>; + + fn instance() -> &'static Arc + where + Self: Sized, + { + &SINGLETON_EXECUTOR + } } pub struct LocalSimulator; impl Executor for LocalSimulator { /// The Multi-Party Computation is performed using the full cryptographic protocol exposed by the - /// [`Contributor`] and [`Evaluator`]. The messages between contributor and evaluator are exchanged + /// `Contributor` and `Evaluator`. The messages between contributor and evaluator are exchanged /// using local message queues. This function thus simulates an MPC execution on a local machine /// under ideal network conditions, without any latency or bandwidth restrictions. fn execute( @@ -58,12 +74,3 @@ impl Executor for LocalSimulator { Ok(output) } } - -/// A static Lazy instance for holding the singleton LocalSimulator. -static SINGLETON_EXECUTOR: Lazy> = - Lazy::new(|| Arc::new(LocalSimulator) as Arc); - -/// Provides access to the singleton Executor instance. -pub(crate) fn get_executor() -> Arc { - SINGLETON_EXECUTOR.clone() -} diff --git a/compute/src/int.rs b/compute/src/int.rs index 2dbee3c..e6e8773 100644 --- a/compute/src/int.rs +++ b/compute/src/int.rs @@ -11,6 +11,10 @@ pub type GarbledInt16 = GarbledInt<16>; pub type GarbledInt32 = GarbledInt<32>; pub type GarbledInt64 = GarbledInt<64>; pub type GarbledInt128 = GarbledInt<128>; +pub type GarbledInt160 = GarbledInt<160>; +pub type GarbledInt256 = GarbledInt<256>; +pub type GarbledInt512 = GarbledInt<512>; +pub type GarbledInt1024 = GarbledInt<1024>; // Define a new type GarbledInt #[derive(Debug, Clone)] @@ -28,6 +32,7 @@ impl Display for GarbledInt { 32 => write!(f, "{}", i32::from(self.clone())), 64 => write!(f, "{}", i64::from(self.clone())), 128 => write!(f, "{}", i128::from(self.clone())), + 160..=1024 => write!(f, "GarbledInt<{}>", N), _ => panic!("Unsupported bit size for GarbledInt"), } } diff --git a/compute/src/lib.rs b/compute/src/lib.rs index 37161d6..a7d59ed 100644 --- a/compute/src/lib.rs +++ b/compute/src/lib.rs @@ -1,8 +1,21 @@ pub mod evaluator; -mod executor; +pub mod executor; pub mod garbler; pub mod int; pub mod operations; pub mod uint; -pub use tandem::Circuit; +pub mod prelude { + pub use crate::operations::circuits::builder::CircuitBuilder; + + pub use crate::int::{ + GarbledInt, GarbledInt128, GarbledInt16, GarbledInt256, GarbledInt32, GarbledInt512, + GarbledInt64, GarbledInt8, + }; + pub use crate::operations::circuits::types::GateIndexVec; + pub use crate::uint::{ + GarbledBoolean, GarbledUint, GarbledUint128, GarbledUint16, GarbledUint2, GarbledUint256, + GarbledUint32, GarbledUint4, GarbledUint512, GarbledUint64, GarbledUint8, + }; + pub use circuit_macro::circuit; +} diff --git a/compute/src/operations/arithmetic.rs b/compute/src/operations/arithmetic.rs index 40cbf32..cd71799 100644 --- a/compute/src/operations/arithmetic.rs +++ b/compute/src/operations/arithmetic.rs @@ -1,9 +1,12 @@ use crate::int::GarbledInt; use crate::operations::circuits::builder::{ - build_and_execute_addition, build_and_execute_multiplication, build_and_execute_subtraction, + build_and_execute_addition, build_and_execute_division, build_and_execute_multiplication, + build_and_execute_subtraction, }; use crate::uint::GarbledUint; -use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; +use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Sub, SubAssign}; + +use super::circuits::builder::build_and_execute_remainder; // Implement the Add operation for Uint and &GarbledUint impl Add for GarbledUint { @@ -94,6 +97,65 @@ impl MulAssign<&GarbledUint> for GarbledUint { } } +// Implement the Div operation for GarbledUint and &GarbledUint +impl Div for GarbledUint { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + build_and_execute_division(&self, &rhs) + } +} + +impl Div for &GarbledUint { + type Output = GarbledUint; + + fn div(self, rhs: Self) -> Self::Output { + build_and_execute_division(self, rhs) + } +} + +// Implement the DivAssign operation for GarbledUint and &GarbledUint +impl DivAssign for GarbledUint { + fn div_assign(&mut self, rhs: Self) { + *self = build_and_execute_division(self, &rhs); + } +} + +impl DivAssign<&GarbledUint> for GarbledUint { + fn div_assign(&mut self, rhs: &Self) { + *self = build_and_execute_division(self, rhs); + } +} + +// rem +impl Rem for GarbledUint { + type Output = Self; + + fn rem(self, rhs: Self) -> Self::Output { + build_and_execute_remainder(&self, &rhs) + } +} + +impl Rem for &GarbledUint { + type Output = GarbledUint; + + fn rem(self, rhs: Self) -> Self::Output { + build_and_execute_remainder(self, rhs) + } +} + +impl RemAssign for GarbledUint { + fn rem_assign(&mut self, rhs: Self) { + *self = build_and_execute_remainder(self, &rhs); + } +} + +impl RemAssign<&GarbledUint> for GarbledUint { + fn rem_assign(&mut self, rhs: &Self) { + *self = build_and_execute_remainder(self, rhs); + } +} + // Implement the Add operation for GarbledInt and &GarbledInt impl Add for GarbledInt { type Output = Self; @@ -183,3 +245,63 @@ impl MulAssign<&GarbledInt> for GarbledInt { *self = build_and_execute_multiplication(&self.clone().into(), &rhs.into()).into(); } } + +// implement Div operation for GarbledInt and &GarbledInt +impl Div for GarbledInt { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + build_and_execute_division(&self.into(), &rhs.into()).into() + } +} + +impl Div for &GarbledInt { + type Output = GarbledInt; + + fn div(self, rhs: Self) -> Self::Output { + build_and_execute_division(&self.into(), &rhs.into()).into() + } +} + +// Implement the DivAssign operation for GarbledInt and &GarbledInt +impl DivAssign for GarbledInt { + fn div_assign(&mut self, rhs: Self) { + *self = build_and_execute_division(&self.clone().into(), &rhs.into()).into(); + } +} + +impl DivAssign<&GarbledInt> for GarbledInt { + fn div_assign(&mut self, rhs: &Self) { + *self = build_and_execute_division(&self.clone().into(), &rhs.into()).into(); + } +} + +// Implement the Rem operation for GarbledInt and &GarbledInt +impl Rem for GarbledInt { + type Output = Self; + + fn rem(self, rhs: Self) -> Self::Output { + build_and_execute_remainder(&self.into(), &rhs.into()).into() + } +} + +impl Rem for &GarbledInt { + type Output = GarbledInt; + + fn rem(self, rhs: Self) -> Self::Output { + build_and_execute_remainder(&self.into(), &rhs.into()).into() + } +} + +// Implement the RemAssign operation for GarbledInt and &GarbledInt +impl RemAssign for GarbledInt { + fn rem_assign(&mut self, rhs: Self) { + *self = build_and_execute_remainder(&self.clone().into(), &rhs.into()).into(); + } +} + +impl RemAssign<&GarbledInt> for GarbledInt { + fn rem_assign(&mut self, rhs: &Self) { + *self = build_and_execute_remainder(&self.clone().into(), &rhs.into()).into(); + } +} diff --git a/compute/src/operations/circuits/builder.rs b/compute/src/operations/circuits/builder.rs index 7c439f5..b759f33 100644 --- a/compute/src/operations/circuits/builder.rs +++ b/compute/src/operations/circuits/builder.rs @@ -1,25 +1,49 @@ -use crate::executor::get_executor; +use crate::operations::circuits::types::GateIndexVec; use crate::uint::GarbledUint; +use crate::{executor::get_executor, uint::GarbledBoolean}; +use std::cell::RefCell; use std::cmp::Ordering; -use tandem::GateIndex; use tandem::{Circuit, Gate}; -pub struct CircuitBuilder { +pub type GateIndex = u32; + +// Global instance of CircuitBuilder +thread_local! { + static CIRCUIT_BUILDER: RefCell = RefCell::new(CircuitBuilder::default()); +} + +#[derive(Default)] +pub struct CircuitBuilder { + inputs: Vec, gates: Vec, } -impl Default for CircuitBuilder { - fn default() -> Self { - let gates = Vec::new(); - Self { gates } +impl CircuitBuilder { + // Static `global` function to access `CIRCUIT_BUILDER` + #[allow(dead_code)] + pub(super) fn global(f: F) -> R + where + F: FnOnce(&mut CircuitBuilder) -> R, + { + CIRCUIT_BUILDER.with(|builder| { + // Borrow mutably and pass to the closure `f` + f(&mut builder.borrow_mut()) + }) } -} -impl CircuitBuilder { - pub fn push_input(&mut self, input: &GarbledUint) { - for _ in &input.bits { - self.gates.push(Gate::InContrib); + pub fn input(&mut self, input: &GarbledUint) -> GateIndexVec { + // get the cumulative size of all inputs in input_labels + //let input_offset = self.input_labels.iter().map(|x| x.len()).sum::(); + + let input_offset = self.inputs.len(); + let mut input_label = GateIndexVec::default(); + for (i, bool_value) in input.bits.iter().enumerate() { + self.gates.insert(0, Gate::InContrib); + + self.inputs.push(*bool_value); + input_label.push((input_offset + i) as GateIndex); } + input_label } pub fn len(&self) -> GateIndex { @@ -30,278 +54,432 @@ impl CircuitBuilder { self.gates.is_empty() } + pub fn inputs(&self) -> &Vec { + &self.inputs + } + // Add a XOR gate between two inputs and return the index - pub fn push_xor(&mut self, a: GateIndex, b: GateIndex) -> GateIndex { + pub fn push_xor(&mut self, a: &GateIndex, b: &GateIndex) -> GateIndex { let xor_index = self.gates.len() as u32; - self.gates.push(Gate::Xor(a, b)); + self.gates.push(Gate::Xor(*a, *b)); xor_index } - // Add an AND gate between two inputs and return the index - pub fn push_and(&mut self, a: GateIndex, b: GateIndex) -> GateIndex { + pub fn xor(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndexVec { + // repeat with output_indices + let mut output = GateIndexVec::default(); + for i in 0..a.len() { + let xor = self.push_xor(&a[i], &b[i]); + output.push(xor); + } + output + } + + // Add an Aa.len()D gate between two inputs and return the index + pub fn push_and(&mut self, a: &GateIndex, b: &GateIndex) -> GateIndex { let and_index = self.gates.len() as u32; - self.gates.push(Gate::And(a, b)); + self.gates.push(Gate::And(*a, *b)); and_index } - // Add a NOT gate for a single input and return the index - pub fn push_not(&mut self, a: GateIndex) -> GateIndex { + pub fn and(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndexVec { + // repeat with output_indices + let mut output = GateIndexVec::default(); + for i in 0..a.len() { + let and = self.push_and(&a[i], &b[i]); + output.push(and); + } + output + } + + pub fn land(&mut self, a: &GateIndex, b: &GateIndex) -> GateIndex { + // repeat with output_indices + let mut output = GateIndexVec::default(); + let and = self.push_and(a, b); + output.push(and); + output.into() + } + + // Add a a.len()OT gate for a single input and return the index + pub fn push_not(&mut self, a: &GateIndex) -> GateIndex { let not_index = self.gates.len() as u32; - self.gates.push(Gate::Not(a)); + self.gates.push(Gate::Not(*a)); not_index } + pub fn not(&mut self, a: &GateIndexVec) -> GateIndexVec { + let mut output = GateIndexVec::default(); + for i in 0..a.len() { + let not_gate = self.push_not(&a[i]); + output.push(not_gate); + } + output + } + // Add a gate for OR operation: OR(a, b) = (a ⊕ b) ⊕ (a & b) - pub fn push_or(&mut self, a: GateIndex, b: GateIndex) -> GateIndex { + pub fn push_or(&mut self, a: &GateIndex, b: &GateIndex) -> GateIndex { let xor_gate = self.push_xor(a, b); let and_gate = self.push_and(a, b); - self.push_xor(xor_gate, and_gate) + self.push_xor(&xor_gate, &and_gate) + } + + pub fn or(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndexVec { + // repeat with output_indices + let mut output = GateIndexVec::default(); + for i in 0..a.len() { + let or_gate = self.push_or(&a[i], &b[i]); + output.push(or_gate); + } + output + } + + pub fn lor(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndex { + let output = self.or(a, b); + output.into() } - // Add a NAND gate: NAND(a, b) = NOT(a & b) - pub fn push_nand(&mut self, a: GateIndex, b: GateIndex) -> GateIndex { + // Add a a.len()Aa.len()D gate: a.len()Aa.len()D(a, b) = a.len()OT(a & b) + pub fn push_nand(&mut self, a: &GateIndex, b: &GateIndex) -> GateIndex { let and_gate = self.push_and(a, b); - self.push_not(and_gate) + self.push_not(&and_gate) } - // Add a NOR gate: NOR(a, b) = NOT(OR(a, b)) - pub fn push_nor(&mut self, a: GateIndex, b: GateIndex) -> GateIndex { + pub fn nand(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndexVec { + // repeat with output_indices + let mut output = GateIndexVec::default(); + for i in 0..a.len() { + let nand = self.push_nand(&a[i], &b[i]); + output.push(nand); + } + output + } + + pub fn push_nor(&mut self, a: &GateIndex, b: &GateIndex) -> GateIndex { let or_gate = self.push_or(a, b); - self.push_not(or_gate) + self.push_not(&or_gate) } - // Add an XNOR gate: XNOR(a, b) = NOT(a ⊕ b) - pub fn push_xnor(&mut self, a: GateIndex, b: GateIndex) -> GateIndex { + pub fn nor(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndexVec { + // repeat with output_indices + let mut output = GateIndexVec::default(); + for i in 0..a.len() { + let nor = self.push_nor(&a[i], &b[i]); + output.push(nor); + } + output + } + + // Add an Xa.len()OR gate: Xa.len()OR(a, b) = a.len()OT(a ⊕ b) + pub fn push_xnor(&mut self, a: &GateIndex, b: &GateIndex) -> GateIndex { let xor_gate = self.push_xor(a, b); - self.push_not(xor_gate) + self.push_not(&xor_gate) + } + + pub fn xnor(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndexVec { + // repeat with output_indices + let mut output = GateIndexVec::default(); + for i in 0..a.len() { + let xnor = self.push_xnor(&a[i], &b[i]); + output.push(xnor); + } + output + } + + pub fn mux(&mut self, s: &GateIndex, a: &GateIndexVec, b: &GateIndexVec) -> GateIndexVec { + // repeat with output_indices + let mut output = GateIndexVec::default(); + for i in 0..a.len() { + let mux = self.push_mux(s, &b[i], &a[i]); + output.push(mux); + } + output } #[allow(dead_code)] // Add a MUX gate: MUX(a, b, s) = (a & !s) | (b & s) - pub fn push_mux(&mut self, a: GateIndex, b: GateIndex, s: GateIndex) -> GateIndex { + pub fn push_mux(&mut self, s: &GateIndex, a: &GateIndex, b: &GateIndex) -> GateIndex { let not_s = self.push_not(s); - let and_a_not_s = self.push_and(a, not_s); + let and_a_not_s = self.push_and(a, ¬_s); let and_b_s = self.push_and(b, s); - self.push_or(and_a_not_s, and_b_s) + self.push_or(&and_a_not_s, &and_b_s) } - // Build and return a Circuit from the current gates with given output indices - pub fn build(self, output_indices: Vec) -> Circuit { - Circuit::new(self.gates, output_indices) + pub fn add(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndexVec { + let mut carry = None; + let mut output_indices = GateIndexVec::default(); + for i in 0..a.len() { + let (sum, new_carry) = full_adder(self, a[i], b[i], carry); + output_indices.push(sum); + carry = new_carry; + } + output_indices } - fn push_garbled_uints( - &mut self, - a: &[GateIndex], - b: &[GateIndex], - ) -> (Vec, Option) { - let mut result = Vec::with_capacity(a.len()); - let mut carry = None; + pub fn sub(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndexVec { + let mut borrow = None; + let mut output_indices = GateIndexVec::default(); + for i in 0..a.len() { + let (diff, new_borrow) = full_subtractor(self, &a[i], &b[i], &borrow); + output_indices.push(diff); + borrow = new_borrow; + } + output_indices + } + + pub fn mul(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndexVec { + let mut partial_products: Vec = Vec::with_capacity(a.len()); + // Generate partial products for i in 0..a.len() { - let sum = self.full_adder(a[i], b[i], carry); - result.push(sum.0); - carry = sum.1; + let shifted_product = partial_product_shift(self, a, b, i); + partial_products.push(shifted_product); } - (result, carry) - } + // Sum up all partial products + let mut result = partial_products[0].clone(); + for partial_product in partial_products.iter().take(a.len()).skip(1) { + result = self.add(&result, partial_product); + } - fn full_adder( - &mut self, - a: GateIndex, - b: GateIndex, - carry: Option, - ) -> (GateIndex, Option) { - let xor_ab = self.len(); - self.gates.push(Gate::Xor(a, b)); + result + } - let sum = if let Some(c) = carry { - let sum_with_carry = self.len(); - self.gates.push(Gate::Xor(xor_ab, c)); - sum_with_carry - } else { - xor_ab - }; + fn div_inner(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> (GateIndexVec, GateIndexVec) { + let n = a.len(); + let mut quotient = GateIndexVec::default(); + let mut remainder = GateIndexVec::default(); - let and_ab = self.len(); - self.gates.push(Gate::And(a, b)); + // Initialize remainder with 0 + for _ in 0..n { + remainder.push(GateIndex::default()); // Zero initialize + } - let new_carry = if let Some(c) = carry { - let and_axorb_c = self.len(); - self.gates.push(Gate::And(xor_ab, c)); + // Iterate through each bit, starting from the most significant + for i in (0..n).rev() { + // Shift remainder left by 1 (equivalent to adding a bit) + remainder.insert(0, a[i]); + if remainder.len() > n { + remainder.truncate(n); // Ensure remainder does not exceed bit width + } + + // Check if remainder is greater than or equal to divisor + let greater_or_equal = self.ge(&remainder, b); + + // If remainder is greater than or equal to divisor, set quotient bit to 1 and subtract divisor from remainder + if greater_or_equal != GateIndex::default() { + // Subtract divisor from remainder if it’s greater than or equal + let new_remainder = self.sub(&remainder, b); + remainder = self.mux(&greater_or_equal, &new_remainder, &remainder); + + // Set quotient bit to 1 + quotient.insert(0, greater_or_equal); + } else { + // Set quotient bit to 0 + quotient.insert(0, GateIndex::default()); + } + + if quotient.len() > n { + quotient.truncate(n); // Ensure quotient does not exceed bit width + } + } - let or_gate = self.len(); - self.gates.push(Gate::Xor(and_ab, and_axorb_c)); - Some(or_gate) - } else { - Some(and_ab) - }; + (quotient, remainder) + } - (sum, new_carry) + pub fn div(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndexVec { + self.div_inner(a, b).0 } - // Simulate the circuit using the provided input values - pub fn execute( - &self, - lhs: &GarbledUint, - rhs: &GarbledUint, - output_indices: Vec, - ) -> anyhow::Result> { - let input = [lhs.bits.clone(), rhs.bits.clone()].concat(); - self.execute_with_input(&input, output_indices) + pub fn rem(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndexVec { + self.div_inner(a, b).1 } - pub fn execute_with_input( - &self, - input: &[bool], - output_indices: Vec, - ) -> anyhow::Result> { - let program = Circuit::new(self.gates.clone(), output_indices); - let result = get_executor().execute(&program, input, &[])?; - Ok(GarbledUint::new(result)) + pub fn eq(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndex { + let mut eq_list = vec![0; a.len()]; + + let i = a.len() - 1; + let eq_i = self.push_xnor(&a[i], &b[i]); + eq_list[i] = eq_i; + + for idx in (0..i).rev() { + let xn = self.push_xnor(&a[idx], &b[idx]); + let eq_i = self.push_and(&eq_list[idx + 1], &xn); + eq_list[idx] = eq_i; + } + + eq_list[0] } -} -pub(crate) fn build_and_execute_xor( - lhs: &GarbledUint, - rhs: &GarbledUint, -) -> GarbledUint { - let mut builder = CircuitBuilder::default(); - builder.push_input(lhs); - builder.push_input(rhs); + pub fn ne(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndex { + let eq = self.eq(a, b); + self.push_not(&eq) + } - // Add XOR gates for each bit - let mut output_indices = Vec::with_capacity(N); - for i in 0..N { - let xor_gate = builder.push_xor(i as u32, (N + i) as u32); - output_indices.push(xor_gate); + pub fn gt(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndex { + let (lt, eq) = self.compare(a, b); + let or_gate = self.push_or(<, &eq); + self.push_not(&or_gate) } - // Simulate the circuit - builder - .execute(lhs, rhs, output_indices) - .expect("Failed to execute XOR circuit") -} + pub fn ge(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndex { + let lt = self.lt(a, b); + self.push_not(<) + } -pub(crate) fn build_and_execute_and( - lhs: &GarbledUint, - rhs: &GarbledUint, -) -> GarbledUint { - let mut builder = CircuitBuilder::default(); - builder.push_input(lhs); - builder.push_input(rhs); + pub fn lt(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndex { + let (lt, _eq) = self.compare(a, b); + lt + } - // Add AND gates for each bit - let mut output_indices = Vec::with_capacity(N); - for i in 0..N { - let and_gate = builder.push_and(i as u32, (N + i) as u32); - output_indices.push(and_gate); + pub fn le(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> GateIndex { + let gt = self.gt(a, b); + self.push_not(>) } - // Simulate the circuit - builder - .execute(lhs, rhs, output_indices) - .expect("Failed to execute AND circuit") -} + pub fn compare(&mut self, a: &GateIndexVec, b: &GateIndexVec) -> (GateIndex, GateIndex) { + let mut eq_list = vec![0; a.len()]; + let mut lt_list = vec![0; a.len()]; -pub(crate) fn build_and_execute_or( - lhs: &GarbledUint, - rhs: &GarbledUint, -) -> GarbledUint { - let mut builder = CircuitBuilder::default(); - builder.push_input(lhs); - builder.push_input(rhs); + let i = a.len() - 1; + let eq_i = self.push_xnor(&a[i], &b[i]); + eq_list[i] = eq_i; - // Add OR gates for each bit - let mut output_indices = Vec::with_capacity(N); - for i in 0..N { - let or_gate = builder.push_or(i as u32, (N + i) as u32); - output_indices.push(or_gate); - } + let nt = self.push_not(&a[i]); + let lt_i = self.push_and(&nt, &b[i]); + lt_list[i] = lt_i; - // Simulate the circuit - builder - .execute(lhs, rhs, output_indices) - .expect("Failed to execute OR circuit") -} + for idx in (0..i).rev() { + let xn = self.push_xnor(&a[idx], &b[idx]); + let eq_i = self.push_and(&eq_list[idx + 1], &xn); + eq_list[idx] = eq_i; -pub(crate) fn build_and_execute_addition( - lhs: &GarbledUint, - rhs: &GarbledUint, -) -> GarbledUint { - let mut builder = CircuitBuilder::default(); - builder.push_input(lhs); - builder.push_input(rhs); + let nt = self.push_not(&a[idx]); + let aa = self.push_and(&nt, &b[idx]); + let temp_lt = self.push_and(&eq_list[idx + 1], &aa); + lt_list[idx] = self.push_or(<_list[idx + 1], &temp_lt); + } - let mut carry = None; + (lt_list[0], eq_list[0]) + } - // Create a full adder for each bit - let mut output_indices = Vec::with_capacity(N); - for i in 0..N { - let (sum, new_carry) = builder.full_adder(i as GateIndex, (N + i) as GateIndex, carry); - output_indices.push(sum); - carry = new_carry; + pub fn compile(&self, output_indices: &GateIndexVec) -> Circuit { + Circuit::new(self.gates.clone(), output_indices.clone().into()) } - // Simulate the circuit - builder - .execute(lhs, rhs, output_indices) - .expect("Failed to execute addition circuit") + pub fn execute(&self, circuit: &Circuit) -> anyhow::Result> { + let result = get_executor().execute(circuit, &self.inputs, &[])?; + Ok(GarbledUint::new(result)) + } + + // Simulate the circuit using the provided input values + pub fn compile_and_execute( + &self, + output_indices: &GateIndexVec, + ) -> anyhow::Result> { + let circuit = self.compile(output_indices); + let result = get_executor().execute(&circuit, &self.inputs, &[])?; + Ok(GarbledUint::new(result)) + } } -pub(crate) fn build_and_execute_subtraction( - lhs: &GarbledUint, - rhs: &GarbledUint, -) -> GarbledUint { - let mut builder = CircuitBuilder::default(); - builder.push_input(lhs); - builder.push_input(rhs); +macro_rules! build_and_execute { + ($fn_name:ident, $op:ident) => { + pub(crate) fn $fn_name( + lhs: &GarbledUint, + rhs: &GarbledUint, + ) -> GarbledUint { + let mut builder = CircuitBuilder::default(); + // Access the global CircuitBuilder instance + //let mut builder = CircuitBuilder::instance().lock().unwrap(); + + let a = builder.input(lhs); + let b = builder.input(rhs); + + let output = builder.$op(&a, &b); + let circuit = builder.compile(&output); + + // Execute the circuit + builder + .execute(&circuit) + .expect("Failed to execute circuit") + } + }; +} - let mut borrow = None; +build_and_execute!(build_and_execute_xor, xor); +build_and_execute!(build_and_execute_and, and); +build_and_execute!(build_and_execute_or, or); +build_and_execute!(build_and_execute_nand, nand); +build_and_execute!(build_and_execute_nor, nor); +build_and_execute!(build_and_execute_xnor, xnor); +build_and_execute!(build_and_execute_addition, add); +build_and_execute!(build_and_execute_subtraction, sub); +build_and_execute!(build_and_execute_multiplication, mul); +build_and_execute!(build_and_execute_division, div); +build_and_execute!(build_and_execute_remainder, rem); + +fn full_adder( + builder: &mut CircuitBuilder, + a: GateIndex, + b: GateIndex, + carry: Option, +) -> (GateIndex, Option) { + let xor_ab = builder.len(); + builder.gates.push(Gate::Xor(a, b)); + + let sum = if let Some(c) = carry { + let sum_with_carry = builder.len(); + builder.gates.push(Gate::Xor(xor_ab, c)); + sum_with_carry + } else { + xor_ab + }; - // Create a full subtractor for each bit - let mut output_indices = Vec::with_capacity(N); - for i in 0..N { - let (diff, new_borrow) = full_subtractor(&mut builder, i as u32, (N + i) as u32, borrow); - output_indices.push(diff); - borrow = new_borrow; - } + let and_ab = builder.len(); + builder.gates.push(Gate::And(a, b)); - // Simulate the circuit - builder - .execute(lhs, rhs, output_indices) - .expect("Failed to execute subtraction circuit") + let new_carry = if let Some(c) = carry { + let and_axorb_c = builder.len(); + builder.gates.push(Gate::And(xor_ab, c)); + + let or_gate = builder.len(); + builder.gates.push(Gate::Xor(and_ab, and_axorb_c)); + Some(or_gate) + } else { + Some(and_ab) + }; + + (sum, new_carry) } -fn full_subtractor( - builder: &mut CircuitBuilder, - a: u32, - b: u32, - borrow: Option, +fn full_subtractor( + builder: &mut CircuitBuilder, + a: &u32, + b: &u32, + borrow: &Option, ) -> (u32, Option) { // XOR gate for difference bit (a ⊕ b) let xor_ab = builder.push_xor(a, b); // If borrow exists, XOR the result of the previous XOR with the borrow let diff = if let Some(borrow) = borrow { - builder.push_xor(xor_ab, borrow) + builder.push_xor(&xor_ab, borrow) } else { xor_ab }; // Compute the new borrow: (!a & b) | (a & borrow) | (!b & borrow) let not_a = builder.push_not(a); - let and_not_a_b = builder.push_and(not_a, b); + let and_not_a_b = builder.push_and(¬_a, b); let new_borrow = if let Some(borrow) = borrow { let and_a_borrow = builder.push_and(a, borrow); let not_b = builder.push_not(b); - let and_not_b_borrow = builder.push_and(not_b, borrow); + let and_not_b_borrow = builder.push_and(¬_b, borrow); - // Combine borrow parts using XOR and AND to simulate OR - let xor_borrow_parts = builder.push_xor(and_not_a_b, and_a_borrow); - builder.push_xor(xor_borrow_parts, and_not_b_borrow) + // Combine borrow parts using XOR and Aa.len()D to simulate OR + let xor_borrow_parts = builder.push_xor(&and_not_a_b, &and_a_borrow); + builder.push_xor(&xor_borrow_parts, &and_not_b_borrow) } else { and_not_a_b }; @@ -309,154 +487,63 @@ fn full_subtractor( (diff, Some(new_borrow)) } -pub(crate) fn build_and_execute_multiplication( - lhs: &GarbledUint, - rhs: &GarbledUint, -) -> GarbledUint { - let mut builder = CircuitBuilder::default(); - builder.push_input(lhs); - builder.push_input(rhs); - - let mut partial_products = Vec::with_capacity(N); - - // Generate partial products - for i in 0..N { - let shifted_product = generate_partial_product(&mut builder, 0, N as GateIndex, i); - partial_products.push(shifted_product); - } - - // Sum up all partial products - let mut result = partial_products[0].clone(); - for partial_product in partial_products.iter().take(N).skip(1) { - (result, _) = builder.push_garbled_uints(&result, partial_product); - } - - // Simulate the circuit - builder - .execute(lhs, rhs, result.to_vec()) - .expect("Failed to execute multiplication circuit") -} - -fn generate_partial_product( - builder: &mut CircuitBuilder, - lhs_start: GateIndex, - rhs_start: GateIndex, +fn partial_product_shift( + builder: &mut CircuitBuilder, + lhs: &GateIndexVec, + rhs: &GateIndexVec, shift: usize, -) -> Vec { - let mut partial_product = Vec::with_capacity(N); +) -> GateIndexVec { + let mut shifted = GateIndexVec::default(); - for i in 0..N { + for i in 0..lhs.len() { if i < shift { - // For lower bits, we use a constant 0 + // For the lower bits, we push a constant 0. let zero_bit = builder.len(); - builder.push_not(rhs_start); - builder.push_and(rhs_start, zero_bit); // Constant 0 - partial_product.push(builder.len() - 1); + builder.push_not(&rhs[0]); + let _zero = builder.push_and(&rhs[0], &zero_bit); // Constant 0 + shifted.push(builder.len() - 1); } else { - let lhs_bit = lhs_start + (i - shift) as u32; - let and_gate = builder.len(); - builder.push_and(lhs_bit, rhs_start + shift as u32); - partial_product.push(and_gate); + let lhs_bit = lhs[i - shift]; + let and_gate = builder.push_and(&lhs_bit, &(rhs[shift])); + // Shift the bit from the input array + shifted.push(and_gate); } } - partial_product -} - -pub(crate) fn build_and_execute_nand( - lhs: &GarbledUint, - rhs: &GarbledUint, -) -> GarbledUint { - let mut builder = CircuitBuilder::default(); - builder.push_input(lhs); - builder.push_input(rhs); - - let mut output_indices = Vec::with_capacity(N); - - for i in 0..N { - let nand_gate = builder.push_nand(i as u32, (N + i) as u32); - output_indices.push(nand_gate); - } - - builder - .execute(lhs, rhs, output_indices) - .expect("Failed to execute NAND circuit") -} - -pub(crate) fn build_and_execute_nor( - lhs: &GarbledUint, - rhs: &GarbledUint, -) -> GarbledUint { - let mut builder = CircuitBuilder::default(); - builder.push_input(lhs); - builder.push_input(rhs); - - let mut output_indices = Vec::with_capacity(N); - - for i in 0..N { - let nor_gate = builder.push_nor(i as u32, (N + i) as u32); - output_indices.push(nor_gate); - } - - builder - .execute(lhs, rhs, output_indices) - .expect("Failed to execute NOR circuit") -} - -pub(crate) fn build_and_execute_xnor( - lhs: &GarbledUint, - rhs: &GarbledUint, -) -> GarbledUint { - let mut builder = CircuitBuilder::default(); - builder.push_input(lhs); - builder.push_input(rhs); - - let mut output_indices = Vec::with_capacity(N); - - for i in 0..N { - let xnor_gate = builder.push_xnor(i as u32, (N + i) as u32); - output_indices.push(xnor_gate); - } - - builder - .execute(lhs, rhs, output_indices) - .expect("Failed to execute XNOR circuit") + shifted } pub(crate) fn build_and_execute_equality( lhs: &GarbledUint, rhs: &GarbledUint, ) -> bool { - let mut builder: CircuitBuilder = CircuitBuilder::default(); - builder.push_input(lhs); - builder.push_input(rhs); - - let mut result = builder.push_xnor(0, N as u32); - - for i in 1..N { - let current_comparison = builder.push_xnor(i as u32, (N + i) as u32); - result = builder.push_and(result, current_comparison); - } - let result = builder.execute(lhs, rhs, vec![result]).unwrap(); - result.bits[0] + let mut builder = CircuitBuilder::default(); + let a = builder.input(lhs); + let b = builder.input(rhs); + + let result = builder.eq(&a, &b); + let result = builder + .compile_and_execute::<1>(&vec![result].into()) + .expect("Failed to execute equality circuit"); + result.into() } pub(crate) fn build_and_execute_comparator( lhs: &GarbledUint, rhs: &GarbledUint, ) -> Ordering { - let mut builder: CircuitBuilder = CircuitBuilder::default(); - builder.push_input(lhs); - builder.push_input(rhs); + let mut builder = CircuitBuilder::default(); + let a = builder.input(lhs); + let b = builder.input(rhs); - let (lt_output, eq_output) = comparator_circuit::(&mut builder); + let (lt_output, eq_output) = builder.compare(&a, &b); - let program = builder.build(vec![lt_output, eq_output]); - let input = [lhs.bits.clone(), rhs.bits.clone()].concat(); - let result = get_executor().execute(&program, &input, &[]).unwrap(); + let result = builder + .compile_and_execute::<2>(&vec![lt_output, eq_output].into()) + .expect("Failed to execute equality circuit"); - let lt = result[0]; - let eq = result[1]; + let lt = result.bits[0]; + let eq = result.bits[1]; if lt { Ordering::Less @@ -467,77 +554,49 @@ pub(crate) fn build_and_execute_comparator( } } -fn comparator_circuit(builder: &mut CircuitBuilder) -> (u32, u32) { - let mut eq_list = vec![0; N]; - let mut lt_list = vec![0; N]; - - let i = N - 1; - let eq_i = builder.push_xnor(i as u32, (N + i) as u32); - eq_list[i] = eq_i; - - let nt = builder.push_not(i as u32); - let lt_i = builder.push_and(nt, (N + i) as u32); - lt_list[i] = lt_i; - - for idx in (0..i).rev() { - let xn = builder.push_xnor(idx as u32, (N + idx) as u32); - let eq_i = builder.push_and(eq_list[idx + 1], xn); - eq_list[idx] = eq_i; - - let nt = builder.push_not(idx as u32); - let aa = builder.push_and(nt, (N + idx) as u32); - let temp_lt = builder.push_and(eq_list[idx + 1], aa); - lt_list[idx] = builder.push_or(lt_list[idx + 1], temp_lt); - } - - (lt_list[0], eq_list[0]) -} - pub(crate) fn build_and_execute_not(input: &GarbledUint) -> GarbledUint { let mut builder = CircuitBuilder::default(); - builder.push_input(input); + builder.input(input); - let mut output_indices = Vec::with_capacity(N); + let mut output_indices = GateIndexVec::default(); - for i in 0..N { - let not_gate = builder.push_not(i as u32); + let n = N as u32; + for i in 0..n { + let not_gate = builder.push_not(&i); output_indices.push(not_gate); } builder - .execute_with_input(&input.bits, output_indices) - .expect("Failed to execute NOT circuit") + .compile_and_execute(&output_indices) + .expect("Failed to execute a.len()OT circuit") } #[allow(dead_code)] -pub(crate) fn build_and_execute_mux( - condition: &GarbledUint, +pub(crate) fn build_and_execute_mux( + condition: &GarbledBoolean, if_true: &GarbledUint, if_false: &GarbledUint, ) -> GarbledUint { let mut builder = CircuitBuilder::default(); - builder.push_input(if_false); - builder.push_input(if_true); - builder.push_input(condition); + let a = builder.input(if_true); + let b = builder.input(if_false); + let s = builder.input(condition); // Add MUX gates for each bit - let mut output_indices = Vec::with_capacity(N); - for i in 0..N { - let mux_gate = builder.push_mux(i as u32, (N + i) as u32, (2 * N) as u32); + /* + let mut output_indices = Vec::with_capacity(a.len()); + let n = a.len() as u32; + for i in 0..n { + let mux_gate = builder.push_mux(&i, &(n + i), &(2 * n)); output_indices.push(mux_gate); } + */ - // combine the three inputs into a single value - let input = [ - if_false.bits.clone(), - if_true.bits.clone(), - condition.bits.clone(), - ] - .concat(); + let output = builder.mux(&s[0], &a, &b); // Simulate the circuit builder - .execute_with_input(&input, output_indices) + .compile_and_execute(&output) .expect("Failed to execute MUX circuit") } @@ -551,47 +610,23 @@ mod tests { use crate::uint::GarbledUint8; #[test] - fn test_mux() { - const N: usize = 32; - - let mut builder: CircuitBuilder = CircuitBuilder::default(); - let a: GarbledUint32 = 1900142_u32.into(); // if s is false, output should be a - let b: GarbledUint32 = 771843900_u32.into(); // if s is true, output should be b - let s: GarbledBit = true.into(); - - builder.push_input(&a); - builder.push_input(&b); - builder.push_input(&s); - - // Add MUX gates for each bit - let mut output_indices = Vec::with_capacity(N); - for i in 0..N { - let mux_gate = builder.push_mux(i as u32, (N + i) as u32, (2 * N) as u32); - output_indices.push(mux_gate); - } - - // combine the three inputs into a single value - let input = [a.bits.clone(), b.bits.clone(), s.bits].concat(); + fn test_div() { + let a: GarbledUint8 = 10_u8.into(); + let b: GarbledUint8 = 2_u8.into(); - // Simulate the circuit - let result = builder - .execute_with_input(&input, output_indices.clone()) - .expect("Failed to execute MUX circuit"); - - println!("MUX result: {}", result); - assert_eq!(result, b); - - let s: GarbledBit = false.into(); - // combine the three inputs into a single value - let input = [a.bits.clone(), b.bits.clone(), s.bits].concat(); + let result = build_and_execute_division(&a, &b); + let result_value: u8 = result.into(); + assert_eq!(result_value, 10 / 2); + } - // Simulate the circuit - let result = builder - .execute_with_input(&input, output_indices) - .expect("Failed to execute MUX circuit"); + #[test] + fn test_rem() { + let a: GarbledUint8 = 10_u8.into(); + let b: GarbledUint8 = 3_u8.into(); - println!("MUX result: {}", result); - assert_eq!(result, a); + let result = build_and_execute_remainder(&a, &b); + let result_value: u8 = result.into(); + assert_eq!(result_value, 10 % 3); } #[test] @@ -608,6 +643,7 @@ mod tests { assert_eq!(result, b); } + #[ignore = "mixed bits not supported yet"] #[test] fn test_build_and_execute_mux() { let s: GarbledBit = true.into(); @@ -624,30 +660,380 @@ mod tests { #[test] fn test_build_and_execute_mux32() { - let s: GarbledUint32 = 0b11111111_11111111_11111111_11111111_u32.into(); + let s: GarbledBoolean = true.into(); let a: GarbledUint32 = 28347823_u32.into(); let b: GarbledUint32 = 8932849_u32.into(); let result = build_and_execute_mux(&s, &a, &b); assert_eq!(result, a); - let s: GarbledUint32 = 0_u32.into(); - let result = build_and_execute_mux(&s, &a, &b); + let result = build_and_execute_mux(&false.into(), &a, &b); assert_eq!(result, b); } #[test] fn test_build_and_execute_mux64() { - let s: GarbledUint64 = - 0b11111111_11111111_11111111_11111111_11111111_11111111_11111111_11111111_u64.into(); + let s: GarbledBoolean = true.into(); let a: GarbledUint64 = 23948323290804923_u64.into(); let b: GarbledUint64 = 834289823983634323_u64.into(); let result = build_and_execute_mux(&s, &a, &b); assert_eq!(result, a); - let s: GarbledUint64 = 0_u64.into(); - let result = build_and_execute_mux(&s, &a, &b); + let result = build_and_execute_mux(&false.into(), &a, &b); assert_eq!(result, b); } + + #[test] + fn test_build_and_execute_multiplication() { + let a: GarbledUint8 = 9_u8.into(); + let b: GarbledUint8 = 3_u8.into(); + + let result = build_and_execute_multiplication(&a, &b); + let result_value: u8 = result.into(); + assert_eq!(result_value, 9 * 3); + } + + #[test] + fn test_eq_true() { + let a: GarbledUint8 = 42_u8.into(); + let b: GarbledUint8 = 42_u8.into(); + + let mut builder = CircuitBuilder::default(); + let a = builder.input(&a); + let b = builder.input(&b); + + let output = builder.eq(&a, &b); + + let circuit = builder.compile(&vec![output].into()); + let result = builder + .execute::<1>(&circuit) + .expect("Failed to execute equality circuit"); + let result_value: bool = result.into(); + assert!(result_value); + } + + #[test] + fn test_eq_false() { + let a: GarbledUint8 = 123_u8.into(); + let b: GarbledUint8 = 124_u8.into(); + + let mut builder = CircuitBuilder::default(); + let a = builder.input(&a); + let b = builder.input(&b); + + let output = builder.eq(&a, &b); + + let circuit = builder.compile(&vec![output].into()); + let result = builder + .execute::<1>(&circuit) + .expect("Failed to execute equality circuit"); + let result_value: bool = result.into(); + assert!(!result_value); + } + + #[test] + fn test_ne_true() { + let a: GarbledUint8 = 123_u8.into(); + let b: GarbledUint8 = 124_u8.into(); + + let mut builder = CircuitBuilder::default(); + let a = builder.input(&a); + let b = builder.input(&b); + + let output = builder.ne(&a, &b); + + let circuit = builder.compile(&vec![output].into()); + let result = builder + .execute::<1>(&circuit) + .expect("Failed to execute inequality circuit"); + let result_value: bool = result.into(); + assert!(result_value); + } + + #[test] + fn test_ne_false() { + let a: GarbledUint8 = 42_u8.into(); + let b: GarbledUint8 = 42_u8.into(); + + let mut builder = CircuitBuilder::default(); + let a = builder.input(&a); + let b = builder.input(&b); + + let output = builder.ne(&a, &b); + + let circuit = builder.compile(&vec![output].into()); + let result = builder + .execute::<1>(&circuit) + .expect("Failed to execute inequality circuit"); + let result_value: bool = result.into(); + assert!(!result_value); + } + + #[test] + fn test_lt_true() { + let a: GarbledUint8 = 42_u8.into(); + let b: GarbledUint8 = 43_u8.into(); + + let mut builder = CircuitBuilder::default(); + let a = builder.input(&a); + let b = builder.input(&b); + + let output = builder.lt(&a, &b); + + let circuit = builder.compile(&vec![output].into()); + let result = builder + .execute::<1>(&circuit) + .expect("Failed to execute less than circuit"); + let result_value: bool = result.into(); + assert!(result_value); + } + + #[test] + fn test_lt_false() { + let a: GarbledUint8 = 43_u8.into(); + let b: GarbledUint8 = 42_u8.into(); + + let mut builder = CircuitBuilder::default(); + let a = builder.input(&a); + let b = builder.input(&b); + + let output = builder.lt(&a, &b); + + let circuit = builder.compile(&vec![output].into()); + let result = builder + .execute::<1>(&circuit) + .expect("Failed to execute less than circuit"); + let result_value: bool = result.into(); + assert!(!result_value); + } + + #[test] + fn test_build_and_execute_mixed() { + fn build_and_execute_mixed( + lhs: &GarbledUint, + rhs: &GarbledUint, + ) -> GarbledUint { + let mut builder = CircuitBuilder::default(); + let a = builder.input(lhs); + let b = builder.input(rhs); + + // Create a full adder for each bit + //let add_output = builder.add(&a, &b).0; + //let sub_output = builder.sub(&add_output, &b).0; + //let output = builder.or(&sub_output, &a); + + let output = builder.mul(&a, &b); + let output = builder.mul(&output, &a); + + println!("output: {:?}", output); + // debug gates + builder.gates.iter().for_each(|gate| { + println!("{:?}", gate); + }); + + let circuit = builder.compile(&output); + + // Execute the circuit + builder + .execute(&circuit) + .expect("Failed to execute addition circuit") + } + + let a: GarbledUint8 = 2_u8.into(); + let b: GarbledUint8 = 5_u8.into(); + + let result = build_and_execute_mixed(&a, &b); + let result_value: u8 = result.into(); + assert_eq!(result_value, 2 * 5 * 2); + } + + #[test] + fn test_add_three() { + let mut builder = CircuitBuilder::default(); + let a: GarbledUint8 = 2_u8.into(); + let a = builder.input(&a); + + let b: GarbledUint8 = 5_u8.into(); + let b = builder.input(&b); + + let c: GarbledUint8 = 3_u8.into(); + let c = builder.input(&c); + + let output = builder.add(&a, &b); + let output = builder.add(&output, &c); + + println!("output: {:?}", output); + // debug gates + builder.gates.iter().for_each(|gate| { + println!("{:?}", gate); + }); + + let circuit = builder.compile(&output); + + // Execute the circuit + let result = builder + .execute::<8>(&circuit) + .expect("Failed to execute addition circuit"); + + let result_value: u8 = result.into(); + assert_eq!(result_value, 2 + 5 + 3); + } + + #[test] + fn test_embedded_if_else() { + let mut builder = CircuitBuilder::default(); + let a: GarbledUint8 = 2_u8.into(); + let a = builder.input(&a); + + let b: GarbledUint8 = 5_u8.into(); + let b = builder.input(&b); + + let s: GarbledBoolean = false.into(); + let s: GateIndexVec = builder.input(&s); + + // fails with 'cannot borrow `builder` as mutable more than once at a time' + // let output = builder.mux(s, builder.mul(a.clone(), b.clone()), builder.add(a.clone(), b.clone())); + + let if_true = builder.mul(&a, &b); + let if_false = builder.add(&a, &b); + let output = builder.mux(&s[0], &if_true, &if_false); + + println!("output: {:?}", output); + + let circuit = builder.compile(&output); + + // Execute the circuit + let result = builder + .execute::<8>(&circuit) + .expect("Failed to execute addition circuit"); + + let result_value: u8 = result.into(); + assert_eq!(result_value, 2 + 5); + } + + use circuit_macro::circuit; + + #[test] + fn test_macro_arithmetic() { + let a = 2_u8; + let b = 5_u8; + let c = 3_u8; + let d = 4_u8; + + let result_u8 = my_circuit(&2u8, &3u8, &1u8, &4u8); + println!("Result for u8: {}", result_u8); + + let result: u8 = my_circuit(&a, &b, &c, &d); + assert_eq!(result, a * b + c - d); + + let result = my_circuit_from_macro(a, b, c, d); + assert_eq!(result, a * b + c - d); + + let result = my_circuit_from_macro2(&a, &b, &c, &d); + assert_eq!(result, a * b + c - d); + } + + #[circuit(execute)] + fn my_circuit_from_macro(a: U8, b: U8, c: U8, d: U8) -> U8 { + let res = a * b; + let res = res + c; + res - d + } + + fn my_circuit_from_macro2(a: &U8, b: &U8, c: &U8, d: &U8) -> U8 + where + U8: Into> + + From> + + Into> + + From> + + Into> + + From> + + Into> + + From> + + Into> + + From> + + Clone, + { + fn generate(a: &U8, b: &U8, c: &U8, d: &U8) -> U8 + where + U8: Into> + From> + Clone, + { + let mut context = CircuitBuilder::default(); + let a = &context.input(&a.clone().into()); + let b = &context.input(&b.clone().into()); + let c = &context.input(&c.clone().into()); + let d = &context.input(&d.clone().into()); + let output = { + { + let res = &context.mul(a, b); + let res = &context.add(res, c); + &context.sub(res, d) + } + }; + let compiled_circuit = context.compile(output); + let result = context + .execute::(&compiled_circuit) + .expect("Failed to execute the circuit"); + result.into() + } + match std::any::type_name::() { + "u8" => generate::<8, U8>(a, b, c, d), + "u16" => generate::<16, U8>(a, b, c, d), + "u32" => generate::<32, U8>(a, b, c, d), + "u64" => generate::<64, U8>(a, b, c, d), + "u128" => generate::<128, U8>(a, b, c, d), + _ => panic!("Unsupported type"), + } + } + + fn my_circuit(a: &T, b: &T, c: &T, d: &T) -> T + where + T: Into> + + From> + + Into> + + From> + + Into> + + From> + + Into> + + From> + + Into> + + From> + + Clone, + { + fn generate(a: &T, b: &T, c: &T, d: &T) -> T + where + T: Into> + From> + Clone, + { + let mut context = CircuitBuilder::default(); + //let a = &2_u8; + let a = &context.input(&a.clone().into()); + let b = &context.input(&b.clone().into()); + let c = &context.input(&c.clone().into()); + let d = &context.input(&d.clone().into()); + + let output = { + let res = &context.mul(a, b); + let res = &context.add(res, c); + &context.sub(res, d) + }; + + let output = &output.clone(); + + let compiled_circuit = context.compile(output); + let result = context + .execute::(&compiled_circuit) + .expect("Failed to execute the circuit"); + result.into() + } + + match std::any::type_name::() { + "u8" => generate::<8, T>(a, b, c, d), + "u16" => generate::<16, T>(a, b, c, d), + "u32" => generate::<32, T>(a, b, c, d), + "u64" => generate::<64, T>(a, b, c, d), + "u128" => generate::<128, T>(a, b, c, d), + _ => panic!("Unsupported type"), + } + } } diff --git a/compute/src/operations/circuits/mod.rs b/compute/src/operations/circuits/mod.rs index 5575a85..ab228e5 100644 --- a/compute/src/operations/circuits/mod.rs +++ b/compute/src/operations/circuits/mod.rs @@ -1 +1,2 @@ pub mod builder; +pub mod types; diff --git a/compute/src/operations/circuits/types.rs b/compute/src/operations/circuits/types.rs new file mode 100644 index 0000000..fd36b7b --- /dev/null +++ b/compute/src/operations/circuits/types.rs @@ -0,0 +1,364 @@ +use crate::operations::circuits::builder::GateIndex; +use crate::uint::GarbledBoolean; + +#[derive(Default, Debug, Eq, Hash, PartialEq, Clone)] +pub struct GateIndexVec(Vec); + +impl GateIndexVec { + pub fn new(indices: Vec) -> Self { + Self(indices) + } + + pub fn push(&mut self, value: GateIndex) { + self.0.push(value); + } + + pub fn push_all(&mut self, values: &GateIndexVec) { + self.0.extend_from_slice(&values.0); + } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn iter(&self) -> std::slice::Iter { + self.0.iter() + } + + pub fn with_capacity(capacity: usize) -> Self { + Self(Vec::with_capacity(capacity)) + } + + pub fn capacity(&self) -> usize { + self.0.capacity() + } + + pub fn insert(&mut self, index: usize, element: GateIndex) { + self.0.insert(index, element); + } + + pub fn truncate(&mut self, len: usize) { + self.0.truncate(len); + } +} + +// Implement indexing for GateVector +impl std::ops::Index for GateIndexVec { + type Output = GateIndex; + + fn index(&self, index: usize) -> &Self::Output { + &self.0[index] + } +} + +impl From for Vec { + fn from(vec: GateIndexVec) -> Self { + vec.0.to_vec() + } +} + +impl From> for GateIndexVec { + fn from(vec: Vec) -> Self { + Self(vec) + } +} + +impl From for GarbledBoolean { + fn from(vec: GateIndexVec) -> Self { + GarbledBoolean::from(vec.0[0]) + } +} + +impl From> for GateIndexVec { + fn from(vec: Vec<&u32>) -> Self { + let mut indices = Vec::new(); + for index in vec { + indices.push(*index); + } + Self(indices) + } +} + +impl From<&u32> for GateIndexVec { + fn from(index: &u32) -> Self { + Self(vec![*index]) + } +} + +impl From for GateIndexVec { + fn from(index: u32) -> Self { + Self(vec![index]) + } +} + +impl From<&GateIndexVec> for GateIndexVec { + fn from(vec: &GateIndexVec) -> Self { + vec.clone() + } +} + +impl From for GateIndex { + fn from(vec: GateIndexVec) -> Self { + vec.0[0] + } +} + +impl From<&GateIndexVec> for GateIndex { + fn from(vec: &GateIndexVec) -> Self { + vec.0[0] + } +} + +/* +use crate::operations::circuits::builder::CircuitBuilder; + +use std::ops::{ + Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, + Mul, MulAssign, Not, Rem, RemAssign, Sub, SubAssign, +}; + +// Implement Add trait for GateIndexVec using the builder reference +impl Add for GateIndexVec { + type Output = Self; + + fn add(self, other: Self) -> Self { + CircuitBuilder::global(|builder| builder.add(&self, &other)) + } +} + +impl Add for &GateIndexVec { + type Output = GateIndexVec; + + fn add(self, other: Self) -> GateIndexVec { + CircuitBuilder::global(|builder| builder.add(self, other)) + } +} + +impl AddAssign for GateIndexVec { + fn add_assign(&mut self, other: Self) { + *self = CircuitBuilder::global(|builder| builder.add(&self, &other)); + } +} + +impl AddAssign<&GateIndexVec> for GateIndexVec { + fn add_assign(&mut self, other: &Self) { + *self = CircuitBuilder::global(|builder| builder.add(self, other)); + } +} + +impl Sub for GateIndexVec { + type Output = Self; + + fn sub(self, other: Self) -> Self { + CircuitBuilder::global(|builder| builder.sub(&self, &other)) + } +} + +impl Sub for &GateIndexVec { + type Output = GateIndexVec; + + fn sub(self, other: Self) -> GateIndexVec { + CircuitBuilder::global(|builder| builder.sub(self, other)) + } +} + +impl SubAssign for GateIndexVec { + fn sub_assign(&mut self, other: Self) { + *self = CircuitBuilder::global(|builder| builder.sub(&self, &other)); + } +} + +impl SubAssign<&GateIndexVec> for GateIndexVec { + fn sub_assign(&mut self, other: &Self) { + *self = CircuitBuilder::global(|builder| builder.sub(self, other)); + } +} + +impl Mul for GateIndexVec { + type Output = Self; + + fn mul(self, other: Self) -> Self { + CircuitBuilder::global(|builder| builder.mul(&self, &other)) + } +} + +impl Mul for &GateIndexVec { + type Output = GateIndexVec; + + fn mul(self, other: Self) -> GateIndexVec { + CircuitBuilder::global(|builder| builder.mul(self, other)) + } +} + +impl MulAssign for GateIndexVec { + fn mul_assign(&mut self, other: Self) { + *self = CircuitBuilder::global(|builder| builder.mul(&self, &other)); + } +} + +impl MulAssign<&GateIndexVec> for GateIndexVec { + fn mul_assign(&mut self, other: &Self) { + *self = CircuitBuilder::global(|builder| builder.mul(self, other)); + } +} + +impl Div for GateIndexVec { + type Output = Self; + + fn div(self, other: Self) -> Self { + CircuitBuilder::global(|builder| builder.div(&self, &other)) + } +} + +impl Div for &GateIndexVec { + type Output = GateIndexVec; + + fn div(self, other: Self) -> GateIndexVec { + CircuitBuilder::global(|builder| builder.div(self, other)) + } +} + +impl DivAssign for GateIndexVec { + fn div_assign(&mut self, other: Self) { + *self = CircuitBuilder::global(|builder| builder.div(&self, &other)); + } +} + +impl DivAssign<&GateIndexVec> for GateIndexVec { + fn div_assign(&mut self, other: &Self) { + *self = CircuitBuilder::global(|builder| builder.div(self, other)); + } +} + +impl Rem for GateIndexVec { + type Output = Self; + + fn rem(self, other: Self) -> Self { + CircuitBuilder::global(|builder| builder.rem(&self, &other)) + } +} + +impl Rem for &GateIndexVec { + type Output = GateIndexVec; + + fn rem(self, other: Self) -> GateIndexVec { + CircuitBuilder::global(|builder| builder.rem(self, other)) + } +} + +impl RemAssign for GateIndexVec { + fn rem_assign(&mut self, other: Self) { + *self = CircuitBuilder::global(|builder| builder.rem(&self, &other)); + } +} + +impl RemAssign<&GateIndexVec> for GateIndexVec { + fn rem_assign(&mut self, other: &Self) { + *self = CircuitBuilder::global(|builder| builder.rem(self, other)); + } +} + +impl BitAnd for GateIndexVec { + type Output = Self; + + fn bitand(self, other: Self) -> Self { + CircuitBuilder::global(|builder| builder.and(&self, &other)) + } +} + +impl BitAnd for &GateIndexVec { + type Output = GateIndexVec; + + fn bitand(self, other: Self) -> GateIndexVec { + CircuitBuilder::global(|builder| builder.and(self, other)) + } +} + +impl BitAndAssign for GateIndexVec { + fn bitand_assign(&mut self, other: Self) { + *self = CircuitBuilder::global(|builder| builder.and(&self, &other)); + } +} + +impl BitAndAssign<&GateIndexVec> for GateIndexVec { + fn bitand_assign(&mut self, other: &Self) { + *self = CircuitBuilder::global(|builder| builder.and(self, other)); + } +} + +impl BitOr for GateIndexVec { + type Output = Self; + + fn bitor(self, other: Self) -> Self { + CircuitBuilder::global(|builder| builder.or(&self, &other)) + } +} + +impl BitOr for &GateIndexVec { + type Output = GateIndexVec; + + fn bitor(self, other: Self) -> GateIndexVec { + CircuitBuilder::global(|builder| builder.or(self, other)) + } +} + +impl BitOrAssign for GateIndexVec { + fn bitor_assign(&mut self, other: Self) { + *self = CircuitBuilder::global(|builder| builder.or(&self, &other)); + } +} + +impl BitOrAssign<&GateIndexVec> for GateIndexVec { + fn bitor_assign(&mut self, other: &Self) { + *self = CircuitBuilder::global(|builder| builder.or(self, other)); + } +} + +impl Not for GateIndexVec { + type Output = Self; + + fn not(self) -> Self { + CircuitBuilder::global(|builder| builder.not(&self)) + } +} + +impl Not for &GateIndexVec { + type Output = GateIndexVec; + + fn not(self) -> GateIndexVec { + CircuitBuilder::global(|builder| builder.not(self)) + } +} + +impl BitXor for GateIndexVec { + type Output = Self; + + fn bitxor(self, other: Self) -> Self { + CircuitBuilder::global(|builder| builder.xor(&self, &other)) + } +} + +impl BitXor for &GateIndexVec { + type Output = GateIndexVec; + + fn bitxor(self, other: Self) -> GateIndexVec { + CircuitBuilder::global(|builder| builder.xor(self, other)) + } +} + +impl BitXorAssign for GateIndexVec { + fn bitxor_assign(&mut self, other: Self) { + *self = CircuitBuilder::global(|builder| builder.xor(&self, &other)); + } +} + +impl BitXorAssign<&GateIndexVec> for GateIndexVec { + fn bitxor_assign(&mut self, other: &Self) { + *self = CircuitBuilder::global(|builder| builder.xor(self, other)); + } +} +*/ diff --git a/compute/src/operations/mux.rs b/compute/src/operations/mux.rs index 06b67f1..25062df 100644 --- a/compute/src/operations/mux.rs +++ b/compute/src/operations/mux.rs @@ -1,16 +1,12 @@ use crate::int::GarbledInt; use crate::operations::circuits::builder::build_and_execute_mux; +use crate::uint::GarbledBoolean; use crate::uint::GarbledUint; impl GarbledUint { // implementation of the MUX operation - pub fn mux(&self, if_true: &GarbledUint, if_false: &GarbledUint) -> GarbledUint { - build_and_execute_mux(self, if_true, if_false) - } - - // implementation of the MUX operation - pub fn mux3( - condition: &GarbledUint, + pub fn mux( + condition: &GarbledBoolean, if_true: &GarbledUint, if_false: &GarbledUint, ) -> GarbledUint { @@ -20,16 +16,11 @@ impl GarbledUint { impl GarbledInt { // implementation of the MUX operation - pub fn mux(&self, if_true: &GarbledInt, if_false: &GarbledInt) -> GarbledInt { - build_and_execute_mux(&self.into(), &if_true.into(), &if_false.into()).into() - } - - // implementation of the MUX operation - pub fn mux3( - condition: &GarbledInt, + pub fn mux( + condition: &GarbledBoolean, if_true: &GarbledInt, if_false: &GarbledInt, ) -> GarbledInt { - build_and_execute_mux(&condition.into(), &if_true.into(), &if_false.into()).into() + build_and_execute_mux(condition, &if_true.into(), &if_false.into()).into() } } diff --git a/compute/src/uint.rs b/compute/src/uint.rs index 5c20e43..4a8ce4e 100644 --- a/compute/src/uint.rs +++ b/compute/src/uint.rs @@ -11,6 +11,10 @@ pub type GarbledUint16 = GarbledUint<16>; pub type GarbledUint32 = GarbledUint<32>; pub type GarbledUint64 = GarbledUint<64>; pub type GarbledUint128 = GarbledUint<128>; +pub type GarbledUint160 = GarbledUint<160>; +pub type GarbledUint256 = GarbledUint<256>; +pub type GarbledUint512 = GarbledUint<512>; +pub type GarbledUint1024 = GarbledUint<1024>; // Define a new type Uint #[derive(Debug, Clone)] @@ -27,6 +31,14 @@ impl GarbledUint { pub fn one() -> Self { GarbledUint::new(vec![true]) } + + pub fn len(&self) -> usize { + self.bits.len() + } + + pub fn is_empty(&self) -> bool { + self.bits.is_empty() + } } impl Display for GarbledUint { @@ -39,7 +51,7 @@ impl Display for GarbledUint { impl GarbledUint { // Constructor for GarbledUint from a boolean vector pub fn new(bits: Vec) -> Self { - assert_eq!(bits.len(), N, "The number of bits must be {}", N); + //assert_eq!(bits.len(), N, "The number of bits must be {}", N); GarbledUint { bits, _phantom: PhantomData, @@ -66,12 +78,6 @@ impl From<&GarbledInt> for GarbledUint { } } -impl From for GarbledBit { - fn from(value: bool) -> Self { - GarbledUint::new(vec![value]) - } -} - impl From for GarbledUint { fn from(value: u8) -> Self { assert!(N <= 8, "Uint can only support up to 8 bits for u8"); @@ -137,8 +143,8 @@ impl From for GarbledUint { } } -impl From for bool { - fn from(guint: GarbledUint<1>) -> Self { +impl From> for bool { + fn from(guint: GarbledUint) -> Self { guint.bits[0] } } @@ -220,3 +226,17 @@ impl From> for u128 { value } } + +impl From for GarbledBit { + fn from(value: bool) -> Self { + GarbledUint::new(vec![value]) + } +} + +/* +impl From for bool { + fn from(guint: GarbledUint<1>) -> Self { + guint.bits[0] + } +} +*/ diff --git a/compute/tests/arithmetic.rs b/compute/tests/arithmetic.rs index 49e4853..fc05377 100644 --- a/compute/tests/arithmetic.rs +++ b/compute/tests/arithmetic.rs @@ -5,6 +5,16 @@ use compute::uint::{ GarbledUint, GarbledUint128, GarbledUint16, GarbledUint32, GarbledUint64, GarbledUint8, }; +#[test] +fn test_uint16_add() { + let a: GarbledUint16 = 11_u16.into(); + let b: GarbledUint16 = 2_u16.into(); + let c: GarbledUint16 = 3_u16.into(); + + let result: u16 = (a + b - c).into(); // Perform addition on the 4-bit values + assert_eq!(result, 11 + 2 - 3); // Expected result of addition between 1010101010101011, 0101010101010101 and 42 +} + #[test] fn test_uint_add() { let a: GarbledUint8 = 170_u8.into(); // Binary 10101010 @@ -402,3 +412,127 @@ fn test_multiple_additions() { let result: u32 = (a + b + c + d + e).into(); assert_eq!(result, 170_u32 + 85_u32 + 42_u32 + 21_u32 + 10_u32); } + +// div + +#[test] +fn test_uint_div() { + let a: GarbledUint8 = 6_u8.into(); // Binary 0110 + let b: GarbledUint8 = 2_u8.into(); // Binary 0010 + + let result: u8 = (a / b).into(); + assert_eq!(result, 6 / 2); // 0110 / 0010 = 0011 + + let a: GarbledUint16 = 300_u16.into(); // Binary 1010101010101011 + let b: GarbledUint16 = 7_u16.into(); // Binary 0101010101010101 + + let result: u16 = (a / b).into(); + assert_eq!(result, 300_u16 / 7_u16); // Expected result of division between 1010101010101011 and 0101010101010101 +} + +#[test] +fn test_int_div() { + let a: GarbledInt8 = 6_i8.into(); + let b: GarbledInt8 = 2_i8.into(); + + let result: i8 = (a / b).into(); + assert_eq!(result, 6_i8 / 2_i8); + + let a: GarbledInt16 = 134_i16.into(); + let b: GarbledInt16 = 85_i16.into(); + + let result: i16 = (a / b).into(); + assert_eq!(result, 134_i16 / 85_i16); +} + +#[test] +fn test_uint_div_assign() { + let mut a: GarbledUint8 = 6_u8.into(); // Binary 0110 + let b: GarbledUint8 = 2_u8.into(); // Binary 0010 + + a /= b; + assert_eq!( as Into>::into(a), 6 / 2); // 0110 / 0010 = 0011 + + let mut a: GarbledUint16 = 300_u16.into(); // Binary 1010101010101011 + let b: GarbledUint16 = 7_u16.into(); // Binary 0101010101010101 + + a /= b; + assert_eq!( as Into>::into(a), 300_u16 / 7_u16); // Expected result of division between 1010101010101011 and 0101010101010101 +} + +#[test] +fn test_int_div_assign() { + let mut a: GarbledInt8 = 6_i8.into(); + let b: GarbledInt8 = 2_i8.into(); + + a /= b; + assert_eq!( as Into>::into(a), 6_i8 / 2_i8); + + let mut a: GarbledInt16 = 134_i16.into(); + let b: GarbledInt16 = 85_i16.into(); + + a /= b; + assert_eq!( as Into>::into(a), 134_i16 / 85_i16); +} + +// rem + +#[test] +fn test_uint_rem() { + let a: GarbledUint8 = 6_u8.into(); // Binary 0110 + let b: GarbledUint8 = 2_u8.into(); // Binary 0010 + + let result: u8 = (a % b).into(); + assert_eq!(result, 6 % 2); // 0110 % 0010 = 0000 + + let a: GarbledUint16 = 300_u16.into(); // Binary 1010101010101011 + let b: GarbledUint16 = 7_u16.into(); // Binary 0101010101010101 + + let result: u16 = (a % b).into(); + assert_eq!(result, 300_u16 % 7_u16); // Expected result of remainder between 1010101010101011 and 0101010101010101 +} + +#[test] +fn test_int_rem() { + let a: GarbledInt8 = 6_i8.into(); + let b: GarbledInt8 = 2_i8.into(); + + let result: i8 = (a % b).into(); + assert_eq!(result, 6_i8 % 2_i8); + + let a: GarbledInt16 = 134_i16.into(); + let b: GarbledInt16 = 85_i16.into(); + + let result: i16 = (a % b).into(); + assert_eq!(result, 134_i16 % 85_i16); +} + +#[test] +fn test_uint_rem_assign() { + let mut a: GarbledUint8 = 6_u8.into(); // Binary 0110 + let b: GarbledUint8 = 2_u8.into(); // Binary 0010 + + a %= b; + assert_eq!( as Into>::into(a), 6 % 2); // 0110 % 0010 = 0000 + + let mut a: GarbledUint16 = 300_u16.into(); // Binary 1010101010101011 + let b: GarbledUint16 = 7_u16.into(); // Binary 0101010101010101 + + a %= b; + assert_eq!( as Into>::into(a), 300_u16 % 7_u16); // Expected result of remainder between 1010101010101011 and 0101010101010101 +} + +#[test] +fn test_int_rem_assign() { + let mut a: GarbledInt8 = 6_i8.into(); + let b: GarbledInt8 = 2_i8.into(); + + a %= b; + assert_eq!( as Into>::into(a), 6_i8 % 2_i8); + + let mut a: GarbledInt16 = 134_i16.into(); + let b: GarbledInt16 = 85_i16.into(); + + a %= b; + assert_eq!( as Into>::into(a), 134_i16 % 85_i16); +} diff --git a/compute/tests/macro.rs b/compute/tests/macro.rs new file mode 100644 index 0000000..a1149c8 --- /dev/null +++ b/compute/tests/macro.rs @@ -0,0 +1,600 @@ +use circuit_macro::circuit; +use compute::executor::get_executor; +use compute::operations::circuits::builder::CircuitBuilder; +use compute::uint::GarbledUint; + +use tandem::Circuit; + +#[test] +fn test_macro_arithmetic_compiler() { + #[circuit(compile)] + fn multi_arithmetic(a: u8, b: u8, c: u8, d: u8) -> (Circuit, Vec) { + let res = a * b; + let res = res + c; + res - d + } + + let a = 2_u8; + let b = 5_u8; + let c = 3_u8; + let d = 4_u8; + + let (circuit, inputs) = multi_arithmetic(a, b, c, d); + let result = get_executor().execute(&circuit, &inputs, &[]).unwrap(); + let result: GarbledUint<8> = GarbledUint::new(result); + let result: u8 = result.into(); + assert_eq!(result, a * b + c - d); +} + +#[test] +fn test_macro_arithmetic() { + #[circuit(execute)] + fn multi_arithmetic(a: u8, b: u8, c: u8, d: u8) -> u8 { + let res = a * b; + let res = res + c; + res - d + } + + let a = 2_u8; + let b = 5_u8; + let c = 3_u8; + let d = 4_u8; + + let result = multi_arithmetic(a, b, c, d); + assert_eq!(result, a * b + c - d); +} + +#[test] +fn test_macro_arithmetic_u128() { + #[circuit(execute)] + fn multi_arithmetic_u128(a: u8, b: u8, c: u8, d: u8) -> u8 { + let res = a + b; + let res = res + c; + res - d + } + + let a = 2_u128; + let b = 5_u128; + let c = 3_u128; + let d = 4_u128; + + let result = multi_arithmetic_u128(a, b, c, d); + assert_eq!(result, a + b + c - d); +} + +#[test] +fn test_macro_mixed_arithmetic() { + #[circuit(execute)] + fn mixed_arithmetic(a: u8, b: u8, c: u8, d: u8) -> u8 { + let res = a * b; + let res = context.add(res, c); + let res = res - d; + context.mul(res, a) + } + + let a = 2_u8; + let b = 5_u8; + let c = 3_u8; + let d = 4_u8; + + let result = mixed_arithmetic(a, b, c, d); + assert_eq!(result, ((a * b + c - d) * a)); +} + +#[test] +fn test_macro_addition() { + #[circuit(execute)] + fn addition(a: u8, b: u8) -> u8 { + a + b + } + + let a = 2_u8; + let b = 5_u8; + + let result = addition(a, b); + assert_eq!(result, a + b); +} + +#[test] +fn test_macro_subtraction() { + #[circuit(execute)] + fn subtraction(a: u8, b: u8) -> u8 { + a - b + } + + let a = 20_u8; + let b = 5_u8; + + let result = subtraction(a, b); + assert_eq!(result, a - b); +} + +#[test] +fn test_macro_multiplication() { + #[circuit(execute)] + fn multiplication(a: u8, b: u8) -> u8 { + a * b + } + + let a = 20_u8; + let b = 5_u8; + + let result = multiplication(a, b); + assert_eq!(result, a * b); +} + +#[test] +fn test_macro_mux() { + #[circuit(execute)] + fn mux_circuit(a: u8, b: u8) -> u8 { + let condition = a == b; + &context.mux(condition, a, b) + } + + let a = 5_u8; + let b = 10_u8; + + let result = mux_circuit(a, b); + assert_eq!(result, b); +} + +#[test] +fn test_macro_if_else() { + #[circuit(execute)] + fn mux_circuit(a: T, b: T) -> T { + if a == b { + let c = a * b; + c + a + } else { + a + b + } + } + + let a = 10_u16; + let b = 5_u16; + + let result: u16 = mux_circuit(a, b); + assert_eq!(result, a + b); +} + +#[test] +fn test_macro_if_else2() { + #[circuit(execute)] + fn mux_circuit(a: u8, b: u8) -> u8 { + let true_branch = a * b; + let false_branch = a + b; + let condition = a == b; + if condition { + true_branch + } else { + false_branch + } + } + + let a = 10_u8; + let b = 5_u8; + + let result = mux_circuit(a, b); + assert_eq!(result, a + b); + + let a = 5_u8; + let result = mux_circuit(a, b); + assert_eq!(result, a * b); +} + +#[test] +fn test_macro_if_else3() { + #[circuit(execute)] + fn mux_circuit(a: u8, b: u8) -> u8 { + if a == b { + a * b + } else { + a + b + } + } + + let a = 4_u8; + let b = 4_u8; + + let result = mux_circuit(a, b); + assert_eq!(result, a * b); + + let a = 5_u8; + let result = mux_circuit(a, b); + assert_eq!(result, a + b); +} + +#[test] +fn test_macro_if_else4() { + #[circuit(execute)] + fn mux_circuit(a: u8, b: u8) -> u8 { + if a == b { + let c = a * b; + c + a + } else { + let x = a + b; + x * x + } + } + + let a = 5_u8; + let b = 7_u8; + + let result = mux_circuit(a, b); + assert_eq!(result, (a + b) * (a + b)); +} + +#[ignore = "division not yet supported"] +#[test] +fn test_macro_division() { + #[circuit(execute)] + fn division(a: u8, b: u8) -> u8 { + a / b + } + + let a = 20_u8; + let b = 5_u8; + + let result = division(a, b); + assert_eq!(result, a / b); +} + +#[ignore = "modulo not yet supported"] +#[test] +fn test_macro_remainder() { + #[circuit(execute)] + fn remainder(a: u8, b: u8) -> u8 { + a % b + } + + let a = 20_u8; + let b = 5_u8; + + let result = remainder(a, b); + assert_eq!(result, a % b); +} + +#[test] +fn test_macro_nested_arithmetic() { + #[circuit(execute)] + fn nested_arithmetic(a: u8, b: u8, c: u8, d: u8) -> u8 { + let res = a * b; + let res = res + c; + res - d + } + + let a = 2_u8; + let b = 5_u8; + let c = 3_u8; + let d = 4_u8; + + let result = nested_arithmetic(a, b, c, d); + assert_eq!(result, a * b + c - d); +} + +// test bitwise operations +#[test] +fn test_macro_bitwise_and() { + #[circuit(execute)] + fn bitwise_and(a: u8, b: u8) -> u8 { + a & b + } + + let a = 2_u8; + let b = 3_u8; + + let result = bitwise_and(a, b); + assert_eq!(result, a & b); +} + +#[test] +fn test_macro_bitwise_or() { + #[circuit(execute)] + fn bitwise_or(a: u8, b: u8) -> u8 { + a | b + } + + let a = 2_u8; + let b = 3_u8; + + let result = bitwise_or(a, b); + assert_eq!(result, a | b); +} + +#[test] +fn test_macro_bitwise_xor() { + #[circuit(execute)] + fn bitwise_xor(a: u8, b: u8) -> u8 { + a ^ b + } + + let a = 2_u8; + let b = 3_u8; + + let result = bitwise_xor(a, b); + assert_eq!(result, a ^ b); +} + +#[test] +fn test_macro_bitwise_not() { + #[circuit(execute)] + fn bitwise_not(a: u8) -> u8 { + !a + } + + let a = 2_u8; + + let result = bitwise_not(a); + assert_eq!(result, !a); +} + +#[test] +fn test_macro_bitwise_nand() { + #[circuit(execute)] + fn bitwise_nand(a: u8, b: u8) -> u8 { + let and = a & b; + !and + } + + let a = 2_u8; + let b = 3_u8; + + let result = bitwise_nand(a, b); + assert_eq!(result, !(a & b)); +} + +#[test] +fn test_macro_bitwise_nor() { + #[circuit(execute)] + fn bitwise_nor(a: u8, b: u8) -> u8 { + let or = a | b; + !or + } + + let a = 2_u8; + let b = 3_u8; + + let result = bitwise_nor(a, b); + assert_eq!(result, !(a | b)); +} + +#[test] +fn test_macro_bitwise_xnor() { + #[circuit(execute)] + fn bitwise_xnor(a: u8, b: u8) -> u8 { + let xor = a ^ b; + !xor + } + + let a = 2_u8; + let b = 3_u8; + + let result = bitwise_xnor(a, b); + assert_eq!(result, !(a ^ b)); +} + +#[test] +fn test_macro_equal() { + #[circuit(execute)] + fn equal(a: u8, b: u8) -> u8 { + if a == b { + a * b + } else { + a + b + } + } + + let a = 2_u8; + let b = 3_u8; + + let result = equal(a, b); + assert_eq!(result, a + b); +} + +#[test] +fn test_macro_not_equal() { + #[circuit(execute)] + fn not_equal(a: u8, b: u8) -> u8 { + if a != b { + a * b + } else { + a + b + } + } + + let a = 2_u8; + let b = 3_u8; + + let result = not_equal(a, b); + assert_eq!(result, a * b); +} + +#[test] +fn test_macro_greater_than() { + #[circuit(execute)] + fn greater_than(a: u8, b: u8) -> u8 { + if a > b { + a * b + } else { + a + b + } + } + + let a = 2_u8; + let b = 3_u8; + + let result = greater_than(a, b); + assert_eq!(result, a + b); + + let a = 3_u8; + let result = greater_than(a, b); + assert_eq!(result, a + b); + + let a = 4_u8; + let result = greater_than(a, b); + assert_eq!(result, a * b); +} + +#[test] +fn test_macro_greater_than_or_equal() { + #[circuit(execute)] + fn greater_than_or_equal(a: u8, b: u8) -> u8 { + if a >= b { + a * b + } else { + a + b + } + } + + let a = 2_u8; + let b = 3_u8; + + let result = greater_than_or_equal(a, b); + assert_eq!(result, a + b); + + let a = 3_u8; + let result = greater_than_or_equal(a, b); + assert_eq!(result, a * b); + + let a = 4_u8; + let result = greater_than_or_equal(a, b); + assert_eq!(result, a * b); +} + +#[test] +fn test_macro_less_than() { + #[circuit(execute)] + fn less_than(a: u8, b: u8) -> u8 { + if a < b { + a * b + } else { + a + b + } + } + + let a = 2_u8; + let b = 3_u8; + + let result = less_than(a, b); + assert_eq!(result, a * b); + + let a = 3_u8; + let result = less_than(a, b); + assert_eq!(result, a + b); + + let a = 4_u8; + let result = less_than(a, b); + assert_eq!(result, a + b); +} + +#[test] +fn test_macro_less_than_or_equal() { + #[circuit(execute)] + fn less_than_or_equal(a: u8, b: u8) -> u8 { + if a <= b { + a * b + } else { + a + b + } + } + + let a = 2_u8; + let b = 3_u8; + + let result = less_than_or_equal(a, b); + assert_eq!(result, a * b); + + let a = 3_u8; + let result = less_than_or_equal(a, b); + assert_eq!(result, a * b); + + let a = 4_u8; + let result = less_than_or_equal(a, b); + assert_eq!(result, a + b); +} + +#[test] +fn test_macro_bool_return() { + #[circuit(execute)] + fn equal(a: u8, b: u8) -> bool { + a == b + } + + let a = 2_u8; + let b = 3_u8; + + let result = equal(a, b); + assert!(!result); +} + +// div +#[test] +fn test_macro_div() { + #[circuit(execute)] + fn div(a: u8, b: u8) -> u8 { + a / b + } + + let a = 20_u8; + let b = 5_u8; + + let result = div(a, b); + assert_eq!(result, a / b); +} + +#[test] +fn test_macro_div_with_remainder() { + #[circuit(execute)] + fn div(a: u8, b: u8) -> u8 { + a / b + } + + let a = 20_u8; + let b = 3_u8; + + let result = div(a, b); + assert_eq!(result, a / b); +} + +#[test] +fn test_macro_div_with_remainder2() { + #[circuit(execute)] + fn div(a: u8, b: u8) -> u8 { + a / b + } + + let a = 20_u8; + let b = 7_u8; + + let result = div(a, b); + assert_eq!(result, a / b); +} + +// rem +#[test] +fn test_macro_rem() { + #[circuit(execute)] + fn rem(a: u8, b: u8) -> u8 { + a % b + } + + let a = 20_u8; + let b = 5_u8; + + let result = rem(a, b); + assert_eq!(result, a % b); +} + +#[test] +fn test_macro_rem_with_remainder() { + #[circuit(execute)] + fn rem(a: u8, b: u8) -> u8 { + a % b + } + + let a = 20_u8; + let b = 3_u8; + + let result = rem(a, b); + assert_eq!(result, a % b); +} diff --git a/compute/tests/single_macro.rs b/compute/tests/single_macro.rs new file mode 100644 index 0000000..1436242 --- /dev/null +++ b/compute/tests/single_macro.rs @@ -0,0 +1,20 @@ +/* +use circuit_macro::circuit; +use compute::operations::circuits::builder::{CircuitBuilder, GateIndex}; +use compute::uint::GarbledUint; + +#[ignore = "reason"] +#[test] +fn test_macro_logical_and() { + #[circuit(execute)] + fn logical_and(a: bool, b: bool) -> bool { + a && b + } + + let a = true; + let b = false; + + let result = logical_and(a, b); + assert_eq!(result, false); +} +*/