This was the original automatic differentiation engine for Flux.jl, before being replaced by Zygote.jl in 2019. Both were written by Mike Innes.
This package is solid and still in active use, but is no longer heavily maintained. PRs and issues may go unanswered.
## Introduction
Like ReverseDiff.jl and AutoGrad.jl, Tracker traces through a program by wrapping arrays in a special `TrackedArray` type. The final answer contains a "tape" of the operations performed, which is reversed by `back!`:
```julia
x = param([1, 2, 3])  # Tracked 3-element Vector{Float64}

f(x) = sum(abs2, x) + prod(x[2:end])

y = f(x)         # TrackedReal
back!(y)         # run back-propagation

Tracker.grad(x)  # extract gradient from TrackedArray
```
This is a much lower-tech approach than that of Zygote, Yota and Diffractor. At best, those can produce fast, compiled Julia code for the reverse pass, instead of an interpreted tape. At worst, they can have extremely long compile-times and can be difficult to debug.
## Interface
Instead of calling `back!` yourself, you can pass the function and the input to `gradient`.
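For example, reusing `f` from above (`Tracker.gradient` returns one gradient per argument, as tracked arrays; `Tracker.data` unwraps them):

```julia
using Tracker

f(x) = sum(abs2, x) + prod(x[2:end])

g = Tracker.gradient(f, [1, 2, 3])  # 1-tuple: one gradient per argument
Tracker.data(g[1])                  # [2.0, 7.0, 8.0]
```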
The original interface to Flux.jl involved a dictionary of arrays called `Params`, much like Zygote's "implicit" parameter interface. This appears not to be documented.
A more modern way to use Flux relies on `withgradient`'s ability to take gradients with respect to complex nested structures. This is what Optimisers.jl is designed to accept:
```julia
julia> using Flux, Tracker

julia> model = Chain(Dense(2 => 1, tanh), Dense(1 => 1, bias=false));

julia> withgradient(model, rand(Float32, 2)) do m, x
         sum(abs2, m(x))
       end
(val = 0.035716165f0,
 grad = ((layers = ((weight = Float32[-0.4241869 -0.16741231], bias = Float32[-0.5529184], σ = nothing),
                    (weight = Float32[-0.04804218;;], bias = nothing, σ = nothing)),),
         Float32[0.12706584, -0.08858479]))
```
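A hedged sketch of that hand-off to Optimisers.jl (assuming its `setup`/`update!` API; the exact numbers will differ on each run because the model is randomly initialised):

```julia
using Flux, Tracker, Optimisers

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1, bias=false))
x = rand(Float32, 2)

val, grads = withgradient(model, x) do m, x
    sum(abs2, m(x))
end

# grads[1] mirrors the model's nested structure, so Optimisers.jl can walk it:
state = Optimisers.setup(Optimisers.Adam(0.01), model)
state, model = Optimisers.update!(state, model, grads[1])
```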
## Rules
Tracker.jl contains rules for many common operations. It relies on DiffRules.jl for many definitions, and does not connect to the newer ChainRules.jl at all.
To define more rules, use `track` and `@grad`. See the source for more examples:
```julia
f(x::TrackedArray) = track(f, x)  # entry point, via dispatch

@grad function f(x)
    y = f(data(x))  # forward pass, without tracking
    back(dy) = (dy * ∂f∂x(data(x)),)  # pullback function, returns a tuple
    return y, back
end
```
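As a concrete sketch of that template, here is a made-up elementwise function `cube` (the name and its derivative are for illustration only, not part of Tracker):

```julia
using Tracker
using Tracker: TrackedArray, track, @grad, data

cube(x) = x .^ 3

cube(x::TrackedArray) = track(cube, x)  # entry point, via dispatch

@grad function cube(x)
    xd = data(x)                      # unwrap for an untracked forward pass
    back(dy) = (dy .* 3 .* xd .^ 2,)  # pullback: dy ⊙ ∂(x³)/∂x
    return cube(xd), back
end

g = Tracker.gradient(x -> sum(cube(x)), [1.0, 2.0])
Tracker.data(g[1])  # [3.0, 12.0]
```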