Skip to content

Commit

Permalink
Merge pull request #511 from FStarLang/protz_renormalize_data_types
Browse files Browse the repository at this point in the history
Renormalize data types properly
  • Loading branch information
msprotz authored Dec 27, 2024
2 parents 3823e3d + ebff0d8 commit f82ecfe
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 5 deletions.
12 changes: 8 additions & 4 deletions lib/Checker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1089,16 +1089,18 @@ and assert_cons_of env t id: fields_t =
checker_error env "the annotated type %a is not a variant type" ptyp (TAnonymous t)

and subtype env t1 t2 =
if Options.debug "checker" then
KPrint.bprintf "%a <=? %a\n" ptyp t1 ptyp t2;
let normalize t =
match MonomorphizationState.resolve (expand_abbrev env t) with
match MonomorphizationState.resolve_deep (expand_abbrev env t) with
| TBuf (TApp ((["Eurydice"], "derefed_slice"), [ t ]), _) ->
TApp ((["Eurydice"], "slice"), [t])
| t ->
t
in
match normalize t1, normalize t2 with
let t1 = normalize t1 in
let t2 = normalize t2 in
if Options.debug "checker" then
KPrint.bprintf "%a <=? %a\n" ptyp t1 ptyp t2;
match t1, t2 with
| TInt w1, TInt w2 when w1 = w2 ->
true
| TInt K.SizeT, TInt K.UInt32 when Options.wasm () ->
Expand Down Expand Up @@ -1184,6 +1186,8 @@ and subtype env t1 t2 =
subtype env t2 t1

| _ ->
if Options.debug "checker" then
MonomorphizationState.debug ();
false

and eqtype env t1 t2 =
Expand Down
18 changes: 18 additions & 0 deletions lib/Monomorphization.ml
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,17 @@ let monomorphize_data_types map = object(self)
TQualified chosen_lid
end)#visit_typ () t

(* We need to renormalize entries in the map for the Checker module. For
instance, the map might contain `t (u v) -> t0` and `u v -> u0`, but at
this stage, we will have a type error when trying to compare `t (u v)` and
`t u0`, since the latter does not appear in the map. *)
method private renormalize_entry (n, ts, cgs) chosen_lid =
(* We do this on the fly to make sure that types that appear in ts have
themselves been renormalized. *)
let ts' = List.map resolve_deep ts in
if not (Hashtbl.mem state (n, ts', cgs)) then
Hashtbl.add state (n, ts', cgs) (Black, chosen_lid)

(* Compute the name of a given node in the graph. *)
method private lid_of (n: node) =
let lid, ts, cgs = n in
Expand Down Expand Up @@ -340,6 +351,7 @@ let monomorphize_data_types map = object(self)
(* For tuples, we immediately know how to generate a definition. *)
let fields = List.mapi (fun i arg -> Some (self#field_at i), (arg, false)) args in
self#record (DType (chosen_lid, [ Common.Private ] @ flag, 0, 0, Flat fields));
self#renormalize_entry n chosen_lid;
Hashtbl.replace state n (Black, chosen_lid)
end else begin
(* This specific node has not been visited yet. *)
Expand All @@ -352,6 +364,7 @@ let monomorphize_data_types map = object(self)
begin match Hashtbl.find map lid with
| exception Not_found ->
(* Unknown, external non-polymorphic lid, e.g. Prims.int *)
self#renormalize_entry n chosen_lid;
Hashtbl.replace state n (Black, chosen_lid)
| flags, ((Variant _ | Flat _ | Union _) as def) when under_ref && not (Hashtbl.mem seen_declarations lid) ->
(* Because this looks up a definition in the global map, the
Expand Down Expand Up @@ -382,10 +395,12 @@ let monomorphize_data_types map = object(self)
let branches = List.map (fun (cons, fields) -> cons, subst fields) branches in
let branches = self#visit_branches_t under_ref branches in
self#record (DType (chosen_lid, flag @ flags, 0, 0, Variant branches));
self#renormalize_entry n chosen_lid;
Hashtbl.replace state n (Black, chosen_lid)
| flags, Flat fields ->
let fields = self#visit_fields_t_opt under_ref (subst fields) in
self#record (DType (chosen_lid, flag @ flags, 0, 0, Flat fields));
self#renormalize_entry n chosen_lid;
Hashtbl.replace state n (Black, chosen_lid)
| flags, Union fields ->
let fields = List.map (fun (f, t) ->
Expand All @@ -394,13 +409,16 @@ let monomorphize_data_types map = object(self)
f, t
) fields in
self#record (DType (chosen_lid, flag @ flags, 0, 0, Union fields));
self#renormalize_entry n chosen_lid;
Hashtbl.replace state n (Black, chosen_lid)
| flags, Abbrev t ->
let t = DeBruijn.subst_tn args t in
let t = self#visit_typ under_ref t in
self#record (DType (chosen_lid, flag @ flags, 0, 0, Abbrev t));
self#renormalize_entry n chosen_lid;
Hashtbl.replace state n (Black, chosen_lid)
| _ ->
self#renormalize_entry n chosen_lid;
Hashtbl.replace state n (Black, chosen_lid)
end
end;
Expand Down
30 changes: 29 additions & 1 deletion lib/MonomorphizationState.ml
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
open Ast
open PrintAst.Ops

(* Various bits of state for monomorphization, the two most important being
`state` (type monomorphization) and `generated_lids` (function
monomorphization). *)

(* Monomorphization of data types. *)
type node = lident * typ list * cg list
type color = Gray | Black

(* Each polymorphic type `lid` applied to types `ts` and const generics `ts`
appears in `state`, and maps to `monomorphized_lid`, the name of its
monomorphized instance. *)
let state: (node, color * lident) Hashtbl.t = Hashtbl.create 41

(* Because of polymorphic externals, one still encounters,
post-monomorphizations, application nodes in types (e.g. after instantiating
a polymorphic type scheme). The `resolve*` functions, below, normalize a type
to only contain monomorphic type names (and no more type applications) *)
let resolve t: typ =
match t with
| TApp _ | TCgApp _ when Hashtbl.mem state (flatten_tapp t) ->
Expand All @@ -27,4 +42,17 @@ let resolve_deep = (object(self)
resolve (TTuple ts)
end)#visit_typ ()

let generated_lids: (lident * expr list * typ list, lident) Hashtbl.t = Hashtbl.create 41
(* Monomorphization of functions *)
type reverse_mapping = (lident * expr list * typ list, lident) Hashtbl.t

let generated_lids: reverse_mapping = Hashtbl.create 41

let debug () =
Hashtbl.iter (fun (lid, ts, cgs) (_, monomorphized_lid) ->
KPrint.bprintf "%a <%a> <%a> ~~> %a\n" plid lid pcgs cgs ptyps ts plid
monomorphized_lid
) state;
Hashtbl.iter (fun (lid, es, ts) monomorphized_lid ->
KPrint.bprintf "%a <%a> <%a> ~~> %a\n" plid lid pexprs es ptyps ts plid
monomorphized_lid
) generated_lids

0 comments on commit f82ecfe

Please sign in to comment.