Table of Contents
Deep Learning with JAX MEAP V06
Copyright
Welcome
Brief contents
Chapter 1: Intro to JAX
1.1 What is JAX?
1.1.1 JAX as NumPy
1.1.2 Composable transformations
1.2 Why use JAX?
1.2.1 Computational performance
1.2.2 Functional approach
1.2.3 JAX ecosystem
1.3 How is JAX different from TensorFlow/PyTorch?
1.4 Summary
Chapter 2: Your first program in JAX
2.1 A toy ML problem: classifying handwritten digits
2.2 Loading and preparing the dataset
2.3 A simple neural network in JAX
2.3.1 Neural network initialization
2.3.2 Neural network forward pass
2.4 vmap: auto-vectorizing calculations to work with batches
2.5 Autodiff: how to calculate gradients without knowing about derivatives
2.6 JIT: compiling your code to make it faster
2.7 Pure functions and composable transformations: why is it important?
2.8 An overview of a JAX deep learning project
2.9 Exercises
2.10 Summary
Chapter 3: Working with tensors
3.1 Image processing with NumPy arrays
3.1.1 Loading and storing images in NumPy arrays
3.1.2 Performing basic image processing with NumPy API
3.2 Tensors in JAX
3.2.1 Switching to JAX NumPy-like API
3.2.2 What is the DeviceArray?
3.2.3 Device-related operations
3.2.4 Asynchronous dispatch
3.2.5 Moving image processing to TPU
3.3 Differences with NumPy
3.3.1 Immutability
3.3.2 Types
3.4 High-level and low-level interfaces: jax.numpy and jax.lax
3.5 Exercises
3.6 Summary
Chapter 4: Autodiff
4.1 Different ways of getting derivatives
4.1.1 Manual differentiation
4.1.2 Symbolic differentiation
4.1.3 Numerical differentiation
4.1.4 Automatic differentiation
4.2 Calculating gradients with autodiff
4.2.1 Working with gradients in TensorFlow
4.2.2 Working with gradients in PyTorch
4.2.3 Working with gradients in JAX
4.2.4 Higher-order derivatives
4.2.5 Multivariable case
4.3 Forward and Reverse mode autodiff
4.3.1 Evaluation trace
4.3.2 Forward mode and jvp()
4.3.3 Reverse mode and vjp()
4.3.4 Going deeper
4.4 Summary
Chapter 5: Compiling your code
5.1 Using compilation
5.1.1 Using Just-in-Time (JIT) compilation
5.1.2 Pure functions
5.2 JIT internals
5.2.1 Jaxpr, an intermediate representation for JAX programs
5.2.2 XLA
5.2.3 Using Ahead-of-Time (AOT) compilation
5.3 JIT limitations
5.4 Summary
Chapter 6: Vectorizing your code
6.1 Different ways to vectorize a function
6.1.1 Naive approaches
6.1.2 Manual vectorization
6.1.3 Automatic vectorization
6.1.4 Speed comparisons
6.2 Controlling vmap() behavior
6.2.1 Controlling array axes to map over
6.2.2 Controlling output array axes
6.2.3 Using named arguments
6.2.4 Using decorator style
6.2.5 Using collective operations
6.3 Real-life use cases for vmap()
6.3.1 Batch data processing
6.3.2 Batching neural network models
6.3.3 Per-sample gradients
6.3.4 Vectorizing loops
6.4 Summary
Chapter 7: Parallelizing your computations
7.1 Parallelizing computations with pmap()
7.1.1 Setting up a problem
7.1.2 Using pmap (almost) like vmap
7.2 Controlling pmap() behavior
7.2.1 Controlling input and output mapping axes
7.2.2 Using named axes and collectives
7.3 Data parallel neural network training example
7.3.1 Preparing data and neural network structure
7.3.2 Implementing data parallel training
7.4 Summary