Skip to content

Commit b027a5a

Browse files
committed
Create repo
1 parent e2692e6 commit b027a5a

4 files changed

Lines changed: 206 additions & 3 deletions

File tree

.JuliaFormatter.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
indent = 2

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@ uuid = "4a4311fe-81ff-4bcc-86a3-5026a4bf5b9f"
33
authors = ["chriselrod <elrodc@gmail.com> and contributors"]
44
version = "0.1.0"
55

6+
[deps]
7+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
9+
610
[compat]
11+
ForwardDiff = "0.10"
12+
StaticArrays = "1"
713
julia = "1"
814

915
[extras]

src/RecursiveTupleMath.jl

Lines changed: 170 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,174 @@
1+
"""
2+
Implements broadcasting functions that operate elementwise and recursively across `Tuple`s, `NamedTuple`s, `SArray`s, and `ForwardDiff.Dual` numbers.
3+
4+
The functions are:
5+
- `badd` (broadcast add)
6+
- `bsub` (broadcast sub)
7+
- `bmul` (broadcast mul)
8+
- `bdiv` (broadcast div)
9+
- `bmax` (broadcast max)
10+
- `bmin` (broadcast min)
11+
12+
Note that because it is recursive, `bmul(a, b)` will not necessarilly do the same thing as `map(*, a, b)`. For example, if `a` and `b` are tuples of `SMatrices`, `bmul` will multiply them elementwise, while `map` will multiply them as matrices. This is because `bmul` will call `bmul` on the elements of `a` and `b`, while `map` will call `*` on the elements of `a` and `b`.
13+
"""
114
module RecursiveTupleMath
215

3-
# Write your package code here.
16+
export bmax, bmin, badd, bsub, bmul, bdiv
17+
18+
using StaticArrays, ForwardDiff
19+
20+
@inline lt_fast(a, b) = a < b
21+
@inline lt_fast(a::Float64, b::Float64) = Base.lt_float_fast(a, b)
22+
@inline lt_fast(a::Float32, b::Float32) = Base.lt_float_fast(a, b)
23+
@inline le_fast(a, b) = a <= b
24+
@inline le_fast(a::Float64, b::Float64) = Base.le_float_fast(a, b)
25+
@inline le_fast(a::Float32, b::Float32) = Base.le_float_fast(a, b)
26+
27+
@inline gt_fast(a, b) = lt_fast(b, a)
28+
@inline ge_fast(a, b) = le_fast(b, a)
29+
30+
@inline badd(a, b) = Base.FastMath.add_fast(a, b)
31+
@inline bsub(a, b) = Base.FastMath.sub_fast(a, b)
32+
@inline bmul(a, b) = Base.FastMath.mul_fast(a, b)
33+
@inline bdiv(a, b) = Base.FastMath.div_fast(a, b)
34+
@inline bmax(a, b) = ifelse(gt_fast(a, b), a, b)
35+
@inline bmin(a, b) = ifelse(lt_fast(a, b), a, b)
36+
37+
for bf [:bmax, :bmin, :badd, :bsub, :bmul, :bdiv]
38+
@eval begin
39+
# fall back to fast
40+
# terminating case
41+
@inline $bf(::Number, ::Tuple{}) = ()
42+
@inline $bf(::Tuple{}, ::Number) = ()
43+
@inline $bf(::Number, ::Nothing) = nothing
44+
@inline $bf(::Nothing, ::Number) = nothing
45+
46+
# broadcast
47+
@inline $bf(x::Number, y::StaticArray{S}) where {S} = SArray{S}($bf(x, Tuple(y)))
48+
@inline $bf(y::StaticArray{S}, x::Number) where {S} = SArray{S}($bf(Tuple(y), x))
49+
@inline $bf(x::Number, y::NamedTuple{S}) where {S} = NamedTuple{S}($bf(x, Tuple(y)))
50+
@inline $bf(y::NamedTuple{S}, x::Number) where {S} = NamedTuple{S}($bf(Tuple(y), x))
51+
52+
@inline $bf(a::NamedTuple{S}, b::NamedTuple{S}) where {S} =
53+
NamedTuple{S}($bf(Tuple(a), Tuple(b)))
54+
@inline $bf(a::StaticArray{S}, b::StaticArray{S}) where {S} =
55+
SArray{S}($bf(Tuple(a), Tuple(b)))
56+
57+
# recurse
58+
@inline $bf(a::Number, b::Tuple{T,Vararg}) where {T} =
59+
($bf(a, first(b)), $bf(a, Base.tail(b))...)
60+
@inline $bf(b::Tuple{T,Vararg}, a::Number) where {T} =
61+
($bf(first(b), a), $bf(Base.tail(b), a)...)
62+
@inline $bf(a::Tuple{T,Vararg}, b::Tuple{T,Vararg}) where {T} =
63+
($bf(first(a), first(b)), $bf(Base.tail(a), Base.tail(b))...)
64+
65+
@inline $bf(a::Tuple, b::Tuple) = map($bf, a, b)
66+
end
67+
end
68+
@inline bsub(x::Number) = Base.FastMath.sub_fast(x)
69+
@inline bsub(x::Tuple) = map(bsub, x)
70+
@inline bsub(x::NamedTuple) = map(bsub, x)
71+
@inline bsub(x::StaticArray{S}) where {S} = SArray{S}(map(bsub, Tuple(x)))
72+
73+
@static if VERSION < v"1.7"
74+
struct Returns{T}
75+
v::T
76+
end
77+
(r::Returns)(_) = r.v
78+
end
79+
@inline function btuple(v, ::Val{D}) where {D}
80+
ntuple(Returns(v), Val(D))
81+
end
82+
83+
# @inline
84+
85+
ForwardDiff.@define_binary_dual_op(
86+
RecursiveTupleMath.badd,
87+
ForwardDiff.Dual{Txy}(badd(x.value, y.value), badd(x.partials.values, y.partials.values)),
88+
ForwardDiff.Dual{Tx}(badd(x.value, y), x.partials),
89+
ForwardDiff.Dual{Ty}(badd(x, y.value), y.partials.values)
90+
)
91+
ForwardDiff.@define_binary_dual_op(
92+
RecursiveTupleMath.bsub,
93+
ForwardDiff.Dual{Txy}(bsub(x.value, y.value), bsub(x.partials.values, y.partials.values)),
94+
ForwardDiff.Dual{Tx}(bsub(x.value, y), x.partials),
95+
ForwardDiff.Dual{Ty}(bsub(x, y.value), bsub(y.partials.values))
96+
)
97+
ForwardDiff.@define_binary_dual_op(
98+
RecursiveTupleMath.bmul,
99+
ForwardDiff.Dual{Txy}(
100+
bmul(x.value, y.value),
101+
badd(bmul(x.value, y.partials.values), bmul(x.partials.values, y.value)),
102+
),
103+
ForwardDiff.Dual{Tx}(bmul(x.value, y), bmul(x.partials.values, y)),
104+
ForwardDiff.Dual{Ty}(bmul(x, y.value), bmul(x, y.partials.values))
105+
)
106+
ForwardDiff.@define_binary_dual_op(
107+
RecursiveTupleMath.bdiv,
108+
ForwardDiff.Dual{Txy}(
109+
bdiv(x.value, y.value),
110+
bdiv(
111+
bsub(bmul(x.partials.values, y.value), bmul(x.value, y.partials.values)),
112+
bmul(y.value, y.value),
113+
),
114+
),
115+
ForwardDiff.Dual{Tx}(bdiv(x.value, y), bdiv(bmul(x.partials.values, y), bmul(y, y))),
116+
ForwardDiff.Dual{Ty}(
117+
bdiv(x, y.value),
118+
bdiv(bsub(bmul(x, y.partials.values)), bmul(y.value, y.value)),
119+
),
120+
)
121+
ForwardDiff.@define_binary_dual_op(
122+
RecursiveTupleMath.bmax,
123+
begin
124+
cmp = gt_fast(x.value, y.value)
125+
v = ifelse(cmp, x.value, y.value)
126+
bcmp = btuple(cmp, Val(length(x.partials)))
127+
p = map(ifelse, bcmp, x.partials.values, y.partials.values)
128+
ForwardDiff.Dual{Txy}(v, p)
129+
end,
130+
begin
131+
cmp = gt_fast(x.value, y)
132+
v = ifelse(cmp, x.value, y)
133+
bcmp = btuple(cmp, Val(length(x.partials)))
134+
bnil = map(zero, x.partials.values)
135+
p = map(ifelse, bcmp, x.partials.values, bnil)
136+
ForwardDiff.Dual{Tx}(v, p)
137+
end,
138+
begin
139+
cmp = gt_fast(x, y.value)
140+
v = ifelse(cmp, x, y.value)
141+
bcmp = btuple(cmp, Val(length(y.partials)))
142+
bnil = map(zero, y.partials.values)
143+
p = map(ifelse, bcmp, bnil, y.partials.values)
144+
ForwardDiff.Dual{Ty}(v, p)
145+
end,
146+
)
147+
ForwardDiff.@define_binary_dual_op(
148+
RecursiveTupleMath.bmin,
149+
begin
150+
cmp = lt_fast(x.value, y.value)
151+
v = ifelse(cmp, x.value, y.value)
152+
bcmp = btuple(cmp, Val(length(x.partials)))
153+
p = map(ifelse, bcmp, x.partials.values, y.partials.values)
154+
ForwardDiff.Dual{Txy}(v, p)
155+
end,
156+
begin
157+
cmp = lt_fast(x.value, y)
158+
v = ifelse(cmp, x.value, y)
159+
bcmp = btuple(cmp, Val(length(x.partials)))
160+
bnil = map(zero, x.partials.values)
161+
p = map(ifelse, bcmp, x.partials.values, bnil)
162+
ForwardDiff.Dual{Tx}(v, p)
163+
end,
164+
begin
165+
cmp = lt_fast(x, y.value)
166+
v = ifelse(cmp, x, y.value)
167+
bcmp = btuple(cmp, Val(length(y.partials)))
168+
bnil = map(zero, y.partials.values)
169+
p = map(ifelse, bcmp, bnil, y.partials.values)
170+
ForwardDiff.Dual{Ty}(v, p)
171+
end,
172+
)
4173

5174
end

test/runtests.jl

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,33 @@
1-
using RecursiveTupleMath
1+
using Base: Forward
2+
using RecursiveTupleMath, StaticArrays, ForwardDiff
23
using Test
34

5+
foo(x, y, z) = sum(min.(z .+ x, (x .- y)) ./ max.(z .- y, (x .* y)))
6+
bar(x, y, z) = sum(bdiv(bmin(badd(z, x), bsub(x, y)), bmax(bsub(z, y), bmul(x, y))))
7+
48
@testset "RecursiveTupleMath.jl" begin
5-
# Write your tests here.
9+
x = rand()
10+
xv = @SVector rand(3)
11+
for y in (rand(), @SVector(rand(3))), z in (rand(), @SVector(rand(3)))
12+
13+
@test ForwardDiff.derivative(x -> foo(x, y, z), x)
14+
ForwardDiff.derivative(x -> bar(x, y, z), x)
15+
@test ForwardDiff.derivative(x -> foo(y, x, z), x)
16+
ForwardDiff.derivative(x -> bar(y, x, z), x)
17+
@test ForwardDiff.derivative(x -> foo(y, z, x), x)
18+
ForwardDiff.derivative(x -> bar(y, z, x), x)
19+
20+
@test ForwardDiff.gradient(x -> foo(x, y, z), xv)
21+
ForwardDiff.gradient(x -> bar(x, y, z), xv)
22+
@test ForwardDiff.gradient(x -> foo(y, x, z), xv)
23+
ForwardDiff.gradient(x -> bar(y, x, z), xv)
24+
@test ForwardDiff.gradient(x -> foo(y, z, x), xv)
25+
ForwardDiff.gradient(x -> bar(y, z, x), xv)
26+
@test ForwardDiff.hessian(x -> foo(x, y, z), xv)
27+
ForwardDiff.hessian(x -> bar(x, y, z), xv)
28+
@test ForwardDiff.hessian(x -> foo(y, x, z), xv)
29+
ForwardDiff.hessian(x -> bar(y, x, z), xv)
30+
@test ForwardDiff.hessian(x -> foo(y, z, x), xv)
31+
ForwardDiff.hessian(x -> bar(y, z, x), xv)
32+
end
633
end

0 commit comments

Comments
 (0)