Skip to content

Commit

Permalink
fix(macro): added support for order of ops; added if/elseif support
Browse files Browse the repository at this point in the history
  • Loading branch information
10d9e committed Oct 28, 2024
1 parent 641cc49 commit 394c876
Show file tree
Hide file tree
Showing 6 changed files with 261 additions and 35 deletions.
25 changes: 21 additions & 4 deletions benchmark/benches/benchmarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ fn tfhe_encrypted_addition() -> Result<(), Box<dyn std::error::Error>> {
}

// Another function to benchmark
fn gateway_encrypted_addition() -> Result<(), Box<dyn std::error::Error>> {
fn _gateway_encrypted_addition2() -> Result<(), Box<dyn std::error::Error>> {
use compute::uint::GarbledUint128;

let clear_a = 12297829382473034410u128;
Expand All @@ -45,6 +45,23 @@ fn gateway_encrypted_addition() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

// Another function to benchmark
fn gateway_encrypted_addition() -> Result<(), Box<dyn std::error::Error>> {
use compute::prelude::*;

#[circuit(execute)]
fn addition(a: u128, b: u128) -> u128 {
a + b
}

let clear_a = 12297829382473034410u128;
let clear_b = 424242424242u128;

let result = addition(clear_a, clear_b);
assert_eq!(result, clear_a + clear_b);
Ok(())
}

fn tfhe_encrypted_bitwise_and() -> Result<(), Box<dyn std::error::Error>> {
use tfhe::prelude::*;
use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint128};
Expand Down Expand Up @@ -1034,15 +1051,15 @@ criterion_group!(
name = benches;
config = custom_criterion();
targets =
benchmark_gateway_encrypted_addition,
benchmark_tfhe_encrypted_addition,

benchmark_gateway_encrypted_division,
benchmark_tfhe_encrypted_division,
benchmark_gateway_encrypted_modulus,
benchmark_tfhe_encrypted_modulus,

benchmark_gateway_encrypted_mux,
benchmark_tfhe_encrypted_mux,
benchmark_gateway_encrypted_addition,
benchmark_tfhe_encrypted_addition,
benchmark_gateway_encrypted_subtraction,
benchmark_tfhe_encrypted_subtraction,
benchmark_gateway_encrypted_multiplication,
Expand Down
55 changes: 39 additions & 16 deletions circuit_macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ fn generate_macro(item: TokenStream, mode: &str) -> TokenStream {
let mut context = CircuitBuilder::default();
#(#mapped_inputs)*
#(#constants)*
let const_true = &context.input::<N>(&1u128.into());
let const_false = &context.input::<N>(&0u128.into());

// Use the transformed function block (with context.add and if/else replacements)
let output = { #transformed_block };
Expand Down Expand Up @@ -170,6 +172,11 @@ fn modify_body(block: syn::Block, constants: &mut Vec<proc_macro2::TokenStream>)
/// Replaces binary operators and if/else expressions with appropriate context calls.
fn replace_expressions(expr: Expr, constants: &mut Vec<proc_macro2::TokenStream>) -> Expr {
match expr {
// Handle parentheses to ensure proper order of operations
Expr::Paren(expr_paren) => {
let inner_expr = replace_expressions(*expr_paren.expr, constants);
syn::parse_quote! { (#inner_expr) }
}
Expr::Lit(syn::ExprLit {
lit: Lit::Int(lit_int),
..
Expand Down Expand Up @@ -443,32 +450,48 @@ fn replace_expressions(expr: Expr, constants: &mut Vec<proc_macro2::TokenStream>
&context.not(&single.into())
}}
}

Expr::If(ExprIf {
cond,
then_branch,
else_branch,
..
}) => {
if let Some((_, else_branch)) = else_branch {
let then_expr = modify_body(then_branch.clone(), constants);
let cond_expr = replace_expressions(*cond, constants);
let then_block = modify_body(then_branch, constants);

let else_expr = match *else_branch {
syn::Expr::Block(syn::ExprBlock { block, .. }) => {
modify_body(block.clone(), constants)
if let Some((_, else_expr)) = else_branch {
match *else_expr {
Expr::If(else_if) => {
let else_if_expr = replace_expressions(Expr::If(else_if), constants);
syn::parse_quote! {{
let if_true = #then_block;
let if_false = #else_if_expr;
let cond = #cond_expr;
&context.mux(cond, if_true, if_false)
}}
}
_ => panic!("Expected a block in else branch"),
};

let cond = replace_expressions(*cond.clone(), constants);

_ => {
let else_block = modify_body(syn::parse_quote! { #else_expr }, constants);
syn::parse_quote! {{
let if_true = #then_block;
let if_false = #else_block;
let cond = #cond_expr;
&context.mux(cond, if_true, if_false)
}}
}
}
} else {
panic!("If without else is not supported");
/*
syn::parse_quote! {{
let if_true = #then_expr;
let if_false = #else_expr;
let cond = #cond;
&context.mux(cond, if_true, if_false)
let if_true = #then_block;
//let if_false = context.len() + 1;
let cond = #cond_expr;
let if_false = &context.len() + 1;
&context.mux(&cond, &if_true, &if_false.into());
}}
} else {
panic!("Expected else branch for if expression");
*/
}
}

Expand Down
18 changes: 9 additions & 9 deletions compute/examples/loan_eligibility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,23 @@ fn evaluate_loan_application(income: u32, credit_score: u32, debt_ratio: u32) ->
if income >= HIGH_INCOME_REQ && credit_score >= MIN_CREDIT_SCORE && debt_ratio <= MAX_DEBT_RATIO
{
FULLY_APPROVED
} else {
let income_and_credit_score = income >= MIN_INCOME_REQ && credit_score >= MIN_CREDIT_SCORE;
} else if income >= MIN_INCOME_REQ
&& credit_score >= MIN_CREDIT_SCORE
&& debt_ratio >= MAX_CONDITIONAL_DEBT_RATIO
{
// Check for Conditional Approval
if income_and_credit_score && debt_ratio <= MAX_CONDITIONAL_DEBT_RATIO {
CONDITIONAL_APPROVED
} else {
// Denied if neither criteria met
DENIED
}
CONDITIONAL_APPROVED
} else {
// Denied if neither criteria met
DENIED
}
}

fn main() {
// Example applicant data
let income = 75000_u32;
let credit_score = 680_u32;
let debt_ratio = 30_u32;
let debt_ratio = 90_u32;

let result = evaluate_loan_application(income, credit_score, debt_ratio);

Expand Down
19 changes: 14 additions & 5 deletions compute/src/operations/circuits/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl CircuitBuilder {
output.into()
}

// Add a a.len()OT gate for a single input and return the index
// Add a NOT gate for a single input and return the index
pub fn push_not(&mut self, a: &GateIndex) -> GateIndex {
let not_index = self.gates.len() as u32;
self.gates.push(Gate::Not(*a));
Expand Down Expand Up @@ -195,6 +195,16 @@ impl CircuitBuilder {
output
}

pub fn mux2(&mut self, s: &GateIndex, a: &GateIndexVec, b: &GateIndex) -> GateIndexVec {
// repeat with output_indices
let mut output = GateIndexVec::default();
for i in 0..a.len() {
let mux = self.push_mux(s, b, &a[i]);
output.push(mux);
}
output
}

#[allow(dead_code)]
// Add a MUX gate: MUX(a, b, s) = (a & !s) | (b & s)
pub fn push_mux(&mut self, s: &GateIndex, a: &GateIndex, b: &GateIndex) -> GateIndex {
Expand Down Expand Up @@ -498,10 +508,9 @@ fn partial_product_shift(
for i in 0..lhs.len() {
if i < shift {
// For the lower bits, we push a constant 0.
let zero_bit = builder.len();
builder.push_not(&rhs[0]);
let _zero = builder.push_and(&rhs[0], &zero_bit); // Constant 0
shifted.push(builder.len() - 1);
let zero_bit = builder.push_not(&rhs[0]);
let and_gate = builder.push_and(&rhs[0], &zero_bit); // Constant 0
shifted.push(and_gate);
} else {
let lhs_bit = lhs[i - shift];
let and_gate = builder.push_and(&lhs_bit, &(rhs[shift]));
Expand Down
28 changes: 28 additions & 0 deletions compute/tests/macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -623,3 +623,31 @@ fn test_macro_embedded_constants() {
let result = embedded_constants(a);
assert_eq!(result, 30_u8);
}

#[test]
fn test_order_of_operations() {
#[circuit(execute)]
fn order_of_operations(a: u16, b: u16, c: u16) -> u16 {
a + b * c
}

let a = 10_u16;
let b = 20_u16;
let c = 30_u16;
let result = order_of_operations(a, b, c);
assert_eq!(result, 610_u16);
}

#[test]
fn test_order_of_operations2() {
#[circuit(execute)]
fn order_of_operations(a: u16, b: u16, c: u16) -> u16 {
(a + b) * c
}

let a = 10_u16;
let b = 20_u16;
let c = 30_u16;
let result = order_of_operations(a, b, c);
assert_eq!(result, 900);
}
Loading

0 comments on commit 394c876

Please sign in to comment.