Skip to content

Commit 5822584

Browse files
authored
feat: Float is no longer opaque (#14091)
This PR changes the definition of the `Float` and `Float32` types to wrap the `Float.Model` type introduced in #14079. We then redefine `add`, `sub`, `mul`, `div`, `sqrt`, `abs`, `neg`, `isNaN`, `isInf`, `isFinite`, `le`, `lt`, `beq`, `ofInt`, `ofNat`, `ofIntX`, `ofUIntX`, `toIntX`, `toUIntX` on `Float`/`Float32` to delegate to their `Float.Model` counterparts. This does not change anything about compiled code; compiled code uses the accelerated functions from the runtime as before. Minor adjustments to the compiler are necessary for proper handling of the `Float.toModel` and `Float.ofModel` functions. A small breaking change: `Float.lt` and `Float.le` now are `Float -> Float -> Bool` instead of `Float -> Float -> Prop`. This is unlikely to affect anyone as hopefully everyone was using the `LE`/`LT` instances instead. Note that this will *not* grow into a fully-featured float library. Instead, the idea is to make it technically possible for users to connect up our `Float` type to downstream fully-featured float libraries, so that theorems about `Float` can be deduced by transporting lemmas from those libraries. The material about `Float.Model`/`UnpackedFloat` should also not be used as a starting point for such downstream libraries; they are simply not designed for that purpose. Instead, downstream libraries should start from scratch with whatever development of floating-point numbers they deem appropriate and then show equivalence between `Float.Model` and their 64-bit float type afterwards.
1 parent 1fb7efa commit 5822584

9 files changed

Lines changed: 391 additions & 161 deletions

File tree

src/Init/Data/Float/Float.lean

Lines changed: 98 additions & 83 deletions
Large diffs are not rendered by default.

src/Init/Data/Float/Float32.lean

Lines changed: 97 additions & 78 deletions
Large diffs are not rendered by default.

src/Init/Data/Float/Model/Float.lean

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ Negate a `Float.Model`.
100100
def neg (a : Float.Model) : Float.Model :=
101101
pack a.unpack.neg
102102

103+
instance : Neg Float.Model where
104+
neg a := a.neg
105+
103106
/--
104107
Return a `Float.Model` with positive sign.
105108
-/
@@ -292,4 +295,7 @@ Converts a `Float.Model` to an `ISize`, truncating after the decimal point, send
292295
-/
293296
def toISize (f : Float.Model) : ISize := f.unpack.toISize
294297

298+
instance : Inhabited Float.Model where
299+
default := ofNat 0
300+
295301
end Float.Model

src/Init/Data/Float/Model/Float32.lean

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ Negate a `Float32.Model`.
103103
def neg (a : Float32.Model) : Float32.Model :=
104104
pack a.unpack.neg
105105

106+
instance : Neg Float32.Model where
107+
neg a := a.neg
108+
106109
/--
107110
Return a `Float32.Model` with positive sign.
108111
-/
@@ -295,4 +298,7 @@ Converts a `Float32.Model` to an `ISize`, truncating after the decimal point, se
295298
-/
296299
def toISize (f : Float32.Model) : ISize := f.unpack.toISize
297300

301+
instance : Inhabited Float32.Model where
302+
default := ofNat 0
303+
298304
end Float32.Model

src/Lean/Compiler/LCNF/ToMono.lean

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,28 @@ partial def casesStringToMono (c : Cases .pure) (_ : c.typeName == ``String) : T
290290
let k ← k.toMono
291291
return .let decl k
292292

293+
/-- Eliminate `cases` for `Float`. -/
294+
partial def casesFloatToMono (c : Cases .pure) (_ : c.typeName == ``Float) : ToMonoM (Code .pure) := do
295+
assert! c.alts.size == 1
296+
let .alt _ ps k := c.alts[0]! | unreachable!
297+
eraseParams ps
298+
let p := ps[0]!
299+
let decl := { fvarId := p.fvarId, binderName := p.binderName, type := anyExpr, value := .const ``Float.toModel [] #[.fvar c.discr] }
300+
modifyLCtx fun lctx => lctx.addLetDecl decl
301+
let k ← k.toMono
302+
return .let decl k
303+
304+
/-- Eliminate `cases` for `Float32`. -/
305+
partial def casesFloat32ToMono (c : Cases .pure) (_ : c.typeName == ``Float32) : ToMonoM (Code .pure) := do
306+
assert! c.alts.size == 1
307+
let .alt _ ps k := c.alts[0]! | unreachable!
308+
eraseParams ps
309+
let p := ps[0]!
310+
let decl := { fvarId := p.fvarId, binderName := p.binderName, type := anyExpr, value := .const ``Float32.toModel [] #[.fvar c.discr] }
311+
modifyLCtx fun lctx => lctx.addLetDecl decl
312+
let k ← k.toMono
313+
return .let decl k
314+
293315
/-- Eliminate `cases` for `Thunk. -/
294316
partial def casesThunkToMono (c : Cases .pure) (_ : c.typeName == ``Thunk) : ToMonoM (Code .pure) := do
295317
assert! c.alts.size == 1
@@ -373,6 +395,10 @@ partial def Code.toMono (code : Code .pure) : ToMonoM (Code .pure) := do
373395
casesFloatArrayToMono c h
374396
else if h : c.typeName == ``String then
375397
casesStringToMono c h
398+
else if h : c.typeName == ``Float then
399+
casesFloatToMono c h
400+
else if h : c.typeName == ``Float32 then
401+
casesFloat32ToMono c h
376402
else if h : c.typeName == ``Thunk then
377403
casesThunkToMono c h
378404
else if h : c.typeName == ``Task then

tests/compile/float.lean

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,27 @@ def tst3 (xs : List Float) (y : Float) : IO Unit :=
7575
def tst4 (xs : List Float) : IO Unit :=
7676
IO.println (fMap (fun x => x.abs) xs)
7777

78+
/--
79+
`Float.ofModel`/`Float.toModel` behave correctly: the triangle relating the two
80+
`toBits` functions commutes (`Float.toBits = Float.Model.toBits ∘ Float.toModel`),
81+
and round-tripping a `Float` through its `Float.Model` reproduces it bit-for-bit.
82+
-/
83+
def checkModel (f : Float) : Bool :=
84+
-- the two ways of obtaining the bits agree
85+
(f.toBits == f.toModel.toBits) &&
86+
-- `Float.ofModel ∘ Float.toModel = id`, bit-for-bit
87+
((Float.ofModel f.toModel).toBits == f.toBits)
88+
89+
def tst5 : IO Unit := do
90+
let samples : List Float :=
91+
[0.0, -0.0, 1.0, -1.0, 3.14, -2.5, 1e308, 1e-308, 1 / 0, -1 / 0, 0 / 0]
92+
IO.println (samples.all checkModel)
93+
7894
def main : IO Unit := do
7995
tst1
8096
IO.println "-----"
8197
tst2 7
8298
tst3 [3, 4, 7, 8, 9, 11] 2
8399
tst4 [3, -3, 0, -0, -1 / 0, -0 / 0]
100+
IO.println "-----"
101+
tst5

tests/compile/float.lean.out.expected

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,5 @@ true
4646
3.500000
4747
[1.500000, 2.000000, 3.500000, 4.000000, 4.500000, 5.500000]
4848
[3.000000, 3.000000, 0.000000, 0.000000, inf, NaN]
49+
-----
50+
true

tests/elab/float_compiler.lean

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/-!
2+
Check that the compiler correctly deals with `Float` in various situations.
3+
-/
4+
5+
-- If the output of `trace.Compiler.result` turns out to be too unstable, we'll have to
6+
-- assert this some other way.
7+
8+
set_option trace.Compiler.result true
9+
10+
namespace F
11+
12+
/--
13+
trace: [Compiler.result] size: 1
14+
def F.f x.1 : UInt64 :=
15+
let toModel.2 := Float.toModel x.1;
16+
return toModel.2
17+
[Compiler.result] size: 4
18+
def F.f._boxed x.1 : obj :=
19+
let x.10.boxed := unbox x.1;
20+
dec[ref] x.1;
21+
let res := F.f x.10.boxed;
22+
let r := box res;
23+
return r
24+
-/
25+
#guard_msgs in
26+
def f : Float → UInt64
27+
| ⟨model⟩ => model.toBits
28+
29+
/--
30+
trace: [Compiler.result] size: 1
31+
def F.g x.1 : Float :=
32+
let _x.2 := Float.ofModel x.1;
33+
return _x.2
34+
[Compiler.result] size: 4
35+
def F.g._boxed x.1 : obj :=
36+
let x.4.boxed := unbox x.1;
37+
dec[ref] x.1;
38+
let res := F.g x.4.boxed;
39+
let r := box res;
40+
return r
41+
-/
42+
#guard_msgs in
43+
def g : Float.Model → Float
44+
| a => ⟨a⟩
45+
46+
structure Foo where
47+
foo : Float
48+
49+
/--
50+
trace: [Compiler.result] size: 0
51+
def F.h x.1 : Float :=
52+
return x.1
53+
[Compiler.result] size: 4
54+
def F.h._boxed x.1 : obj :=
55+
let x.4.boxed := unbox x.1;
56+
dec[ref] x.1;
57+
let res := F.h x.4.boxed;
58+
let r := box res;
59+
return r
60+
-/
61+
#guard_msgs in
62+
def h : Float → Foo
63+
| a => ⟨a⟩
64+
65+
end F
66+
67+
namespace F32
68+
69+
/--
70+
trace: [Compiler.result] size: 1
71+
def F32.f x.1 : UInt32 :=
72+
let toModel.2 := Float32.toModel x.1;
73+
return toModel.2
74+
[Compiler.result] size: 4
75+
def F32.f._boxed x.1 : tobj :=
76+
let x.10.boxed := unbox x.1;
77+
dec[ref] x.1;
78+
let res := F32.f x.10.boxed;
79+
let r := box res;
80+
return r
81+
-/
82+
#guard_msgs in
83+
def f : Float32 → UInt32
84+
| ⟨model⟩ => model.toBits
85+
86+
/--
87+
trace: [Compiler.result] size: 1
88+
def F32.g x.1 : Float32 :=
89+
let _x.2 := Float32.ofModel x.1;
90+
return _x.2
91+
[Compiler.result] size: 4
92+
def F32.g._boxed x.1 : obj :=
93+
let x.4.boxed := unbox x.1;
94+
dec x.1;
95+
let res := F32.g x.4.boxed;
96+
let r := box res;
97+
return r
98+
-/
99+
#guard_msgs in
100+
def g : Float32.Model → Float32
101+
| a => ⟨a⟩
102+
103+
structure Foo where
104+
foo : Float32
105+
106+
/--
107+
trace: [Compiler.result] size: 0
108+
def F32.h x.1 : Float32 :=
109+
return x.1
110+
[Compiler.result] size: 4
111+
def F32.h._boxed x.1 : obj :=
112+
let x.4.boxed := unbox x.1;
113+
dec[ref] x.1;
114+
let res := F32.h x.4.boxed;
115+
let r := box res;
116+
return r
117+
-/
118+
#guard_msgs in
119+
def h : Float32 → Foo
120+
| a => ⟨a⟩
121+
122+
end F32

tests/elab/float_model_tobits.lean

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
/-!
2+
Checks that `Float.ofModel` and `Float.toModel` behave correctly at elaboration
3+
time via `#guard`: the triangle relating the two `toBits` functions commutes
4+
(`Float.toBits = Float.Model.toBits ∘ Float.toModel`), and round-tripping a
5+
`Float` through its `Float.Model` reproduces it bit-for-bit.
6+
-/
7+
8+
-- The `toBits` triangle commutes for a single value.
9+
#guard (1.0 : Float).toBits == (1.0 : Float).toModel.toBits
10+
11+
-- `Float.ofModel ∘ Float.toModel = id`, bit-for-bit.
12+
#guard (Float.ofModel (3.14 : Float).toModel).toBits == (3.14 : Float).toBits
13+
14+
-- Both properties hold across normals, signed zeroes, infinities, and `NaN`.
15+
#guard [0.0, -0.0, 1.0, -1.0, 3.14, -2.5, 1e308, 1e-308, 1 / 0, -1 / 0, 0 / 0].all
16+
fun f => f.toBits == f.toModel.toBits && (Float.ofModel f.toModel).toBits == f.toBits

0 commit comments

Comments
 (0)