This is a simple decorator for enabling gradient checkpointing (Chen et al., 2016) in TF2. It isn't very polished, but it's been letting me train bigger GPT-2 models on smaller hardware, so I thought I'd share it.
## Basic Usage
Use the `checkpointable` decorator to allow a function (or a callable object such as a Keras `Layer`) to use gradient checkpointing. To enable checkpointing, call the decorated function with the `_checkpoint` keyword argument set to `True`.
The example below builds a model with 40000 "layers", but checkpointing means only 400 of them need to be held in memory at any point. On a GTX 1070 Ti, this code results in an OOM error when the `_checkpoint` argument is set to `False`.
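A minimal sketch of such an example is shown below. The import path `checkpointable` and the 100 × 400 block structure are assumptions for illustration, not the repo's exact example code:

```python
import tensorflow as tf

from checkpointable import checkpointable  # import path is an assumption


@checkpointable
def block(x):
    # 400 cheap "layers"; with checkpointing enabled, only this block's
    # activations are live at any one time during the backward pass.
    for _ in range(400):
        x = tf.tanh(x)
    return x


def model(x):
    # 100 checkpointed blocks x 400 layers each = 40000 "layers" total.
    for _ in range(100):
        x = block(x, _checkpoint=True)  # flip to False to reproduce the OOM
    return x


x = tf.random.normal([100, 1000])
with tf.GradientTape() as tape:
    tape.watch(x)
    loss = tf.reduce_sum(model(x))
grads = tape.gradient(loss, x)
print(grads.shape)
```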
Arguments that are not `float32` tensors (or nested list/tuple structures of such tensors) are allowed, but they are ignored for the purposes of gradient computation.
## Variables
If the decorated function uses variables that are not arguments, pass a list of them via the `_watch_vars` keyword argument, as shown below.
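For instance, a checkpointed function that closes over a Keras layer might look like this (a sketch; the import path and the `Dense` layer are assumptions):

```python
import tensorflow as tf

from checkpointable import checkpointable  # import path is an assumption

dense = tf.keras.layers.Dense(256, activation="relu")
dense.build((None, 256))  # create the layer's variables up front


@checkpointable
def block(x):
    # The kernel and bias of `dense` are closed over rather than passed
    # as arguments, so they must be declared via _watch_vars below.
    return dense(x)


x = tf.random.normal([32, 256])
with tf.GradientTape() as tape:
    y = tf.reduce_sum(block(x, _checkpoint=True,
                            _watch_vars=dense.trainable_variables))
grads = tape.gradient(y, dense.trainable_variables)
```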
Because gradient checkpointing relies on re-running the forward pass, stochastic layers such as dropout will give different results on each pass. A hacky workaround is available: pass `_force_seed=True` to the decorated function. This uses Python's `random` module to draw a random number and sets it as TensorFlow's random seed before each forward pass. If you have a better idea for addressing this issue, please let me know.
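A minimal sketch of the workaround (the dropout rate and shapes are arbitrary, and the import path is an assumption):

```python
import tensorflow as tf

from checkpointable import checkpointable  # import path is an assumption


@checkpointable
def noisy_block(x):
    # Dropout draws a fresh mask on every forward pass, so without
    # re-seeding the recomputed pass would differ from the original.
    return tf.nn.dropout(x, rate=0.1)


x = tf.random.normal([32, 128])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.reduce_sum(noisy_block(x, _checkpoint=True, _force_seed=True))
grads = tape.gradient(y, x)
```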