JAX: accelerated machine learning research via composable function transformations in Python
This talk is about JAX, a system for high-performance machine learning research. It offers the familiarity of Python+NumPy together with hardware acceleration on GPUs and TPUs. JAX combines these features with user-wielded function transformations, including automatic differentiation, automatic vectorized batching, end-to-end optimized compilation, parallelization over multiple accelerators, and more. Composing these transformations is key to JAX's power and simplicity.
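As a small illustration of that composition, here is a minimal sketch (the function `f` is an invented example, not from the talk) showing JAX's `grad`, `vmap`, and `jit` transformations stacking on an ordinary Python function:

```python
import jax
import jax.numpy as jnp

# An ordinary Python+NumPy-style function.
def f(x):
    return x ** 2 + 2.0 * x

df = jax.grad(f)             # automatic differentiation: df/dx = 2x + 2
df_batch = jax.vmap(df)      # automatic vectorized batching over inputs
df_fast = jax.jit(df_batch)  # end-to-end compilation via XLA

xs = jnp.arange(3.0)         # [0., 1., 2.]
print(df_fast(xs))           # evaluates 2x + 2 at each point: [2., 4., 6.]
```

Each transformation takes a function and returns a new function, so they compose freely in any order; that closure under composition is the simplicity the abstract refers to.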
In this talk you'll see a demo of what JAX does, then look behind the curtain at how it all works. You'll also learn what's new, what's next, and how to get involved.