@@ -235,13 +235,36 @@ let make_variant_id_to_containing_type ast =
235235 acc_map variants)
236236 StringMap. empty typedefs
237237
238+ let make_field_to_containing_variant ast =
239+ let typedefs =
240+ List. filter_map (function Elem_Type def -> Some def | _ -> None ) ast
241+ in
242+ List. fold_left
243+ (fun acc_map { Type. variants } ->
244+ List. fold_left
245+ (fun acc_map ({ TypeVariant. term } as variant ) ->
246+ match term with
247+ | Term. Record { fields } ->
248+ let field_names = List. map (fun { Term. name } -> name) fields in
249+ List. fold_left
250+ (fun acc_map field_name ->
251+ StringMap. add field_name variant acc_map)
252+ acc_map field_names
253+ | _ -> acc_map)
254+ acc_map variants)
255+ StringMap. empty typedefs
256+
238257type t = {
239258 ast : AST .t ; (* * The AST, added with builtin definitions, transformed. *)
240259 id_to_defining_node : definition_node StringMap .t ;
241260 (* * Associates identifiers with the AST nodes where they are defined. *)
242261 variant_id_to_containing_type : string StringMap .t ;
243262 (* * Associates variant labels with the name of the type that contains
244263 them. *)
264+ field_to_containing_variant : TypeVariant .t StringMap .t ;
265+ (* * Associates field names with the variant containing them. Since field
266+ names are unique, there is only one variant containing a given field.
267+ *)
245268 assign : Relation .t ;
246269 reverse_assign : Relation .t ;
247270 bottom_constant : Constant .t ;
@@ -279,6 +302,7 @@ let update_spec_ast spec ast =
279302 ast;
280303 id_to_defining_node = make_symbol_table ast;
281304 variant_id_to_containing_type = make_variant_id_to_containing_type ast;
305+ field_to_containing_variant = make_field_to_containing_variant ast;
282306 }
283307
284308let defined_ids self =
@@ -297,13 +321,26 @@ let defining_node_for_id self id =
297321 | Some def -> def
298322 | None -> Error. undefined_element id
299323
300- (* * [relation_for_id self id] returns the relation definition node for the given
301- identifier [id], which is assumed to correspond to a relation definition. *)
302324let relation_for_id self id =
303325 match defining_node_for_id self id with
304326 | Node_Relation def -> def
305327 | _ -> assert false
306328
329+ let record_variant_for_expr spec expr =
330+ match expr with
331+ | Expr. Record { fields } ->
332+ let first_field_name =
333+ match fields with
334+ | (field_name , _ ) :: _ -> field_name
335+ | _ -> failwith " Record expression must have a non-empty list of fields"
336+ in
337+ StringMap. find first_field_name spec.field_to_containing_variant
338+ | _ ->
339+ let msg =
340+ Format. asprintf " Expected record expression, found %a" PP. pp_expr expr
341+ in
342+ failwith msg
343+
307344let is_defined_id self id = StringMap. mem id self.id_to_defining_node
308345let elements self = self.ast
309346
@@ -2351,9 +2388,15 @@ module Check = struct
23512388 in
23522389 Term. Record { label_opt; fields = record_fields }
23532390 in
2391+ let { TypeVariant. term = declared_type } =
2392+ record_variant_for_expr spec expr
2393+ in
23542394 let () =
2355- check_subsumed_by_opt_labelled_type spec inferred_type label_opt
2356- ~context_expr: expr
2395+ if
2396+ not
2397+ (CheckTypeInstantiations. subsumed spec inferred_type
2398+ declared_type)
2399+ then Error. type_subsumption_failure inferred_type declared_type
23572400 in
23582401 (inferred_type, type_env)
23592402 | RecordUpdate { record_expr; updates } -> (
@@ -2403,8 +2446,19 @@ module Check = struct
24032446 Term. Tuple { label_opt; args = anonymous_typed_args }
24042447 in
24052448 let () =
2406- check_subsumed_by_opt_labelled_type spec inferred_type label_opt
2407- ~context_expr: expr
2449+ match label_opt with
2450+ | Some label -> (
2451+ match StringMap. find label spec.id_to_defining_node with
2452+ | Node_TypeVariant { TypeVariant. term = declared_type } ->
2453+ if
2454+ not
2455+ (CheckTypeInstantiations. subsumed spec inferred_type
2456+ declared_type)
2457+ then
2458+ Error. type_subsumption_failure inferred_type
2459+ declared_type
2460+ | _ -> Error. invalid_labelled_type label ~context_expr: expr)
2461+ | None -> ()
24082462 in
24092463 (inferred_type, type_env)
24102464 | Relation { is_operator = true ; name; args = [ lhs; rhs ] }
@@ -2585,24 +2639,6 @@ module Check = struct
25852639 in
25862640 (List. rev types, type_env)
25872641
2588- (* * [check_subsumed_by_opt_labelled_type spec actual_type label_opt
2589- ~context_expr] checks that [actual_type] is subsumed by the labelled
2590- type indicated by [label_opt], if any. The [context_expr] is used for
2591- error reporting. *)
2592- and check_subsumed_by_opt_labelled_type spec actual_type label_opt
2593- ~context_expr =
2594- match label_opt with
2595- | Some label -> (
2596- match StringMap. find label spec.id_to_defining_node with
2597- | Node_TypeVariant { TypeVariant. term = declared_type } ->
2598- if
2599- not
2600- (CheckTypeInstantiations. subsumed spec actual_type
2601- declared_type)
2602- then Error. type_subsumption_failure actual_type declared_type
2603- | _ -> Error. invalid_labelled_type label ~context_expr )
2604- | None -> ()
2605-
26062642 and check_arg_types spec arg_exprs arg_types arg_formal_types ~context_expr
26072643 =
26082644 Utils. list_iter3
@@ -2957,29 +2993,31 @@ module Check = struct
29572993 ~expected: (List. length type_components)
29582994 ~actual: (List. length args)
29592995 | None -> () )
2960- | Record { label_opt; fields } -> (
2996+ | Record { fields } ->
29612997 let expr_field_names, expr_field_inits = List. split fields in
29622998 let () = check_expr_list_in_context expr_field_inits in
2963- match label_opt with
2964- | Some label ->
2965- let record_type_fields =
2966- match StringMap. find label id_to_defining_node with
2967- | Node_TypeVariant { TypeVariant. term = Record { fields } } ->
2968- fields
2969- | _ -> Error. illegal_lhs_application expr
2970- in
2971- let record_type_field_names =
2972- List. map (fun { Term. name } -> name) record_type_fields
2973- in
2974- if
2975- not
2976- (Utils. list_is_equal String. equal expr_field_names
2977- record_type_field_names)
2978- then
2979- Error. invalid_record_field_names expr expr_field_names
2980- record_type_field_names
2981- else ()
2982- | None -> () )
2999+ let { TypeVariant. term = record_term } =
3000+ record_variant_for_expr spec expr
3001+ in
3002+ let record_type_field_names =
3003+ match record_term with
3004+ | Term. Record { fields = record_fields } ->
3005+ List. map (fun { Term. name } -> name) record_fields
3006+ | _ -> failwith " Expected record term."
3007+ in
3008+ let expr_field_names_sorted =
3009+ List. sort String. compare expr_field_names
3010+ in
3011+ let record_type_field_names_sorted =
3012+ List. sort String. compare record_type_field_names
3013+ in
3014+ if
3015+ not
3016+ (Utils. string_list_is_subset expr_field_names
3017+ record_type_field_names)
3018+ then
3019+ Error. invalid_record_field_names expr expr_field_names_sorted
3020+ record_type_field_names_sorted
29833021 | RecordUpdate { record_expr; updates } ->
29843022 let () = check_expr_in_context record_expr in
29853023 let update_field_names, update_field_inits = List. split updates in
@@ -3421,6 +3459,7 @@ let make_spec_with_builtins ast =
34213459 some_operator = get_relation " some" ;
34223460 cond_operator = get_relation " cond_op" ;
34233461 variant_id_to_containing_type = make_variant_id_to_containing_type ast;
3462+ field_to_containing_variant = make_field_to_containing_variant ast;
34243463 }
34253464
34263465let from_ast ast =
0 commit comments