Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions src/Init/Internal/Order/While.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,33 @@ public import Init.Internal.Order.MonadTail
set_option linter.missingDocs true

/-!
# Order-theoretic unfolding for `whileM`
# Order-theoretic unfolding for `repeatM`

This module exposes the user-facing one-step unfolding lemma `whileM_eq_of_monadTail`,
This module exposes the user-facing one-step unfolding lemma `repeatM_eq_of_monadTail`,
which holds for any monad with a `Lean.Order.MonadTail` instance. It works by exhibiting
the order-theoretic least fixed point `Lean.Order.fix (whileM.body f)` and feeding it to
the module-internal `whileM_eq`.
the order-theoretic least fixed point `Lean.Order.fix (repeatM.body f)` and feeding it to
the module-internal `repeatM_eq`.

The lemma does not appear in `Init.While` itself to keep that module's import closure
small; clients that want to unfold `whileM`/`Loop.forIn` should import this file.
small; clients that want to unfold `repeatM`/`Loop.forIn` should import this file.
-/

variable {α : Type u} {m : Type u → Type v} [Monad m]

/-- `whileM.body f` is monotone in its `recur` argument whenever `m` admits `MonadTail`. -/
private theorem whileM.body_monotone_of_monadTail [Lean.Order.MonadTail m] [Nonempty β]
/-- `repeatM.body f` is monotone in its `recur` argument whenever `m` admits `MonadTail`. -/
private theorem repeatM.body_monotone_of_monadTail [Lean.Order.MonadTail m] [Nonempty β]
(f : α → m (α ⊕ β)) :
Lean.Order.monotone (whileM.body f) :=
Lean.Order.monotone (repeatM.body f) :=
fun _ _ h _ => Lean.Order.MonadTail.bind_mono_right fun
| .inl a' => h a'
| .inr _ => Lean.Order.PartialOrder.rel_refl

/-- One-step unfolding of `whileM` for any `MonadTail m`. -/
public theorem whileM_eq_of_monadTail [Lean.Order.MonadTail m] [Nonempty β]
/-- One-step unfolding of `repeatM` for any `MonadTail m`. -/
public theorem repeatM_eq_of_monadTail [Lean.Order.MonadTail m] [Nonempty β]
{f : α → m (α ⊕ β)} (a : α) :
whileM f a = whileM.body f (whileM f) a :=
have hMono := whileM.body_monotone_of_monadTail f
whileM_eq a ⟨Lean.Order.fix (whileM.body f) hMono, (Lean.Order.fix_eq hMono).symm⟩
repeatM f a = repeatM.body f (repeatM f) a :=
have hMono := repeatM.body_monotone_of_monadTail f
repeatM_eq a ⟨Lean.Order.fix (repeatM.body f) hMono, (Lean.Order.fix_eq hMono).symm⟩

namespace Lean

Expand All @@ -51,8 +51,8 @@ public theorem Loop.forIn_eq_of_monadTail [LawfulMonad m] [Lean.Order.MonadTail
| .done val => pure val
| .yield val => Loop.forIn l val f) := by
haveI : Nonempty β := ⟨b⟩
show whileM _ b = _
rw [whileM_eq_of_monadTail]; unfold whileM.body; rw [bind_assoc]
show repeatM _ b = _
rw [repeatM_eq_of_monadTail]; unfold repeatM.body; rw [bind_assoc]
exact bind_congr fun
| .done _ => by simp
| .yield _ => by simp [Loop.forIn]
Expand Down
76 changes: 38 additions & 38 deletions src/Init/While.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,81 +10,81 @@ public import Init.Core
public import Init.Classical

/-!
# `whileM`
# `repeatM`

`whileM f a` iterates `f : α → m (α ⊕ β)`, recursing on `.inl` and terminating on
`.inr`. The public unfolding lemma `whileM_eq_of_monadTail`, which requires a
`repeatM f a` iterates `f : α → m (α ⊕ β)`, recursing on `.inl` and terminating on
`.inr`. The public unfolding lemma `repeatM_eq_of_monadTail`, which requires a
`Lean.Order.MonadTail m` instance, lives in `Init.Internal.Order.While` to keep this
module's import closure small.
-/

variable {α : Type u} {m : Type u → Type v} [Monad m]

/-- The body of `whileM`: run `f a`, recurse via `recur` on `.inl`, return on `.inr`. -/
public abbrev whileM.body (f : α → m (α ⊕ β)) (recur : α → m β) (a : α) : m β := do
/-- The body of `repeatM`: run `f a`, recurse via `recur` on `.inl`, return on `.inr`. -/
public abbrev repeatM.body (f : α → m (α ⊕ β)) (recur : α → m β) (a : α) : m β := do
match ← f a with
| .inl a => recur a
| .inr b => pure b

/-- Pinning predicate for `whileM.impl`: trivial unless `whileM.body f` has a fixed point,
/-- Pinning predicate for `repeatM.impl`: trivial unless `repeatM.body f` has a fixed point,
in which case `r` is logically pinned to that fixed point applied to `a`. -/
-- For monads like `List`, `Multiset`, no fixed point of `whileM.body f` need exist:
-- For monads like `List`, `Multiset`, no fixed point of `repeatM.body f` need exist:
-- e.g. for `List`, `f a = [.inr 0, .inl a]` forces `g a = [0] ++ g a`, unsatisfiable in
-- finite lists because `++` isn't idempotent. There this `Pred` collapses to `True`;
-- a future per-point `Acc` / `MonadAttach` branch could pin `r` for the cases where
-- execution from `a` is structurally well-founded.
private abbrev whileM.Pred (f : α → m (α ⊕ β)) (a : α) (r : m β) : Prop :=
private abbrev repeatM.Pred (f : α → m (α ⊕ β)) (a : α) (r : m β) : Prop :=
open scoped Classical in
if h : ∃ g, whileM.body f g = g then
if h : ∃ g, repeatM.body f g = g then
r = h.choose a
else
True

private instance [Nonempty β] {f : α → m (α ⊕ β)} {a : α} :
Nonempty (Subtype (whileM.Pred f a)) :=
Nonempty (Subtype (repeatM.Pred f a)) :=
open scoped Classical in
if h : ∃ g, whileM.body f g = g then
⟨⟨h.choose a, by simp only [whileM.Pred, dif_pos h]⟩⟩
if h : ∃ g, repeatM.body f g = g then
⟨⟨h.choose a, by simp only [repeatM.Pred, dif_pos h]⟩⟩
else
⟨⟨pure (Classical.choice inferInstance), by simp only [whileM.Pred, dif_neg h]⟩⟩
⟨⟨pure (Classical.choice inferInstance), by simp only [repeatM.Pred, dif_neg h]⟩⟩

/-- Computational core of `whileM`: returns the loop value paired with its
`whileM.Pred` proof. -/
private partial def whileM.impl [Nonempty β]
/-- Computational core of `repeatM`: returns the loop value paired with its
`repeatM.Pred` proof. -/
private partial def repeatM.impl [Nonempty β]
(f : α → m (α ⊕ β)) (a : α) :
Subtype (whileM.Pred f a) :=
whileM.body f (whileM.impl f · |>.val) a, by
simp only [whileM.Pred]
Subtype (repeatM.Pred f a) :=
repeatM.body f (repeatM.impl f · |>.val) a, by
simp only [repeatM.Pred]
split <;> rename_i h
· have key : (fun x => (whileM.impl f x).val) = h.choose := funext fun x => by
simpa only [whileM.Pred, dif_pos h] using (whileM.impl f x).property
· have key : (fun x => (repeatM.impl f x).val) = h.choose := funext fun x => by
simpa only [repeatM.Pred, dif_pos h] using (repeatM.impl f x).property
rw [key]; exact congrFun h.choose_spec a
· trivial⟩

/--
An erased version of `whileM.impl` that eta-expands better in the compiler.
Can be removed once `whileM.impl` optimizes to the same code.
An erased version of `repeatM.impl` that eta-expands better in the compiler.
Can be removed once `repeatM.impl` optimizes to the same code.
-/
@[specialize] private partial def whileM.erased [Nonempty β] (f : α → m (α ⊕ β)) (a : α) : m β :=
whileM.body f (whileM.erased f ·) a
@[specialize] private partial def repeatM.erased [Nonempty β] (f : α → m (α ⊕ β)) (a : α) : m β :=
repeatM.body f (repeatM.erased f ·) a

/--
`whileM f a` iterates `f` at `a`, recursing on `.inl` and terminating on `.inr`.
`repeatM f a` iterates `f` at `a`, recursing on `.inl` and terminating on `.inr`.

Its unfolding lemma is `whileM_eq_of_monadTail`.
Its unfolding lemma is `repeatM_eq_of_monadTail`.
-/
@[implemented_by whileM.erased] -- See comment above `whileM.erased`.
public def whileM [Nonempty β] (f : α → m (α ⊕ β)) (a : α) : m β :=
(whileM.impl f a).val
@[implemented_by repeatM.erased] -- See comment above `repeatM.erased`.
public def repeatM [Nonempty β] (f : α → m (α ⊕ β)) (a : α) : m β :=
(repeatM.impl f a).val

-- This lemma is intentionally private. Users are expected to unfold using
-- `whileM_eq_of_monadTail` instead.
private theorem whileM_eq [Nonempty β] {f : α → m (α ⊕ β)} (a : α)
(h : ∃ g, whileM.body f g = g) :
whileM f a = whileM.body f (whileM f) a := by
have key : (fun x => (whileM.impl f x).val) = h.choose := funext fun x => by
simpa only [whileM.Pred, dif_pos h] using (whileM.impl f x).property
show (whileM.impl f a).val = whileM.body f (fun x => (whileM.impl f x).val) a
-- `repeatM_eq_of_monadTail` instead.
private theorem repeatM_eq [Nonempty β] {f : α → m (α ⊕ β)} (a : α)
(h : ∃ g, repeatM.body f g = g) :
repeatM f a = repeatM.body f (repeatM f) a := by
have key : (fun x => (repeatM.impl f x).val) = h.choose := funext fun x => by
simpa only [repeatM.Pred, dif_pos h] using (repeatM.impl f x).property
show (repeatM.impl f a).val = repeatM.body f (fun x => (repeatM.impl f x).val) a
rw [key, congrFun key a]; exact (congrFun h.choose_spec a).symm

namespace Lean
Expand All @@ -102,7 +102,7 @@ public structure Loop
@[inline, expose] public protected def Loop.forIn {β : Type u} {m : Type u → Type v} [Monad m]
(_ : Loop) (init : β) (f : Unit → β → m (ForInStep β)) : m β :=
haveI : Nonempty β := ⟨init⟩
whileM (a := init) fun b => do
repeatM (a := init) fun b => do
match ← f () b with
| .done b' => pure (.inr b')
| .yield b' => pure (.inl b')
Expand Down
18 changes: 9 additions & 9 deletions src/Std/Do/Triple/SpecLemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2210,14 +2210,14 @@ open Std.Do
variable {α β : Type u} {m : Type u → Type v} {ps : PostShape.{u}}

/--
An invariant for a `whileM` loop, given as a `PostCond` over the `α ⊕ β` cursor:
An invariant for a `repeatM` loop, given as a `PostCond` over the `α ⊕ β` cursor:
`.inl a` is the `continue` case at `a`; `.inr b` is the `break` case with result `b`.
-/
@[spec_invariant_type]
def WhileInvariant (α β : Type u) (ps : PostShape.{u}) :=
PostCond (α ⊕ β) ps

/-- A termination measure for a `whileM` loop, SVal-typed so it can read monadic state. -/
/-- A termination measure for a `repeatM` loop, SVal-typed so it can read monadic state. -/
@[spec_invariant_type]
def WhileVariant (α : Type u) (ps : PostShape.{u}) :=
α → SVal ps.args (ULift Nat)
Expand All @@ -2242,13 +2242,13 @@ private theorem WhileVariant.add_eval {P Q : SPred ps.args} (variant : WhileVari
variable [Monad m] [Lean.Order.MonadTail m] [WPMonad m ps]

/--
Specification for `whileM`. The user supplies a (possibly state-dependent) termination
Specification for `repeatM`. The user supplies a (possibly state-dependent) termination
`measure`, an invariant, and a step `Triple` whose pre asserts the variant evaluates to `ma`
and the in-progress invariant holds, and whose post either continues with a strictly smaller
variant value (the invariant still holding) or finishes with the `.inr` invariant.
-/
@[spec]
theorem Spec.whileM
theorem Spec.repeatM
{init : α} {f : α → m (α ⊕ β)} [Nonempty β]
(measure : WhileVariant α ps)
(inv : WhileInvariant α β ps)
Expand All @@ -2259,22 +2259,22 @@ theorem Spec.whileM
| .inl a' => spred(∃ ma', WhileVariant.eval measure a' ma' ∧ ⌜ma' < ma⌝ ∧ inv.1 (.inl a'))
| .inr b => inv.1 (.inr b),
inv.2)) :
Triple (whileM f init) spred(inv.1 (.inl init))
Triple (repeatM f init) spred(inv.1 (.inl init))
(fun b => inv.1 (.inr b), inv.2) := by
apply WhileVariant.add_eval measure init
apply SPred.exists_elim
intro minit
suffices key : ∀ (n : Nat) (a : α),
(spred(WhileVariant.eval measure a n ∧ inv.1 (.inl a)) ⊢ₛ
wp⟦(_root_.whileM f a : m β)⟧ (fun b => inv.1 (.inr b), inv.2)) from
wp⟦(_root_.repeatM f a : m β)⟧ (fun b => inv.1 (.inr b), inv.2)) from
key minit init
intro n
induction n using Nat.strongRecOn with
| _ n ih =>
intro a
rw [whileM_eq_of_monadTail (f := f) a]
rw [repeatM_eq_of_monadTail (f := f) a]
refine Triple.bind (f := fun x => match x with
| .inl a' => _root_.whileM f a' | .inr a' => Pure.pure a')
| .inl a' => _root_.repeatM f a' | .inr a' => Pure.pure a')
(f a) (step a n) ?_
rintro (a' | b)
· refine Triple.iff.mpr ?_
Expand Down Expand Up @@ -2306,7 +2306,7 @@ theorem Spec.forIn_loop
haveI : Nonempty β := ⟨init⟩
change Triple (_root_.Lean.Loop.forIn l init f) _ _
simp only [_root_.Lean.Loop.forIn]
apply Spec.whileM (β := β) (measure := measure) (inv := inv)
apply Spec.repeatM (β := β) (measure := measure) (inv := inv)
intro b mb
apply Triple.bind
· exact step b mb
Expand Down
44 changes: 29 additions & 15 deletions src/Std/Internal/Do/Triple/SpecLemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2418,28 +2418,28 @@ variable [Monad m] [Lean.Order.MonadTail m] [Assertion Pred] [Assertion EPred]
[WPMonad m Pred EPred]

/--
An invariant for a `whileM` loop, given as a predicate over the `α ⊕ β` cursor:
An invariant for a `repeatM` loop, given as a predicate over the `α ⊕ β` cursor:
`.inl a` is the `continue` case at `a`; `.inr b` is the `break` case with result `b`.
-/
@[spec_invariant_type, simp, grind =]
def WhileInvariant (α β : Type u) (Pred : Type u) :=
def RepeatInvariant (α β : Type u) (Pred : Type u) :=
α ⊕ β → Pred

/-- A termination measure for a `whileM` loop. -/
/-- A termination measure for a `repeatM` loop. -/
@[spec_invariant_type]
def WhileVariant (α : Type u) :=
def RepeatVariant (α : Type u) :=
α → Nat

/--
Specification for `whileM`. The user supplies a termination `measure`, an invariant, and a step
Specification for `repeatM`. The user supplies a termination `measure`, an invariant, and a step
`Triple` whose post either continues with a strictly smaller measure or finishes with the `.inr`
invariant.
-/
@[spec]
theorem Spec.whileM
theorem Spec.repeatM
{init : α} {f : α → m (α ⊕ β)} [Nonempty β]
(measure : WhileVariant α)
(inv : WhileInvariant α β Pred)
(measure : RepeatVariant α)
(inv : RepeatInvariant α β Pred)
(einv : EPred)
(step : ∀ a,
Triple
Expand All @@ -2451,23 +2451,23 @@ theorem Spec.whileM
einv) :
Triple
(inv (.inl init))
(whileM f init)
(repeatM f init)
(fun b => inv (.inr b))
einv := by
suffices key : ∀ (n : Nat) (a : α), measure a ≤ n →
Triple
(inv (.inl a))
(_root_.whileM f a)
(_root_.repeatM f a)
(fun b => inv (.inr b))
einv
from key (measure init) init (Nat.le_refl _)
intro n
induction n using Nat.strongRecOn with
| _ n ih =>
intro a hle
rw [whileM_eq_of_monadTail (f := f) a]
rw [repeatM_eq_of_monadTail (f := f) a]
refine Triple.bind (f := fun x => match x with
| .inl a' => _root_.whileM f a'
| .inl a' => _root_.repeatM f a'
| .inr b => Pure.pure b)
(f a) (fun r => match r with
| .inl a' => ⌜measure a' < measure a⌝ ⊓ inv (.inl a')
Expand All @@ -2480,15 +2480,29 @@ theorem Spec.whileM
exact Triple.iff.mp (ih (measure a') (Nat.lt_of_lt_of_le hlt hle) a' (Nat.le_refl _))
· exact Triple.pure b Lean.Order.PartialOrder.rel_refl

/--
Construct an invariant from a loop invariant `inv` and a break condition `onBreak`.

`inv` holds at the end of every loop iteration (including the breaking one), and `onBreak` holds in
addition to `inv` once the loop is done. For a normal `while` loop `onBreak` can be taken as the
negation of the loop condition.
-/
@[simp]
noncomputable abbrev RepeatInvariant.ofInvariantAndBreak {α : Type u} {Pred : Type u} [Assertion Pred]
(inv : α → Pred) (onBreak : α → Pred) : RepeatInvariant α α Pred
| .inl a => inv a
| .inr a => inv a ⊓ onBreak a


/--
Specification for `forIn` over a `Lean.Loop`. The cursor is `β ⊕ β`: `.inl b` means
"still iterating with `b`", `.inr b` means "finished with result `b`".
-/
@[spec]
theorem Spec.forIn_loop
{l : Lean.Loop} {init : β} {f : Unit → β → m (ForInStep β)}
(measure : WhileVariant β)
(inv : WhileInvariant β β Pred)
(measure : RepeatVariant β)
(inv : RepeatInvariant β β Pred)
(einv : EPred)
(step : ∀ b,
Triple
Expand All @@ -2507,7 +2521,7 @@ theorem Spec.forIn_loop
change Triple (pre := inv (.inl init)) (_root_.Lean.Loop.forIn l init f)
(fun b => inv (.inr b)) einv
simp only [_root_.Lean.Loop.forIn]
apply Spec.whileM (measure := measure) (inv := inv) (einv := einv)
apply Spec.repeatM (measure := measure) (inv := inv) (einv := einv)
intro b
apply Triple.bind
· exact step b
Expand Down
25 changes: 25 additions & 0 deletions tests/bench/mvcgen/sym/test_do_logic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -729,3 +729,28 @@ theorem incr_poly (amounts : List Nat) :


end TopBetaReduction

namespace RepeatInvariantOfInvariantAndBreak

/-! Verifies a `while` loop whose `mvcgen'` invariant is supplied via
`RepeatInvariant.ofInvariantAndBreak`: a loop invariant `inv` that holds after every iteration plus
an `onBreak` condition (here the negated loop condition) that additionally holds once the loop
exits. -/

/-- Counts `i` down from `n`, incrementing the state on each iteration, so the final state is `n`. -/
def countdown (n : Nat) : StateT Nat Id Unit := do
let mut i := n
while i > 0 do
i := i - 1
modify (· + 1)
return

theorem countdown_spec (n : Nat) :
⦃ fun s => s = 0 ⦄ countdown n ⦃ fun _ s => s = n ⦄ := by
mvcgen' [countdown]
case inv1 => exact RepeatInvariant.ofInvariantAndBreak (fun i s => s + i = n) (fun i _ => i = 0)
case inv2 => exact fun i => i
any_goals simp at *
all_goals grind

end RepeatInvariantOfInvariantAndBreak
Loading
Loading