Skip to content

Commit b1f28f3

Browse files
committed
feat: Add more specific elim constraints to non-prim projections
1 parent 97661e8 commit b1f28f3

File tree

2 files changed

+180
-39
lines changed

2 files changed

+180
-39
lines changed

test-suite/success/sort_poly_elab.v

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -499,14 +499,13 @@ Module Records.
499499
Fail Check fun (A:SProp) (x y : R6 A) =>
500500
eq_refl : Conversion.box _ x.(R6f2 _) = Conversion.box _ y.(R6f2 _).
501501

502-
(* Elimination constraints are accumulated by fields, even on independent fields *)
502+
(* Elimination constraints are added specifically for each projection *)
503503
#[projections(primitive=no)] Record R7 (A:Type) := { R7f1 : A; R7f2 : nat }.
504504
(* Record R7@{α α0 ; u |} (A : Type@{α ; u}) : Type@{α0 ; max(Set,u)} *)
505505
(* R7f1@{α α0 ; u |} : forall A : Type@{α ; u}, R7@{α α0 ; u} A -> A
506506
α α0 ; u |= α0 -> α *)
507507
(* R7f2@{α α0 ; u |} : forall A : Type@{α ; u}, R7@{α α0 ; u} A -> nat
508-
α α0 ; u |= α0 -> α
509-
α0 -> Type *)
508+
α α0 ; u |= α0 -> Type *)
510509

511510
(* sigma as a primitive record works better *)
512511
Record Rsigma@{s;u v|} (A:Type@{s;u}) (B:A -> Type@{s;v}) : Type@{s;max(u,v)}
@@ -534,7 +533,7 @@ Module Records.
534533

535534
Unset Primitive Projections.
536535

537-
(* Elimination constraints are accumulated by fields *)
536+
(* Elimination constraints are added specifically for each projection *)
538537
Record R8 := {
539538
R8f1 : Type;
540539
R8f2 : R8f1
@@ -545,6 +544,94 @@ Module Records.
545544
(* R8f2@{α α0 ; u |} : forall r : R8@{α α0 ; u}, R8f1@{α α0 ; u} r
546545
α α0 ; u |= α -> α0
547546
α -> Type *)
547+
548+
Inductive eq {A} x : A -> Type :=
549+
eq_refl : eq x x.
550+
551+
Inductive bool := true | false.
552+
553+
(* Elimination constraints are added specifically for each projection *)
554+
Record R (A : Type) := {
555+
x : A ;
556+
y : eq x x ;
557+
z : bool
558+
}.
559+
(* R@{α α0 α1 α2 ; u u0} : forall _ : Type@{α0 ; u}, Type@{α ; max(Set,u,u0)} *)
560+
(* x@{α α0 α1 α2 ; u u0} : forall (A : Type@{α0 ; u}) (_ : R@{α α0 α1 α2 ; u u0} A), A *)
561+
(* α α0 α1 α2 ; u u0 |= α -> α0 *)
562+
(* y@{α α0 α1 α2 ; u u0} : forall (A : Type@{α0 ; u}) (r : R@{α α0 α1 α2 ; u u0} A),
563+
@eq@{α0 α1 ; u u0} A (x@{α α0 α1 α2 ; u u0} A r) (x@{α α0 α1 α2 ; u u0} A r) *)
564+
(* α α0 α1 α2 ; u u0 |= α -> α0
565+
α -> α1 *)
566+
(* z@{α α0 α1 α2 ; u u0} : forall (A : Type@{α0 ; u}) (_ : R@{α α0 α1 α2 ; u u0} A), bool@{α2 ; } *)
567+
(* α α0 α1 α2 ; u u0 |= α -> α2 *)
568+
569+
(* Elimination constraints added to the inductive itself and propagated to projections.
570+
Elimination constraints of projections are specifically for each projection *)
571+
Record R' := {
572+
a1 : Type ;
573+
a2 : Type ;
574+
a3 : bool;
575+
a4 : forall (b : bool),
576+
match b with
577+
| true => match a3 with (* Depends on a3 *)
578+
| true => a1
579+
| false => a2
580+
end
581+
| false => bool
582+
end
583+
}.
584+
(* R'@{α α0 α1 α2 ; u u0} : Type@{α ; max(Set,u+1,u0+1)} *)
585+
(* α α0 α1 α2 ; u u0 |= α1 -> Type
586+
α2 -> Type,
587+
u0 <= u *)
588+
(* a3@{α α0 α1 α2 ; u u0} : forall _ : R'@{α α0 α1 α2 ; u u0}, bool@{α1 ; } *)
589+
(* α α0 α1 α2 ; u u0 |= α -> α1
590+
α1 -> Type
591+
α2 -> Type,
592+
u0 <= u *)
593+
(* a4@{α α0 α1 α2 ; u u0} : ... *)
594+
(* α α0 α1 α2 ; u u0 |= α -> α0
595+
α -> α1
596+
α -> Type
597+
α1 -> Type
598+
α2 -> Type,
599+
u0 <= u *)
600+
601+
Record R'' := {
602+
b1 : bool ;
603+
b2 : let r := {| x := true; y := eq_refl true ; z := b1 |} in
604+
if z bool r then
605+
bool
606+
else
607+
bool ;
608+
b3 : bool
609+
}.
610+
(* R''@{α α0 α1 α2 α3 α4 α5 ; u} : Type@{α ; Set} *)
611+
(* α α0 α1 α2 α3 α4 α5 ; u |= α0 -> α3
612+
α3 -> Type *)
613+
(* b2 : ... *)
614+
(* α α0 α1 α2 α3 α4 α5 ; u |= α -> α3
615+
α -> α4
616+
α0 -> α3
617+
α3 -> Type *)
618+
(* b3 : ... *)
619+
(* α α0 α1 α2 α3 α4 α5 ; u |= α -> α5
620+
α0 -> α3
621+
α3 -> Type *)
622+
623+
Record R''' := {
624+
b : bool ;
625+
f : let f' :=
626+
fix F n :=
627+
if b then n else O
628+
in
629+
match f' O with
630+
| O => bool
631+
| S _ => nat
632+
end
633+
}.
634+
548635
End Records.
549636

550637
Module Class.

vernac/record.ml

Lines changed: 89 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -565,39 +565,88 @@ let declare_proj_coercion_instance ~flags ref from =
565565
in
566566
()
567567

568+
(* Collects elimination constraints from other projections that might be referenced
569+
* in the type of the current projection being built.
570+
* elim_cstrs_map keeps the mapping of (field name -> elim constraints) *)
571+
let collect_elim_cstrs elim_cstrs_map proj_type =
572+
let open Sorts in
573+
let rec aux c =
574+
(* Aux function to fold over arrays *)
575+
let array_fold_aux cs =
576+
Array.fold_left
577+
(fun elim_cstrs c -> ElimConstraints.union elim_cstrs (aux c))
578+
ElimConstraints.empty cs
579+
in
580+
match Constr.kind c with
581+
| Cast (c, _, _) -> aux c
582+
| Prod (_, t, b) | Lambda (_, t, b) ->
583+
let t_elim_cstrs = aux t in
584+
let b_elim_cstrs = aux b in
585+
ElimConstraints.union t_elim_cstrs b_elim_cstrs
586+
| LetIn (_, c, t, b) ->
587+
let c_elim_cstrs = aux c in
588+
let t_elim_cstrs = aux t in
589+
let b_elim_cstrs = aux b in
590+
let elim_cstrs = ElimConstraints.union c_elim_cstrs t_elim_cstrs in
591+
ElimConstraints.union elim_cstrs b_elim_cstrs
592+
| App (f, args) ->
593+
let f_elim_cstrs = aux f in
594+
let args_elim_cstrs = array_fold_aux args in
595+
ElimConstraints.union f_elim_cstrs args_elim_cstrs
596+
| Const (c, _) -> (
597+
let label = Constant.label c in
598+
match Id.Map.find_opt label elim_cstrs_map with
599+
| None -> ElimConstraints.empty
600+
| Some elim_cstrs -> elim_cstrs)
601+
| Case (_, _, params, ((_, p), _), _, c, branches) ->
602+
let params_elim_cstrs = array_fold_aux params in
603+
let return_elim_cstrs = aux p in
604+
let discr_elim_cstrs = aux c in
605+
let branches = Array.map snd branches in
606+
let branches_elim_cstrs = array_fold_aux branches in
607+
let elim_cstrs =
608+
ElimConstraints.union params_elim_cstrs return_elim_cstrs
609+
in
610+
let elim_cstrs = ElimConstraints.union elim_cstrs discr_elim_cstrs in
611+
ElimConstraints.union elim_cstrs branches_elim_cstrs
612+
| Fix (_, (_, tys, bs)) | CoFix (_, (_, tys, bs)) ->
613+
let tys_elim_cstrs = array_fold_aux tys in
614+
let bs_elim_cstrs = array_fold_aux bs in
615+
ElimConstraints.union tys_elim_cstrs bs_elim_cstrs
616+
| Rel _ -> ElimConstraints.empty
617+
| _ -> ElimConstraints.empty
618+
in
619+
aux proj_type
620+
568621
(* Checks whether the record's quality can be eliminated into the projection's
569622
quality. If not, then it adds the elimination constraint. *)
570-
let check_add_elimination_constraints ~primitive univs record_quality proj_typ =
571-
(* Each field is assigned the elimination constraints from its own definition, plus the
572-
constraints from previous fields.
573-
We accumulate these constraints in case a field depends on a previous field.
574-
This accumulation is an over-approximation, since a field may be independent from the rest,
575-
but checking for dependence at this point seems more complicated and costly. *)
576-
if primitive then univs
623+
let check_add_elimination_constraints ~primitive (entry, binders as univs) fld_id elim_cstrs_map record_quality proj_typ =
624+
(* When the record has primitive projections, then the constraints are added to the record itself,
625+
* not to the projections *)
626+
if primitive then univs, elim_cstrs_map
577627
else
578-
(* XXX: I hope there's a better way to do this... *)
579628
let env = Global.env () in
580629
let evd = Evd.from_env env in
581630
let proj_quality = EConstr.ESorts.quality evd @@ Retyping.get_sort_of env evd @@ EConstr.of_constr proj_typ in
582631
let open QGraph in
583632
let qgraph = Environ.qualities env in
584633
let qgraph = try add_quality record_quality qgraph with AlreadyDeclared -> qgraph in
585634
let qgraph = try add_quality proj_quality qgraph with AlreadyDeclared -> qgraph in
586-
if eliminates_to qgraph record_quality proj_quality then univs
635+
if eliminates_to qgraph record_quality proj_quality then univs, elim_cstrs_map
587636
else
588-
let open Sorts in
589-
let new_elim_cstr = record_quality, ElimConstraint.ElimTo, proj_quality in
590-
let (entry, binders) = univs in
591-
let entry = match entry with
637+
let entry, elim_cstrs_map' = match entry with
592638
| UState.Polymorphic_entry uctx ->
593-
let open UVars.UContext in
594-
let (elim_cstrs, univ_cstrs) = constraints uctx in
595-
let elim_cstrs' = ElimConstraints.add new_elim_cstr elim_cstrs in
596-
let uctx' = make (names uctx) (instance uctx, (elim_cstrs', univ_cstrs)) in
597-
UState.Polymorphic_entry uctx'
598-
| _ -> entry
639+
let open Sorts in
640+
let new_elim_cstr = record_quality, ElimConstraint.ElimTo, proj_quality in
641+
let (elim_cstrs, univ_cstrs) = UVars.UContext.constraints uctx in
642+
let related_elim_cstrs = collect_elim_cstrs elim_cstrs_map proj_typ in
643+
let elim_cstrs' = ElimConstraints.add new_elim_cstr elim_cstrs in
644+
let elim_cstrs' = ElimConstraints.union related_elim_cstrs elim_cstrs' in
645+
let uctx' = UVars.UContext.make (UVars.UContext.names uctx) (UVars.UContext.instance uctx, (elim_cstrs', univ_cstrs)) in
646+
UState.Polymorphic_entry uctx', Id.Map.add fld_id elim_cstrs' elim_cstrs_map
647+
| _ -> entry, elim_cstrs_map
599648
in
600-
(entry, binders)
649+
(entry, binders), elim_cstrs_map'
601650

602651
(* TODO: refactor the declaration part here; this requires some
603652
surgery as Evarutil.finalize is called too early in the path *)
@@ -607,7 +656,7 @@ let check_add_elimination_constraints ~primitive univs record_quality proj_typ =
607656
this could be refactored as noted above by moving to the
608657
higher-level declare constant API *)
609658
let build_named_proj ~primitive ~flags ~univs ~uinstance ~kind env paramdecls
610-
paramargs decl impls {CAst.v=fid; loc} subst nfi ti i indsp mib lifted_fields x rp record_quality =
659+
paramargs decl impls {CAst.v=fid; loc} subst nfi ti i indsp mib lifted_fields x rp record_quality elim_cstrs_map =
611660
let ccl = subst_projection fid subst ti in
612661
let body, p_opt = match decl with
613662
| LocalDef (_,ci,_) -> subst_projection fid subst ci, None
@@ -629,9 +678,13 @@ let build_named_proj ~primitive ~flags ~univs ~uinstance ~kind env paramdecls
629678
in
630679
let proj = it_mkLambda_or_LetIn (mkLambda (x, rp, body)) paramdecls in
631680
let proj_typ = it_mkProd_or_LetIn (mkProd (x, rp, ccl)) paramdecls in
632-
let univs = match decl with
633-
| LocalDef _ -> univs (* A local def might need previous elim constraints but it doesn't introduce new ones *)
634-
| LocalAssum _ -> check_add_elimination_constraints ~primitive univs record_quality proj_typ
681+
let univs, elim_cstrs_map =
682+
match decl with
683+
(* A local def might need previous elim constraints but it doesn't introduce new ones *)
684+
| LocalDef _ -> univs, elim_cstrs_map
685+
| LocalAssum _ ->
686+
check_add_elimination_constraints ~primitive univs fid elim_cstrs_map
687+
record_quality proj_typ
635688
in
636689
let entry = Declare.definition_entry ~univs ~types:proj_typ proj in
637690
let kind = Decls.IsDefinition kind in
@@ -657,29 +710,29 @@ let build_named_proj ~primitive ~flags ~univs ~uinstance ~kind env paramdecls
657710
Impargs.maybe_declare_manual_implicits false refi impls;
658711
declare_proj_coercion_instance ~flags refi (GlobRef.IndRef indsp);
659712
let i = if is_local_assum decl then i+1 else i in
660-
(env, univs, Some kn, i, Projection term::subst)
713+
(elim_cstrs_map, Some kn, i, Projection term::subst)
661714

662715
(** [build_proj] will build a projection for each field, or skip if
663716
the field is anonymous, i.e. [_ : t] *)
664-
let build_proj mib indsp primitive x rp lifted_fields paramdecls paramargs record_quality ~uinstance ~kind
665-
(env, univs, nfi, i, kinds, subst) flags loc decl impls =
717+
let build_proj env mib indsp primitive x rp lifted_fields paramdecls paramargs record_quality ~uinstance ~kind ~univs
718+
(elim_cstrs_map, nfi, i, kinds, subst) flags loc decl impls =
666719
let fi = RelDecl.get_name decl in
667720
let ti = RelDecl.get_type decl in
668-
let (env, univs, sp_proj, i, subst) =
721+
let (elim_cstrs_map, sp_proj, i, subst) =
669722
match fi with
670723
| Anonymous ->
671-
(env, univs, None, i, NoProjection fi::subst)
724+
(elim_cstrs_map, None, i, NoProjection fi::subst)
672725
| Name fid ->
673726
let fid = CAst.make ?loc fid in
674727
try build_named_proj
675728
~primitive ~flags ~univs ~uinstance ~kind env paramdecls paramargs decl impls fid
676-
subst nfi ti i indsp mib lifted_fields x rp record_quality
729+
subst nfi ti i indsp mib lifted_fields x rp record_quality elim_cstrs_map
677730
with NotDefinable why as exn ->
678731
let _, info = Exninfo.capture exn in
679732
warning_or_error ?loc ~info flags indsp why;
680-
(env, univs, None, i, NoProjection fi::subst)
733+
(elim_cstrs_map, None, i, NoProjection fi::subst)
681734
in
682-
(env, univs, nfi - 1, i,
735+
(elim_cstrs_map, nfi - 1, i,
683736
{ Structure.proj_name = fi
684737
; proj_true = is_local_assum decl
685738
; proj_canonical = flags.Data.pf_canonical
@@ -701,6 +754,7 @@ let declare_projections indsp ~kind ~inhabitant_id flags ?fieldlocs fieldimpls =
701754
| Polymorphic auctx -> UState.Polymorphic_entry (UVars.AbstractContext.repr auctx)
702755
in
703756
let univs = univs, UnivNames.empty_binders in
757+
let elim_cstrs_map : Sorts.ElimConstraints.t Id.Map.t = Id.Map.empty in
704758
let record_quality = Sorts.quality mip.mind_sort in
705759
let fields, _ = mip.mind_nf_lc.(0) in
706760
let fields = List.firstn mip.mind_consnrealdecls.(0) fields in
@@ -720,10 +774,10 @@ let declare_projections indsp ~kind ~inhabitant_id flags ?fieldlocs fieldimpls =
720774
| None -> List.make (List.length fields) None
721775
| Some fieldlocs -> fieldlocs
722776
in
723-
let (_, _, _, _, canonical_projections, _) =
777+
let (_, _, _, canonical_projections, _) =
724778
List.fold_left4
725-
(build_proj mib indsp primitive x rp lifted_fields paramdecls paramargs record_quality ~uinstance ~kind)
726-
(env, univs, List.length fields,0,[],[]) flags (List.rev fieldlocs) (List.rev fields) (List.rev fieldimpls)
779+
(build_proj env mib indsp primitive x rp lifted_fields paramdecls paramargs record_quality ~uinstance ~kind ~univs)
780+
(elim_cstrs_map, List.length fields,0,[],[]) flags (List.rev fieldlocs) (List.rev fields) (List.rev fieldimpls)
727781
in
728782
List.rev canonical_projections
729783

0 commit comments

Comments
 (0)