From 9257f9eba598dd7819a7107fc16e2e90abf48e3c Mon Sep 17 00:00:00 2001 From: microproofs Date: Thu, 9 Jan 2025 15:41:47 +0700 Subject: [PATCH] Inline now handles (if cond then body else error) patterns. This allows conditions like expect x == 1 to match performance with x == 1 && ... --- crates/aiken-lang/src/gen_uplc.rs | 4 +- ...oject__export__tests__recursive_types.snap | 5 +- crates/aiken-project/src/tests/gen_uplc.rs | 23 +- crates/uplc/src/optimize/shrinker.rs | 393 +++++++++++++----- 4 files changed, 298 insertions(+), 127 deletions(-) diff --git a/crates/aiken-lang/src/gen_uplc.rs b/crates/aiken-lang/src/gen_uplc.rs index 588edc00c..d6f4c4b6a 100644 --- a/crates/aiken-lang/src/gen_uplc.rs +++ b/crates/aiken-lang/src/gen_uplc.rs @@ -217,9 +217,11 @@ impl<'a> CodeGenerator<'a> { fn finalize(&mut self, mut term: Term) -> Program { term = self.special_functions.apply_used_functions(term); - + println!("PROG BEFORE IS {}", term.to_pretty()); let program = aiken_optimize_and_intern(self.new_program(term)); + println!("PROG IS {}", program.to_pretty()); + // This is very important to call here. // If this isn't done, re-using the same instance // of the generator will result in free unique errors diff --git a/crates/aiken-project/src/snapshots/aiken_project__export__tests__recursive_types.snap b/crates/aiken-project/src/snapshots/aiken_project__export__tests__recursive_types.snap index 586b057ca..b88efa23a 100644 --- a/crates/aiken-project/src/snapshots/aiken_project__export__tests__recursive_types.snap +++ b/crates/aiken-project/src/snapshots/aiken_project__export__tests__recursive_types.snap @@ -1,7 +1,6 @@ --- source: crates/aiken-project/src/export.rs description: "Code:\n\npub type Foo {\n Empty\n Bar(a, Foo)\n}\n\npub fn add(a: Foo, b: Foo) -> Int {\n when (a, b) is {\n (Empty, Empty) -> 0\n (Bar(x, y), Bar(c, d)) -> x + c + add(y, d)\n (Empty, Bar(c, d)) -> c + add(Empty, d)\n (Bar(x, y), Empty) -> x + add(y, Empty)\n }\n}\n" -snapshot_kind: text --- { "name": "test_module.add", @@ -25,8 +24,8 @@ snapshot_kind: text "$ref": "#/definitions/Int" } }, - "compiledCode": "59017d0101003232323232322232323232325333008300430093754002264a666012600a60146ea800452000132337006eb4c038004cc011300103d8798000300e300f001300b37540026018601a00a264a66601266e1d2002300a37540022646466e00cdc01bad300f002375a601e0026600a601e6020004601e602000260186ea8008c02cdd500109919b80375a601c00266008601c601e002980103d8798000300b37540046018601a00a601600860020024446464a666014600c60166ea80044c94ccc02cc01cc030dd50008a400026466e00dd69808000999803803a60103d879800030103011001300d3754002601c601e004264a66601666e1d2002300c37540022646466e00cdc01bad3011002375a60220026660100106022602400460226024002601c6ea8008c034dd500109919b80375a602000266600e00e60206022002980103d8798000300d3754004601c601e004601a002660160046601600297ae0370e90001980300119803000a5eb815cd2ab9d5573cae815d0aba201", - "hash": "c6af3f04e300cb8c1d0429cc0d8e56a0413eef9fcb338f72076b426c", + "compiledCode": "59017d0101003333332222222232323232325333008300430093754002264a666012600a60146ea800452000132337006eb4c038004cc011300103d8798000300e300f001300b37540026018601a00a264a66601266e1d2002300a37540022646466e00cdc01bad300f002375a601e0026600a601e6020004601e602000260186ea8008c02cdd500109919b80375a601c00266008601c601e002980103d8798000300b37540046018601a00a601600860020024446464a666014600c60166ea80044c94ccc02cc01cc030dd50008a400026466e00dd69808000999803803a60103d879800030103011001300d3754002601c601e004264a66601666e1d2002300c37540022646466e00cdc01bad3011002375a60220026660100106022602400460226024002601c6ea8008c034dd500109919b80375a602000266600e00e60206022002980103d8798000300d3754004601c601e004601a002660160046601600297ae0370e90001980300119803000a5eb815d12ba15740aae7955ceab9a01", + "hash": "1f7ba1be8ac6bba7b61d818e0f274b1feae61d24737dd9e833d59430", "definitions": { "Int": { "dataType": "integer" diff --git a/crates/aiken-project/src/tests/gen_uplc.rs b/crates/aiken-project/src/tests/gen_uplc.rs index 8a5506a5f..a51292966 100644 --- a/crates/aiken-project/src/tests/gen_uplc.rs +++ b/crates/aiken-project/src/tests/gen_uplc.rs @@ -64,6 +64,7 @@ fn assert_uplc(source_code: &str, expected: Term, should_fail: bool, verbo version: (1, 1, 0), term: expected, }; + println!("BEFORE OPT IS {}", expected.to_pretty()); let expected = optimize::aiken_optimize_and_intern(expected); @@ -3603,7 +3604,7 @@ fn when_bool_is_true() { assert_uplc( src, Term::var("subject") - .delayed_if_then_else(Term::bool(true), Term::Error) + .delayed_if_then_else(Term::bool(true), Term::Error.delay().force()) .lambda("subject") .apply(Term::bool(true)), false, @@ -3627,7 +3628,7 @@ fn when_bool_is_true_switched_cases() { assert_uplc( src, Term::var("subject") - .delayed_if_then_else(Term::bool(true), Term::Error) + .delayed_if_then_else(Term::bool(true), Term::Error.delay().force()) .lambda("subject") .apply(Term::bool(true)), false, @@ -3651,7 +3652,7 @@ fn when_bool_is_false() { assert_uplc( src, Term::var("subject") - .delayed_if_then_else(Term::bool(true), Term::Error) + .delayed_if_then_else(Term::bool(true), Term::Error.delay().force()) .lambda("subject") .apply(Term::bool(false)), true, @@ -4088,16 +4089,16 @@ fn generic_validator_type_test() { Term::tail_list() .apply(Term::Var(tail_id_5.clone())) .as_var("tail_id_6", |tail_id_6| { - Term::head_list() + Term::tail_list() .apply(Term::Var(tail_id_6.clone())) - .as_var("__val", |val| { - Term::tail_list() + .delayed_choose_list( + Term::head_list() .apply(Term::Var(tail_id_6)) - .delayed_choose_list( - expect_b(val, Term::Var(then_delayed), trace), - Term::Error, - ) - }) + .as_var("__val", |val| { + expect_b(val, Term::Var(then_delayed), trace) + }), + Term::Error, + ) }) } }); diff --git a/crates/uplc/src/optimize/shrinker.rs b/crates/uplc/src/optimize/shrinker.rs index 8dacdd069..0b37631da 100644 --- a/crates/uplc/src/optimize/shrinker.rs +++ b/crates/uplc/src/optimize/shrinker.rs @@ -167,6 +167,7 @@ impl DefaultFunction { ) } /// For now all of the curry builtins are not forceable + /// Curryable builtins must take in 2 or more arguments pub fn can_curry_builtin(self) -> bool { matches!( self, @@ -205,7 +206,8 @@ impl DefaultFunction { | DefaultFunction::MultiplyInteger | DefaultFunction::EqualsInteger | DefaultFunction::LessThanInteger - | DefaultFunction::LessThanEqualsInteger => arg_stack.iter().all(|arg| { + | DefaultFunction::LessThanEqualsInteger + | DefaultFunction::IData => arg_stack.iter().all(|arg| { if let Term::Constant(c) = arg { matches!(c.as_ref(), Constant::Integer(_)) } else { @@ -226,10 +228,13 @@ impl DefaultFunction { false } }), - DefaultFunction::EqualsByteString + DefaultFunction::LengthOfByteString + | DefaultFunction::EqualsByteString | DefaultFunction::AppendByteString | DefaultFunction::LessThanEqualsByteString - | DefaultFunction::LessThanByteString => arg_stack.iter().all(|arg| { + | DefaultFunction::LessThanByteString + | DefaultFunction::DecodeUtf8 + | DefaultFunction::BData => arg_stack.iter().all(|arg| { if let Term::Constant(c) = arg { matches!(c.as_ref(), Constant::ByteString(_)) } else { @@ -282,24 +287,26 @@ impl DefaultFunction { } } - DefaultFunction::EqualsString | DefaultFunction::AppendString => { + DefaultFunction::EqualsString + | DefaultFunction::AppendString + | DefaultFunction::EncodeUtf8 => arg_stack.iter().all(|arg| { + if let Term::Constant(c) = arg { + matches!(c.as_ref(), Constant::String(_)) + } else { + false + } + }), + + DefaultFunction::EqualsData | DefaultFunction::SerialiseData => { arg_stack.iter().all(|arg| { if let Term::Constant(c) = arg { - matches!(c.as_ref(), Constant::String(_)) + matches!(c.as_ref(), Constant::Data(_)) } else { false } }) } - DefaultFunction::EqualsData => arg_stack.iter().all(|arg| { - if let Term::Constant(c) = arg { - matches!(c.as_ref(), Constant::Data(_)) - } else { - false - } - }), - DefaultFunction::Bls12_381_G1_Equal | DefaultFunction::Bls12_381_G1_Add => { arg_stack.iter().all(|arg| { if let Term::Constant(c) = arg { @@ -337,6 +344,10 @@ impl DefaultFunction { _ => false, } } + + pub fn wrapped_name(self) -> String { + format!("__{}_wrapped", self.aiken_name()) + } } #[derive(PartialEq, Clone, Debug)] @@ -1102,7 +1113,7 @@ impl Term { } } Term::Delay(body) => { - let not_forced: isize = isize::from(force_stack.pop().is_none()); + let not_forced = isize::from(force_stack.pop().is_none()); body.var_occurrences(search_for, arg_stack, force_stack) .delay_if_found(not_forced) @@ -1125,11 +1136,34 @@ impl Term { } } Term::Apply { function, argument } => { + // unwrap apply and add void to arg stack! arg_stack.push(()); - function - .var_occurrences(search_for.clone(), arg_stack, force_stack) - .combine(argument.var_occurrences(search_for, vec![], vec![])) + let apply_var_occurrence_stack = |term: &Term, arg_stack: Vec<()>| { + term.var_occurrences(search_for.clone(), arg_stack, force_stack) + }; + + let apply_var_occurrence_no_stack = + |term: &Term| term.var_occurrences(search_for.clone(), vec![], vec![]); + + if let Term::Apply { + function: next_func, + argument: next_arg, + } = function.as_ref() + { + // unwrap apply and add void to arg stack! + arg_stack.push(()); + next_func.carry_args_to_branch( + next_arg, + argument, + arg_stack, + apply_var_occurrence_stack, + apply_var_occurrence_no_stack, + ) + } else { + apply_var_occurrence_stack(function, arg_stack) + .combine(apply_var_occurrence_no_stack(argument)) + } } Term::Force(x) => { force_stack.push(()); @@ -1141,6 +1175,81 @@ impl Term { } } + // This handles the very common case of (if condition then body else error) + // or (if condition then error else body) + // In this case it is fine to treat the body as if it is not delayed + // since the other branch is error + fn carry_args_to_branch( + &self, + then_arg: &Rc>, + else_arg: &Rc>, + mut arg_stack: Vec<()>, + var_occurrence_stack: impl FnOnce(&Term, Vec<()>) -> VarLookup, + var_occurrence_no_stack: impl Fn(&Term) -> VarLookup, + ) -> VarLookup { + let Term::Apply { + function: builtin, + argument: condition, + } = self + else { + return var_occurrence_stack(self, arg_stack) + .combine(var_occurrence_no_stack(then_arg)) + .combine(var_occurrence_no_stack(else_arg)); + }; + + // unwrap apply and add void to arg stack! + arg_stack.push(()); + + let Term::Delay(else_arg) = else_arg.as_ref() else { + return var_occurrence_stack(builtin, arg_stack) + .combine(var_occurrence_no_stack(condition)) + .combine(var_occurrence_no_stack(then_arg)) + .combine(var_occurrence_no_stack(else_arg)); + }; + + let Term::Delay(then_arg) = then_arg.as_ref() else { + return var_occurrence_stack(builtin, arg_stack) + .combine(var_occurrence_no_stack(condition)) + .combine(var_occurrence_no_stack(then_arg)) + .combine(var_occurrence_no_stack(else_arg)); + }; + + match builtin.as_ref() { + Term::Var(a) + if a.text == DefaultFunction::IfThenElse.wrapped_name() + || a.text == DefaultFunction::ChooseList.wrapped_name() => + { + if matches!(else_arg.as_ref(), Term::Error) { + // Pop 3 args of arg_stack due to branch execution + arg_stack.pop(); + arg_stack.pop(); + arg_stack.pop(); + + var_occurrence_no_stack(condition) + .combine(var_occurrence_stack(then_arg, arg_stack)) + } else if matches!(then_arg.as_ref(), Term::Error) { + // Pop 3 args of arg_stack due to branch execution + arg_stack.pop(); + arg_stack.pop(); + arg_stack.pop(); + + var_occurrence_no_stack(condition) + .combine(var_occurrence_stack(else_arg, arg_stack)) + } else { + var_occurrence_stack(builtin, arg_stack) + .combine(var_occurrence_no_stack(condition)) + .combine(var_occurrence_no_stack(then_arg)) + .combine(var_occurrence_no_stack(else_arg)) + } + } + + _ => var_occurrence_stack(builtin, arg_stack) + .combine(var_occurrence_no_stack(condition)) + .combine(var_occurrence_no_stack(then_arg)) + .combine(var_occurrence_no_stack(else_arg)), + } + } + fn lambda_reducer( &mut self, _id: Option, @@ -1207,7 +1316,7 @@ impl Term { if has_forces { context.builtins_map.insert(*func as u8, ()); - *self = Term::var(format!("__{}_wrapped", func.aiken_name())); + *self = Term::var(func.wrapped_name()); } } } @@ -1269,44 +1378,37 @@ impl Term { } => { let body = Rc::make_mut(body); // pops stack here no matter what - let temp = Term::Error; - if let ( - arg_id, - Term::Lambda { - parameter_name: identity_name, - body: identity_body, - }, - ) = match &arg_stack.pop() { - Some(Args::Apply( - arg_id, - Term::Lambda { - parameter_name: inline_name, - body, - }, - )) if inline_name.text == NO_INLINE => (*arg_id, body.as_ref()), - Some(Args::Apply(arg_id, term)) => (*arg_id, term), - _ => (0, &temp), - } { - let Term::Var(identity_var) = identity_body.as_ref() else { - return false; - }; + let Some(Args::Apply(arg_id, identity_func)) = arg_stack.pop() else { + return false; + }; + + let Term::Lambda { + parameter_name: identity_name, + body: identity_body, + } = identity_func.pierce_no_inlines() + else { + return false; + }; - if identity_var.text == identity_name.text - && identity_var.unique == identity_name.unique + let Term::Var(identity_var) = identity_body.as_ref() else { + return false; + }; + + if identity_var.text == identity_name.text + && identity_var.unique == identity_name.unique + { + // Replace all applied usages of identity with the arg + body.replace_identity_usage(parameter_name.clone()); + // Have to check if the body still has any occurrences of the parameter + // After attempting replacement + if !body + .var_occurrences(parameter_name.clone(), vec![], vec![]) + .found { - // Replace all applied usages of identity with the arg - body.replace_identity_usage(parameter_name.clone()); - // Have to check if the body still has any occurrences of the parameter - // After attempting replacement - if !body - .var_occurrences(parameter_name.clone(), vec![], vec![]) - .found - { - changed = true; - context.inlined_apply_ids.push(arg_id); - *self = std::mem::replace(body, Term::Error.force()); - } + changed = true; + context.inlined_apply_ids.push(arg_id); + *self = std::mem::replace(body, Term::Error.force()); } } } @@ -1333,53 +1435,44 @@ impl Term { body, } => { // pops stack here no matter what - if let Some(Args::Apply(arg_id, arg_term)) = arg_stack.pop() { - let arg_term = match &arg_term { - Term::Lambda { - parameter_name, - body, - } if parameter_name.text == NO_INLINE => body.as_ref().clone(), - _ => arg_term, - }; + let Some(Args::Apply(arg_id, arg_term)) = arg_stack.pop() else { + return false; + }; - let body = Rc::make_mut(body); + let arg_term = arg_term.pierce_no_inlines(); - let var_lookup = body.var_occurrences(parameter_name.clone(), vec![], vec![]); + let body = Rc::make_mut(body); - let substitute_condition = (var_lookup.delays == 0 && !var_lookup.no_inline) - || matches!( - &arg_term, - Term::Var(_) - | Term::Constant(_) - | Term::Delay(_) - | Term::Lambda { .. } - | Term::Builtin(_), - ); + let var_lookup = body.var_occurrences(parameter_name.clone(), vec![], vec![]); - if var_lookup.occurrences == 1 && substitute_condition { - changed = true; - body.substitute_var(parameter_name.clone(), arg_term.pierce_no_inlines()); + let must_execute_condition = var_lookup.delays == 0 && !var_lookup.no_inline; - context.inlined_apply_ids.push(arg_id); - *self = std::mem::replace(body, Term::Error.force()); + let cant_throw_condition = matches!( + arg_term, + Term::Var(_) + | Term::Constant(_) + | Term::Delay(_) + | Term::Lambda { .. } + | Term::Builtin(_), + ); - // This will strip out unused terms that can't throw an error by themselves - } else if !var_lookup.found - && matches!( - arg_term, - Term::Var(_) - | Term::Constant(_) - | Term::Delay(_) - | Term::Lambda { .. } - | Term::Builtin(_) - ) - { - changed = true; - context.inlined_apply_ids.push(arg_id); - *self = std::mem::replace(body, Term::Error.force()); - } + // This will inline terms that only occur once + // if they are guaranteed to execute or can't throw an error by themselves + if var_lookup.occurrences == 1 && (must_execute_condition || cant_throw_condition) { + changed = true; + body.substitute_var(parameter_name.clone(), arg_term); + + context.inlined_apply_ids.push(arg_id); + *self = std::mem::replace(body, Term::Error.force()); + + // This will strip out unused terms that can't throw an error by themselves + } else if !var_lookup.found && cant_throw_condition { + changed = true; + context.inlined_apply_ids.push(arg_id); + *self = std::mem::replace(body, Term::Error.force()); } } + Term::Constr { .. } => todo!(), Term::Case { .. } => todo!(), _ => {} @@ -1647,10 +1740,7 @@ impl Term { term.pierce_no_inlines() }) .collect_vec(); - if func.can_curry_builtin() - && arg_stack.len() == func.arity() - && func.is_error_safe(&args) - { + if arg_stack.len() == func.arity() && func.is_error_safe(&args) { changed = true; let applied_term = arg_stack @@ -1671,7 +1761,7 @@ impl Term { } .to_named_debruijn() .unwrap() - .eval(ExBudget::max()) + .eval(ExBudget::default()) .result() .unwrap() .try_into() @@ -1827,13 +1917,17 @@ impl Program { for default_func_index in context.builtins_map.keys().sorted().cloned() { let default_func: DefaultFunction = default_func_index.try_into().unwrap(); - term = term - .lambda(format!("__{}_wrapped", default_func.aiken_name())) - .apply(if default_func.force_count() == 1 { - Term::Builtin(default_func).force() - } else { - Term::Builtin(default_func).force().force() - }); + term = term.lambda(default_func.wrapped_name()); + } + + for default_func_index in context.builtins_map.keys().sorted().cloned().rev() { + let default_func: DefaultFunction = default_func_index.try_into().unwrap(); + + term = term.apply(if default_func.force_count() == 1 { + Term::Builtin(default_func).force() + } else { + Term::Builtin(default_func).force().force() + }); } let mut program = Program { @@ -2165,10 +2259,11 @@ fn is_a_builtin_wrapper(term: &Term) -> bool { while let Term::Apply { function, argument } = term { match argument.as_ref() { - Term::Var(name) => arg_names.push(name), + Term::Var(name) => arg_names.push(format!("{}_{}", name.text, name.unique)), Term::Constant(_) => {} _ => { + //Break loop, it's not a builtin wrapper function return false; } } @@ -2178,7 +2273,7 @@ fn is_a_builtin_wrapper(term: &Term) -> bool { arg_names.iter().all(|item| names.contains(item)) && matches!(term, Term::Builtin(_)) } -fn pop_lambdas_and_get_names(term: &Term) -> (Vec>, &Term) { +fn pop_lambdas_and_get_names(term: &Term) -> (Vec, &Term) { let mut names = vec![]; let mut term = term; @@ -2189,7 +2284,7 @@ fn pop_lambdas_and_get_names(term: &Term) -> (Vec>, &Term) } = term { if parameter_name.text != NO_INLINE { - names.push(parameter_name.clone()); + names.push(format!("{}_{}", parameter_name.text, parameter_name.unique)); } term = body.as_ref(); } @@ -2365,11 +2460,11 @@ mod tests { .lambda("y") // Forces are automatically applied by builder .lambda("__cons_list_wrapped") - .apply(Term::mk_cons()) .lambda("__head_list_wrapped") - .apply(Term::head_list()) .lambda("__tail_list_wrapped") - .apply(Term::tail_list()), + .apply(Term::tail_list()) + .apply(Term::head_list()) + .apply(Term::mk_cons()), }; compare_optimization(expected, program, |p| p.run_once_pass()); @@ -2453,9 +2548,9 @@ mod tests { .apply(Term::data(Data::integer(5.into()))), ) .lambda("__fst_pair_wrapped") - .apply(Term::fst_pair()) .lambda("__snd_pair_wrapped") - .apply(Term::snd_pair()), + .apply(Term::snd_pair()) + .apply(Term::fst_pair()), }; compare_optimization(expected, program, |p| p.run_once_pass()); @@ -2644,6 +2739,80 @@ mod tests { }); } + #[test] + fn inline_reduce_if_then_else_then() { + let program: Program = Program { + version: (1, 0, 0), + term: Term::var("__if_then_else_wrapped") + .apply(Term::bool(true)) + .apply(Term::sha3_256().apply(Term::var("x")).delay()) + .apply(Term::Error.delay()) + .force() + .lambda("x") + .apply(Term::sha3_256().apply(Term::byte_string(vec![]))) + .lambda("__if_then_else_wrapped") + .apply(Term::Builtin(DefaultFunction::IfThenElse).force()), + }; + + let expected = Program { + version: (1, 0, 0), + term: Term::var("__if_then_else_wrapped") + .apply(Term::bool(true)) + .apply( + Term::sha3_256() + .apply(Term::sha3_256().apply(Term::byte_string(vec![]))) + .delay(), + ) + .apply(Term::Error.delay()) + .force() + .lambda("__if_then_else_wrapped") + .apply(Term::Builtin(DefaultFunction::IfThenElse).force()), + }; + + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.inline_reducer(id, arg_stack, scope, context); + }) + }); + } + + #[test] + fn inline_reduce_if_then_else_else() { + let program: Program = Program { + version: (1, 0, 0), + term: Term::var("__if_then_else_wrapped") + .apply(Term::bool(true)) + .apply(Term::Error.delay()) + .apply(Term::sha3_256().apply(Term::var("x")).delay()) + .force() + .lambda("x") + .apply(Term::sha3_256().apply(Term::byte_string(vec![]))) + .lambda("__if_then_else_wrapped") + .apply(Term::Builtin(DefaultFunction::IfThenElse).force()), + }; + + let expected = Program { + version: (1, 0, 0), + term: Term::var("__if_then_else_wrapped") + .apply(Term::bool(true)) + .apply(Term::Error.delay()) + .apply( + Term::sha3_256() + .apply(Term::sha3_256().apply(Term::byte_string(vec![]))) + .delay(), + ) + .force() + .lambda("__if_then_else_wrapped") + .apply(Term::Builtin(DefaultFunction::IfThenElse).force()), + }; + + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.inline_reducer(id, arg_stack, scope, context); + }) + }); + } + #[test] fn inline_reduce_0_occurrence() { let program: Program = Program {