9/19/2021

Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies

Paul Vicol, Luke Metz, and Jascha Sohl-Dickstein, Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies, ICML 2021 (paper; Outstanding Paper Award).

Unrolled computation graphs arise in many scenarios, including training RNNs, tuning hyperparameters through unrolled optimization, and training learned optimizers. Current approaches to optimizing parameters in such computation graphs suffer from high variance gradients, bias, slow updates, or large memory usage. We introduce a method called Persistent Evolution Strategies (PES), which divides the computation graph into a series of truncated unrolls, and performs an evolution strategies-based update step after each unroll. PES eliminates bias from these truncations by accumulating correction terms over the entire sequence of unrolls. PES allows for rapid parameter updates, has low memory usage, is unbiased, and has reasonable variance characteristics. We experimentally demonstrate the advantages of PES compared to several other methods for gradient estimation on synthetic tasks, and show its applicability to training learned optimizers and tuning hyperparameters.
ES
Evolution strategies (ES) is a family of algorithms that estimates gradients using stochastic finite differences, yielding an unbiased estimate of the gradient of the objective smoothed with a Gaussian. ES works well on pathological meta-optimization loss surfaces (Metz et al., 2019); however, because full unrolls are computationally expensive, ES can only practically be applied in a truncated fashion, which introduces bias.
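To make this concrete, below is a minimal sketch of the standard antithetic ES estimator in JAX. The function name `es_grad`, the flat parameter vector, and the default constants are illustrative assumptions, not from the paper's code; the estimator itself is the usual one: with perturbations eps_i ~ N(0, sigma^2 I), the gradient of the smoothed loss is estimated as (1 / (2 sigma^2 N)) * sum_i eps_i * (L(theta + eps_i) - L(theta - eps_i)).

```python
import jax
import jax.numpy as jnp

def es_grad(loss_fn, theta, key, sigma=0.1, n_pairs=32):
    # Antithetic ES: an unbiased estimate of the gradient of the
    # Gaussian-smoothed objective E_{eps ~ N(0, sigma^2 I)}[loss_fn(theta + eps)].
    # theta is assumed to be a flat parameter vector of shape (d,).
    eps = sigma * jax.random.normal(key, (n_pairs, theta.shape[0]))
    pos = jax.vmap(lambda e: loss_fn(theta + e))(eps)  # L(theta + eps_i)
    neg = jax.vmap(lambda e: loss_fn(theta - e))(eps)  # L(theta - eps_i)
    # g_hat = (1 / (2 sigma^2 N)) * sum_i eps_i * (L(theta+eps_i) - L(theta-eps_i))
    return jnp.mean(((pos - neg) / (2 * sigma**2))[:, None] * eps, axis=0)
```

Evaluating each perturbation in antithetic pairs (+eps and -eps) is a standard variance-reduction choice; the estimate remains unbiased for the smoothed objective either way.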

Contributions 

  • We introduce a method called Persistent Evolution Strategies (PES) to obtain unbiased gradient estimates for the parameters of an unrolled system from partial unrolls of the system. 
  • We prove that PES is an unbiased gradient estimate for a smoothed version of the loss, and an unbiased estimate of the true gradient for quadratic losses. We provide theoretical and empirical analyses of its variance. 
  • We demonstrate the applicability of PES in several illustrative scenarios: 1) we apply PES to tune hyperparameters, including learning rates and momentum coefficients, by estimating hypergradients through partial unrolls of optimization algorithms; 2) we use PES to meta-train a learned optimizer; 3) we use PES to learn policy parameters for a continuous control task.
Links: JAX implementation · Supplementary PDF · more information.
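For intuition about how PES differs from plain truncated ES, here is a minimal sketch of one PES outer step in JAX, under stated assumptions: `unroll(state, theta)` is a hypothetical function that runs the system for K inner steps with parameters theta and returns (new_state, loss), and `xis` holds the per-particle perturbation accumulators. This is an illustration of the update, not the authors' released implementation.

```python
import jax
import jax.numpy as jnp

def pes_step(unroll, theta, pos_states, neg_states, xis, key, sigma=0.1):
    # One truncated unroll of PES with antithetic sampling.
    # xis: (n_pairs, dim) perturbation accumulators, one per antithetic pair,
    # carried across unrolls and reset only when the inner problem restarts.
    eps = sigma * jax.random.normal(key, xis.shape)
    xis = xis + eps  # accumulate this unroll's perturbation (the PES correction)
    # Each particle keeps its own persistent inner state across unrolls.
    pos_states, pos_loss = jax.vmap(lambda s, e: unroll(s, theta + e))(pos_states, eps)
    neg_states, neg_loss = jax.vmap(lambda s, e: unroll(s, theta - e))(neg_states, eps)
    # PES estimate: average of xi_i * loss_i / sigma^2, in antithetic form.
    g_hat = jnp.mean(((pos_loss - neg_loss) / (2 * sigma**2))[:, None] * xis, axis=0)
    return g_hat, pos_states, neg_states, xis
```

The key difference from truncated ES is that each particle's inner state persists across unrolls, and `xis` keeps accumulating perturbations within a sequence (it is reset, together with the states, only when the inner problem restarts); weighting the loss by the accumulated perturbation rather than only the current one is what removes the truncation bias. To tune a learning rate, for example, one would repeatedly call `pes_step` with an `unroll` that applies K steps of SGD and returns the resulting training loss, feeding `g_hat` to an outer optimizer.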
