JAX
Differential Programming with JAX course, JAX in Action & JaxTon are great. Meta Optimal Transport is a nice JAX repo to run/study.
Robert Lange has nice JAX repos.
Learning to use BlackJAX to do Bayesian inference on models.
Equinox is a great JAX library.
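A minimal sketch of the differentiable-programming idea, using only core `jax` (`jax.grad`, `jax.numpy`). `jax.grad` turns an ordinary Python loss function into a function that returns its gradient, and that gradient then drives a hand-written gradient-descent loop. The toy data, the learning rate, and the names `loss` and `grad_loss` are illustrative assumptions and are not taken from any of the resources above.

```python
import jax
import jax.numpy as jnp

def loss(w, xs, ys):
    # Mean squared error of a 1-D linear model y ≈ w * x (hypothetical toy model).
    preds = w * xs
    return jnp.mean((preds - ys) ** 2)

# Gradient of `loss` with respect to its first argument, w.
grad_loss = jax.grad(loss)

xs = jnp.array([1.0, 2.0, 3.0])
ys = 2.0 * xs  # toy data generated with a true slope of 2

w = 0.0  # must be a float: jax.grad does not accept integer inputs
for _ in range(100):
    # Plain gradient descent with a fixed step size of 0.1.
    w = w - 0.1 * grad_loss(w, xs, ys)

print(float(w))
```

Because the data come from a slope of 2, `w` should converge very close to `2.0`.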
Notes
Links
- audax - Home for audio ML in JAX. Has common features, learnable frontends, pretrained supervised and self-supervised models.
- tinygp - Extremely lightweight library for building Gaussian Process models in Python, built on top of jax.
- GPJax - Didactic Gaussian process package for researchers in Jax.
- Mctx - Monte Carlo tree search in JAX.
- Pipelined Swarm Training - Swarm training framework using Haiku + JAX + Ray for layer parallel transformer language models on unreliable, heterogeneous nodes.
- JAX MuZero - JAX implementation of the MuZero agent.
- Jax Influence - Scalable implementation of Influence Functions in JAX.
- BlackJAX - Library of samplers for JAX that works on CPU as well as GPU. (Twitter) (Contribute) (Sampling Book) (Sampling Book Code)
- GPax - Jax/Flax codebase for Gaussian processes including meta and multi-task Gaussian processes.
- jax-fenics-adjoint - Differentiable interface to FEniCS/Firedrake for JAX using dolfin-adjoint/pyadjoint.
- jax-ekf - Generic EKF, with support for non-Euclidean manifolds.
- PaLM - Jax - Implementation of the specific Transformer architecture from "PaLM: Scaling Language Modeling with Pathways", in Jax.
- Pre-trained image classification models for Jax/Haiku
- Flaxformer: transformer architectures in JAX/Flax
- KFAC-JAX - Second Order Optimization with Approximate Curvature in JAX
- flowjax - Normalizing flow implementations in jax.
- Jax3D - Library for neural rendering in Jax that aims to be a nimble NeRF ecosystem.
- DALL·E 2 in JAX
- JAXNS - Nested sampling in JAX.
- AUX - Audio processing library in JAX, for JAX.
- Nice DeepMind Jax libraries
- Machine Learning with JAX - From Zero to Hero (2021)
- Flax - Neural network library for JAX designed for flexibility. (Docs)
- JAX talks by HuggingFace
- Homomorphic Encryption in JAX
- JAX implementation of Learning to learn by gradient descent by gradient descent
- Normalizing Flows in JAX
- Big Vision - Designed for training large-scale vision models on Cloud TPU VMs. Based on Jax/Flax libraries.
- Jax vs. Julia (Vs PyTorch) (2022) (HN)
- minGPT in JAX
- flaxvision - Selection of neural network models ported from torchvision for JAX & Flax.
- JAX version of clip guided diffusion scripts
- Functorch - Jax-like composable function transforms for PyTorch. (HN)
- Ninjax - Module system for JAX that offers full state access and makes it easy to combine modules from other libraries.
- Functional Transformer - Pure-functional implementation of a machine learning transformer model in Python/JAX.
- JAX + Units - Provides an interface between JAX and Pint to allow JAX to support operations with units.
- Infinite Recommendation Networks (∞-AE) in JAX
- Differential Programming with JAX course (Code)
- Algorithms for Privacy-Preserving Machine Learning in JAX
- Connex - Small JAX library built on Equinox whose aim is to incorporate artificial analogues of biological neural network attributes into deep learning research and architecture design.
- Rax - Composable Learning to Rank using JAX.
- JAX is faster than PyTorch but harder to debug
- JAX Meta Learning - Collection of meta-learning algorithms in JAX.
- Gymnax - RL Environments in JAX.
- Pax - Framework to configure and run machine learning experiments on top of Jax.
- SymPy2Jax - Turn SymPy expressions into trainable JAX expressions.
- JAX Typing - Type annotations and runtime checking for shape and dtype of JAX arrays, and PyTrees.
- CoDeX - Data compression in JAX.
- DiBS - Python JAX implementation of DiBS, a fully differentiable method for joint Bayesian inference of the DAG and parameters of general, causal Bayesian networks.
- Generative Adversarial Networks in JAX
- Neural implicit queries - Perform geometric queries on neural implicit surfaces like ray casting, intersection testing, fast mesh extraction, closest points, and more.
- CLIP-JAX - Train CLIP models using JAX and transformers.
- BLOOM Inference in JAX
- v-diffusion-jax - V objective diffusion inference code for JAX.
- Euclidean Neural Networks Jax
- Jax + CUDA boilerplate
- mlff - Build neural networks for machine learning force fields with JAX.
- Jax-Triton - Integrations between JAX and Triton.
- JAX Synergistic Memory Inspector - Tool for real-time inspection of the memory usage of a JAX process.
- Jaxformer - Minimal library to train LLMs on TPU in JAX.
- xpag - Modular reinforcement learning library with JAX agents.
- Orbax - Library providing common utilities for JAX users.
- Praxis - Layer library for Pax. While Praxis is optimized for ML at scale, it also aims to be usable by other JAX-based ML projects.
- JAX in Action (2022)
- DeepMind JAX Ecosystem
- JAX Tutorial
- SBX: Stable Baselines Jax (SB3 + Jax)
- Ciclo - Training loop utilities and abstractions for JAX.
- DYNAMAX - State Space Models library in JAX.
- Myriad - Real-world testbed that aims to bridge trajectory optimization and deep learning.
- JaQMC - JAX accelerated Quantum Monte Carlo.
- Rieoptax - Riemannian Optimization Using JAX.
- Jax 0.4.0
- JAXGA - JAX Geometric Algebra. (HN)
- PyTorch Lightning + Jax
- Structural Time Series (STS) in JAX
- safejax - Serialize JAX/Flax models with safetensors.
- dejax - Accelerated replay buffers in JAX.
- JaxTon - 100 exercises to learn JAX.
- JMP - Mixed Precision library for JAX.
- PIPs JAX - JAX implementation of Persistent Independent Particles.
- Jax Decompiler
- MACE - Equivariant machine learning interatomic potentials in JAX.
- JAX Wavelets - 2D discrete wavelet transform for JAX.
- Autodidax: JAX core from scratch (HN)
- Aesara on JAX
- Training Deep Networks with Data Parallelism in Jax (2023) (HN)
- JAX – Augments numpy and Python code with function transformations (2019) (HN)
- PureJaxRL - End-to-End RL Training in Pure Jax.
- 4000x Speedup in Reinforcement Learning with Jax (2023) (HN)
- Creative Machine Learning - Creative Machine Learning course and notebooks in JAX, PyTorch and Numpy.
- LAST - JAX library for building lattice-based speech transducer models.
- MaxText - Simple, performant and scalable Jax LLM.
- Stable Diffusion JAX
- JAX-AM - Additive manufacturing simulation with JAX.
- AutoBound - Automatically computes upper and lower bounds on functions.
- WAX-ML - Python library for machine learning and feedback loops on streaming data.
- Oryx - Library for probabilistic programming and deep learning built on top of Jax.
- Path Tracing in JAX
- JAX Metrics - Metrics library for the JAX ecosystem.
- Lineax - Linear solvers in JAX and Equinox.
- Levanter and Haliax - Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax.
- MAGVIT: Masked Generative Video Transformer in JAX
- Calabi-Yau metrics with JAX
- JAX Implementation of Llama 2
- SynJax - Neural network library of structured probability distributions for JAX. (HN)
- Graph Learning with JAX
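Several entries above (the 2019 JAX overview, Functorch) describe JAX in terms of composable function transformations. Here is a minimal sketch of what composing them looks like, using only core `jax` calls. `jax.grad` differentiates a single-example loss, `jax.vmap` maps that gradient over the batch axis to get per-example gradients, and `jax.jit` compiles the composed function. The linear model, the data, and the name `per_example_grads` are hypothetical.

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Squared error of a linear model on a single example (hypothetical model).
    w, b = params
    pred = jnp.dot(w, x) + b
    return (pred - y) ** 2

# grad: gradient of the single-example loss w.r.t. params (argument 0).
# vmap: map it over axis 0 of x and y while sharing params (in_axes=None).
# jit:  compile the whole composed function with XLA.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

params = (jnp.array([1.0, -1.0]), jnp.array(0.5))
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
ys = jnp.array([0.0, 1.0, 2.0])

# gw: one weight gradient per example, shape (3, 2); gb: shape (3,).
gw, gb = per_example_grads(params, xs, ys)
```

The order of transformations matters. Applying `jax.grad` to a batch-averaged loss would return a single averaged gradient rather than one gradient per example.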