diff --git a/ocaml/core.ml b/ocaml/core.ml index 12bf3c33ef..ee58cb46dd 100644 --- a/ocaml/core.ml +++ b/ocaml/core.ml @@ -75,9 +75,7 @@ let init env = begin | _ -> T.Int 0)); Env.set env (Types.symbol "=") (Types.fn (function - | [T.List a; T.Vector b] -> T.Bool (a = b) - | [T.Vector a; T.List b] -> T.Bool (a = b) - | [a; b] -> T.Bool (a = b) + | [a; b] -> T.Bool (Types.mal_equal a b) | _ -> T.Bool false)); Env.set env (Types.symbol "pr-str") diff --git a/ocaml/types.ml b/ocaml/types.ml index 9df9761042..45d10bdb30 100644 --- a/ocaml/types.ml +++ b/ocaml/types.ml @@ -48,3 +48,22 @@ let rec list_into_map target source = | k :: v :: more -> list_into_map (MalMap.add k v target) more | [] -> map target | _ :: [] -> raise (Invalid_argument "Literal maps must contain an even number of forms") + +let rec mal_list_equal a b = + List.length a = List.length b && List.for_all2 mal_equal a b + +and mal_hash_equal a b = + if MalMap.cardinal a = MalMap.cardinal b + then + let identical_to_b k v = MalMap.mem k b && mal_equal v (MalMap.find k b) in + MalMap.for_all identical_to_b a + else false + +and mal_equal a b = + match (a, b) with + | (Types.List a, Types.List b) + | (Types.List a, Types.Vector b) + | (Types.Vector a, Types.List b) + | (Types.Vector a, Types.Vector b) -> mal_list_equal a.Types.value b.Types.value + | (Types.Map a, Types.Map b) -> mal_hash_equal a.Types.value b.Types.value + | _ -> a = b