
Efficient and Modular Implicit Differentiation

Part of Advances in Neural Information Processing Systems 35 (NeurIPS 2022) Main Conference Track


Authors

Mathieu Blondel, Quentin Berthet, Marco Cuturi, Roy Frostig, Stephan Hoyer, Felipe Llinares-Lopez, Fabian Pedregosa, Jean-Philippe Vert

Abstract

Automatic differentiation (autodiff) has revolutionized machine learning. It allows complex computations to be expressed by composing elementary ones in creative ways, and it removes the burden of computing their derivatives by hand. More recently, differentiation of the solutions of optimization problems has attracted widespread attention, with applications such as optimization layers and bi-level problems such as hyper-parameter optimization and meta-learning. However, implicit differentiation has so far remained difficult for practitioners to use, as it often required tedious case-by-case mathematical derivations and implementations. In this paper, we propose automatic implicit differentiation, an efficient and modular approach for implicit differentiation of optimization problems. In our approach, the user directly defines, in Python, a function F capturing the optimality conditions of the problem to be differentiated. Once this is done, we leverage autodiff of F and the implicit function theorem to automatically differentiate the optimization problem. Our approach thus combines the benefits of implicit differentiation and autodiff. It is efficient, as it can be added on top of any state-of-the-art solver, and modular, as the optimality condition specification is decoupled from the implicit differentiation mechanism. We show that seemingly simple principles allow us to recover many existing implicit differentiation methods and to easily create new ones. We demonstrate the ease of formulating and solving bi-level optimization problems using our framework. We also showcase an application to the sensitivity analysis of molecular dynamics.
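The idea in the abstract can be illustrated with a small sketch. If x*(θ) satisfies F(x*(θ), θ) = 0, the implicit function theorem gives ∂x*(θ) = −[∂₁F(x*, θ)]⁻¹ ∂₂F(x*, θ), and both Jacobians of F can be obtained by autodiff. The example below assumes a ridge regression problem, min over x of ½‖Ax − b‖² + ½θ‖x‖², whose optimality condition is F(x, θ) = Aᵀ(Ax − b) + θx. All function names (`optimality_fun`, `ridge_solver`, `implicit_jacobian`, `solution`) are hypothetical and are not the API of the authors' library. The solver is treated as a black box; `jax.custom_vjp` attaches the implicit derivative to it so that `jax.grad` never differentiates through the solver's internals.

```python
import jax
import jax.numpy as jnp


def optimality_fun(x, theta, A, b):
    # F(x, theta): gradient of 0.5*||A x - b||^2 + 0.5*theta*||x||^2 w.r.t. x.
    # It vanishes exactly at the solution x*(theta).
    return A.T @ (A @ x - b) + theta * x


def ridge_solver(theta, A, b):
    # Stand-in for any black-box solver; here a direct linear solve.
    n = A.shape[1]
    return jnp.linalg.solve(A.T @ A + theta * jnp.eye(n), A.T @ b)


def implicit_jacobian(theta, A, b):
    # Implicit function theorem: dx*/dtheta = -[d_x F]^{-1} d_theta F,
    # with both Jacobians of F computed by autodiff at the solution.
    x_star = ridge_solver(theta, A, b)
    jac_x = jax.jacobian(optimality_fun, argnums=0)(x_star, theta, A, b)
    jac_theta = jax.jacobian(optimality_fun, argnums=1)(x_star, theta, A, b)
    return -jnp.linalg.solve(jac_x, jac_theta)


@jax.custom_vjp
def solution(theta, A, b):
    # Differentiable wrapper around the black-box solver.
    return ridge_solver(theta, A, b)


def solution_fwd(theta, A, b):
    x_star = ridge_solver(theta, A, b)
    return x_star, (x_star, theta, A, b)


def solution_bwd(res, g):
    x_star, theta, A, b = res
    # Adjoint linear system: [d_x F]^T u = g.
    jac_x = jax.jacobian(optimality_fun, argnums=0)(x_star, theta, A, b)
    u = jnp.linalg.solve(jac_x.T, g)
    # Pull u back through F w.r.t. (theta, A, b) and negate.
    _, vjp_fun = jax.vjp(lambda t, a, bb: optimality_fun(x_star, t, a, bb), theta, A, b)
    return tuple(-c for c in vjp_fun(u))


solution.defvjp(solution_fwd, solution_bwd)

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (6, 3))
b = jax.random.normal(key_b, (6,))
theta = jnp.asarray(0.5)

implicit = implicit_jacobian(theta, A, b)
# Reference: unrolled autodiff straight through the solver's operations.
unrolled = jax.jacobian(ridge_solver)(theta, A, b)
# Gradient of a scalar loss through the custom_vjp wrapper.
grad_custom = jax.grad(lambda t: jnp.sum(solution(t, A, b)))(theta)
print(jnp.allclose(implicit, unrolled, atol=1e-4))
```

Here the direct solve could be replaced by any iterative solver without changing `solution_bwd`, because the backward pass only needs F and the returned solution. This decoupling of the optimality condition from the differentiation mechanism is the modularity the abstract describes.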