diff --git a/.noir-sync-commit b/.noir-sync-commit index 612897a2dba..07ca104bbc0 100644 --- a/.noir-sync-commit +++ b/.noir-sync-commit @@ -1 +1 @@ -a9e985064303b0843cbf68fb5a9d41f9ade1e30d +130d99125a09110a3ee3e877d88d83b5aa37f369 diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen.rs index f23e64aec52..1594bac2acc 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen.rs @@ -58,7 +58,11 @@ pub(crate) fn gen_brillig_for( brillig: &Brillig, ) -> Result, InternalError> { // Create the entry point artifact - let globals_memory_size = brillig.globals_memory_size.get(&func.id()).copied().unwrap_or(0); + let globals_memory_size = brillig + .globals_memory_size + .get(&func.id()) + .copied() + .expect("Should have the globals memory size specified for an entry point"); let mut entry_point = BrilligContext::new_entry_point_artifact( arguments, FunctionContext::return_values(func), diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs index 639458f5a7d..6f5645485a2 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs @@ -182,20 +182,11 @@ impl BrilligGlobals { &self, brillig_function_id: FunctionId, ) -> SsaToBrilligGlobals { - let mut globals_allocations = HashMap::default(); - - // First check whether the `brillig_function_id` is itself an entry point, - // if so we can fetch the global allocations directly from `self.entry_point_globals_map`. - if let Some(globals) = self.entry_point_globals_map.get(&brillig_function_id) { - globals_allocations.extend(globals); - return globals_allocations; - } - - // If the Brillig function we are compiling is not an entry point, we should search - // for the entry point which triggers the given function. let entry_points = self.inner_call_to_entry_point.get(&brillig_function_id); + + let mut globals_allocations = HashMap::default(); if let Some(entry_points) = entry_points { - // A Brillig function can be used by multiple entry points. Fetch both globals allocations + // A Brillig function is used by multiple entry points. Fetch both globals allocations // in case one is used by the internal call. let entry_point_allocations = entry_points .iter() @@ -204,6 +195,11 @@ impl BrilligGlobals { for map in entry_point_allocations { globals_allocations.extend(map); } + } else if let Some(globals) = self.entry_point_globals_map.get(&brillig_function_id) { + // If there is no mapping from an inner call to an entry point, that means `brillig_function_id` + // is itself an entry point and we can fetch the global allocations directly from `self.entry_point_globals_map`. + // vec![globals] + globals_allocations.extend(globals); } else { unreachable!( "ICE: Expected global allocation to be set for function {brillig_function_id}" diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index f4ddc60c8a8..37ffa0ab9d7 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -98,6 +98,49 @@ impl Ssa { function.constant_fold(false, brillig_info); } + // It could happen that we inlined all calls to a given brillig function. + // In that case it's unused so we can remove it. This is what we check next. + self.remove_unused_brillig_functions(brillig_functions) + } + + fn remove_unused_brillig_functions( + mut self, + mut brillig_functions: BTreeMap, + ) -> Ssa { + // Remove from the above map functions that are called + for function in self.functions.values() { + for block_id in function.reachable_blocks() { + for instruction_id in function.dfg[block_id].instructions() { + let instruction = &function.dfg[*instruction_id]; + let Instruction::Call { func: func_id, arguments: _ } = instruction else { + continue; + }; + + let func_value = &function.dfg[*func_id]; + let Value::Function(func_id) = func_value else { continue }; + + if function.runtime().is_acir() { + brillig_functions.remove(func_id); + } + } + } + } + + // The ones that remain are never called: let's remove them. + for (func_id, func) in &brillig_functions { + // We never want to remove the main function (it could be `unconstrained` or it + // could have been turned into brillig if `--force-brillig` was given). + // We also don't want to remove entry points. + let runtime = func.runtime(); + if self.main_id == *func_id + || (runtime.is_entry_point() && matches!(runtime, RuntimeType::Acir(_))) + { + continue; + } + + self.functions.remove(func_id); + } + self } } @@ -682,6 +725,11 @@ impl<'brillig> Context<'brillig> { // Should we consider calls to slice_push_back and similar to be mutating operations as well? if let Store { value: array, .. } | ArraySet { array, .. } = instruction { + if function.dfg.is_global(*array) { + // Early return as we expect globals to be immutable. + return; + }; + let instruction = match &function.dfg[*array] { Value::Instruction { instruction, .. } => &function.dfg[*instruction], _ => return, @@ -1533,6 +1581,82 @@ mod test { assert_normalized_ssa_equals(ssa, expected); } + #[test] + fn inlines_brillig_call_with_entry_point_globals() { + let src = " + g0 = Field 2 + + acir(inline) fn main f0 { + b0(): + v1 = call f1() -> Field + return v1 + } + + brillig(inline) fn one f1 { + b0(): + v1 = add g0, Field 3 + return v1 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let mut ssa = ssa.dead_instruction_elimination(); + let used_globals_map = std::mem::take(&mut ssa.used_globals); + let brillig = ssa.to_brillig_with_globals(false, used_globals_map); + + let expected = " + g0 = Field 2 + + acir(inline) fn main f0 { + b0(): + return Field 5 + } + "; + + let ssa = ssa.fold_constants_with_brillig(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_brillig_call_with_non_entry_point_globals() { + let src = " + g0 = Field 2 + + acir(inline) fn main f0 { + b0(): + v1 = call f1() -> Field + return v1 + } + + brillig(inline) fn entry_point f1 { + b0(): + v1 = call f2() -> Field + return v1 + } + + brillig(inline) fn one f2 { + b0(): + v1 = add g0, Field 3 + return v1 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let mut ssa = ssa.dead_instruction_elimination(); + let used_globals_map = std::mem::take(&mut ssa.used_globals); + let brillig = ssa.to_brillig_with_globals(false, used_globals_map); + + let expected = " + g0 = Field 2 + + acir(inline) fn main f0 { + b0(): + return Field 5 + } + "; + + let ssa = ssa.fold_constants_with_brillig(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + #[test] fn does_not_use_cached_constrain_in_block_that_is_not_dominated() { let src = " diff --git a/noir/noir-repo/compiler/noirc_frontend/src/ast/expression.rs b/noir/noir-repo/compiler/noirc_frontend/src/ast/expression.rs index d36966e2efe..9c9c0ded867 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/ast/expression.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/ast/expression.rs @@ -31,6 +31,7 @@ pub enum ExpressionKind { Cast(Box), Infix(Box), If(Box), + Match(Box), Variable(Path), Tuple(Vec), Lambda(Box), @@ -465,6 +466,12 @@ pub struct IfExpression { pub alternative: Option, } +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct MatchExpression { + pub expression: Expression, + pub rules: Vec<(/*pattern*/ Expression, /*branch*/ Expression)>, +} + #[derive(Debug, PartialEq, Eq, Clone)] pub struct Lambda { pub parameters: Vec<(Pattern, UnresolvedType)>, @@ -612,6 +619,7 @@ impl Display for ExpressionKind { Cast(cast) => cast.fmt(f), Infix(infix) => infix.fmt(f), If(if_expr) => if_expr.fmt(f), + Match(match_expr) => match_expr.fmt(f), Variable(path) => path.fmt(f), Constructor(constructor) => constructor.fmt(f), MemberAccess(access) => access.fmt(f), @@ -790,6 +798,16 @@ impl Display for IfExpression { } } +impl Display for MatchExpression { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "match {} {{", self.expression)?; + for (pattern, branch) in &self.rules { + writeln!(f, " {pattern} -> {branch},")?; + } + write!(f, "}}") + } +} + impl Display for Lambda { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let parameters = vecmap(&self.parameters, |(name, r#type)| format!("{name}: {type}")); diff --git a/noir/noir-repo/compiler/noirc_frontend/src/ast/visitor.rs b/noir/noir-repo/compiler/noirc_frontend/src/ast/visitor.rs index 30b8deb4925..a43bd0a5d3d 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/ast/visitor.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/ast/visitor.rs @@ -22,7 +22,7 @@ use crate::{ use super::{ ForBounds, FunctionReturnType, GenericTypeArgs, IntegerBitSize, ItemVisibility, - NoirEnumeration, Pattern, Signedness, TraitBound, TraitImplItemKind, TypePath, + MatchExpression, NoirEnumeration, Pattern, Signedness, TraitBound, TraitImplItemKind, TypePath, UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, }; @@ -222,6 +222,10 @@ pub trait Visitor { true } + fn visit_match_expression(&mut self, _: &MatchExpression, _: Span) -> bool { + true + } + fn visit_tuple(&mut self, _: &[Expression], _: Span) -> bool { true } @@ -866,6 +870,9 @@ impl Expression { ExpressionKind::If(if_expression) => { if_expression.accept(self.span, visitor); } + ExpressionKind::Match(match_expression) => { + match_expression.accept(self.span, visitor); + } ExpressionKind::Tuple(expressions) => { if visitor.visit_tuple(expressions, self.span) { visit_expressions(expressions, visitor); @@ -1073,6 +1080,22 @@ impl IfExpression { } } +impl MatchExpression { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_match_expression(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.expression.accept(visitor); + for (pattern, branch) in &self.rules { + pattern.accept(visitor); + branch.accept(visitor); + } + } +} + impl Lambda { pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { if visitor.visit_lambda(self, span) { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs index ff5ff48cbf4..16278995104 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -7,9 +7,9 @@ use crate::{ ast::{ ArrayLiteral, BlockExpression, CallExpression, CastExpression, ConstructorExpression, Expression, ExpressionKind, Ident, IfExpression, IndexExpression, InfixExpression, - ItemVisibility, Lambda, Literal, MemberAccessExpression, MethodCallExpression, Path, - PathSegment, PrefixExpression, StatementKind, UnaryOp, UnresolvedTypeData, - UnresolvedTypeExpression, + ItemVisibility, Lambda, Literal, MatchExpression, MemberAccessExpression, + MethodCallExpression, Path, PathSegment, PrefixExpression, StatementKind, UnaryOp, + UnresolvedTypeData, UnresolvedTypeExpression, }, hir::{ comptime::{self, InterpreterError}, @@ -51,6 +51,7 @@ impl<'context> Elaborator<'context> { ExpressionKind::Cast(cast) => self.elaborate_cast(*cast, expr.span), ExpressionKind::Infix(infix) => return self.elaborate_infix(*infix, expr.span), ExpressionKind::If(if_) => self.elaborate_if(*if_), + ExpressionKind::Match(match_) => self.elaborate_match(*match_), ExpressionKind::Variable(variable) => return self.elaborate_variable(variable), ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple), ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda, None), @@ -926,6 +927,10 @@ impl<'context> Elaborator<'context> { (HirExpression::If(if_expr), ret_type) } + fn elaborate_match(&mut self, _match_expr: MatchExpression) -> (HirExpression, Type) { + (HirExpression::Error, Type::Error) + } + fn elaborate_tuple(&mut self, tuple: Vec) -> (HirExpression, Type) { let mut element_ids = Vec::with_capacity(tuple.len()); let mut element_types = Vec::with_capacity(tuple.len()); diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/display.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/display.rs index 6be5e19577d..1be4bbe61ab 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/display.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/display.rs @@ -8,9 +8,9 @@ use crate::{ ArrayLiteral, AsTraitPath, AssignStatement, BlockExpression, CallExpression, CastExpression, ConstrainStatement, ConstructorExpression, Expression, ExpressionKind, ForBounds, ForLoopStatement, ForRange, GenericTypeArgs, IfExpression, IndexExpression, - InfixExpression, LValue, Lambda, LetStatement, Literal, MemberAccessExpression, - MethodCallExpression, Pattern, PrefixExpression, Statement, StatementKind, UnresolvedType, - UnresolvedTypeData, + InfixExpression, LValue, Lambda, LetStatement, Literal, MatchExpression, + MemberAccessExpression, MethodCallExpression, Pattern, PrefixExpression, Statement, + StatementKind, UnresolvedType, UnresolvedTypeData, }, hir_def::traits::TraitConstraint, node_interner::{InternedStatementKind, NodeInterner}, @@ -241,6 +241,7 @@ impl<'interner> TokenPrettyPrinter<'interner> { | Token::GreaterEqual | Token::Equal | Token::NotEqual + | Token::FatArrow | Token::Arrow => write!(f, " {token} "), Token::Assign => { if last_was_op { @@ -602,6 +603,14 @@ fn remove_interned_in_expression_kind( .alternative .map(|alternative| remove_interned_in_expression(interner, alternative)), })), + ExpressionKind::Match(match_expr) => ExpressionKind::Match(Box::new(MatchExpression { + expression: remove_interned_in_expression(interner, match_expr.expression), + rules: vecmap(match_expr.rules, |(pattern, branch)| { + let pattern = remove_interned_in_expression(interner, pattern); + let branch = remove_interned_in_expression(interner, branch); + (pattern, branch) + }), + })), ExpressionKind::Variable(_) => expr, ExpressionKind::Tuple(expressions) => ExpressionKind::Tuple(vecmap(expressions, |expr| { remove_interned_in_expression(interner, expr) diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index 6503b0cf77b..9abb1b190d5 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -246,7 +246,7 @@ impl<'local, 'context> Interpreter<'local, 'context> { "unresolved_type_is_bool" => unresolved_type_is_bool(interner, arguments, location), "unresolved_type_is_field" => unresolved_type_is_field(interner, arguments, location), "unresolved_type_is_unit" => unresolved_type_is_unit(interner, arguments, location), - "zeroed" => zeroed(return_type, location.span), + "zeroed" => Ok(zeroed(return_type, location.span)), _ => { let item = format!("Comptime evaluation for builtin function '{name}'"); Err(InterpreterError::Unimplemented { item, location }) @@ -499,21 +499,21 @@ fn struct_def_generics( _ => return Err(InterpreterError::TypeMismatch { expected, actual, location }), }; - let generics: IResult<_> = struct_def + let generics = struct_def .generics .iter() - .map(|generic| -> IResult { + .map(|generic| { let generic_as_named = generic.clone().as_named_generic(); let numeric_type = match generic_as_named.kind() { Kind::Numeric(numeric_type) => Some(Value::Type(*numeric_type)), _ => None, }; - let numeric_type = option(option_typ.clone(), numeric_type, location.span)?; - Ok(Value::Tuple(vec![Value::Type(generic_as_named), numeric_type])) + let numeric_type = option(option_typ.clone(), numeric_type, location.span); + Value::Tuple(vec![Value::Type(generic_as_named), numeric_type]) }) .collect(); - Ok(Value::Slice(generics?, slice_item_type)) + Ok(Value::Slice(generics, slice_item_type)) } fn struct_def_hash(arguments: Vec<(Value, Location)>, location: Location) -> IResult { @@ -811,7 +811,7 @@ fn quoted_as_expr( }, ); - option(return_type, value, location.span) + Ok(option(return_type, value, location.span)) } // fn as_module(quoted: Quoted) -> Option @@ -834,7 +834,7 @@ fn quoted_as_module( module.map(Value::ModuleDefinition) }); - option(return_type, option_value, location.span) + Ok(option(return_type, option_value, location.span)) } // fn as_trait_constraint(quoted: Quoted) -> TraitConstraint @@ -1146,7 +1146,7 @@ where let option_value = f(typ)?; - option(return_type, option_value, location.span) + Ok(option(return_type, option_value, location.span)) } // fn type_eq(_first: Type, _second: Type) -> bool @@ -1181,7 +1181,7 @@ fn type_get_trait_impl( _ => None, }; - option(return_type, option_value, location.span) + Ok(option(return_type, option_value, location.span)) } // fn implements(self, constraint: TraitConstraint) -> bool @@ -1302,7 +1302,7 @@ fn typed_expr_as_function_definition( } else { None }; - option(return_type, option_value, location.span) + Ok(option(return_type, option_value, location.span)) } // fn get_type(self) -> Option @@ -1324,7 +1324,7 @@ fn typed_expr_get_type( } else { None }; - option(return_type, option_value, location.span) + Ok(option(return_type, option_value, location.span)) } // fn as_mutable_reference(self) -> Option @@ -1407,80 +1407,97 @@ where let typ = get_unresolved_type(interner, value)?; let option_value = f(typ); - - option(return_type, option_value, location.span) + Ok(option(return_type, option_value, location.span)) } // fn zeroed() -> T -fn zeroed(return_type: Type, span: Span) -> IResult { +fn zeroed(return_type: Type, span: Span) -> Value { match return_type { - Type::FieldElement => Ok(Value::Field(0u128.into())), + Type::FieldElement => Value::Field(0u128.into()), Type::Array(length_type, elem) => { if let Ok(length) = length_type.evaluate_to_u32(span) { - let element = zeroed(elem.as_ref().clone(), span)?; + let element = zeroed(elem.as_ref().clone(), span); let array = std::iter::repeat(element).take(length as usize).collect(); - Ok(Value::Array(array, Type::Array(length_type, elem))) + Value::Array(array, Type::Array(length_type, elem)) } else { // Assume we can resolve the length later - Ok(Value::Zeroed(Type::Array(length_type, elem))) + Value::Zeroed(Type::Array(length_type, elem)) } } - Type::Slice(_) => Ok(Value::Slice(im::Vector::new(), return_type)), + Type::Slice(_) => Value::Slice(im::Vector::new(), return_type), Type::Integer(sign, bits) => match (sign, bits) { - (Signedness::Unsigned, IntegerBitSize::One) => Ok(Value::U8(0)), - (Signedness::Unsigned, IntegerBitSize::Eight) => Ok(Value::U8(0)), - (Signedness::Unsigned, IntegerBitSize::Sixteen) => Ok(Value::U16(0)), - (Signedness::Unsigned, IntegerBitSize::ThirtyTwo) => Ok(Value::U32(0)), - (Signedness::Unsigned, IntegerBitSize::SixtyFour) => Ok(Value::U64(0)), - (Signedness::Signed, IntegerBitSize::One) => Ok(Value::I8(0)), - (Signedness::Signed, IntegerBitSize::Eight) => Ok(Value::I8(0)), - (Signedness::Signed, IntegerBitSize::Sixteen) => Ok(Value::I16(0)), - (Signedness::Signed, IntegerBitSize::ThirtyTwo) => Ok(Value::I32(0)), - (Signedness::Signed, IntegerBitSize::SixtyFour) => Ok(Value::I64(0)), + (Signedness::Unsigned, IntegerBitSize::One) => Value::U8(0), + (Signedness::Unsigned, IntegerBitSize::Eight) => Value::U8(0), + (Signedness::Unsigned, IntegerBitSize::Sixteen) => Value::U16(0), + (Signedness::Unsigned, IntegerBitSize::ThirtyTwo) => Value::U32(0), + (Signedness::Unsigned, IntegerBitSize::SixtyFour) => Value::U64(0), + (Signedness::Signed, IntegerBitSize::One) => Value::I8(0), + (Signedness::Signed, IntegerBitSize::Eight) => Value::I8(0), + (Signedness::Signed, IntegerBitSize::Sixteen) => Value::I16(0), + (Signedness::Signed, IntegerBitSize::ThirtyTwo) => Value::I32(0), + (Signedness::Signed, IntegerBitSize::SixtyFour) => Value::I64(0), }, - Type::Bool => Ok(Value::Bool(false)), + Type::Bool => Value::Bool(false), Type::String(length_type) => { if let Ok(length) = length_type.evaluate_to_u32(span) { - Ok(Value::String(Rc::new("\0".repeat(length as usize)))) + Value::String(Rc::new("\0".repeat(length as usize))) } else { // Assume we can resolve the length later - Ok(Value::Zeroed(Type::String(length_type))) + Value::Zeroed(Type::String(length_type)) } } Type::FmtString(length_type, captures) => { let length = length_type.evaluate_to_u32(span); let typ = Type::FmtString(length_type, captures); if let Ok(length) = length { - Ok(Value::FormatString(Rc::new("\0".repeat(length as usize)), typ)) + Value::FormatString(Rc::new("\0".repeat(length as usize)), typ) } else { // Assume we can resolve the length later - Ok(Value::Zeroed(typ)) + Value::Zeroed(typ) } } - Type::Unit => Ok(Value::Unit), - Type::Tuple(fields) => Ok(Value::Tuple(try_vecmap(fields, |field| zeroed(field, span))?)), - Type::DataType(struct_type, generics) => { - // TODO: Handle enums - let fields = struct_type.borrow().get_fields(&generics).unwrap(); - let mut values = HashMap::default(); - - for (field_name, field_type) in fields { - let field_value = zeroed(field_type, span)?; - values.insert(Rc::new(field_name), field_value); - } + Type::Unit => Value::Unit, + Type::Tuple(fields) => Value::Tuple(vecmap(fields, |field| zeroed(field, span))), + Type::DataType(data_type, generics) => { + let typ = data_type.borrow(); + + if let Some(fields) = typ.get_fields(&generics) { + let mut values = HashMap::default(); + + for (field_name, field_type) in fields { + let field_value = zeroed(field_type, span); + values.insert(Rc::new(field_name), field_value); + } - let typ = Type::DataType(struct_type, generics); - Ok(Value::Struct(values, typ)) + drop(typ); + Value::Struct(values, Type::DataType(data_type, generics)) + } else if let Some(mut variants) = typ.get_variants(&generics) { + // Since we're defaulting to Vec::new(), this'd allow us to construct 0 element + // variants... `zeroed` is often used for uninitialized values e.g. in a BoundedVec + // though so we'll allow it. + let mut args = Vec::new(); + if !variants.is_empty() { + // is_empty & swap_remove let us avoid a .clone() we'd need if we did .get(0) + let (_name, params) = variants.swap_remove(0); + args = vecmap(params, |param| zeroed(param, span)); + } + + drop(typ); + Value::Enum(0, args, Type::DataType(data_type, generics)) + } else { + drop(typ); + Value::Zeroed(Type::DataType(data_type, generics)) + } } Type::Alias(alias, generics) => zeroed(alias.borrow().get_type(&generics), span), Type::CheckedCast { to, .. } => zeroed(*to, span), typ @ Type::Function(..) => { // Using Value::Zeroed here is probably safer than using FuncId::dummy_id() or similar - Ok(Value::Zeroed(typ)) + Value::Zeroed(typ) } Type::MutableReference(element) => { - let element = zeroed(*element, span)?; - Ok(Value::Pointer(Shared::new(element), false)) + let element = zeroed(*element, span); + Value::Pointer(Shared::new(element), false) } // Optimistically assume we can resolve this type later or that the value is unused Type::TypeVariable(_) @@ -1490,7 +1507,7 @@ fn zeroed(return_type: Type, span: Span) -> IResult { | Type::Quoted(_) | Type::Error | Type::TraitAsType(..) - | Type::NamedGeneric(_, _) => Ok(Value::Zeroed(return_type)), + | Type::NamedGeneric(_, _) => Value::Zeroed(return_type), } } @@ -1543,7 +1560,7 @@ fn expr_as_assert( let option_type = tuple_types.pop().unwrap(); let message = message.map(|msg| Value::expression(msg.kind)); - let message = option(option_type, message, location.span).ok()?; + let message = option(option_type, message, location.span); Some(Value::Tuple(vec![predicate, message])) } else { @@ -1589,7 +1606,7 @@ fn expr_as_assert_eq( let option_type = tuple_types.pop().unwrap(); let message = message.map(|message| Value::expression(message.kind)); - let message = option(option_type, message, location.span).ok()?; + let message = option(option_type, message, location.span); Some(Value::Tuple(vec![lhs, rhs, message])) } else { @@ -1765,7 +1782,7 @@ fn expr_as_constructor( None }; - option(return_type, option_value, location.span) + Ok(option(return_type, option_value, location.span)) } // fn as_for(self) -> Option<(Quoted, Expr, Expr)> @@ -1865,7 +1882,7 @@ fn expr_as_if( Some(Value::Tuple(vec![ Value::expression(if_expr.condition.kind), Value::expression(if_expr.consequence.kind), - alternative.ok()?, + alternative, ])) } else { None @@ -1948,7 +1965,7 @@ fn expr_as_lambda( } else { Some(Value::UnresolvedType(typ.typ)) }; - let typ = option(option_unresolved_type.clone(), typ, location.span).unwrap(); + let typ = option(option_unresolved_type.clone(), typ, location.span); Value::Tuple(vec![pattern, typ]) }) .collect(); @@ -1967,7 +1984,7 @@ fn expr_as_lambda( Some(return_type) }; let return_type = return_type.map(Value::UnresolvedType); - let return_type = option(option_unresolved_type, return_type, location.span).ok()?; + let return_type = option(option_unresolved_type, return_type, location.span); let body = Value::expression(lambda.body.kind); @@ -2001,7 +2018,7 @@ fn expr_as_let( Some(Value::UnresolvedType(let_statement.r#type.typ)) }; - let typ = option(option_type, typ, location.span).ok()?; + let typ = option(option_type, typ, location.span); Some(Value::Tuple(vec![ Value::pattern(let_statement.pattern), @@ -2253,7 +2270,7 @@ where let expr_value = unwrap_expr_value(interner, expr_value); let option_value = f(expr_value); - option(return_type, option_value, location.span) + Ok(option(return_type, option_value, location.span)) } // fn resolve(self, in_function: Option) -> TypedExpr @@ -2902,18 +2919,18 @@ fn trait_def_as_trait_constraint( /// Creates a value that holds an `Option`. /// `option_type` must be a Type referencing the `Option` type. -pub(crate) fn option(option_type: Type, value: Option, span: Span) -> IResult { +pub(crate) fn option(option_type: Type, value: Option, span: Span) -> Value { let t = extract_option_generic_type(option_type.clone()); let (is_some, value) = match value { Some(value) => (Value::Bool(true), value), - None => (Value::Bool(false), zeroed(t, span)?), + None => (Value::Bool(false), zeroed(t, span)), }; let mut fields = HashMap::default(); fields.insert(Rc::new("_is_some".to_string()), is_some); fields.insert(Rc::new("_value".to_string()), value); - Ok(Value::Struct(fields, option_type)) + Value::Struct(fields, option_type) } /// Given a type, assert that it's an Option and return the Type for T diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs index a98c892eb34..a79af9a7630 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs @@ -1742,6 +1742,13 @@ impl Type { ) -> Result<(), UnificationError> { use Type::*; + // If the two types are exactly the same then they trivially unify. + // This check avoids potentially unifying very complex types (usually infix + // expressions) when they are the same. + if self == other { + return Ok(()); + } + let lhs = self.follow_bindings_shallow(); let rhs = other.follow_bindings_shallow(); diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs index 5750365c62d..ce9125cd5f0 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs @@ -63,12 +63,15 @@ impl Type { let dummy_span = Span::default(); // evaluate_to_field_element also calls canonicalize so if we just called // `self.evaluate_to_field_element(..)` we'd get infinite recursion. - if let (Ok(lhs_value), Ok(rhs_value)) = ( - lhs.evaluate_to_field_element_helper(&kind, dummy_span, run_simplifications), - rhs.evaluate_to_field_element_helper(&kind, dummy_span, run_simplifications), - ) { - if let Ok(result) = op.function(lhs_value, rhs_value, &kind, dummy_span) { - return Type::Constant(result, kind); + if let Ok(lhs_value) = + lhs.evaluate_to_field_element_helper(&kind, dummy_span, run_simplifications) + { + if let Ok(rhs_value) = + rhs.evaluate_to_field_element_helper(&kind, dummy_span, run_simplifications) + { + if let Ok(result) = op.function(lhs_value, rhs_value, &kind, dummy_span) { + return Type::Constant(result, kind); + } } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/lexer/lexer.rs b/noir/noir-repo/compiler/noirc_frontend/src/lexer/lexer.rs index 0b7bd0991d9..771af3daba0 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/lexer/lexer.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/lexer/lexer.rs @@ -215,8 +215,19 @@ impl<'a> Lexer<'a> { Ok(prev_token.into_single_span(start)) } } + Token::Assign => { + let start = self.position; + if self.peek_char_is('=') { + self.next_char(); + Ok(Token::Equal.into_span(start, start + 1)) + } else if self.peek_char_is('>') { + self.next_char(); + Ok(Token::FatArrow.into_span(start, start + 1)) + } else { + Ok(prev_token.into_single_span(start)) + } + } Token::Bang => self.single_double_peek_token('=', prev_token, Token::NotEqual), - Token::Assign => self.single_double_peek_token('=', prev_token, Token::Equal), Token::Minus => self.single_double_peek_token('>', prev_token, Token::Arrow), Token::Colon => self.single_double_peek_token(':', prev_token, Token::DoubleColon), Token::Slash => { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/lexer/token.rs b/noir/noir-repo/compiler/noirc_frontend/src/lexer/token.rs index 7d11b97ca16..d0a6f05e05a 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/lexer/token.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/lexer/token.rs @@ -91,6 +91,8 @@ pub enum BorrowedToken<'input> { RightBracket, /// -> Arrow, + /// => + FatArrow, /// | Pipe, /// # @@ -212,6 +214,8 @@ pub enum Token { RightBracket, /// -> Arrow, + /// => + FatArrow, /// | Pipe, /// # @@ -296,6 +300,7 @@ pub fn token_to_borrowed_token(token: &Token) -> BorrowedToken<'_> { Token::LeftBracket => BorrowedToken::LeftBracket, Token::RightBracket => BorrowedToken::RightBracket, Token::Arrow => BorrowedToken::Arrow, + Token::FatArrow => BorrowedToken::FatArrow, Token::Pipe => BorrowedToken::Pipe, Token::Pound => BorrowedToken::Pound, Token::Comma => BorrowedToken::Comma, @@ -473,6 +478,7 @@ impl fmt::Display for Token { Token::LeftBracket => write!(f, "["), Token::RightBracket => write!(f, "]"), Token::Arrow => write!(f, "->"), + Token::FatArrow => write!(f, "=>"), Token::Pipe => write!(f, "|"), Token::Pound => write!(f, "#"), Token::Comma => write!(f, ","), diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/expression.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/expression.rs index 90e9e53921e..eff309154e3 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/expression.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/expression.rs @@ -4,7 +4,7 @@ use noirc_errors::Span; use crate::{ ast::{ ArrayLiteral, BlockExpression, CallExpression, CastExpression, ConstructorExpression, - Expression, ExpressionKind, Ident, IfExpression, IndexExpression, Literal, + Expression, ExpressionKind, Ident, IfExpression, IndexExpression, Literal, MatchExpression, MemberAccessExpression, MethodCallExpression, Statement, TypePath, UnaryOp, UnresolvedType, }, parser::{labels::ParsingRuleLabel, parser::parse_many::separated_by_comma, ParserErrorReason}, @@ -91,8 +91,7 @@ impl<'a> Parser<'a> { } /// AtomOrUnaryRightExpression - /// = Atom - /// | UnaryRightExpression + /// = Atom UnaryRightExpression* fn parse_atom_or_unary_right(&mut self, allow_constructors: bool) -> Option { let start_span = self.current_token_span; let mut atom = self.parse_atom(allow_constructors)?; @@ -311,6 +310,10 @@ impl<'a> Parser<'a> { return Some(kind); } + if let Some(kind) = self.parse_match_expr() { + return Some(kind); + } + if let Some(kind) = self.parse_lambda() { return Some(kind); } @@ -518,6 +521,49 @@ impl<'a> Parser<'a> { Some(ExpressionKind::If(Box::new(IfExpression { condition, consequence, alternative }))) } + /// MatchExpression = 'match' ExpressionExceptConstructor '{' MatchRule* '}' + pub(super) fn parse_match_expr(&mut self) -> Option { + let start_span = self.current_token_span; + if !self.eat_keyword(Keyword::Match) { + return None; + } + + let expression = self.parse_expression_except_constructor_or_error(); + + self.eat_left_brace(); + + let rules = self.parse_many( + "match cases", + without_separator().until(Token::RightBrace), + Self::parse_match_rule, + ); + + self.push_error(ParserErrorReason::ExperimentalFeature("Match expressions"), start_span); + Some(ExpressionKind::Match(Box::new(MatchExpression { expression, rules }))) + } + + /// MatchRule = Expression '->' (Block ','?) | (Expression ',') + fn parse_match_rule(&mut self) -> Option<(Expression, Expression)> { + let pattern = self.parse_expression()?; + self.eat_or_error(Token::FatArrow); + + let start_span = self.current_token_span; + let branch = match self.parse_block() { + Some(block) => { + let span = self.span_since(start_span); + let block = Expression::new(ExpressionKind::Block(block), span); + self.eat_comma(); // comma is optional if we have a block + block + } + None => { + let branch = self.parse_expression_or_error(); + self.eat_or_error(Token::Comma); + branch + } + }; + Some((pattern, branch)) + } + /// ComptimeExpression = 'comptime' Block fn parse_comptime_expr(&mut self) -> Option { if !self.eat_keyword(Keyword::Comptime) { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/statement.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/statement.rs index 005216b1deb..37013e91528 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/statement.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/statement.rs @@ -162,10 +162,13 @@ impl<'a> Parser<'a> { } if let Some(kind) = self.parse_if_expr() { - return Some(StatementKind::Expression(Expression { - kind, - span: self.span_since(start_span), - })); + let span = self.span_since(start_span); + return Some(StatementKind::Expression(Expression { kind, span })); + } + + if let Some(kind) = self.parse_match_expr() { + let span = self.span_since(start_span); + return Some(StatementKind::Expression(Expression { kind, span })); } if let Some(block) = self.parse_block() { diff --git a/noir/noir-repo/test_programs/compile_success_empty/comptime_enums/src/main.nr b/noir/noir-repo/test_programs/compile_success_empty/comptime_enums/src/main.nr index a16a4cf4da4..e76792005ab 100644 --- a/noir/noir-repo/test_programs/compile_success_empty/comptime_enums/src/main.nr +++ b/noir/noir-repo/test_programs/compile_success_empty/comptime_enums/src/main.nr @@ -3,6 +3,9 @@ fn main() { let _two = Foo::Couple(1, 2); let _one = Foo::One(3); let _none = Foo::None; + + // Ensure zeroed works with enums + let _zeroed: Foo = std::mem::zeroed(); } } diff --git a/noir/noir-repo/tooling/lsp/src/requests/completion.rs b/noir/noir-repo/tooling/lsp/src/requests/completion.rs index 0c51772935a..b464c3e7adc 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/completion.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/completion.rs @@ -1,6 +1,7 @@ use std::{ collections::{BTreeMap, HashMap, HashSet}, future::{self, Future}, + ops::Deref, }; use async_lsp::ResponseError; @@ -199,15 +200,15 @@ impl<'a> NodeFinder<'a> { }; let location = Location::new(span, self.file); - let Some(ReferenceId::Type(struct_id)) = self.interner.find_referenced(location) else { + let Some(ReferenceId::Type(type_id)) = self.interner.find_referenced(location) else { return; }; - let struct_type = self.interner.get_type(struct_id); - let struct_type = struct_type.borrow(); + let data_type = self.interner.get_type(type_id); + let data_type = data_type.borrow(); // First get all of the struct's fields - let Some(fields) = struct_type.get_fields_as_written() else { + let Some(fields) = data_type.get_fields_as_written() else { return; }; @@ -223,7 +224,7 @@ impl<'a> NodeFinder<'a> { self.completion_items.push(self.struct_field_completion_item( &field.name.0.contents, &field.typ, - struct_type.id, + data_type.id, *field_index, self_prefix, )); @@ -320,10 +321,11 @@ impl<'a> NodeFinder<'a> { match module_def_id { ModuleDefId::ModuleId(id) => module_id = id, - ModuleDefId::TypeId(struct_id) => { - let struct_type = self.interner.get_type(struct_id); + ModuleDefId::TypeId(type_id) => { + let data_type = self.interner.get_type(type_id); + self.complete_enum_variants_without_parameters(&data_type.borrow(), &prefix); self.complete_type_methods( - &Type::DataType(struct_type, vec![]), + &Type::DataType(data_type, vec![]), &prefix, FunctionKind::Any, function_completion_kind, @@ -657,7 +659,7 @@ impl<'a> NodeFinder<'a> { return; }; - let struct_id = get_type_struct_id(typ); + let type_id = get_type_type_id(typ); let is_primitive = typ.is_primitive(); let has_self_param = matches!(function_kind, FunctionKind::SelfType(..)); @@ -669,15 +671,11 @@ impl<'a> NodeFinder<'a> { for (func_id, trait_id) in methods.find_matching_methods(typ, has_self_param, self.interner) { - if let Some(struct_id) = struct_id { + if let Some(type_id) = type_id { let modifiers = self.interner.function_modifiers(&func_id); let visibility = modifiers.visibility; - if !struct_member_is_visible( - struct_id, - visibility, - self.module_id, - self.def_maps, - ) { + if !struct_member_is_visible(type_id, visibility, self.module_id, self.def_maps) + { continue; } } @@ -801,6 +799,23 @@ impl<'a> NodeFinder<'a> { } } + fn complete_enum_variants_without_parameters(&mut self, data_type: &DataType, prefix: &str) { + let Some(variants) = data_type.get_variants_as_written() else { + return; + }; + + for (index, variant) in variants.iter().enumerate() { + // Variants with parameters are represented as functions and are suggested in `complete_type_methods` + if variant.is_function || !name_matches(&variant.name.0.contents, prefix) { + continue; + } + + let item = + self.enum_variant_completion_item(variant.name.to_string(), data_type.id, index); + self.completion_items.push(item); + } + } + fn complete_struct_fields( &mut self, struct_type: &DataType, @@ -1900,13 +1915,13 @@ fn get_array_element_type(typ: Type) -> Option { } } -fn get_type_struct_id(typ: &Type) -> Option { - match typ { +fn get_type_type_id(typ: &Type) -> Option { + match typ.follow_bindings_shallow().deref() { Type::DataType(struct_type, _) => Some(struct_type.borrow().id), Type::Alias(type_alias, generics) => { let type_alias = type_alias.borrow(); let typ = type_alias.get_type(generics); - get_type_struct_id(&typ) + get_type_type_id(&typ) } _ => None, } @@ -1958,7 +1973,7 @@ fn name_matches(name: &str, prefix: &str) -> bool { fn module_def_id_from_reference_id(reference_id: ReferenceId) -> Option { match reference_id { ReferenceId::Module(module_id) => Some(ModuleDefId::ModuleId(module_id)), - ReferenceId::Type(struct_id) => Some(ModuleDefId::TypeId(struct_id)), + ReferenceId::Type(type_id) => Some(ModuleDefId::TypeId(type_id)), ReferenceId::Trait(trait_id) => Some(ModuleDefId::TraitId(trait_id)), ReferenceId::Function(func_id) => Some(ModuleDefId::FunctionId(func_id)), ReferenceId::Alias(type_alias_id) => Some(ModuleDefId::TypeAliasId(type_alias_id)), diff --git a/noir/noir-repo/tooling/lsp/src/requests/completion/completion_items.rs b/noir/noir-repo/tooling/lsp/src/requests/completion/completion_items.rs index 039b745172b..b3367c287a0 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/completion/completion_items.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/completion/completion_items.rs @@ -86,7 +86,14 @@ impl<'a> NodeFinder<'a> { None, // trait_id false, // self_prefix ), - ModuleDefId::TypeId(struct_id) => vec![self.struct_completion_item(name, struct_id)], + ModuleDefId::TypeId(type_id) => { + let data_type = self.interner.get_type(type_id); + if data_type.borrow().is_struct() { + vec![self.struct_completion_item(name, type_id)] + } else { + vec![self.enum_completion_item(name, type_id)] + } + } ModuleDefId::TypeAliasId(id) => vec![self.type_alias_completion_item(name, id)], ModuleDefId::TraitId(trait_id) => vec![self.trait_completion_item(name, trait_id)], ModuleDefId::GlobalId(global_id) => vec![self.global_completion_item(name, global_id)], @@ -106,14 +113,18 @@ impl<'a> NodeFinder<'a> { name: impl Into, id: ModuleId, ) -> CompletionItem { - let completion_item = module_completion_item(name); - self.completion_item_with_doc_comments(ReferenceId::Module(id), completion_item) + let item = module_completion_item(name); + self.completion_item_with_doc_comments(ReferenceId::Module(id), item) } - fn struct_completion_item(&self, name: String, struct_id: TypeId) -> CompletionItem { - let completion_item = - simple_completion_item(name.clone(), CompletionItemKind::STRUCT, Some(name)); - self.completion_item_with_doc_comments(ReferenceId::Type(struct_id), completion_item) + fn struct_completion_item(&self, name: String, type_id: TypeId) -> CompletionItem { + let items = simple_completion_item(name.clone(), CompletionItemKind::STRUCT, Some(name)); + self.completion_item_with_doc_comments(ReferenceId::Type(type_id), items) + } + + fn enum_completion_item(&self, name: String, type_id: TypeId) -> CompletionItem { + let item = simple_completion_item(name.clone(), CompletionItemKind::ENUM, Some(name)); + self.completion_item_with_doc_comments(ReferenceId::Type(type_id), item) } pub(super) fn struct_field_completion_item( @@ -124,33 +135,42 @@ impl<'a> NodeFinder<'a> { field_index: usize, self_type: bool, ) -> CompletionItem { - let completion_item = struct_field_completion_item(field, typ, self_type); - self.completion_item_with_doc_comments( - ReferenceId::StructMember(struct_id, field_index), - completion_item, - ) + let item = struct_field_completion_item(field, typ, self_type); + let reference_id = ReferenceId::StructMember(struct_id, field_index); + self.completion_item_with_doc_comments(reference_id, item) } fn type_alias_completion_item(&self, name: String, id: TypeAliasId) -> CompletionItem { - let completion_item = - simple_completion_item(name.clone(), CompletionItemKind::STRUCT, Some(name)); - self.completion_item_with_doc_comments(ReferenceId::Alias(id), completion_item) + let item = simple_completion_item(name.clone(), CompletionItemKind::STRUCT, Some(name)); + self.completion_item_with_doc_comments(ReferenceId::Alias(id), item) } fn trait_completion_item(&self, name: String, trait_id: TraitId) -> CompletionItem { - let completion_item = - simple_completion_item(name.clone(), CompletionItemKind::INTERFACE, Some(name)); - self.completion_item_with_doc_comments(ReferenceId::Trait(trait_id), completion_item) + let item = simple_completion_item(name.clone(), CompletionItemKind::INTERFACE, Some(name)); + self.completion_item_with_doc_comments(ReferenceId::Trait(trait_id), item) } fn global_completion_item(&self, name: String, global_id: GlobalId) -> CompletionItem { let global = self.interner.get_global(global_id); let typ = self.interner.definition_type(global.definition_id); let description = typ.to_string(); + let item = simple_completion_item(name, CompletionItemKind::CONSTANT, Some(description)); + self.completion_item_with_doc_comments(ReferenceId::Global(global_id), item) + } - let completion_item = - simple_completion_item(name, CompletionItemKind::CONSTANT, Some(description)); - self.completion_item_with_doc_comments(ReferenceId::Global(global_id), completion_item) + pub(super) fn enum_variant_completion_item( + &self, + name: String, + type_id: TypeId, + variant_index: usize, + ) -> CompletionItem { + let kind = CompletionItemKind::ENUM_MEMBER; + let item = simple_completion_item(name.clone(), kind, Some(name.clone())); + let item = completion_item_with_detail(item, name); + self.completion_item_with_doc_comments( + ReferenceId::EnumVariant(type_id, variant_index), + item, + ) } #[allow(clippy::too_many_arguments)] @@ -354,6 +374,8 @@ impl<'a> NodeFinder<'a> { if let (Some(type_id), Some(variant_index)) = (func_meta.type_id, func_meta.enum_variant_index) { + completion_item.kind = Some(CompletionItemKind::ENUM_MEMBER); + self.completion_item_with_doc_comments( ReferenceId::EnumVariant(type_id, variant_index), completion_item, diff --git a/noir/noir-repo/tooling/lsp/src/requests/completion/tests.rs b/noir/noir-repo/tooling/lsp/src/requests/completion/tests.rs index a3cd6b0d024..f670f26ffeb 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/completion/tests.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/completion/tests.rs @@ -3094,6 +3094,7 @@ fn main() { assert_eq!(items.len(), 1); let item = &items[0]; + assert_eq!(item.kind, Some(CompletionItemKind::ENUM_MEMBER)); assert_eq!(item.label, "Variant(…)".to_string()); let details = item.label_details.as_ref().unwrap(); @@ -3108,4 +3109,52 @@ fn main() { }; assert!(markdown.value.contains("Some docs")); } + + #[test] + async fn test_suggests_enum_variant_without_parameters() { + let src = r#" + enum Enum { + /// Some docs + Variant + } + + fn foo() { + Enum::Var>|< + } + "#; + let items = get_completions(src).await; + assert_eq!(items.len(), 1); + + let item = &items[0]; + assert_eq!(item.kind, Some(CompletionItemKind::ENUM_MEMBER)); + assert_eq!(item.label, "Variant".to_string()); + + let details = item.label_details.as_ref().unwrap(); + assert_eq!(details.description, Some("Variant".to_string())); + + assert_eq!(item.detail, Some("Variant".to_string())); + assert_eq!(item.insert_text, None); + + let Documentation::MarkupContent(markdown) = item.documentation.as_ref().unwrap() else { + panic!("Expected markdown docs"); + }; + assert!(markdown.value.contains("Some docs")); + } + + #[test] + async fn test_suggests_enum_type() { + let src = r#" + enum ThisIsAnEnum { + } + + fn foo() { + ThisIsA>|< + } + "#; + let items = get_completions(src).await; + assert_eq!(items.len(), 1); + + let item = &items[0]; + assert_eq!(item.kind, Some(CompletionItemKind::ENUM)); + } } diff --git a/noir/noir-repo/tooling/lsp/src/requests/inlay_hint.rs b/noir/noir-repo/tooling/lsp/src/requests/inlay_hint.rs index cbf4ed26ef9..8e091d1eb04 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/inlay_hint.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/inlay_hint.rs @@ -590,6 +590,7 @@ fn get_expression_name(expression: &Expression) -> Option { | ExpressionKind::InternedStatement(..) | ExpressionKind::Literal(..) | ExpressionKind::Unsafe(..) + | ExpressionKind::Match(_) | ExpressionKind::Error => None, } } diff --git a/noir/noir-repo/tooling/nargo_fmt/src/formatter/expression.rs b/noir/noir-repo/tooling/nargo_fmt/src/formatter/expression.rs index ef04276a605..98eabe10e7e 100644 --- a/noir/noir-repo/tooling/nargo_fmt/src/formatter/expression.rs +++ b/noir/noir-repo/tooling/nargo_fmt/src/formatter/expression.rs @@ -2,8 +2,8 @@ use noirc_frontend::{ ast::{ ArrayLiteral, BinaryOpKind, BlockExpression, CallExpression, CastExpression, ConstructorExpression, Expression, ExpressionKind, IfExpression, IndexExpression, - InfixExpression, Lambda, Literal, MemberAccessExpression, MethodCallExpression, - PrefixExpression, TypePath, UnaryOp, UnresolvedTypeData, + InfixExpression, Lambda, Literal, MatchExpression, MemberAccessExpression, + MethodCallExpression, PrefixExpression, TypePath, UnaryOp, UnresolvedTypeData, }, token::{Keyword, Token}, }; @@ -57,6 +57,9 @@ impl<'a, 'b> ChunkFormatter<'a, 'b> { false, // force multiple lines )); } + ExpressionKind::Match(match_expression) => { + group.group(self.format_match_expression(*match_expression)); + } ExpressionKind::Variable(path) => { group.text(self.chunk(|formatter| { formatter.format_path(path); @@ -895,6 +898,68 @@ impl<'a, 'b> ChunkFormatter<'a, 'b> { group } + pub(super) fn format_match_expression( + &mut self, + match_expression: MatchExpression, + ) -> ChunkGroup { + let group_tag = self.new_group_tag(); + let mut group = self.format_match_expression_with_group_tag(match_expression, group_tag); + force_if_chunks_to_multiple_lines(&mut group, group_tag); + group + } + + pub(super) fn format_match_expression_with_group_tag( + &mut self, + match_expression: MatchExpression, + group_tag: GroupTag, + ) -> ChunkGroup { + let mut group = ChunkGroup::new(); + group.tag = Some(group_tag); + group.force_multiple_lines = true; + + group.text(self.chunk(|formatter| { + formatter.write_keyword(Keyword::Match); + formatter.write_space(); + })); + + self.format_expression(match_expression.expression, &mut group); + group.trailing_comment(self.skip_comments_and_whitespace_chunk()); + group.space(self); + + group.text(self.chunk(|formatter| { + formatter.write_left_brace(); + })); + + group.increase_indentation(); + for (pattern, branch) in match_expression.rules { + group.line(); + self.format_expression(pattern, &mut group); + group.text(self.chunk(|formatter| { + formatter.write_space(); + formatter.write_token(Token::FatArrow); + formatter.write_space(); + })); + self.format_expression(branch, &mut group); + + // Add a trailing comma regardless of whether the user specified one or not + group.text(self.chunk(|formatter| { + if formatter.token == Token::Comma { + formatter.write_current_token_and_bump(); + } else { + formatter.write(","); + } + })); + } + group.decrease_indentation(); + group.line(); + + group.text(self.chunk(|formatter| { + formatter.write_right_brace(); + })); + + group + } + fn format_index_expression(&mut self, index: IndexExpression) -> ChunkGroup { let mut group = ChunkGroup::new(); self.format_expression(index.collection, &mut group); @@ -2326,4 +2391,19 @@ global y = 1; "; assert_format_with_max_width(src, expected, " Foo { a: 1 },".len() - 1); } + + #[test] + fn format_match() { + let src = "fn main() { match x { A=>B,C => {D}E=>(), } }"; + // We should remove the block on D for single expressions in the future, + // unless D is an if or match. + let expected = "fn main() { + match x { + A => B, + C => { D }, + E => (), + } +}\n"; + assert_format(src, expected); + } }