diff --git a/crates/aiken-lang/src/builtins.rs b/crates/aiken-lang/src/builtins.rs index 08179ba33..346f1dbea 100644 --- a/crates/aiken-lang/src/builtins.rs +++ b/crates/aiken-lang/src/builtins.rs @@ -509,12 +509,8 @@ pub fn plutus(id_gen: &IdGenerator) -> TypeInfo { }; for builtin in DefaultFunction::iter() { - // FIXME: Disabling WriteBits for now, since its signature requires the ability to create - // list of raw integers, which isn't possible through Aiken at the moment. - if !matches!(builtin, DefaultFunction::WriteBits) { - let value = from_default_function(builtin, id_gen); - plutus.values.insert(builtin.aiken_name(), value); - } + let value = from_default_function(builtin, id_gen); + plutus.values.insert(builtin.aiken_name(), value); } let index_tipo = Type::function(vec![Type::data()], Type::int()); diff --git a/crates/aiken-lang/src/gen_uplc.rs b/crates/aiken-lang/src/gen_uplc.rs index 588edc00c..160ea970f 100644 --- a/crates/aiken-lang/src/gen_uplc.rs +++ b/crates/aiken-lang/src/gen_uplc.rs @@ -3712,7 +3712,7 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); let eval_program: Program = - program.clean_up().try_into().unwrap(); + program.clean_up_no_inlines().try_into().unwrap(); Some( eval_program @@ -3822,7 +3822,7 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); let eval_program: Program = - program.clean_up().try_into().unwrap(); + program.clean_up_no_inlines().try_into().unwrap(); let evaluated_term: Term = eval_program .eval(ExBudget::default()) @@ -4028,7 +4028,7 @@ impl<'a> CodeGenerator<'a> { } else { let term = arg_stack.pop().unwrap(); - match term.pierce_no_inlines() { + match term.pierce_no_inlines_ref() { Term::Var(_) => Some(term.force()), Term::Delay(inner_term) => Some(inner_term.as_ref().clone()), Term::Apply { .. } => Some(term.force()), @@ -4356,7 +4356,7 @@ impl<'a> CodeGenerator<'a> { known_data_to_type(term, &tipo) }; - if extract_constant(term.pierce_no_inlines()).is_some() { + if extract_constant(term.pierce_no_inlines_ref()).is_some() { let mut program = self.new_program(term); let mut interner = CodeGenInterner::new(); @@ -4364,7 +4364,7 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); let eval_program: Program = - program.clean_up().try_into().unwrap(); + program.clean_up_no_inlines().try_into().unwrap(); let evaluated_term: Term = eval_program .eval(ExBudget::default()) @@ -4379,7 +4379,7 @@ impl<'a> CodeGenerator<'a> { Air::CastToData { tipo } => { let mut term = arg_stack.pop().unwrap(); - if extract_constant(term.pierce_no_inlines()).is_some() { + if extract_constant(term.pierce_no_inlines_ref()).is_some() { term = builder::convert_type_to_data(term, &tipo); let mut program = self.new_program(term); @@ -4389,7 +4389,7 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); let eval_program: Program = - program.clean_up().try_into().unwrap(); + program.clean_up_no_inlines().try_into().unwrap(); let evaluated_term: Term = eval_program .eval(ExBudget::default()) @@ -4792,7 +4792,7 @@ impl<'a> CodeGenerator<'a> { .apply(term); if arg_vec.iter().all(|item| { - let maybe_const = extract_constant(item.pierce_no_inlines()); + let maybe_const = extract_constant(item.pierce_no_inlines_ref()); maybe_const.is_some() }) { let mut program = self.new_program(term); @@ -4802,7 +4802,7 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); let eval_program: Program = - program.clean_up().try_into().unwrap(); + program.clean_up_no_inlines().try_into().unwrap(); let evaluated_term: Term = eval_program .eval(ExBudget::default()) diff --git a/crates/aiken-project/src/snapshots/aiken_project__export__tests__basic_export.snap b/crates/aiken-project/src/snapshots/aiken_project__export__tests__basic_export.snap index ed1fd9f89..46632a51d 100644 --- a/crates/aiken-project/src/snapshots/aiken_project__export__tests__basic_export.snap +++ b/crates/aiken-project/src/snapshots/aiken_project__export__tests__basic_export.snap @@ -1,7 +1,6 @@ --- source: crates/aiken-project/src/export.rs description: "Code:\n\npub fn add(a: Int, b: Int) -> Int {\n a + b\n}\n" -snapshot_kind: text --- { "name": "test_module.add", @@ -25,8 +24,8 @@ snapshot_kind: text "$ref": "#/definitions/Int" } }, - "compiledCode": "500101002322337000046eb4004dd68009", - "hash": "b8374597a772cef80d891b7f6a03588e10cc19b780251228ba4ce9c6", + "compiledCode": "500101002232337000026eb4008dd68011", + "hash": "e5951afb3263ef11acc0b4c88cd5f5b30b8621ce63fe024b3ea2bec8", "definitions": { "Int": { "dataType": "integer" diff --git a/crates/aiken-project/src/snapshots/aiken_project__export__tests__recursive_types.snap b/crates/aiken-project/src/snapshots/aiken_project__export__tests__recursive_types.snap index 586b057ca..8c4d40af9 100644 --- a/crates/aiken-project/src/snapshots/aiken_project__export__tests__recursive_types.snap +++ b/crates/aiken-project/src/snapshots/aiken_project__export__tests__recursive_types.snap @@ -1,7 +1,6 @@ --- source: crates/aiken-project/src/export.rs description: "Code:\n\npub type Foo {\n Empty\n Bar(a, Foo)\n}\n\npub fn add(a: Foo, b: Foo) -> Int {\n when (a, b) is {\n (Empty, Empty) -> 0\n (Bar(x, y), Bar(c, d)) -> x + c + add(y, d)\n (Empty, Bar(c, d)) -> c + add(Empty, d)\n (Bar(x, y), Empty) -> x + add(y, Empty)\n }\n}\n" -snapshot_kind: text --- { "name": "test_module.add", @@ -25,8 +24,8 @@ snapshot_kind: text "$ref": "#/definitions/Int" } }, - "compiledCode": "59017d0101003232323232322232323232325333008300430093754002264a666012600a60146ea800452000132337006eb4c038004cc011300103d8798000300e300f001300b37540026018601a00a264a66601266e1d2002300a37540022646466e00cdc01bad300f002375a601e0026600a601e6020004601e602000260186ea8008c02cdd500109919b80375a601c00266008601c601e002980103d8798000300b37540046018601a00a601600860020024446464a666014600c60166ea80044c94ccc02cc01cc030dd50008a400026466e00dd69808000999803803a60103d879800030103011001300d3754002601c601e004264a66601666e1d2002300c37540022646466e00cdc01bad3011002375a60220026660100106022602400460226024002601c6ea8008c034dd500109919b80375a602000266600e00e60206022002980103d8798000300d3754004601c601e004601a002660160046601600297ae0370e90001980300119803000a5eb815cd2ab9d5573cae815d0aba201", - "hash": "c6af3f04e300cb8c1d0429cc0d8e56a0413eef9fcb338f72076b426c", + "compiledCode": "590186010100229800aba2aba1aba0aab9eaab9dab9a9b874800122222223322332259800980298039baa0018992cc004c018c020dd5000c5200089919b80375a60180026600898103d8798000300c300d001300937540028038c028c02c012264b30013370e900118041baa0018999119b80337006eb4c034008dd6980680099802980698070011806980700098049baa00230093754003132337006eb4c030004cc010c030c03400530103d8798000300937540048038c028c02c01100618008009804001198028049980280425eb80888c8c966002600c60106ea8006264b300130073009375400314800226466e00dd69806800cc00401e98103d879800098069807000a00e300a37540028040c02cc03000a264b30013370e900118049baa0018999119b80337006eb4c038008dd69807000cc004022601c601e005300e300f001402060146ea8008c028dd5000c4c8cdc01bad300d0019800803cc034c038006980103d8798000401c60146ea800900818059806001200e300a00133008002330080014bd701", + "hash": "dc9b9c2bbcfb1cb422534ed1c4d04f2e2b9b57a0a498175d055f83e8", "definitions": { "Int": { "dataType": "integer" diff --git a/crates/aiken-project/src/tests/gen_uplc.rs b/crates/aiken-project/src/tests/gen_uplc.rs index 8a5506a5f..73f91e239 100644 --- a/crates/aiken-project/src/tests/gen_uplc.rs +++ b/crates/aiken-project/src/tests/gen_uplc.rs @@ -3603,7 +3603,7 @@ fn when_bool_is_true() { assert_uplc( src, Term::var("subject") - .delayed_if_then_else(Term::bool(true), Term::Error) + .delayed_if_then_else(Term::bool(true), Term::Error.delay().force()) .lambda("subject") .apply(Term::bool(true)), false, @@ -3627,7 +3627,7 @@ fn when_bool_is_true_switched_cases() { assert_uplc( src, Term::var("subject") - .delayed_if_then_else(Term::bool(true), Term::Error) + .delayed_if_then_else(Term::bool(true), Term::Error.delay().force()) .lambda("subject") .apply(Term::bool(true)), false, @@ -3651,7 +3651,7 @@ fn when_bool_is_false() { assert_uplc( src, Term::var("subject") - .delayed_if_then_else(Term::bool(true), Term::Error) + .delayed_if_then_else(Term::bool(true), Term::Error.delay().force()) .lambda("subject") .apply(Term::bool(false)), true, @@ -4088,16 +4088,16 @@ fn generic_validator_type_test() { Term::tail_list() .apply(Term::Var(tail_id_5.clone())) .as_var("tail_id_6", |tail_id_6| { - Term::head_list() + Term::tail_list() .apply(Term::Var(tail_id_6.clone())) - .as_var("__val", |val| { - Term::tail_list() + .delayed_choose_list( + Term::head_list() .apply(Term::Var(tail_id_6)) - .delayed_choose_list( - expect_b(val, Term::Var(then_delayed), trace), - Term::Error, - ) - }) + .as_var("__val", |val| { + expect_b(val, Term::Var(then_delayed), trace) + }), + Term::Error, + ) }) } }); diff --git a/crates/uplc/src/ast.rs b/crates/uplc/src/ast.rs index 17bcf6827..82c78eecb 100644 --- a/crates/uplc/src/ast.rs +++ b/crates/uplc/src/ast.rs @@ -517,7 +517,7 @@ impl hash::Hash for Name { impl PartialEq for Name { fn eq(&self, other: &Self) -> bool { - self.unique == other.unique + self.unique == other.unique && self.text == other.text } } diff --git a/crates/uplc/src/builder.rs b/crates/uplc/src/builder.rs index e4174c1c6..abc187961 100644 --- a/crates/uplc/src/builder.rs +++ b/crates/uplc/src/builder.rs @@ -9,6 +9,7 @@ pub const CONSTR_FIELDS_EXPOSER: &str = "__constr_fields_exposer"; pub const CONSTR_INDEX_EXPOSER: &str = "__constr_index_exposer"; pub const EXPECT_ON_LIST: &str = "__expect_on_list"; pub const INNER_EXPECT_ON_LIST: &str = "__inner_expect_on_list"; +pub const INDICES_CONVERTER: &str = "__indices_converter"; impl Term where @@ -30,6 +31,17 @@ where Term::Delay(self.into()) } + pub fn constr(tag: usize, fields: Vec>) -> Self { + Term::Constr { tag, fields } + } + + pub fn case(self, branches: Vec>) -> Self { + Term::Case { + constr: self.into(), + branches, + } + } + // Primitives pub fn integer(i: num_bigint::BigInt) -> Self { Term::Constant(Constant::Integer(i).into()) @@ -71,6 +83,10 @@ where Term::Constant(Constant::ProtoList(Type::Data, vals).into()) } + pub fn int_values(vals: Vec) -> Self { + Term::Constant(Constant::ProtoList(Type::Integer, vals).into()) + } + pub fn empty_map() -> Self { Term::Constant( Constant::ProtoList(Type::Pair(Type::Data.into(), Type::Data.into()), vec![]).into(), @@ -392,6 +408,10 @@ where pub fn serialise_data() -> Self { Term::Builtin(DefaultFunction::SerialiseData) } + + pub fn write_bits() -> Self { + Term::Builtin(DefaultFunction::WriteBits) + } } impl Term @@ -535,6 +555,33 @@ impl Term { ) } + pub fn data_list_to_integer_list(self) -> Self { + self.lambda(INDICES_CONVERTER) + .apply(Term::var(INDICES_CONVERTER).apply(Term::var(INDICES_CONVERTER))) + .lambda(INDICES_CONVERTER) + .apply( + Term::var("xs") + .delayed_choose_list( + Term::int_values(vec![]), + Term::mk_cons() + .apply(Term::var("x")) + .apply( + Term::var(INDICES_CONVERTER) + .apply(Term::var(INDICES_CONVERTER)) + .apply(Term::var("rest")), + ) + .lambda("rest") + .apply(Term::tail_list().apply(Term::var("xs"))) + .lambda("x") + .apply( + Term::un_i_data().apply(Term::head_list().apply(Term::var("xs"))), + ), + ) + .lambda("xs") + .lambda(INDICES_CONVERTER), + ) + } + /// Introduce a let-binding for a given term. The callback receives a Term::Var /// whose name matches the given 'var_name'. Handy to re-use a same var across /// multiple lambda expressions. diff --git a/crates/uplc/src/builtins.rs b/crates/uplc/src/builtins.rs index 574bda470..6f315702e 100644 --- a/crates/uplc/src/builtins.rs +++ b/crates/uplc/src/builtins.rs @@ -6,7 +6,19 @@ use strum_macros::EnumIter; /// All the possible builtin functions in Untyped Plutus Core. #[repr(u8)] #[allow(non_camel_case_types)] -#[derive(Debug, Clone, PartialEq, Eq, Copy, EnumIter, serde::Serialize, serde::Deserialize)] +#[derive( + Debug, + Clone, + PartialEq, + Eq, + Copy, + EnumIter, + serde::Serialize, + serde::Deserialize, + Hash, + PartialOrd, + Ord, +)] pub enum DefaultFunction { // Integer functions AddInteger = 0, diff --git a/crates/uplc/src/optimize.rs b/crates/uplc/src/optimize.rs index fd00a2c9d..bbc0a1c5e 100644 --- a/crates/uplc/src/optimize.rs +++ b/crates/uplc/src/optimize.rs @@ -38,5 +38,5 @@ pub fn aiken_optimize_and_intern(program: Program) -> Program { } } - prog.clean_up() + prog.clean_up_no_inlines().afterwards() } diff --git a/crates/uplc/src/optimize/shrinker.rs b/crates/uplc/src/optimize/shrinker.rs index 8dacdd069..6f0144614 100644 --- a/crates/uplc/src/optimize/shrinker.rs +++ b/crates/uplc/src/optimize/shrinker.rs @@ -1,13 +1,13 @@ use super::interner::CodeGenInterner; use crate::{ ast::{Constant, Data, Name, NamedDeBruijn, Program, Term, Type}, - builder::{CONSTR_FIELDS_EXPOSER, CONSTR_INDEX_EXPOSER}, + builder::{CONSTR_FIELDS_EXPOSER, CONSTR_INDEX_EXPOSER, INDICES_CONVERTER}, builtins::DefaultFunction, - machine::{cost_model::ExBudget, runtime::Compressable}, + machine::{cost_model::ExBudget, runtime::Compressable, value::from_pallas_bigint}, }; use blst::{blst_p1, blst_p2}; use indexmap::IndexMap; -use itertools::Itertools; +use itertools::{FoldWhile, Itertools}; use pallas_primitives::conway::{BigInt, PlutusData}; use std::{cmp::Ordering, iter, ops::Neg, rc::Rc}; @@ -90,8 +90,8 @@ pub const NO_INLINE: &str = "__no_inline__"; #[derive(PartialEq, PartialOrd, Default, Debug, Clone)] pub struct VarLookup { found: bool, - occurrences: isize, - delays: isize, + occurrences: usize, + delays: usize, no_inline: bool, } @@ -123,7 +123,7 @@ impl VarLookup { } } - pub fn delay_if_found(self, delay_amount: isize) -> Self { + pub fn delay_if_found(self, delay_amount: usize) -> Self { if self.found { Self { found: self.found, @@ -167,6 +167,7 @@ impl DefaultFunction { ) } /// For now all of the curry builtins are not forceable + /// Curryable builtins must take in 2 or more arguments pub fn can_curry_builtin(self) -> bool { matches!( self, @@ -205,7 +206,8 @@ impl DefaultFunction { | DefaultFunction::MultiplyInteger | DefaultFunction::EqualsInteger | DefaultFunction::LessThanInteger - | DefaultFunction::LessThanEqualsInteger => arg_stack.iter().all(|arg| { + | DefaultFunction::LessThanEqualsInteger + | DefaultFunction::IData => arg_stack.iter().all(|arg| { if let Term::Constant(c) = arg { matches!(c.as_ref(), Constant::Integer(_)) } else { @@ -226,10 +228,12 @@ impl DefaultFunction { false } }), - DefaultFunction::EqualsByteString + DefaultFunction::LengthOfByteString + | DefaultFunction::EqualsByteString | DefaultFunction::AppendByteString | DefaultFunction::LessThanEqualsByteString - | DefaultFunction::LessThanByteString => arg_stack.iter().all(|arg| { + | DefaultFunction::LessThanByteString + | DefaultFunction::BData => arg_stack.iter().all(|arg| { if let Term::Constant(c) = arg { matches!(c.as_ref(), Constant::ByteString(_)) } else { @@ -282,24 +286,26 @@ impl DefaultFunction { } } - DefaultFunction::EqualsString | DefaultFunction::AppendString => { + DefaultFunction::EqualsString + | DefaultFunction::AppendString + | DefaultFunction::EncodeUtf8 => arg_stack.iter().all(|arg| { + if let Term::Constant(c) = arg { + matches!(c.as_ref(), Constant::String(_)) + } else { + false + } + }), + + DefaultFunction::EqualsData | DefaultFunction::SerialiseData => { arg_stack.iter().all(|arg| { if let Term::Constant(c) = arg { - matches!(c.as_ref(), Constant::String(_)) + matches!(c.as_ref(), Constant::Data(_)) } else { false } }) } - DefaultFunction::EqualsData => arg_stack.iter().all(|arg| { - if let Term::Constant(c) = arg { - matches!(c.as_ref(), Constant::Data(_)) - } else { - false - } - }), - DefaultFunction::Bls12_381_G1_Equal | DefaultFunction::Bls12_381_G1_Add => { arg_stack.iter().all(|arg| { if let Term::Constant(c) = arg { @@ -337,6 +343,10 @@ impl DefaultFunction { _ => false, } } + + pub fn wrapped_name(self) -> String { + format!("__{}_wrapped", self.aiken_name()) + } } #[derive(PartialEq, Clone, Debug)] @@ -562,28 +572,32 @@ impl CurriedArgs { mut fst_args, mut snd_args, }, - BuiltinArgs::TwoArgsAnyOrder { fst, snd }, + BuiltinArgs::TwoArgsAnyOrder { mut fst, snd }, ) => { - let mut switched = false; - let fst_args = if fst_args.iter_mut().any(|item| item.term == fst.1) { - fst_args + let (switched, fst_args) = if fst_args.iter_mut().any(|item| item.term == fst.1) { + (false, fst_args) } else if fst_args.iter_mut().any(|item| match &snd { Some(snd) => item.term == snd.1, None => false, }) { - switched = true; - fst_args + (true, fst_args) } else { fst_args.push(CurriedNode { id: fst.0, - term: fst.1.clone(), + // Replace the value here instead of cloning since + // switched must be false here + // I use Term::Error.force() since it's not a + // naturally occurring term in code gen. + term: std::mem::replace(&mut fst.1, Term::Error.force()), }); - fst_args + (false, fst_args) }; // If switched then put the first arg in the second arg slot let snd_args = if switched { + assert!(fst.1 != Term::Error.force()); + if snd_args.iter_mut().any(|item| item.term == fst.1) { snd_args } else { @@ -669,12 +683,13 @@ impl CurriedArgs { } } - fn get_id_args(&self, path: &BuiltinArgs) -> Option> { + // TODO: switch clones to memory moves out of path + fn get_id_args(&self, path: BuiltinArgs) -> Option> { match (self, path) { (CurriedArgs::TwoArgs { fst_args, snd_args }, BuiltinArgs::TwoArgs { fst, snd }) => { let arg = fst_args.iter().find(|item| fst.1 == item.term)?; - let Some(arg2) = snd_args.iter().find(|item| match snd { + let Some(arg2) = snd_args.iter().find(|item| match &snd { Some(snd) => item.term == snd.1, None => false, }) else { @@ -711,7 +726,7 @@ impl CurriedArgs { term: arg.term.clone(), }); - let Some(arg2) = snd_args.iter().find(|item| match snd { + let Some(arg2) = snd_args.iter().find(|item| match &snd { Some(snd) => snd.1 == item.term, None => false, }) else { @@ -761,7 +776,7 @@ impl CurriedArgs { ) => { let arg = fst_args.iter().find(|item| fst.1 == item.term)?; - let Some(arg2) = snd_args.iter().find(|item| match snd { + let Some(arg2) = snd_args.iter().find(|item| match &snd { Some(snd) => item.term == snd.1, None => false, }) else { @@ -772,7 +787,7 @@ impl CurriedArgs { }]); }; - let Some(arg3) = thd_args.iter().find(|item| match thd { + let Some(arg3) = thd_args.iter().find(|item| match &thd { Some(thd) => item.term == thd.1, None => false, }) else { @@ -844,7 +859,7 @@ impl CurriedBuiltin { } } - pub fn get_id_args(&self, path: &BuiltinArgs) -> Option> { + pub fn get_id_args(&self, path: BuiltinArgs) -> Option> { self.args.get_id_args(path) } @@ -857,9 +872,11 @@ impl CurriedBuiltin { pub struct Context { pub inlined_apply_ids: Vec, pub constants_to_flip: Vec, - pub builtins_map: IndexMap, + pub write_bits_indices_arg: Vec, + pub builtins_map: IndexMap, pub blst_p1_list: Vec, pub blst_p2_list: Vec, + pub write_bits_convert: bool, pub node_count: usize, } @@ -893,6 +910,7 @@ impl Term { ); let apply_id = id_gen.next_id(); + // Here we must clone since we must leave the original AST alone arg_stack.push(Args::Apply(apply_id, arg.clone())); let func = Rc::make_mut(function); @@ -962,7 +980,7 @@ impl Term { Term::Lambda { parameter_name, body, - } if parameter_name.text == p.text && parameter_name.unique == p.unique => { + } if *parameter_name == p => { let body = Rc::make_mut(body); body.traverse_uplc_with_helper( scope, @@ -973,9 +991,6 @@ impl Term { inline_lambda, ); } - - Term::Constr { .. } => todo!(), - Term::Case { .. } => todo!(), other => other.traverse_uplc_with_helper( scope, arg_stack, @@ -1001,8 +1016,53 @@ impl Term { } } - Term::Case { .. } => todo!(), - Term::Constr { .. } => todo!(), + Term::Case { constr, branches } => { + let constr = Rc::make_mut(constr); + constr.traverse_uplc_with_helper( + scope, + vec![], + id_gen, + with, + context, + inline_lambda, + ); + + if branches.len() == 1 { + // save a potentially big clone + // where currently all cases will be 1 branch + branches[0].traverse_uplc_with_helper( + scope, + arg_stack, + id_gen, + with, + context, + inline_lambda, + ); + } else { + for branch in branches { + branch.traverse_uplc_with_helper( + scope, + arg_stack.clone(), + id_gen, + with, + context, + inline_lambda, + ); + } + } + } + Term::Constr { fields, .. } => { + for field in fields { + field.traverse_uplc_with_helper( + scope, + vec![], + id_gen, + with, + context, + inline_lambda, + ); + } + } Term::Builtin(func) => { let mut args = vec![]; @@ -1024,16 +1084,14 @@ impl Term { fn substitute_var(&mut self, original: Rc, replace_with: &Term) { match self { - Term::Var(name) if name.text == original.text && name.unique == original.unique => { + Term::Var(name) if *name == original => { *self = replace_with.clone(); } Term::Delay(body) => Rc::make_mut(body).substitute_var(original, replace_with), Term::Lambda { parameter_name, body, - } if parameter_name.text != original.text - || parameter_name.unique != original.unique => - { + } if *parameter_name != original => { Rc::make_mut(body).substitute_var(original, replace_with); } Term::Apply { function, argument } => { @@ -1058,8 +1116,7 @@ impl Term { parameter_name, body, } => { - if parameter_name.text != original.text || parameter_name.unique != original.unique - { + if *parameter_name != original { Rc::make_mut(body).replace_identity_usage(original.clone()); } } @@ -1074,7 +1131,7 @@ impl Term { return; }; - if name.text == original.text && name.unique == original.unique { + if *name == original { *self = std::mem::replace(arg, Term::Error.force()); } } @@ -1095,14 +1152,14 @@ impl Term { ) -> VarLookup { match self { Term::Var(name) => { - if name.text == search_for.text && name.unique == search_for.unique { + if *name == search_for { VarLookup::new_found() } else { VarLookup::new() } } Term::Delay(body) => { - let not_forced: isize = isize::from(force_stack.pop().is_none()); + let not_forced = usize::from(force_stack.pop().is_none()); body.var_occurrences(search_for, arg_stack, force_stack) .delay_if_found(not_forced) @@ -1114,22 +1171,43 @@ impl Term { if parameter_name.text == NO_INLINE { body.var_occurrences(search_for, arg_stack, force_stack) .no_inline_if_found() - } else if parameter_name.text == search_for.text - && parameter_name.unique == search_for.unique - { + } else if *parameter_name == search_for { VarLookup::new() } else { - let not_applied: isize = isize::from(arg_stack.pop().is_none()); + let not_applied = usize::from(arg_stack.pop().is_none()); body.var_occurrences(search_for, arg_stack, force_stack) .delay_if_found(not_applied) } } Term::Apply { function, argument } => { + // unwrap apply and add void to arg stack! arg_stack.push(()); - function - .var_occurrences(search_for.clone(), arg_stack, force_stack) - .combine(argument.var_occurrences(search_for, vec![], vec![])) + let apply_var_occurrence_stack = |term: &Term, arg_stack: Vec<()>| { + term.var_occurrences(search_for.clone(), arg_stack, force_stack) + }; + + let apply_var_occurrence_no_stack = + |term: &Term| term.var_occurrences(search_for.clone(), vec![], vec![]); + + if let Term::Apply { + function: next_func, + argument: next_arg, + } = function.as_ref() + { + // unwrap apply and add void to arg stack! + arg_stack.push(()); + next_func.carry_args_to_branch( + next_arg, + argument, + arg_stack, + apply_var_occurrence_stack, + apply_var_occurrence_no_stack, + ) + } else { + apply_var_occurrence_stack(function, arg_stack) + .combine(apply_var_occurrence_no_stack(argument)) + } } Term::Force(x) => { force_stack.push(()); @@ -1141,6 +1219,81 @@ impl Term { } } + // This handles the very common case of (if condition then body else error) + // or (if condition then error else body) + // In this case it is fine to treat the body as if it is not delayed + // since the other branch is error + fn carry_args_to_branch( + &self, + then_arg: &Rc>, + else_arg: &Rc>, + mut arg_stack: Vec<()>, + var_occurrence_stack: impl FnOnce(&Term, Vec<()>) -> VarLookup, + var_occurrence_no_stack: impl Fn(&Term) -> VarLookup, + ) -> VarLookup { + let Term::Apply { + function: builtin, + argument: condition, + } = self + else { + return var_occurrence_stack(self, arg_stack) + .combine(var_occurrence_no_stack(then_arg)) + .combine(var_occurrence_no_stack(else_arg)); + }; + + // unwrap apply and add void to arg stack! + arg_stack.push(()); + + let Term::Delay(else_arg) = else_arg.as_ref() else { + return var_occurrence_stack(builtin, arg_stack) + .combine(var_occurrence_no_stack(condition)) + .combine(var_occurrence_no_stack(then_arg)) + .combine(var_occurrence_no_stack(else_arg)); + }; + + let Term::Delay(then_arg) = then_arg.as_ref() else { + return var_occurrence_stack(builtin, arg_stack) + .combine(var_occurrence_no_stack(condition)) + .combine(var_occurrence_no_stack(then_arg)) + .combine(var_occurrence_no_stack(else_arg)); + }; + + match builtin.as_ref() { + Term::Var(a) + if a.text == DefaultFunction::IfThenElse.wrapped_name() + || a.text == DefaultFunction::ChooseList.wrapped_name() => + { + if matches!(else_arg.as_ref(), Term::Error) { + // Pop 3 args of arg_stack due to branch execution + arg_stack.pop(); + arg_stack.pop(); + arg_stack.pop(); + + var_occurrence_no_stack(condition) + .combine(var_occurrence_stack(then_arg, arg_stack)) + } else if matches!(then_arg.as_ref(), Term::Error) { + // Pop 3 args of arg_stack due to branch execution + arg_stack.pop(); + arg_stack.pop(); + arg_stack.pop(); + + var_occurrence_no_stack(condition) + .combine(var_occurrence_stack(else_arg, arg_stack)) + } else { + var_occurrence_stack(builtin, arg_stack) + .combine(var_occurrence_no_stack(condition)) + .combine(var_occurrence_no_stack(then_arg)) + .combine(var_occurrence_no_stack(else_arg)) + } + } + + _ => var_occurrence_stack(builtin, arg_stack) + .combine(var_occurrence_no_stack(condition)) + .combine(var_occurrence_no_stack(then_arg)) + .combine(var_occurrence_no_stack(else_arg)), + } + } + fn lambda_reducer( &mut self, _id: Option, @@ -1175,7 +1328,10 @@ impl Term { let body = Rc::make_mut(body); context.inlined_apply_ids.push(arg_id); - body.substitute_var(parameter_name.clone(), arg_term.pierce_no_inlines()); + body.substitute_var( + parameter_name.clone(), + arg_term.pierce_no_inlines_ref(), + ); // creates new body that replaces all var occurrences with the arg *self = std::mem::replace(body, Term::Error.force()); } @@ -1206,8 +1362,8 @@ impl Term { } if has_forces { - context.builtins_map.insert(*func as u8, ()); - *self = Term::var(format!("__{}_wrapped", func.aiken_name())); + context.builtins_map.insert(*func, ()); + *self = Term::var(func.wrapped_name()); } } } @@ -1253,6 +1409,303 @@ impl Term { } } + // The ultimate function when used in conjunction with case_constr_apply + // This splits [lam fun_name [lam fun_name2 rest ..] ..] into + // [[lam fun_name lam fun_name2 rest ..]..] thus + // allowing for some crazy gains from cast_constr_apply_reducer + fn split_body_lambda(&mut self) { + let mut arg_stack = vec![]; + let mut current_term = &mut std::mem::replace(self, Term::Error.force()); + let mut unsat_lams = vec![]; + + let mut function_groups = vec![vec![]]; + let mut function_dependencies = vec![vec![]]; + + loop { + match current_term { + Term::Apply { function, argument } => { + current_term = Rc::make_mut(function); + + let arg = Rc::make_mut(argument); + + arg.split_body_lambda(); + + arg_stack.push(std::mem::replace(arg, Term::Error.force())); + } + Term::Lambda { + parameter_name, + body, + } => { + current_term = Rc::make_mut(body); + + if let Some(arg) = arg_stack.pop() { + let names = arg.get_var_names(); + + let func = (parameter_name.clone(), arg); + + if let Some((position, _)) = + function_groups.iter().enumerate().rfind(|named_functions| { + named_functions + .1 + .iter() + .any(|(name, _)| names.contains(name)) + }) + { + let insert_position = position + 1; + if insert_position == function_groups.len() { + function_groups.push(vec![func]); + function_dependencies.push(names); + } else { + function_groups[insert_position].push(func); + function_dependencies[insert_position].extend(names); + } + } else { + function_groups[0].push(func); + function_dependencies[0].extend(names); + } + } else { + unsat_lams.push(parameter_name.clone()); + } + } + Term::Delay(term) | Term::Force(term) => { + Rc::make_mut(term).split_body_lambda(); + break; + } + Term::Case { .. } => todo!(), + Term::Constr { .. } => todo!(), + _ => break, + } + } + let mut swap_postions = vec![]; + + function_groups + .iter() + .enumerate() + .for_each(|(group_index, group)| { + if group.len() <= 3 { + group + .iter() + .enumerate() + .rev() + .for_each(|(item_index, (item_name, _))| { + let current_eligible_position = function_dependencies + .iter() + .enumerate() + .fold_while(group_index, |acc, (new_position, dependencies)| { + if dependencies.contains(item_name) { + FoldWhile::Done(acc) + } else { + FoldWhile::Continue(new_position) + } + }) + .into_inner(); + + if current_eligible_position > group_index { + swap_postions.push(( + group_index, + item_index, + current_eligible_position, + )); + } + }); + } + }); + + for (group_index, item_index, swap_index) in swap_postions { + let item = function_groups[group_index].remove(item_index); + + function_groups[swap_index].push(item); + } + + let term_to_build_on = std::mem::replace(current_term, Term::Error.force()); + + // Replace args that weren't consumed + let term = arg_stack + .into_iter() + .rfold(term_to_build_on, |term, arg| term.apply(arg)); + + let term = function_groups.into_iter().rfold(term, |term, group| { + let term = group.iter().rfold(term, |term, (name, _)| Term::Lambda { + parameter_name: name.clone(), + body: term.into(), + }); + + group + .into_iter() + .fold(term, |term, (_, arg)| term.apply(arg)) + }); + + let term = unsat_lams + .into_iter() + .rfold(term, |term, name| Term::Lambda { + parameter_name: name.clone(), + body: term.into(), + }); + + *self = term; + } + + fn get_var_names(&self) -> Vec> { + let mut names = vec![]; + + let mut term = self; + + loop { + match term { + Term::Apply { function, argument } => { + let arg_names = argument.get_var_names(); + + names.extend(arg_names); + + term = function; + } + Term::Var(name) => { + names.push(name.clone()); + break; + } + Term::Delay(t) => { + term = t; + } + Term::Lambda { body, .. } => { + term = body; + } + Term::Constant(_) | Term::Error | Term::Builtin(_) => { + break; + } + Term::Force(t) => { + term = t; + } + Term::Constr { .. } => todo!(), + Term::Case { .. } => todo!(), + } + } + + names + } + + // IMPORTANT: RUNS ONE TIME AND ONLY ON THE LAST PASS + fn case_constr_apply_reducer( + &mut self, + _id: Option, + _arg_stack: Vec, + _scope: &Scope, + _context: &mut Context, + ) { + let mut term = &mut std::mem::replace(self, Term::Error.force()); + + let mut arg_vec = vec![]; + + while let Term::Apply { function, argument } = term { + arg_vec.push(Rc::make_mut(argument)); + + term = Rc::make_mut(function); + } + + arg_vec.reverse(); + + match term { + Term::Case { constr, branches } + if branches.len() == 1 && matches!(constr.as_ref(), Term::Constr { .. }) => + { + let Term::Constr { fields, .. } = Rc::make_mut(constr) else { + unreachable!(); + }; + + for arg in arg_vec { + fields.push(std::mem::replace(arg, Term::Error.force())); + } + + *self = std::mem::replace(term, Term::Error.force()); + } + _ => { + if arg_vec.len() > 2 { + let mut fields = vec![]; + + for arg in arg_vec { + fields.push(std::mem::replace(arg, Term::Error.force())); + } + + *self = Term::constr(0, fields) + .case(vec![std::mem::replace(term, Term::Error.force())]); + } else { + for arg in arg_vec { + *term = (std::mem::replace(term, Term::Error.force())) + .apply(std::mem::replace(arg, Term::Error.force())); + } + + *self = std::mem::replace(term, Term::Error.force()); + } + } + } + } + // List in Aiken is actually List> + // So now we want to convert writeBits arg List> to List + // Important: Only runs once and at the end. + fn write_bits_convert_arg( + &mut self, + id: Option, + mut arg_stack: Vec, + _scope: &Scope, + context: &mut Context, + ) { + match self { + Term::Apply { argument, .. } => { + let id = id.unwrap(); + + if context.write_bits_indices_arg.contains(&id) { + match Rc::make_mut(argument) { + Term::Constant(constant) => { + let Constant::ProtoList(tipo, items) = Rc::make_mut(constant) else { + unreachable!(); + }; + + assert!(*tipo == Type::Data); + *tipo = Type::Integer; + + for item in items { + let Constant::Data(PlutusData::BigInt(i)) = item else { + unreachable!(); + }; + + *item = Constant::Integer(from_pallas_bigint(i)); + } + } + arg => { + context.write_bits_convert = true; + + *arg = Term::var(INDICES_CONVERTER) + .apply(std::mem::replace(arg, Term::Error.force())); + } + } + } + } + + Term::Builtin(DefaultFunction::WriteBits) => { + if arg_stack.is_empty() { + context.write_bits_convert = true; + + *self = Term::write_bits() + .apply(Term::var("__arg_1")) + .apply(Term::var(INDICES_CONVERTER).apply(Term::var("__arg_2"))) + .apply(Term::var("__arg_3")) + .lambda("__arg_3") + .lambda("__arg_2") + .lambda("__arg_1") + } else { + // first arg not needed + arg_stack.pop(); + + let Some(Args::Apply(arg_id, _)) = arg_stack.pop() else { + return; + }; + + context.write_bits_indices_arg.push(arg_id); + } + } + _ => (), + } + } + fn identity_reducer( &mut self, _id: Option, @@ -1269,44 +1722,35 @@ impl Term { } => { let body = Rc::make_mut(body); // pops stack here no matter what - let temp = Term::Error; - if let ( - arg_id, - Term::Lambda { - parameter_name: identity_name, - body: identity_body, - }, - ) = match &arg_stack.pop() { - Some(Args::Apply( - arg_id, - Term::Lambda { - parameter_name: inline_name, - body, - }, - )) if inline_name.text == NO_INLINE => (*arg_id, body.as_ref()), - Some(Args::Apply(arg_id, term)) => (*arg_id, term), - _ => (0, &temp), - } { - let Term::Var(identity_var) = identity_body.as_ref() else { - return false; - }; + let Some(Args::Apply(arg_id, identity_func)) = arg_stack.pop() else { + return false; + }; + + let Term::Lambda { + parameter_name: identity_name, + body: identity_body, + } = identity_func.pierce_no_inlines() + else { + return false; + }; - if identity_var.text == identity_name.text - && identity_var.unique == identity_name.unique + let Term::Var(identity_var) = identity_body.as_ref() else { + return false; + }; + + if *identity_var == identity_name { + // Replace all applied usages of identity with the arg + body.replace_identity_usage(parameter_name.clone()); + // Have to check if the body still has any occurrences of the parameter + // After attempting replacement + if !body + .var_occurrences(parameter_name.clone(), vec![], vec![]) + .found { - // Replace all applied usages of identity with the arg - body.replace_identity_usage(parameter_name.clone()); - // Have to check if the body still has any occurrences of the parameter - // After attempting replacement - if !body - .var_occurrences(parameter_name.clone(), vec![], vec![]) - .found - { - changed = true; - context.inlined_apply_ids.push(arg_id); - *self = std::mem::replace(body, Term::Error.force()); - } + changed = true; + context.inlined_apply_ids.push(arg_id); + *self = std::mem::replace(body, Term::Error.force()); } } } @@ -1333,53 +1777,52 @@ impl Term { body, } => { // pops stack here no matter what - if let Some(Args::Apply(arg_id, arg_term)) = arg_stack.pop() { - let arg_term = match &arg_term { - Term::Lambda { - parameter_name, - body, - } if parameter_name.text == NO_INLINE => body.as_ref().clone(), - _ => arg_term, - }; + let Some(Args::Apply(arg_id, arg_term)) = arg_stack.pop() else { + return false; + }; - let body = Rc::make_mut(body); + let arg_term = arg_term.pierce_no_inlines_ref(); - let var_lookup = body.var_occurrences(parameter_name.clone(), vec![], vec![]); + let body = Rc::make_mut(body); - let substitute_condition = (var_lookup.delays == 0 && !var_lookup.no_inline) - || matches!( - &arg_term, - Term::Var(_) - | Term::Constant(_) - | Term::Delay(_) - | Term::Lambda { .. } - | Term::Builtin(_), - ); + let var_lookup = body.var_occurrences(parameter_name.clone(), vec![], vec![]); - if var_lookup.occurrences == 1 && substitute_condition { - changed = true; - body.substitute_var(parameter_name.clone(), arg_term.pierce_no_inlines()); + let must_execute_condition = var_lookup.delays == 0 && !var_lookup.no_inline; - context.inlined_apply_ids.push(arg_id); - *self = std::mem::replace(body, Term::Error.force()); + let cant_throw_condition = matches!( + arg_term, + Term::Var(_) + | Term::Constant(_) + | Term::Delay(_) + | Term::Lambda { .. } + | Term::Builtin(_), + ); - // This will strip out unused terms that can't throw an error by themselves - } else if !var_lookup.found - && matches!( - arg_term, - Term::Var(_) - | Term::Constant(_) - | Term::Delay(_) - | Term::Lambda { .. } - | Term::Builtin(_) - ) - { - changed = true; - context.inlined_apply_ids.push(arg_id); - *self = std::mem::replace(body, Term::Error.force()); - } + let force_wrapped_builtin = context + .builtins_map + .keys() + .any(|b| b.wrapped_name() == parameter_name.text); + + // This will inline terms that only occur once + // if they are guaranteed to execute or can't throw an error by themselves + if !force_wrapped_builtin + && var_lookup.occurrences == 1 + && (must_execute_condition || cant_throw_condition) + { + changed = true; + body.substitute_var(parameter_name.clone(), arg_term); + + context.inlined_apply_ids.push(arg_id); + *self = std::mem::replace(body, Term::Error.force()); + + // This will strip out unused terms that can't throw an error by themselves + } else if !var_lookup.found && (cant_throw_condition || force_wrapped_builtin) { + changed = true; + context.inlined_apply_ids.push(arg_id); + *self = std::mem::replace(body, Term::Error.force()); } } + Term::Constr { .. } => todo!(), Term::Case { .. } => todo!(), _ => {} @@ -1644,13 +2087,10 @@ impl Term { unreachable!() }; - term.pierce_no_inlines() + term.pierce_no_inlines_ref() }) .collect_vec(); - if func.can_curry_builtin() - && arg_stack.len() == func.arity() - && func.is_error_safe(&args) - { + if arg_stack.len() == func.arity() && func.is_error_safe(&args) { changed = true; let applied_term = arg_stack @@ -1664,14 +2104,14 @@ impl Term { acc.apply(arg.pierce_no_inlines().clone()) }); - // Check above for is error safe + // The check above is to make sure the program is error safe let eval_term: Term = Program { version: (1, 0, 0), term: applied_term, } .to_named_debruijn() .unwrap() - .eval(ExBudget::max()) + .eval(ExBudget::default()) .result() .unwrap() .try_into() @@ -1736,7 +2176,7 @@ impl Term { } } - pub fn pierce_no_inlines(&self) -> &Self { + pub fn pierce_no_inlines_ref(&self) -> &Self { let mut term = self; while let Term::Lambda { @@ -1753,6 +2193,24 @@ impl Term { term } + + pub fn pierce_no_inlines(mut self) -> Self { + let term = &mut self; + + while let Term::Lambda { + parameter_name, + body, + } = term + { + if parameter_name.as_ref().text == NO_INLINE { + *term = std::mem::replace(Rc::make_mut(body), Term::Error.force()); + } else { + break; + } + } + + std::mem::replace(term, Term::Error.force()) + } } impl Program { @@ -1769,9 +2227,11 @@ impl Program { let mut context = Context { inlined_apply_ids: vec![], constants_to_flip: vec![], + write_bits_indices_arg: vec![], builtins_map: IndexMap::new(), blst_p1_list: vec![], blst_p2_list: vec![], + write_bits_convert: false, node_count: 0, }; @@ -1791,16 +2251,16 @@ impl Program { context, ) } - // This one runs the optimizations that are only done a single time + // This runs the optimizations that are only done a single time pub fn run_once_pass(self) -> Self { - let program = self + // First pass is necessary to ensure fst_pair and snd_pair are inlined before + // builtin_force_reducer is run + let (program, context) = self .traverse_uplc_with(false, &mut |id, term, _arg_stack, scope, context| { term.inline_constr_ops(id, vec![], scope, context); }) - .0; - - let (program, context) = - program.traverse_uplc_with(false, &mut |id, term, arg_stack, scope, context| { + .0 + .traverse_uplc_with(false, &mut |id, term, arg_stack, scope, context| { term.bls381_compressor(id, vec![], scope, context); term.builtin_force_reducer(id, arg_stack, scope, context); term.remove_inlined_ids(id, vec![], scope, context); @@ -1824,16 +2284,16 @@ impl Program { .apply(Term::bls12_381_g2_uncompress().apply(Term::byte_string(compressed))); } - for default_func_index in context.builtins_map.keys().sorted().cloned() { - let default_func: DefaultFunction = default_func_index.try_into().unwrap(); + for default_func in context.builtins_map.keys().sorted().cloned() { + term = term.lambda(default_func.wrapped_name()); + } - term = term - .lambda(format!("__{}_wrapped", default_func.aiken_name())) - .apply(if default_func.force_count() == 1 { - Term::Builtin(default_func).force() - } else { - Term::Builtin(default_func).force().force() - }); + for default_func in context.builtins_map.keys().sorted().cloned().rev() { + term = term.apply(if default_func.force_count() == 1 { + Term::Builtin(default_func).force() + } else { + Term::Builtin(default_func).force().force() + }); } let mut program = Program { @@ -1852,38 +2312,36 @@ impl Program { pub fn multi_pass(self) -> (Self, Context) { self.traverse_uplc_with(true, &mut |id, term, arg_stack, scope, context| { - let mut changed; - - changed = term.lambda_reducer(id, arg_stack.clone(), scope, context); - if changed { + let false = term.lambda_reducer(id, arg_stack.clone(), scope, context) else { term.remove_inlined_ids(id, vec![], scope, context); return; - } - changed = term.identity_reducer(id, arg_stack.clone(), scope, context); - if changed { + }; + + let false = term.identity_reducer(id, arg_stack.clone(), scope, context) else { term.remove_inlined_ids(id, vec![], scope, context); return; - } - changed = term.inline_reducer(id, arg_stack.clone(), scope, context); - if changed { + }; + + let false = term.inline_reducer(id, arg_stack.clone(), scope, context) else { term.remove_inlined_ids(id, vec![], scope, context); return; - } - changed = term.force_delay_reducer(id, arg_stack.clone(), scope, context); - if changed { + }; + + let false = term.force_delay_reducer(id, arg_stack.clone(), scope, context) else { term.remove_inlined_ids(id, vec![], scope, context); return; - } - changed = term.cast_data_reducer(id, arg_stack.clone(), scope, context); - if changed { + }; + + let false = term.cast_data_reducer(id, arg_stack.clone(), scope, context) else { term.remove_inlined_ids(id, vec![], scope, context); return; - } - changed = term.builtin_eval_reducer(id, arg_stack.clone(), scope, context); - if changed { + }; + + let false = term.builtin_eval_reducer(id, arg_stack.clone(), scope, context) else { term.remove_inlined_ids(id, vec![], scope, context); return; - } + }; + term.convert_arithmetic_ops(id, arg_stack, scope, context); term.flip_constants(id, vec![], scope, context); term.remove_inlined_ids(id, vec![], scope, context); @@ -1895,21 +2353,53 @@ impl Program { inline_lambda: bool, with: &mut impl FnMut(Option, &mut Term, Vec, &Scope, &mut Context), ) -> Self { - self.traverse_uplc_with(inline_lambda, &mut |id, term, arg_stack, scope, context| { - with(id, term, arg_stack, scope, context); - term.flip_constants(id, vec![], scope, context); - term.remove_inlined_ids(id, vec![], scope, context); - }) - .0 + let (mut program, context) = + self.traverse_uplc_with(inline_lambda, &mut |id, term, arg_stack, scope, context| { + with(id, term, arg_stack, scope, context); + term.flip_constants(id, vec![], scope, context); + term.remove_inlined_ids(id, vec![], scope, context); + }); + + if context.write_bits_convert { + program.term = program.term.data_list_to_integer_list(); + } + + program } - pub fn clean_up(self) -> Self { + pub fn clean_up_no_inlines(self) -> Self { self.traverse_uplc_with(true, &mut |id, term, _arg_stack, scope, context| { term.remove_no_inlines(id, vec![], scope, context); }) .0 } + pub fn afterwards(self) -> Self { + let (mut program, context) = + self.traverse_uplc_with(true, &mut |id, term, arg_stack, scope, context| { + term.write_bits_convert_arg(id, arg_stack, scope, context); + }); + + program = program + .split_body_lambda_reducer() + .traverse_uplc_with(true, &mut |id, term, _arg_stack, scope, context| { + term.case_constr_apply_reducer(id, vec![], scope, context); + }) + .0; + + if context.write_bits_convert { + program.term = program.term.data_list_to_integer_list(); + } + + let mut interner = CodeGenInterner::new(); + + interner.program(&mut program); + + let program = Program::::try_from(program).unwrap(); + + Program::::try_from(program).unwrap() + } + // This one doesn't use the context since it's complicated and traverses the ast twice pub fn builtin_curry_reducer(self) -> Self { let mut curried_terms = vec![]; @@ -1949,19 +2439,18 @@ impl Program { ) { // We found it the builtin was curried before // So now we merge the new args into the existing curried builtin - let curried_builtin = curried_terms.swap_remove(index); let curried_builtin = curried_builtin.merge_node_by_path(builtin_args.clone()); - let Some(id_vec) = curried_builtin.get_id_args(&builtin_args) else { - unreachable!(); - }; - flipped_terms .insert(scope.clone(), curried_builtin.is_flipped(&builtin_args)); + let Some(id_vec) = curried_builtin.get_id_args(builtin_args) else { + unreachable!(); + }; + curried_terms.push(curried_builtin); id_vec @@ -1969,7 +2458,7 @@ impl Program { // Brand new buitlin so we add it to the list let curried_builtin = builtin_args.clone().args_to_curried_args(*func); - let Some(id_vec) = curried_builtin.get_id_args(&builtin_args) else { + let Some(id_vec) = curried_builtin.get_id_args(builtin_args) else { unreachable!(); }; @@ -2068,7 +2557,7 @@ impl Program { let builtin_args = BuiltinArgs::args_from_arg_stack(arg_stack, *func); - let Some(mut id_vec) = curried_builtin.get_id_args(&builtin_args) else { + let Some(mut id_vec) = curried_builtin.get_id_args(builtin_args) else { return; }; @@ -2142,6 +2631,12 @@ impl Program { step_b } + + pub fn split_body_lambda_reducer(mut self) -> Self { + self.term.split_body_lambda(); + + self + } } fn id_vec_function_to_var(func_name: &str, id_vec: &[usize]) -> String { @@ -2165,10 +2660,11 @@ fn is_a_builtin_wrapper(term: &Term) -> bool { while let Term::Apply { function, argument } = term { match argument.as_ref() { - Term::Var(name) => arg_names.push(name), + Term::Var(name) => arg_names.push(format!("{}_{}", name.text, name.unique)), Term::Constant(_) => {} _ => { + //Break loop, it's not a builtin wrapper function return false; } } @@ -2178,7 +2674,7 @@ fn is_a_builtin_wrapper(term: &Term) -> bool { arg_names.iter().all(|item| names.contains(item)) && matches!(term, Term::Builtin(_)) } -fn pop_lambdas_and_get_names(term: &Term) -> (Vec>, &Term) { +fn pop_lambdas_and_get_names(term: &Term) -> (Vec, &Term) { let mut names = vec![]; let mut term = term; @@ -2189,7 +2685,7 @@ fn pop_lambdas_and_get_names(term: &Term) -> (Vec>, &Term) } = term { if parameter_name.text != NO_INLINE { - names.push(parameter_name.clone()); + names.push(format!("{}_{}", parameter_name.text, parameter_name.unique)); } term = body.as_ref(); } @@ -2365,11 +2861,11 @@ mod tests { .lambda("y") // Forces are automatically applied by builder .lambda("__cons_list_wrapped") - .apply(Term::mk_cons()) .lambda("__head_list_wrapped") - .apply(Term::head_list()) .lambda("__tail_list_wrapped") - .apply(Term::tail_list()), + .apply(Term::tail_list()) + .apply(Term::head_list()) + .apply(Term::mk_cons()), }; compare_optimization(expected, program, |p| p.run_once_pass()); @@ -2453,9 +2949,9 @@ mod tests { .apply(Term::data(Data::integer(5.into()))), ) .lambda("__fst_pair_wrapped") - .apply(Term::fst_pair()) .lambda("__snd_pair_wrapped") - .apply(Term::snd_pair()), + .apply(Term::snd_pair()) + .apply(Term::fst_pair()), }; compare_optimization(expected, program, |p| p.run_once_pass()); @@ -2644,6 +3140,80 @@ mod tests { }); } + #[test] + fn inline_reduce_if_then_else_then() { + let program: Program = Program { + version: (1, 0, 0), + term: Term::var("__if_then_else_wrapped") + .apply(Term::bool(true)) + .apply(Term::sha3_256().apply(Term::var("x")).delay()) + .apply(Term::Error.delay()) + .force() + .lambda("x") + .apply(Term::sha3_256().apply(Term::byte_string(vec![]))) + .lambda("__if_then_else_wrapped") + .apply(Term::Builtin(DefaultFunction::IfThenElse).force()), + }; + + let expected = Program { + version: (1, 0, 0), + term: Term::var("__if_then_else_wrapped") + .apply(Term::bool(true)) + .apply( + Term::sha3_256() + .apply(Term::sha3_256().apply(Term::byte_string(vec![]))) + .delay(), + ) + .apply(Term::Error.delay()) + .force() + .lambda("__if_then_else_wrapped") + .apply(Term::Builtin(DefaultFunction::IfThenElse).force()), + }; + + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.inline_reducer(id, arg_stack, scope, context); + }) + }); + } + + #[test] + fn inline_reduce_if_then_else_else() { + let program: Program = Program { + version: (1, 0, 0), + term: Term::var("__if_then_else_wrapped") + .apply(Term::bool(true)) + .apply(Term::Error.delay()) + .apply(Term::sha3_256().apply(Term::var("x")).delay()) + .force() + .lambda("x") + .apply(Term::sha3_256().apply(Term::byte_string(vec![]))) + .lambda("__if_then_else_wrapped") + .apply(Term::Builtin(DefaultFunction::IfThenElse).force()), + }; + + let expected = Program { + version: (1, 0, 0), + term: Term::var("__if_then_else_wrapped") + .apply(Term::bool(true)) + .apply(Term::Error.delay()) + .apply( + Term::sha3_256() + .apply(Term::sha3_256().apply(Term::byte_string(vec![]))) + .delay(), + ) + .force() + .lambda("__if_then_else_wrapped") + .apply(Term::Builtin(DefaultFunction::IfThenElse).force()), + }; + + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.inline_reducer(id, arg_stack, scope, context); + }) + }); + } + #[test] fn inline_reduce_0_occurrence() { let program: Program = Program { @@ -3234,4 +3804,63 @@ mod tests { compare_optimization(expected, program, |p| p.builtin_curry_reducer()); } + + #[test] + fn case_constr_apply_test_1() { + let program: Program = Program { + version: (1, 1, 0), + term: Term::add_integer() + .apply(Term::integer(0.into())) + .apply(Term::integer(0.into())) + .apply(Term::integer(0.into())) + .apply(Term::integer(0.into())) + .apply(Term::integer(0.into())) + .apply(Term::integer(0.into())), + }; + + let expected = Program { + version: (1, 1, 0), + term: Term::constr( + 0, + vec![ + Term::integer(0.into()), + Term::integer(0.into()), + Term::integer(0.into()), + Term::integer(0.into()), + Term::integer(0.into()), + Term::integer(0.into()), + ], + ) + .case(vec![Term::add_integer()]), + }; + + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.case_constr_apply_reducer(id, arg_stack, scope, context); + }) + }); + } + + #[test] + fn case_constr_apply_test_2() { + let program: Program = Program { + version: (1, 1, 0), + term: Term::add_integer() + .apply(Term::integer(0.into())) + .apply(Term::integer(0.into())), + }; + + let expected = Program { + version: (1, 1, 0), + term: Term::add_integer() + .apply(Term::integer(0.into())) + .apply(Term::integer(0.into())), + }; + + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.case_constr_apply_reducer(id, arg_stack, scope, context); + }) + }); + } } diff --git a/crates/uplc/tests/conformance.rs b/crates/uplc/tests/conformance.rs index 7aca4897f..b1ff57614 100644 --- a/crates/uplc/tests/conformance.rs +++ b/crates/uplc/tests/conformance.rs @@ -49,7 +49,7 @@ peg::parser! { fn actual_evaluation_result( file: &Path, language: &Language, -) -> Result<(Program, ExBudget), String> { +) -> Result<(Program, ExBudget), String> { let code = fs::read_to_string(file).expect("Failed to read .uplc file"); let program = parser::program(&code).map_err(|_| PARSE_ERROR.to_string())?; @@ -68,7 +68,7 @@ fn actual_evaluation_result( let program = Program { version, term }; - Ok((program.try_into().unwrap(), cost)) + Ok((program, cost)) } fn plutus_conformance_tests(language: Language) { @@ -89,7 +89,8 @@ fn plutus_conformance_tests(language: Language) { let expected_budget_file = path.with_extension("uplc.budget.expected"); let eval = actual_evaluation_result(path, &language); - let expected = expected_to_program(&expected_file); + let expected = expected_to_program(&expected_file) + .map(|program| Program::::try_from(program).unwrap()); match eval { Ok((actual, cost)) => { diff --git a/examples/acceptance_tests/117/aiken.toml b/examples/acceptance_tests/117/aiken.toml new file mode 100644 index 000000000..86dfc5e86 --- /dev/null +++ b/examples/acceptance_tests/117/aiken.toml @@ -0,0 +1,2 @@ +name = "aiken-lang/acceptance_test_117" +version = "0.0.0" diff --git a/examples/acceptance_tests/117/lib/tests.ak b/examples/acceptance_tests/117/lib/tests.ak new file mode 100644 index 000000000..c8b913bfb --- /dev/null +++ b/examples/acceptance_tests/117/lib/tests.ak @@ -0,0 +1,36 @@ +use aiken/builtin.{write_bits} + +test bar() { + let x = + if True { + [0, 1, 2, 3] + } else { + [0, 1] + } + + write_bits(#"f0", x, True) == #"ff" +} + +test baz() { + let x = [0, 1, 2, 3] + write_bits(#"f0", x, True) == #"ff" +} + +test bur() { + let x = + if True { + [0, 1, 2, 3] + } else { + [0, 1] + } + + if False { + fn(_a, _b, _c) { #"" } + } else { + write_bits + }( + #"f0", + x, + True, + ) == #"ff" +} diff --git a/examples/acceptance_tests/script_context/v3/validators/mint.ak b/examples/acceptance_tests/script_context/v3/validators/mint.ak index 7c7a59ef6..0c4e19e80 100644 --- a/examples/acceptance_tests/script_context/v3/validators/mint.ak +++ b/examples/acceptance_tests/script_context/v3/validators/mint.ak @@ -83,6 +83,8 @@ fn assert_outputs( }, ) == list.at(outputs, 0) + trace @"This test validator has a higher hash than the one below. Change and try again." + expect Some( Output {