Short answer: not really, mostly because of the dynamic shapes of batches in Geometric Deep Learning (GDL).

Wait, what is JAX?

Nowadays most people build neural networks with PyTorch. It has lots of helper tools, most modern AI models are implemented in the Hugging Face transformers library, and there’s a dedicated GDL library called PyTorch Geometric. To make models fast, people write custom GPU kernels that reduce memory movement. The computation graph is built on the fly during every forward pass.

JAX is a library from Google, something like NumPy on steroids with automatic differentiation. Flax is a helper library that implements common neural network layers and classes. In JAX the computation graph is JIT-compiled with XLA and operations are fused into fewer GPU kernels, which makes training and inference fast, especially on non-NVIDIA hardware such as TPUs and other GPUs. The catch is that each compiled function is specialized to the exact shapes of its inputs, so a new input shape means a new compilation. JAX is used by almost every major generative AI player, including Anthropic, xAI, DeepMind, Apple, and Cohere.

And what is geometric deep learning?

From https://geometricdeeplearning.com:

Geometric Deep Learning is an attempt for geometric unification of a broad class of ML problems from the perspectives of symmetry and invariance.

The most visible examples are Graph Neural Networks (GNNs), which are the go-to models for drug discovery and chemistry, social networks, recommender systems, and other graph-structured data.
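Here is a minimal sketch of one message-passing step, the building block of most GNNs, written in plain NumPy to make the idea of symmetry a bit more concrete. It is not the API of any particular library. The names `message_passing_layer`, `edge_index`, `w_self` and `w_neigh` are illustrative, and sum aggregation with a ReLU is just one common choice. The last lines check permutation equivariance: relabeling the nodes only reorders the output rows.

```python
import numpy as np


def message_passing_layer(h, edge_index, w_self, w_neigh):
    """One sum-aggregation message-passing step (illustrative sketch).

    h:          (num_nodes, d_in) node features
    edge_index: (2, num_edges) int array of directed edges src -> dst
    w_self, w_neigh: (d_in, d_out) weight matrices
    """
    src, dst = edge_index
    agg = np.zeros_like(h)
    # Scatter-sum: every node accumulates the features of its incoming neighbours.
    np.add.at(agg, dst, h[src])
    return np.maximum(h @ w_self + agg @ w_neigh, 0.0)  # ReLU


rng = np.random.default_rng(0)
h = rng.normal(size=(4, 3))
# A 4-node path graph 0-1-2-3; each undirected edge is stored in both directions.
edge_index = np.array([[0, 1, 1, 2, 2, 3],
                       [1, 0, 2, 1, 3, 2]])
w_self, w_neigh = rng.normal(size=(3, 2)), rng.normal(size=(3, 2))
out = message_passing_layer(h, edge_index, w_self, w_neigh)

# Relabel the nodes: new node i is old node perm[i], old node j is new node inv[j].
perm = np.array([2, 0, 3, 1])
inv = np.argsort(perm)
out_perm = message_passing_layer(h[perm], inv[edge_index], w_self, w_neigh)
print(np.allclose(out_perm, out[perm]))  # True
```

The output shape follows the graph size. Every differently sized graph, or batch of graphs, is therefore a new input shape.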

Still too complicated? There’s a GDL course that is mostly theory with lots of math, and there’s a Deep Learning for Molecules and Materials course with lessons on math, ML, and deep learning. These are great courses, but it can be hard to get started and stay committed. I’m a big fan of Practical Deep Learning (https://course.fast.ai), which first builds models and shows you results, and only later peels back the layers of complexity.

Do people do GDL in JAX?

Far less than in PyTorch! There was a graph neural network library for JAX called Jraph, but it’s archived now, and the other ones are old and no longer actively maintained. The good news is that with AI it’s now easier to build things from scratch and rely less on frameworks. Right?

Wrong! Current AI is very good at building prototypes or things it has already seen somewhere, but pretty bad at creating new things or working with new libraries and paradigms.

How do you know?

After my previous post about Papers To Datasets, I had several follow-up conversations in which I realized there’s a big field of predicting properties of individual molecules, but a much smaller one for predicting properties of molecular mixtures. I thought it would be worth exploring further, and, inspired by my friend’s machine learning course, I want to wrap it into a fast.ai-like course too.

I pushed the first few helper tools to the IGDL repo:

  • I started with a simple Directed Message Passing Neural Network (D-MPNN) based on chemprop. That was easy and fast, even on CPU.
  • Then I implemented conductivity prediction for polymer electrolytes via Arrhenius parameters, based on chemproppred.
  • To implement attention layers and work with bigger data, I trained a conformer model like in MARCEL on the Drugs-75K dataset. That’s where I ran into the first issue from the title: different molecules encode to different numbers of atoms and bonds, so batches have dynamic shapes. In JAX every new shape triggers a recompilation, and with such batches that happens on almost every iteration. Compilation keeps the CPU busy while the GPU sits idle. To fix it, I implemented dummy padding for batches, but it wastes some compute and is not optimal.
  • Finally, I implemented mixture property prediction like in MolSets, and it matches the accuracy of the PyTorch version.
  • To expand into other areas as well, I implemented a GraphSAGE recommender system for Bluesky. There, instead of many small molecular graphs, I have one big graph that doesn’t fit into GPU memory. I used Grain to load the data, and coding with AI was worse than doing it by hand: Grain is a new library and LLMs couldn’t understand its structure. I wrote a small sampler, but scaling this model will require more changes, since even here my GPU utilization was not 100%.
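The D-MPNN from the first bullet differs from plain message passing in one way: hidden states live on directed bonds. The message into bond v→w sums the states of all bonds arriving at v except the reverse bond w→v. Below is a minimal NumPy sketch of those updates under simplifying assumptions: no bond features, a sum readout, and hypothetical weight matrices `w_i`, `w_m`, `w_a`. It is not chemprop’s code.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def dmpnn(x, edges, w_i, w_m, w_a, depth=3):
    """Simplified D-MPNN forward pass (no bond features, sum readout).

    x:     (num_atoms, d) atom features
    edges: list of directed bonds (v, w); every bond appears in both directions
    """
    src = np.array([v for v, _ in edges], dtype=int)
    dst = np.array([w for _, w in edges], dtype=int)
    pos = {e: i for i, e in enumerate(edges)}
    rev = np.array([pos[(w, v)] for v, w in edges], dtype=int)  # index of w->v

    h0 = relu(x[src] @ w_i)  # initial hidden state of each directed bond
    h = h0
    for _ in range(depth - 1):
        incoming = np.zeros((x.shape[0], h.shape[1]))
        np.add.at(incoming, dst, h)  # sum of bond states arriving at each atom
        m = incoming[src] - h[rev]  # for v->w: everything into v except from w
        h = relu(h0 + m @ w_m)

    atom_msg = np.zeros((x.shape[0], h.shape[1]))
    np.add.at(atom_msg, dst, h)
    atom_h = relu(np.concatenate([x, atom_msg], axis=1) @ w_a)
    return atom_h.sum(axis=0)  # molecule embedding, invariant to atom order


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))  # 3 heavy atoms (C-C-O), 4 features each
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
w_i = rng.normal(size=(4, 5))
w_m = rng.normal(size=(5, 5))
w_a = rng.normal(size=(4 + 5, 6))
mol = dmpnn(x, edges, w_i, w_m, w_a)
```

Excluding the reverse bond keeps a message from flowing straight back to where it came from.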
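Predicting Arrhenius parameters instead of a single conductivity lets one model output describe the electrolyte at any temperature, via σ(T) = A·exp(−Eₐ/(R·T)). Below is a small helper that turns such parameters into conductivities. It assumes a model predicts log10 of the prefactor A and the activation energy Eₐ in kJ/mol. That parameterization is a hypothetical choice for illustration, not necessarily chemproppred’s output format.

```python
import math

R = 8.314462618  # gas constant, J/(mol*K)


def arrhenius_conductivity(log10_a, ea_kj_per_mol, temperature_k):
    """sigma(T) = A * exp(-Ea / (R * T)).

    log10_a:       log10 of the prefactor A (same units as sigma, e.g. S/cm)
    ea_kj_per_mol: activation energy Ea in kJ/mol (assumed unit)
    temperature_k: absolute temperature in kelvin
    """
    ea = ea_kj_per_mol * 1000.0  # kJ/mol -> J/mol
    return 10.0 ** log10_a * math.exp(-ea / (R * temperature_k))


# Hypothetical predicted parameters for one electrolyte.
for t in (298.15, 333.15, 363.15):
    print(f"{t:.2f} K: {arrhenius_conductivity(-1.5, 35.0, t):.3e} S/cm")
```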
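The dummy padding mentioned in the conformer bullet can be sketched as follows. Several small graphs are merged into one batch, then the node and edge counts are rounded up to the next power of two. Padded edges become self-loops on a dummy node, and a mask marks the real nodes. This is a hypothetical NumPy version: `pad_batch`, `next_pow2` and the power-of-two bucketing are illustrative choices, not the IGDL implementation. The final loop compares how many distinct array shapes (and so JIT compilations) raw and padded batches would produce.

```python
import numpy as np


def next_pow2(n):
    """Smallest power of two >= n (at least 1)."""
    return 1 << max(n - 1, 0).bit_length()


def pad_batch(graphs):
    """Merge small graphs into one batch and pad it to power-of-two sizes.

    graphs: list of (node_features (n_i, d), edge_index (2, e_i)) pairs.
    Returns padded node features, edge index, a mask of real nodes and
    per-node graph ids (padding nodes get the extra id len(graphs)).
    """
    feats, edges, graph_ids = [], [], []
    offset = 0
    for gid, (x, ei) in enumerate(graphs):
        feats.append(x)
        edges.append(ei + offset)  # shift node ids into batch numbering
        graph_ids.append(np.full(len(x), gid))
        offset += len(x)
    x = np.concatenate(feats)
    ei = np.concatenate(edges, axis=1)
    gids = np.concatenate(graph_ids)

    n_pad = next_pow2(len(x) + 1)  # +1 reserves at least one dummy node
    e_pad = next_pow2(ei.shape[1])
    dummy = n_pad - 1
    x_p = np.zeros((n_pad, x.shape[1]), dtype=x.dtype)
    x_p[: len(x)] = x
    ei_p = np.full((2, e_pad), dummy, dtype=ei.dtype)  # padded edges: dummy self-loops
    ei_p[:, : ei.shape[1]] = ei
    mask = np.arange(n_pad) < len(x)
    gids_p = np.full(n_pad, len(graphs), dtype=gids.dtype)
    gids_p[: len(x)] = gids
    return x_p, ei_p, mask, gids_p


rng = np.random.default_rng(0)


def random_molecule(n_atoms, d=8):
    # Hypothetical stand-in for a featurized molecule: a chain of n_atoms atoms.
    x = rng.normal(size=(n_atoms, d)).astype(np.float32)
    a = np.arange(n_atoms - 1)
    ei = np.stack([np.concatenate([a, a + 1]), np.concatenate([a + 1, a])])
    return x, ei


raw_shapes, padded_shapes = set(), set()
for _ in range(200):
    batch = [random_molecule(int(n)) for n in rng.integers(10, 40, size=32)]
    raw_shapes.add((sum(len(x) for x, _ in batch), sum(e.shape[1] for _, e in batch)))
    x_p, ei_p, mask, gids = pad_batch(batch)
    padded_shapes.add((x_p.shape, ei_p.shape))
print(len(raw_shapes), "raw shapes vs", len(padded_shapes), "padded shapes")
```

Power-of-two buckets keep the number of distinct shapes to a handful. The price is that up to roughly half of each padded array can be dummies. Finer buckets waste less compute but compile more often. Downstream, the mask or the extra padding graph id keeps dummy nodes out of pooling and the loss.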
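For mixtures, the core requirement is a readout over a set of molecules that ignores the order of the components and accounts for their fractions. The sketch below is a simplified DeepSets-style version of that idea, using a fraction-weighted sum of transformed molecule embeddings. It is not MolSets’ exact architecture. `predict_mixture`, `w_phi` and `w_rho` are hypothetical names, and the molecule embeddings would come from a per-molecule model such as a D-MPNN.

```python
import numpy as np


def predict_mixture(mol_embs, fractions, w_phi, w_rho):
    """DeepSets-style mixture readout (simplified sketch).

    mol_embs:  (k, d) embeddings of the k molecules in the mixture
    fractions: (k,) mole or weight fractions, summing to 1
    """
    phi = np.maximum(mol_embs @ w_phi, 0.0)  # per-molecule transform
    pooled = (fractions[:, None] * phi).sum(axis=0)  # weighted sum: order-free
    return pooled @ w_rho  # map the set summary to the property


rng = np.random.default_rng(0)
mol_embs = rng.normal(size=(3, 8))  # e.g. polymer, salt, additive
fractions = np.array([0.6, 0.3, 0.1])
w_phi = rng.normal(size=(8, 16))
w_rho = rng.normal(size=(16, 1))
pred = predict_mixture(mol_embs, fractions, w_phi, w_rho)

# Listing the components in a different order gives the same prediction.
order = [2, 0, 1]
pred_reordered = predict_mixture(mol_embs[order], fractions[order], w_phi, w_rho)
```

Because the pooling is linear in the fractions, splitting one component into two entries with the same total fraction also leaves the prediction unchanged.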
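For the big Bluesky graph, GraphSAGE trains on sampled neighborhoods instead of the full graph: for each target node it draws a fixed number of neighbors per hop. A fixed fanout also gives every minibatch the same array shapes, which is what a JIT-compiled step wants. Below is a hypothetical NumPy sampler over an adjacency dict. It is not Grain’s API or the sampler from the repo. Sampling with replacement, and falling back to a self-loop for isolated nodes, are assumptions made to keep shapes fixed.

```python
import numpy as np


def sample_neighbors(adj, seeds, fanouts, rng):
    """Fixed-fanout neighbour sampling for a multi-hop GraphSAGE-style minibatch.

    adj:     dict node -> np.ndarray of neighbour ids
    seeds:   (b,) target nodes
    fanouts: neighbours drawn per node at each hop, e.g. (10, 5)
    Returns one id array per hop with shapes (b,), (b*f1,), (b*f1*f2,), ...
    """
    layers = [np.asarray(seeds, dtype=np.int64)]
    for fanout in fanouts:
        frontier = layers[-1]
        sampled = np.empty((len(frontier), fanout), dtype=np.int64)
        for i, node in enumerate(frontier):
            nbrs = adj.get(int(node), np.empty(0, dtype=np.int64))
            if len(nbrs) == 0:
                sampled[i] = node  # isolated node: fall back to a self-loop
            else:
                # With replacement, so low-degree nodes still fill the fixed shape.
                sampled[i] = rng.choice(nbrs, size=fanout, replace=True)
        layers.append(sampled.reshape(-1))
    return layers


adj = {
    0: np.array([1, 2]),
    1: np.array([0]),
    2: np.array([0, 3]),
    3: np.array([2]),
    4: np.array([], dtype=np.int64),
}
rng = np.random.default_rng(0)
layers = sample_neighbors(adj, [0, 4], (3, 2), rng)
print([len(layer) for layer in layers])  # [2, 6, 12]
```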

What to do then?

  • If you need results for a common task, just use PyTorch Geometric.
  • If you want to play with models, understand the details, learn AI, and do research, use JAX.