JAX MD: A Framework for Differentiable Physics

Sam Schoenholz
Google Brain

I will talk about JAX MD, a software package for performing differentiable physics simulations with a focus on molecular dynamics. JAX MD includes a number of physics simulation environments, as well as interaction potentials and neural networks that can be integrated into these environments without writing any additional code. Since the simulations themselves are differentiable functions, entire trajectories can be differentiated to perform meta-optimization. These features are built on primitive operations, such as spatial partitioning, that allow simulations to scale to hundreds of thousands of particles on a single GPU. My talk will include an introduction to automatic differentiation through the JAX software package (www.github.com/google/jax), so no background is required. If you are interested in trying out JAX MD, it is available at github.com/google/jax-md.
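The idea of differentiating an entire trajectory can be illustrated in a few lines of plain JAX. The sketch below is not the JAX MD API: the toy spring potential, the velocity Verlet loop, the target separation of 1.2, and the names `energy`, `simulate`, and `loss` are hypothetical, chosen only to show forces obtained with `jax.grad` and a gradient taken through every integration step.

```python
import jax
import jax.numpy as jnp

def energy(x, k):
    # Hypothetical toy potential: a harmonic spring of stiffness k
    # between two unit-mass particles on a line, with rest length 1.
    r = x[1] - x[0]
    return 0.5 * k * (r - 1.0) ** 2

# Forces are the negative gradient of the energy with respect to positions.
force = jax.grad(lambda x, k: -energy(x, k))

def simulate(k, steps=100, dt=0.01):
    x = jnp.array([0.0, 1.5])  # initial positions (spring stretched by 0.5)
    v = jnp.zeros(2)           # initial velocities

    def step(_, state):
        x, v = state
        # Velocity Verlet: half-kick, drift, half-kick.
        v = v + 0.5 * dt * force(x, k)
        x = x + dt * v
        v = v + 0.5 * dt * force(x, k)
        return x, v

    x, v = jax.lax.fori_loop(0, steps, step, (x, v))
    return x[1] - x[0]  # final separation

def loss(k):
    # Hypothetical meta-objective: final separation should equal 1.2.
    return (simulate(k) - 1.2) ** 2

# Gradient of the objective through all 100 integration steps.
dloss_dk = jax.grad(loss)(2.0)
```

Because `simulate` is an ordinary JAX function, `jax.grad(loss)` differentiates the whole trajectory with respect to the spring stiffness `k`, and for a toy system like this one the result can be checked against a finite difference.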

Speaker Bio
Sam is a Senior Research Scientist at Google Brain working at the intersection of machine learning and physics. His work focuses on better understanding neural networks using techniques from statistical physics, as well as on applying advances in machine learning to physical systems. Sam received his PhD from the University of Pennsylvania, where he used machine learning to study disordered materials and glassy liquids.

Start date: Tuesday, Dec. 8, 2020, noon
Location: Zoom
