Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
ImplicitDifferentiationChainRulesCoreExt = "ChainRulesCore"
ImplicitDifferentiationEnzymeExt = ["EnzymeCore", "ADTypes"]
ImplicitDifferentiationForwardDiffExt = "ForwardDiff"
ImplicitDifferentiationZygoteExt = "Zygote"

Expand All @@ -27,6 +29,7 @@ ChainRulesTestUtils = "1.13.0"
ComponentArrays = "0.15.27"
DifferentiationInterface = "0.6.1,0.7"
Documenter = "1.12.0"
EnzymeCore = "0.8"
ExplicitImports = "1"
FiniteDiff = "2.27.0"
ForwardDiff = "0.10.36, 1"
Expand Down Expand Up @@ -54,6 +57,8 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -80,6 +85,7 @@ test = [
"ComponentArrays",
"DifferentiationInterface",
"Documenter",
"Enzyme",
"ExplicitImports",
"FiniteDiff",
"ForwardDiff",
Expand Down
11 changes: 10 additions & 1 deletion examples/1_basic.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# # Basic use cases

versioninfo()

#=
We show how to differentiate through very common routines:
- an unconstrained optimization problem
Expand All @@ -16,6 +18,7 @@ using NLsolve
using Optim
using Test #src
using Zygote
using Enzyme

#=
In all three cases, we will use the square root as our forward mapping, but expressed in three different ways.
Expand Down Expand Up @@ -75,7 +78,7 @@ end;
We now have all the ingredients to construct our implicit function.
=#

implicit_optim = ImplicitFunction(forward_optim, conditions_optim)
const implicit_optim = ImplicitFunction(forward_optim, conditions_optim)

# And indeed, it behaves as it should when we call it:

Expand All @@ -87,6 +90,12 @@ first(implicit_optim(x, LBFGS())) .^ 2
ForwardDiff.jacobian(_x -> first(implicit_optim(_x, LBFGS())), x)
@test ForwardDiff.jacobian(_x -> first(implicit_optim(_x, LBFGS())), x) ≈ J #src

Enzyme.jacobian(Forward, _x -> first(implicit_optim(_x, LBFGS())), x) |> only
Enzyme.jacobian(Reverse, _x -> first(implicit_optim(_x, LBFGS())), x) |> only

# Fails due to mismatched activity.
# Enzyme.jacobian(Forward, _x -> first(forward_optim(_x, LBFGS())), x)

#=
In this instance, we could use ForwardDiff.jl directly on the solver:
=#
Expand Down
133 changes: 133 additions & 0 deletions ext/ImplicitDifferentiationEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
module ImplicitDifferentiationEnzymeExt

using ADTypes: AutoEnzyme
using EnzymeCore
import EnzymeCore: EnzymeRules
using ImplicitDifferentiation:
ImplicitFunction,
ImplicitFunctionPreparation,
IterativeLeastSquaresSolver,
build_A,
build_Aᵀ,
build_B,
build_Bᵀ

import .EnzymeRules: AugmentedReturn

const AnyDuplicated{T} = Union{Duplicated{T}, BatchDuplicated{T}, DuplicatedNoNeed{T}, BatchDuplicatedNoNeed{T}}

function EnzymeRules.forward(config, implicit::Const{<:ImplicitFunction}, ::Type{<:AnyDuplicated}, x::AnyDuplicated, args::Vararg{<:Const})
implicit = implicit.val

dx = x.dval
x = x.val
args = ntuple(length(args)) do i
args[i].val
end

prep = ImplicitFunctionPreparation(eltype(x))
(; conditions, linear_solver) = implicit

y, z = implicit(x, args...)
c = conditions(x, y, z, args...)

y0 = zero(y)
forward_backend = AutoEnzyme(mode = Forward)
reverse_backend = AutoEnzyme(mode = Reverse)

A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend = forward_backend)
B = build_B(implicit, prep, x, y, z, c, args...; suggested_backend = forward_backend)
Aᵀ = if linear_solver isa IterativeLeastSquaresSolver
build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend = reverse_backend)
else
nothing
end

return if EnzymeRules.width(config) == 1
dc = B(dx)
dy = linear_solver(A, Aᵀ, dc, y0)::typeof(y0)
dz = nothing

if EnzymeRules.needs_primal(config)
return Duplicated((y, z), (dy, dz))
else
return dy, dz
end
else
dc = map(B, dx)
dy = map(dc) do dₖc
linear_solver(A, Aᵀ, -dₖc, y0)
end

df = ntuple(Val(EnzymeRules.width(config))) do i
(dy[i]::typeof(y0), nothing)
end

if EnzymeRules.needs_primal(config)
return BatchDuplicated((y, z), df)
else
# TODO: We need to heal the type instability from the linear solver here
# df::NTuple{EnzymeRules.width(config), Tuple{typeof(y0), Nothing}}
return df::NTuple{EnzymeRules.width(config), Tuple{Vector{Float64}, Nothing}}
end
end
end

function EnzymeRules.augmented_primal(config, implicit::Const{<:ImplicitFunction}, RT::Type{<:AnyDuplicated}, x::AnyDuplicated, args::Vararg{<:Const})
@assert EnzymeRules.width(config) == 1
implicit = implicit.val

x = x.val
args = ntuple(length(args)) do i
args[i].val
end

prep = ImplicitFunctionPreparation(eltype(x))
(; conditions, linear_solver) = implicit

y, z = implicit(x, args...)
c = conditions(x, y, z, args...)
c0 = zero(c)

forward_backend = AutoEnzyme(mode = Forward)
reverse_backend = AutoEnzyme(mode = Reverse)

Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend = reverse_backend)
Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend = reverse_backend)
if linear_solver isa IterativeLeastSquaresSolver
A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend = forward_backend)
else
A = nothing
end

if EnzymeRules.needs_primal(config)
primal = (y, z)
else
primal = nothing
end

dy = EnzymeCore.make_zero(y)
if EnzymeRules.needs_shadow(config)
shadow = (dy, EnzymeCore.make_zero(z))
else
shadow = nothing
end

tape = (; Aᵀ, Bᵀ, A, linear_solver, dy, c0)

AR = EnzymeRules.augmented_rule_return_type(config, RT)

return AR(primal, shadow, tape)
end

function EnzymeRules.reverse(_, ::Const{<:ImplicitFunction}, ::Type, tape, x::AnyDuplicated, ::Vararg{<:Const})
dx = x.dval
(; Aᵀ, Bᵀ, A, linear_solver, dy, c0) = tape

dc = linear_solver(Aᵀ, A, -dy, c0)
dx .+= Bᵀ(dc)

return (nothing, nothing)
end

end # modul
Loading