Reparameterization Trick for the MVN Distribution
This topic has 1 reply, 1 voice, and was last updated 3 months ago by Wolf.
Understanding the Reparameterization Trick for MVN
In the context of Variational Autoencoders (VAEs) and Deep Learning, the Reparameterization Trick is a fundamental technique that allows gradients to be backpropagated through a random sampling step. For a Multivariate Normal (MVN) distribution, it moves all of the randomness into an auxiliary noise variable that does not depend on the distribution's parameters, so the sample becomes a differentiable function of those parameters.
1. The Core Problem
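When we sample a latent vector $\mathbf{z}$ from a distribution $q_{\phi}(\mathbf{z} | \mathbf{x}) = \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma})$, the sampling operation itself is not differentiable with respect to $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$.
$$\mathbf{z} \sim \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma})$$
If we want to optimize $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$ using gradient descent, backpropagation needs $\partial \mathbf{z} / \partial \boldsymbol{\mu}$ and $\partial \mathbf{z} / \partial \boldsymbol{\Sigma}$. A direct draw does not define $\mathbf{z}$ as a function of the parameters (drawing again with the same $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$ gives a different $\mathbf{z}$), so there is no derivative for the gradient to flow through.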
2. The Reparameterization Solution
We rewrite $\mathbf{z}$ as a deterministic function of the parameters and an independent noise variable $\boldsymbol{\epsilon}$. For a Multivariate Normal distribution, we use the property that any MVN random variable can be expressed as an affine transformation (a matrix multiplication plus a shift) of a standard normal variable.
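A minimal plain-Python sketch of this property (the mean `mu` and the lower-triangular matrix `L` below are arbitrary illustrative values, not from the post; `L` plays the role of the Cholesky factor introduced in Step B): if every sample is built as $\mathbf{z} = \boldsymbol{\mu} + \mathbf{L}\boldsymbol{\epsilon}$ with fresh $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, the sample mean should approach $\boldsymbol{\mu}$ and the sample covariance should approach $\mathbf{L}\mathbf{L}^T$.

```python
import random


def sample_mvn(mu, L, rng):
    """Draw one sample z = mu + L @ eps with eps ~ N(0, I).

    mu: list of k floats; L: k x k lower-triangular matrix as a list of lists.
    """
    k = len(mu)
    eps = [rng.gauss(0.0, 1.0) for _ in range(k)]  # noise does not depend on mu or L
    # L is lower-triangular, so row i only needs eps[0..i]
    return [mu[i] + sum(L[i][j] * eps[j] for j in range(i + 1)) for i in range(k)]


def empirical_mean_cov(samples):
    """Sample mean and unbiased sample covariance of a list of k-dimensional points."""
    n, k = len(samples), len(samples[0])
    mean = [sum(s[i] for s in samples) / n for i in range(k)]
    cov = [[sum((s[i] - mean[i]) * (s[j] - mean[j]) for s in samples) / (n - 1)
            for j in range(k)] for i in range(k)]
    return mean, cov


rng = random.Random(0)
mu = [1.0, -2.0]
L = [[2.0, 0.0],
     [0.5, 1.0]]  # implied covariance L L^T = [[4.0, 1.0], [1.0, 1.25]]
samples = [sample_mvn(mu, L, rng) for _ in range(100_000)]
mean, cov = empirical_mean_cov(samples)
print([round(m, 3) for m in mean])
print([[round(c, 3) for c in row] for row in cov])
```

The only randomness lives in `eps`, which never depends on `mu` or `L`; that separation is what makes the gradients in Section 4 possible.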
3. Step-by-Step Derivation
Step A: Define the Standard Noise
We introduce an auxiliary variable $\boldsymbol{\epsilon}$ sampled from a standard Multivariate Normal distribution:
$$\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$
Step B: Decompose the Covariance Matrix
To transform $\boldsymbol{\epsilon}$ into a variable with covariance $\boldsymbol{\Sigma}$, we need a matrix "square root" of $\boldsymbol{\Sigma}$. For a general MVN with a symmetric positive-definite $\boldsymbol{\Sigma}$, we use the Cholesky Decomposition:
$$\boldsymbol{\Sigma} = \mathbf{L}\mathbf{L}^T$$
where $\mathbf{L}$ is a lower-triangular matrix with positive diagonal entries.
In many VAE implementations, we simplify the model by assuming the latent variables are conditionally independent given $\mathbf{x}$. This means the covariance matrix $\boldsymbol{\Sigma}$ is a diagonal matrix:
$$\boldsymbol{\Sigma} = \text{diag}(\sigma_1^2, \sigma_2^2, \dots, \sigma_k^2)$$
This reduces the number of covariance parameters the neural network must output for each input from $k(k+1)/2$ (the lower triangle of $\mathbf{L}$) to $k$, i.e. from $O(k^2)$ to $O(k)$, which makes the model cheaper to compute and typically more stable during training.
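For the general (full-covariance) case, the factorization in Step B can be sketched as a textbook Cholesky routine in plain Python. The names `cholesky` and `covariance_param_count` are illustrative helpers, not from any library; in practice a linear-algebra library would do the factorization. The second helper spells out the parameter count behind the $O(k^2)$ versus $O(k)$ comparison.

```python
import math


def cholesky(sigma):
    """Return lower-triangular L with L @ L^T == sigma.

    sigma must be symmetric positive definite; otherwise a ValueError is raised.
    """
    k = len(sigma)
    L = [[0.0] * k for _ in range(k)]
    for i in range(k):
        for j in range(i + 1):
            s = sum(L[i][m] * L[j][m] for m in range(j))
            if i == j:
                d = sigma[i][i] - s
                if d <= 0.0:
                    raise ValueError("covariance matrix is not positive definite")
                L[i][j] = math.sqrt(d)  # diagonal entries are positive
            else:
                L[i][j] = (sigma[i][j] - s) / L[j][j]
    return L


def covariance_param_count(k, diagonal):
    """Covariance outputs per input: k for a diagonal model, k(k+1)/2 for a full L."""
    return k if diagonal else k * (k + 1) // 2


sigma = [[4.0, 1.0],
         [1.0, 1.25]]
L = cholesky(sigma)
print(L)  # [[2.0, 0.0], [0.5, 1.0]]
print(covariance_param_count(32, diagonal=False), covariance_param_count(32, diagonal=True))  # 528 32
```

For a 32-dimensional latent space, a full factor needs 528 covariance outputs per input versus 32 for the diagonal model. With a diagonal $\boldsymbol{\Sigma}$, the factorization itself becomes trivial.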
In this case, the Cholesky factor $\mathbf{L}$ is simply:
$$\mathbf{L} = \text{diag}(\sigma_1, \sigma_2, \dots, \sigma_k)$$
Step C: The Transformation Formula
We express $\mathbf{z}$ as:
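$$\mathbf{z} = \boldsymbol{\mu} + \mathbf{L}\boldsymbol{\epsilon}$$
Since $\mathbb{E}[\boldsymbol{\epsilon}] = \mathbf{0}$ and $\mathbb{E}[\boldsymbol{\epsilon}\boldsymbol{\epsilon}^T] = \mathbf{I}$, this $\mathbf{z}$ has mean $\boldsymbol{\mu}$ and covariance $\mathbf{L}\mathbf{I}\mathbf{L}^T = \boldsymbol{\Sigma}$, as required.
If we use the diagonal covariance assumption (common in VAEs), this simplifies to the following, where $\odot$ denotes the element-wise (Hadamard) product:
$$\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}$$
Step D: Matrix Representation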
For the general case where $\boldsymbol{\mu}$ is a vector and $\mathbf{L}$ is the Cholesky factor:
$$\mathbf{z} = \begin{pmatrix} \mu_1 \cr \mu_2 \cr \vdots \cr \mu_k \end{pmatrix} + \begin{pmatrix} L_{11} & 0 & \cdots & 0 \cr L_{21} & L_{22} & \cdots & 0 \cr \vdots & \vdots & \ddots & \vdots \cr L_{k1} & L_{k2} & \cdots & L_{kk} \end{pmatrix} \begin{pmatrix} \epsilon_1 \cr \epsilon_2 \cr \vdots \cr \epsilon_k \end{pmatrix}$$
4. Why This Works for Gradients
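Now, $\mathbf{z}$ is a deterministic function of $\boldsymbol{\mu}$ and $\mathbf{L}$ once $\boldsymbol{\epsilon}$ has been drawn. Treating $\frac{\partial J}{\partial \mathbf{z}}$ as a column vector, the chain rule gives the gradients of a loss function $J$ with respect to the parameters:
$$\frac{\partial J}{\partial \boldsymbol{\mu}} = \left(\frac{\partial \mathbf{z}}{\partial \boldsymbol{\mu}}\right)^T \frac{\partial J}{\partial \mathbf{z}} = \mathbf{I}\,\frac{\partial J}{\partial \mathbf{z}} = \frac{\partial J}{\partial \mathbf{z}}$$
$$\frac{\partial J}{\partial \mathbf{L}} = \frac{\partial J}{\partial \mathbf{z}}\,\boldsymbol{\epsilon}^T$$
Only the lower-triangular entries of this outer product are used, because the upper triangle of $\mathbf{L}$ is fixed at zero. In the diagonal case it reduces to $\frac{\partial J}{\partial \boldsymbol{\sigma}} = \frac{\partial J}{\partial \mathbf{z}} \odot \boldsymbol{\epsilon}$.
The randomness $\boldsymbol{\epsilon}$ is treated as a constant during the backpropagation step, allowing the model to learn the optimal mean and variance by gradient descent.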
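The gradient formulas can be checked numerically. The sketch below assumes a toy squared-error loss $J = \tfrac{1}{2}\lVert \mathbf{z} - \mathbf{t} \rVert^2$ with a made-up target $\mathbf{t}$ (a hypothetical stand-in for the real VAE loss, so $\partial J/\partial \mathbf{z} = \mathbf{z} - \mathbf{t}$). It computes $\partial J/\partial \boldsymbol{\mu}$ and $\partial J/\partial \mathbf{L}$ from the chain-rule formulas and compares them with central finite differences, reusing the same $\boldsymbol{\epsilon}$ for every evaluation.

```python
import random


def reparam_sample(mu, L, eps):
    """z = mu + L @ eps for lower-triangular L; deterministic once eps is given."""
    return [mu[i] + sum(L[i][j] * eps[j] for j in range(i + 1)) for i in range(len(mu))]


def loss(z, target):
    """Toy loss J = 0.5 * ||z - target||^2 (hypothetical stand-in for the VAE loss)."""
    return 0.5 * sum((zi - ti) ** 2 for zi, ti in zip(z, target))


def analytic_grads(mu, L, eps, target):
    """Chain rule: dJ/dmu = dJ/dz and dJ/dL = (dJ/dz) eps^T on the lower triangle."""
    z = reparam_sample(mu, L, eps)
    g = [zi - ti for zi, ti in zip(z, target)]  # dJ/dz for the toy loss
    grad_mu = list(g)
    grad_L = [[g[i] * eps[j] if j <= i else 0.0 for j in range(len(eps))]
              for i in range(len(g))]
    return grad_mu, grad_L


def numeric_grads(mu, L, eps, target, h=1e-6):
    """Central finite differences, holding eps fixed for every evaluation."""
    def J(mu_vec, L_mat):
        return loss(reparam_sample(mu_vec, L_mat, eps), target)

    grad_mu = []
    for i in range(len(mu)):
        mu_plus, mu_minus = list(mu), list(mu)
        mu_plus[i] += h
        mu_minus[i] -= h
        grad_mu.append((J(mu_plus, L) - J(mu_minus, L)) / (2 * h))

    grad_L = [[0.0] * len(L) for _ in L]
    for i in range(len(L)):
        for j in range(i + 1):  # the upper triangle of L is not a parameter
            L_plus = [list(row) for row in L]
            L_minus = [list(row) for row in L]
            L_plus[i][j] += h
            L_minus[i][j] -= h
            grad_L[i][j] = (J(mu, L_plus) - J(mu, L_minus)) / (2 * h)
    return grad_mu, grad_L


rng = random.Random(1)
mu = [0.3, -0.7]
L = [[1.2, 0.0],
     [0.4, 0.8]]
target = [1.0, 1.0]
eps = [rng.gauss(0.0, 1.0) for _ in mu]  # drawn once, then treated as a constant

analytic = analytic_grads(mu, L, eps, target)
numeric = numeric_grads(mu, L, eps, target)
print(analytic)
print(numeric)
```

Holding `eps` fixed is the whole point: if a new `eps` were drawn for each perturbed evaluation, the finite differences would be dominated by sampling noise rather than by the change in the parameters.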
Breakdown of $q_{\phi}(\mathbf{z} | \mathbf{x})$
In the context of Variational Autoencoders (VAEs) and Bayesian Inference, the expression $q_{\phi}(\mathbf{z} | \mathbf{x})$ represents the Encoder (or Inference Network). Here is a detailed breakdown of each symbol for an AI learner:
1. The Symbols
| Symbol | Name | Meaning in AI/VAEs |
|---|---|---|
| $q$ | Approximate Posterior | Represents a probability distribution that estimates the true but intractable posterior $p(\mathbf{z} \mid \mathbf{x})$. Usually, we assume this is a Gaussian (Normal) distribution. |
| $\phi$ | Variational Parameters | These are the weights and biases of the neural network (the Encoder). The network learns these parameters to output the best mean and variance. |
| $\mathbf{z}$ | Latent Variable | The low-dimensional, “hidden” representation of the data. It captures the essential features (factors of variation) of the input. |
| $\mathbf{x}$ | Input Data | The raw high-dimensional observation (e.g., an image of a cat or a digit) that is fed into the model. |
| $\mid$ | Conditioning | The vertical bar denotes a conditional probability. It means the distribution of $\mathbf{z}$ is determined “given” a specific input $\mathbf{x}$. |
2. Conceptual Meaning
In plain English, $q_{\phi}(\mathbf{z} | \mathbf{x})$ answers the question:
“Given this specific input $\mathbf{x}$, what is the most likely range of values for the hidden features $\mathbf{z}$?”
Instead of outputting a single fixed vector, the Encoder outputs the parameters of a distribution:
1. Mean ($\boldsymbol{\mu}$): The most likely location of the input in the latent space.
2. Variance ($\boldsymbol{\sigma}^2$): The degree of uncertainty or “spread” around that mean.
3. The Mathematical Form
Usually, we define $q_{\phi}(\mathbf{z} | \mathbf{x})$ as a Multivariate Normal distribution with a diagonal covariance matrix:
$$q_{\phi}(\mathbf{z} | \mathbf{x}) = \mathcal{N}(\mathbf{z}; \boldsymbol{\mu}_{\phi}(\mathbf{x}), \text{diag}(\boldsymbol{\sigma}_{\phi}^2(\mathbf{x})))$$
Where:
* $\boldsymbol{\mu}_{\phi}(\mathbf{x})$ is a function (neural network) that maps $\mathbf{x}$ to the mean.
* $\boldsymbol{\sigma}_{\phi}^2(\mathbf{x})$ is a function that maps $\mathbf{x}$ to the variance.
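A toy encoder can make this formula concrete. The sketch below is hypothetical: `ToyGaussianEncoder`, its single linear layer per output, and all weights are invented for illustration (a real encoder is a deeper network). It assumes the encoder outputs $\log \boldsymbol{\sigma}^2$ rather than $\boldsymbol{\sigma}^2$, a common convention that keeps the variance positive. It samples $\mathbf{z}$ with the reparameterization trick and evaluates $\log q_{\phi}(\mathbf{z} | \mathbf{x})$ as a sum of $k$ independent one-dimensional Gaussian log-densities, which is what the diagonal covariance allows.

```python
import math
import random


def linear(W, b, x):
    """y = W x + b for plain Python lists (W is out_dim x in_dim)."""
    return [sum(w * xj for w, xj in zip(row, x)) + bi for row, bi in zip(W, b)]


class ToyGaussianEncoder:
    """Hypothetical q_phi(z | x) = N(z; mu_phi(x), diag(sigma_phi^2(x))).

    phi = (W_mu, b_mu, W_logvar, b_logvar). Predicting log sigma^2 keeps sigma^2 > 0.
    """

    def __init__(self, W_mu, b_mu, W_logvar, b_logvar):
        self.W_mu, self.b_mu = W_mu, b_mu
        self.W_logvar, self.b_logvar = W_logvar, b_logvar

    def forward(self, x):
        """Return (mu, log_var), each a list of k floats."""
        return linear(self.W_mu, self.b_mu, x), linear(self.W_logvar, self.b_logvar, x)

    def sample(self, x, rng):
        """Reparameterized sample z = mu + sigma ⊙ eps; also returns eps."""
        mu, log_var = self.forward(x)
        eps = [rng.gauss(0.0, 1.0) for _ in mu]
        sigma = [math.exp(0.5 * lv) for lv in log_var]
        z = [m + s * e for m, s, e in zip(mu, sigma, eps)]
        return z, eps

    def log_prob(self, z, x):
        """log q_phi(z | x): sum of k independent 1-D Gaussian log-densities."""
        mu, log_var = self.forward(x)
        return -0.5 * sum(math.log(2 * math.pi) + lv + (zi - m) ** 2 / math.exp(lv)
                          for zi, m, lv in zip(z, mu, log_var))


enc = ToyGaussianEncoder(
    W_mu=[[0.5, -1.0, 0.0], [0.2, 0.1, 0.3]], b_mu=[0.0, 0.1],
    W_logvar=[[0.1, 0.0, -0.2], [0.0, 0.3, 0.1]], b_logvar=[-1.0, 0.0],
)
x = [1.0, 2.0, -1.0]  # a made-up 3-dimensional input; the latent space has k = 2
z, eps = enc.sample(x, random.Random(0))
print(enc.forward(x))
print(z, enc.log_prob(z, x))
```

For a reparameterized sample, $(z_i - \mu_i)^2 / \sigma_i^2 = \epsilon_i^2$, so its log-density equals $-\tfrac{1}{2}\sum_i \left(\log 2\pi + \log \sigma_i^2 + \epsilon_i^2\right)$.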
This reply was modified 3 months ago by Wolf.
This reply was modified 3 months ago by yRocket.
-
This reply was modified 3 months ago by
yRocket.
-
This reply was modified 3 months ago by
yRocket.
-
This reply was modified 3 months ago by
yRocket.
-
This reply was modified 3 months ago by
yRocket.
-
This reply was modified 3 months ago by
yRocket.
-
This reply was modified 3 months ago by
yRocket.
-
This reply was modified 3 months ago by
yRocket.
-
This reply was modified 3 months ago by
yRocket.
-
This reply was modified 3 months ago by
-
AuthorPosts
- You must be logged in to reply to this topic.
