Out-of-Distribution Generalization


Setup: I.I.D.

We must start by defining what we mean by ‘prediction’. In this particular course, we first assume that every input-output pair [math](x,y)[/math] is sampled independently of every other pair. This is a fairly strong assumption, since the world often changes in response to what has been observed: those who saw a sample pair may, and often do, change their behaviors. For instance, consider building a stock price forecasting model. Once you use a predictor to predict whether the price of a particular stock goes up or down and trade based on the outcome, the next input [math]x[/math], that is, the next stock of interest, is no longer selected independently but depends on your own success or failure in the previous trade. This assumption is nevertheless reasonable, because there are many phenomena in which our behaviors do not matter much over a reasonably short horizon. For instance, consider installing and using a bird classifier in a particular forest. With a fixed camera, the input to this classifier will be largely independent of which birds (if any) were seen earlier, although spotting a particular bird may attract poachers to this forest, who would dramatically affect the bird population over a longer time frame.

Next, we assume that all these pairs are drawn from an ‘identical’ distribution. This is similar, if not identical, to the stationarity assumption in randomized controlled trials (RCTs). In an RCT, we often rely on a double-blind experimental design in order to ensure that the causal effect [math]p^*(y|a,x)[/math] does not change over the course of the trial. In this section, as in conventional statistical learning theory, we assume that all input-output pairs are drawn from the same distribution. Combining these two assumptions, we arrive at a so-called training set [math]D[/math] which satisfies

[[math]] \begin{align} p(D) = \prod_{(x,y) \in D} p^*(x, y), \end{align} [[/math]]

according to the definition of independence, with every pair following the same distribution [math]p^*[/math]. We have neither access to nor knowledge of [math]p^*[/math]. We use this training set [math]D[/math] for both model fitting (training) and model selection (validation). Once the predictive model [math]\hat{p}[/math] is ready, we deploy it to make a prediction on a novel input [math]x'[/math] drawn from a distribution [math]q^*[/math]. That is,

[[math]] \begin{align} \hat{y} \sim \hat{p}(y | x'), \end{align} [[/math]]

where [math](x', y') \sim q^*[/math]. We are often not given [math]y'[/math]. After all, [math]y'[/math] is what we want to use our predictive model to infer. We say that the predictive model is accurate if the following expected loss, or risk, is low:

[[math]] \begin{align} R(\hat{p}) = \mathbb{E}_{(x',y') \sim q^*} \left[ l(y', \hat{p}(y|x')) \right], \end{align} [[/math]]

where [math]l(\cdot, \cdot) \geq 0[/math] is a loss function, e.g., the 0-1 loss, whose expectation is the misclassification rate. In traditional statistical learning theory, [math]q^*[/math] is assumed to be identical to [math]p^*[/math], and under this assumption, the goal of designing a learning algorithm is to minimize a so-called excess risk:

[[math]] \begin{align} R_{\mathrm{excess}}(\hat{p}) = R(\hat{p}) - R(p^*) \end{align} [[/math]]

with respect to [math]\hat{p}[/math]. Since we do not have access to [math]p^*[/math], we often approximate [math]R(\hat{p})[/math] with a Monte Carlo estimate over the training examples, as follows

[[math]] \begin{align} \label{eq:empirical-risk} R(\hat{p}) \approx \hat{R}_N(\hat{p}) = \frac{1}{N} \sum_{n=1}^N l(y_n, \hat{p}(y|x_n)), \end{align} [[/math]]

where [math](x_n, y_n) \sim p^*[/math] are the [math]N[/math] examples in the training set [math]D[/math]. Under the (strong) assumption of uniform convergence, defined as

[[math]] \begin{align} \sup_{\hat{p}} \left| R(\hat{p}) - \hat{R}_N(\hat{p}) \right| \to_p 0, \end{align} [[/math]]

we can minimize [math]R[/math] by minimizing [math]\hat{R}_N[/math] given a large enough data set, i.e., [math]N \to \infty[/math], and thereby find a good predictive model [math]\hat{p}[/math]. Of course, since [math]N[/math] is always finite in reality, there is almost always a non-zero generalization error. Because we never have access to [math]R(\hat{p})[/math], even after learning, it is common practice to use a separate, held-out set of examples, again drawn from the same distribution [math]p^* = q^*[/math], as a test set to approximate the generalization error of a trained model [math]\hat{p}[/math]. Let [math]D' = \left\{ (x'_1,y'_1), \ldots, (x'_K, y'_K) \right\}[/math]. Then,

[[math]] \begin{align} R(\hat{p}) \approx \frac{1}{K} \sum_{k=1}^K l(y'_k, \hat{p}(y|x'_k)). \end{align} [[/math]]
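The Monte Carlo estimates above can be illustrated with a minimal sketch on a hypothetical toy problem, not taken from the text: [math]x[/math] takes five values uniformly, [math]p^*(y=1|x)[/math] is a known table (`P_STAR`), and [math]\hat{p}[/math] is fit by add-one smoothed counting, a stand-in for any learning algorithm. All names are illustrative. Because [math]p^*[/math] is known here, we can also compute the exact risk and compare it with the training-set and held-out estimates under the log loss.

```python
import math
import random

# Hypothetical toy problem: x takes 5 values uniformly, and p*(y=1|x) is this known table.
P_STAR = [0.1, 0.3, 0.5, 0.7, 0.9]

def sample(n, rng):
    """Draw n i.i.d. pairs (x, y) ~ p*(x) p*(y|x)."""
    data = []
    for _ in range(n):
        x = rng.randrange(len(P_STAR))
        y = 1 if rng.random() < P_STAR[x] else 0
        data.append((x, y))
    return data

def fit(data):
    """Stand-in learner: p_hat(y=1|x) from add-one smoothed counts."""
    ones = [1] * len(P_STAR)
    totals = [2] * len(P_STAR)
    for x, y in data:
        ones[x] += y
        totals[x] += 1
    return [o / t for o, t in zip(ones, totals)]

def log_loss(y, p1):
    """l(y, p_hat(y|x)) = -log p_hat(y|x), where p1 = p_hat(y=1|x)."""
    return -math.log(p1 if y == 1 else 1.0 - p1)

def empirical_risk(data, p_hat):
    """Monte Carlo estimate: average loss over the given examples."""
    return sum(log_loss(y, p_hat[x]) for x, y in data) / len(data)

def true_risk(p_hat):
    """Exact R(p_hat) under p*; computable here only because p* is known."""
    return sum(
        P_STAR[x] * log_loss(1, p_hat[x]) + (1 - P_STAR[x]) * log_loss(0, p_hat[x])
        for x in range(len(P_STAR))
    ) / len(P_STAR)

rng = random.Random(0)
train = sample(200, rng)      # training set D
test = sample(100_000, rng)   # held-out test set D', also drawn from p* (= q*)
p_hat = fit(train)
print("train:", empirical_risk(train, p_hat))
print("test: ", empirical_risk(test, p_hat))
print("exact:", true_risk(p_hat), "excess:", true_risk(p_hat) - true_risk(P_STAR))
```

With a large held-out set, the test-set average should be close to the exact risk, while the training-set average is often optimistic because [math]\hat{p}[/math] was fit to those same examples. The excess risk is never negative here, because [math]p^*[/math] itself minimizes the expected log loss.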

Such test-set accuracy, or simply test accuracy, has been a workhorse behind rapid advances in machine learning over the past several decades. With this whole paradigm in mind, it is important to notice that the key assumption here is [math]q^*(x,y)=p^*(x,y)[/math]. In other words, we assume that the instances on which a predictive model is tested during deployment follow the same distribution as the training examples, i.e., [math]q^*(x) = p^*(x)[/math]. Furthermore, the conditional distribution over the outcome does not change either, i.e., [math]q^*(y|x) = p^*(y|x)[/math]. In this case, there is no reason for us to consider the underlying generating process behind [math]p^*[/math] or [math]q^*[/math] separately.

Out-of-Distribution Generalization

Impossibility of out-of-distribution (OOD) generalization. In reality, it is rarely the case that [math]q^* = p^*[/math], because the world changes. When [math]q^* \neq p^*[/math], we must be careful about discussing generalization, because we can always choose [math]q^*[/math] such that minimizing the empirical risk [math]\hat{R}_N(\hat{p})[/math] in Eq. \eqref{eq:empirical-risk} would lead to maximizing

[[math]] \begin{align} R^{q^*}(\hat{p}) = \mathbb{E}_{(x,y) \sim q^*}[ l(y, \hat{p}(y|x)) ]. \end{align} [[/math]]


Assume [math]y \in \left\{0, 1\right\}[/math]. Consider the following [math]q^*[/math], given [math]p^*(x,y) = p^*(x) p^*(y|x)[/math],

[[math]] \begin{align} q^*(x, y) = p^*(x) q^*(y|x), \end{align} [[/math]]

where

[[math]] \begin{align} \label{eq:contrarian-q} q^*(y|x) = 1 - p^*(y|x). \end{align} [[/math]]

That is, the mapping from [math]x[/math] to [math]y[/math] is reversed. When [math]x[/math] was more likely to be observed together with [math]y=1[/math] under [math]p^*[/math], it is now more likely to be observed together with [math]y=0[/math] under [math]q^*[/math], and vice versa. If we take the log loss, which is defined as

[[math]] \begin{align} l(y, \hat{p}(y|x)) = -\log \hat{p}(y|x), \end{align} [[/math]]

learning corresponds to minimizing the KL divergence from the true distribution to the learned, predictive distribution. Mathematically,

[[math]] \begin{align} \arg\min_{\hat{p}} \frac{1}{N} \sum_{n=1}^N l(y_n, \hat{p}(y_n|x_n)) \approx \arg\min_{\hat{p}} \mathbb{E}_{x \sim p^*(x)} \mathrm{KL}( p^*(\cdot | x) \| \hat{p}(\cdot |x) ). \end{align} [[/math]]
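The correspondence between the log loss and the KL divergence rests on a standard decomposition of the expected log loss at a fixed [math]x[/math], where [math]H[/math] denotes the entropy:

[[math]] \begin{align} \mathbb{E}_{y \sim p^*(y|x)} \left[ -\log \hat{p}(y|x) \right] &= \sum_{y} p^*(y|x) \log \frac{p^*(y|x)}{\hat{p}(y|x)} - \sum_{y} p^*(y|x) \log p^*(y|x) \\ &= \mathrm{KL}( p^*(\cdot | x) \| \hat{p}(\cdot |x) ) + H(p^*(\cdot|x)). \end{align} [[/math]]

Because [math]H(p^*(\cdot|x))[/math] does not depend on [math]\hat{p}[/math], averaging over [math]x \sim p^*(x)[/math] and replacing the expectation with the average over [math]D[/math] leaves the minimizer unchanged, up to the Monte Carlo approximation.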

In other words, learning corresponds to recovering [math]p^*(y|x)[/math] as closely as we can for the inputs [math]x[/math] that are probable under [math]p^*(x)[/math]. It is then clear that minimizing this loss function makes our predictive model worse on the new distribution in Eq. \eqref{eq:contrarian-q}, because the following holds for any particular example [math](x,y)[/math]:

[[math]] \begin{align} \log p^*(y|x) = \log (1 - q^*(y|x)). \end{align} [[/math]]

Since [math]\log[/math] is monotonically increasing, raising the log-probability of a label under [math]p^*[/math] corresponds to lowering its probability under [math]q^*[/math]; pushing [math]\hat{p}[/math] toward [math]p^*[/math] therefore pushes it away from [math]q^*[/math]. As soon as we start minimizing the log loss for learning, out-of-distribution generalization to [math]q^*[/math] gets worse, and there is no way to avoid it other than not learning at all. This is a simple but clear example showing that out-of-distribution generalization is not possible in general: there will always be a target distribution that disagrees with the original distribution, such that learning on the latter is guaranteed to hurt the generalization accuracy on the former. In general, such a target distribution can be written down as

[[math]] \begin{align} q^*(y|x) = \frac{1 - p^*(y|x)}{|\mathcal{Y}| - 1} \propto 1 - p^*(y|x). \end{align} [[/math]]

We can also come up with a similar formula for [math]q^*(x)[/math], such that there is almost no support overlap between [math]p^*(x)[/math] and [math]q^*(x)[/math].
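The effect of the contrarian target in Eq. \eqref{eq:contrarian-q} can be checked numerically with a minimal sketch. As hypothetical assumptions, there are four equally likely inputs with made-up values of [math]p^*(y=1|x)[/math], and "learning" is stood in for by moving [math]\hat{p}[/math] along a straight line from the uninformed prediction 0.5 (`t = 0`) to [math]p^*[/math] itself (`t = 1`).

```python
import math

# Hypothetical p*(y=1|x) for four inputs x, each equally likely under p*(x).
P_STAR = [0.1, 0.3, 0.7, 0.9]
# Contrarian target: q*(y|x) = 1 - p*(y|x), with the same input marginal.
Q_STAR = [1.0 - p for p in P_STAR]

def risk(target, p_hat):
    """Expected log loss when y ~ target(y|x), averaged over the (uniform) x's."""
    return sum(
        -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p)) for t, p in zip(target, p_hat)
    ) / len(target)

# Stand-in for learning: move p_hat from the uninformed 0.5 toward p* as t goes from 0 to 1.
for t in [0.0, 0.25, 0.5, 0.75, 1.0]:
    p_hat = [(1.0 - t) * 0.5 + t * p for p in P_STAR]
    print(f"t={t:.2f}  risk under p*={risk(P_STAR, p_hat):.3f}  risk under q*={risk(Q_STAR, p_hat):.3f}")
```

Both risks start at [math]\log 2[/math] for `t = 0`; every step toward [math]p^*[/math] lowers the first and raises the second.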


Out-of-distribution generalization. We must therefore narrow down the scope in order to discuss out-of-distribution generalization. There are many different ways to narrow the scope, and one way is to ensure that the target distribution [math]q^*[/math] is not too far from the original distribution [math]p^*[/math]. Let [math]D: \mathcal{P} \times \mathcal{P} \to \mathbb{R}_+[/math] be a (possibly asymmetric) divergence between two distributions, such that a larger [math]D(p, q)[/math] implies a greater difference between the two distributions [math]p[/math] and [math]q[/math]. Then, we can write a so-called distributionally robust loss as

[[math]] \begin{align} \min_{\hat{p}} \sup_{q: D(p^*,q)\leq \delta} \mathbb{E}_{(x,y) \sim q} \left[ l(y, \hat{p}(y|x)) \right], \end{align} [[/math]]

where [math]\sup[/math] denotes the supremum, i.e., the least upper bound: the smallest element that is greater than or equal to every element of a subset of a partially ordered set [1]. The distributionally robust loss above minimizes ([math]\min_{\hat{p}}[/math]) the expected loss ([math]\mathbb{E}_{(x,y) \sim q} \left[ l(y, \hat{p}(y|x)) \right][/math]) under the worst-case distribution ([math]\sup_{q}[/math]) within the divergence constraint ([math]q: D(p^*,q)\leq \delta[/math]). Despite its generality, which comes from the freedom in choosing the divergence [math]D[/math] and from guarding against the worst case, such distributionally robust optimization is challenging to use in practice. The challenge mainly comes from the fact that we must solve a nested optimization problem: for each update of [math]\hat{p}[/math], we must solve another optimization problem that maximizes the loss w.r.t. the distribution [math]q[/math]. This problem can be cast as a two-player minimax game, which is more challenging than a conventional minimization problem, both in terms of whether and how fast it converges. Furthermore, it is often unclear how to choose an appropriate divergence [math]D[/math] and threshold [math]\delta[/math], as these choices are not grounded in the problem of interest.

We are therefore more interested in an alternative to distributionally robust optimization. Instead of specifying a divergence, we can describe how the distribution changes in terms of the probabilistic graphical model, or equivalently the structural causal model, underlying [math]p^*[/math] and [math]q^*[/math]. Depending on such a distributional change, we may be able to characterize the degree of generalization or even come up with a better learning algorithm.
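The nested structure of the distributionally robust loss can be made concrete with a minimal sketch under deliberately simple, hypothetical assumptions: there is no input [math]x[/math], the model predicts a single probability [math]\theta = \hat{p}(y=1)[/math], the original label marginal is [math]p^*(y=1)=0.8[/math], and the divergence is [math]|q(y=1) - p^*(y=1)|[/math], the total variation distance between two Bernoulli distributions (a symmetric choice made only for illustration). Both the inner maximization over [math]q[/math] and the outer minimization over [math]\theta[/math] are solved by brute-force grid search.

```python
import math

def expected_log_loss(theta, pi):
    """E_{y ~ Bernoulli(pi)}[-log p_hat(y)] when the model predicts p_hat(y=1) = theta."""
    return -(pi * math.log(theta) + (1.0 - pi) * math.log(1.0 - theta))

def worst_case_loss(theta, p, delta, num_inner=21):
    """Inner problem: maximize the loss over label marginals pi with |pi - p| <= delta."""
    lo, hi = max(p - delta, 0.0), min(p + delta, 1.0)
    candidates = [lo + (hi - lo) * j / (num_inner - 1) for j in range(num_inner)]
    return max(expected_log_loss(theta, pi) for pi in candidates)

def dro_fit(p, delta, num_outer=99):
    """Outer problem: pick theta on a grid over (0, 1) that minimizes the worst-case loss."""
    grid = [(i + 1) / (num_outer + 1) for i in range(num_outer)]
    return min(grid, key=lambda theta: worst_case_loss(theta, p, delta))

theta_erm = dro_fit(p=0.8, delta=0.0)  # no shift allowed: ordinary risk minimization
theta_dro = dro_fit(p=0.8, delta=0.1)  # robust to label-marginal shifts of up to 0.1
print(theta_erm, theta_dro)  # prints: 0.8 0.7
```

With [math]\delta = 0[/math] the inner problem is trivial and we recover ordinary risk minimization. With [math]\delta = 0.1[/math], the worst case is the label marginal closest to uniform, and the robust prediction is pulled toward [math]1/2[/math]. Even in this one-parameter case, every outer candidate requires solving an inner maximization, which is the source of the cost discussed above.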

Case Studies

The label proportion shift. Let us consider a very basic example of a generative classifier which assumes the following generating process, in which the label [math]y[/math] generates the input [math]x[/math], i.e., [math]y \to x[/math]. Under this generating process, the joint probability is written as

[[math]] \begin{align} p^*(x,y) = p^*(y) p^*(x|y), \end{align} [[/math]]

and the posterior distribution over the output [math]y[/math] is

[[math]] \begin{align} p(y|x) = \frac{p(y) p(x|y)}{p(x)} = \frac{p(y) p(x|y)}{\sum_{y' \in \mathcal{Y}} p(y') p(x|y')}. \end{align} [[/math]]


Suppose we are given a training set [math]D=\left\{ (x_1, y_1), \ldots, (x_N, y_N) \right\}[/math], where each [math](x_n,y_n)[/math] was drawn from the generating process above, that is,

[[math]] \begin{align} &y_n \sim p^*(y) \\ &x_n \sim p^*(x|y_n). \end{align} [[/math]]

We can train a neural network classifier that takes as input [math]x[/math] and outputs a probability for each possible value of [math]y[/math]. This neural network can be written as

[[math]] \begin{align} \label{eq:softmax-nn} \hat{p}(y|x; \theta, b) = \frac{\exp(f_y(x; \theta) + b_y)} {\sum_{y' \in \mathcal{Y}} \exp(f_{y'}(x; \theta)+ b_{y'})}, \end{align} [[/math]]

where [math]f_y(x; \theta)[/math] is the [math]y[/math]-th element of the [math]|\mathcal{Y}|[/math]-dimensional output of the neural network [math]f[/math] parametrized by [math]\theta[/math], and [math]b \in \mathbb{R}^{|\mathcal{Y}|}[/math] is a bias vector. Inspecting this formulation, which is based on the so-called softmax output, we notice the following correspondences:

  • [math]p^*(y) \approx \frac{1}{Z_y} \exp(b_y)[/math]
  • [math]p^*(x|y) \approx \frac{1}{Z_{x|y}} \exp(f_y(x; \theta))[/math]

where [math]Z_y[/math] and [math]Z_{x|y}[/math] are normalization constants, which cancel out in Eq. \eqref{eq:softmax-nn} ([math]Z_{x|y}[/math] does so only if it is the same for all [math]y[/math]).[Notes 1] In other words, the bias [math]b_y[/math] captures the marginal distribution over the output, and the rest captures the conditional distribution over the input given the output. This view suggests a two-stage learning process. In the first stage, we simply set [math]b_y[/math] to [math]\log \hat{p}(y)[/math], where [math]\hat{p}(y)[/math] is an estimate of [math]p^*(y)[/math] such as the label frequency in [math]D[/math] (thereby implicitly setting [math]Z_y=1[/math]). In the second stage, we use an optimizer, such as stochastic gradient descent, to estimate the rest of the parameters, [math]\theta[/math]. After learning is over, we get

[[math]] \begin{align} \label{eq:y-given-x} \hat{p}(y|x) = \hat{p}(y) \frac{\exp(f_y(x; \hat{\theta}))}{\sum_{y'} \hat{p}(y') \exp(f_{y'}(x; \hat{\theta}))}. \end{align} [[/math]]

It is important to notice that the second factor on the right-hand side is not an estimate of [math]p^*(x|y)[/math], since it is additionally divided by an estimate of the marginal [math]p(x)[/math]. In other words,

[[math]] \begin{align} \frac{\exp(f_y(x; \hat{\theta}))}{\sum_{y'} \hat{p}(y') \exp(f_{y'}(x; \hat{\theta}))} = \frac{\hat{p}(x|y)}{\hat{p}(x)}. \end{align} [[/math]]
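The two-stage procedure can be sketched as follows, under hypothetical assumptions that are not part of the text: binary labels with [math]p^*(y=1)=0.3[/math], class-conditional distributions [math]\mathcal{N}(-1, 1)[/math] and [math]\mathcal{N}(1, 1)[/math] for a one-dimensional [math]x[/math], and a deliberately tiny "network" [math]f_y(x; \theta) = \theta_y x[/math] with one weight per class. Stage one fixes [math]b_y[/math] to the log label frequency in [math]D[/math]; stage two fits [math]\theta[/math] by full-batch gradient descent on the log loss (plain rather than stochastic gradient descent, to keep the example deterministic), keeping [math]b[/math] fixed.

```python
import math
import random

def sample_training_set(n, prior_y1, rng):
    """Draw y_n ~ p*(y), then x_n ~ p*(x|y_n) = N(mu_y, 1) with hypothetical mu_0 = -1, mu_1 = +1."""
    data = []
    for _ in range(n):
        y = 1 if rng.random() < prior_y1 else 0
        x = rng.gauss(1.0 if y == 1 else -1.0, 1.0)
        data.append((x, y))
    return data

def softmax(scores):
    """Numerically stable softmax."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def two_stage_fit(data, num_classes=2, lr=0.5, steps=200):
    # Stage 1: fix b_y = log p_hat(y), the log label frequency in D.
    counts = [0] * num_classes
    for _, y in data:
        counts[y] += 1
    bias = [math.log(c / len(data)) for c in counts]
    # Stage 2: fit theta by full-batch gradient descent on the average log loss, b held fixed.
    # Hypothetical model: f_y(x; theta) = theta_y * x.
    theta = [0.0] * num_classes
    for _ in range(steps):
        grad = [0.0] * num_classes
        for x, y in data:
            probs = softmax([theta[k] * x + bias[k] for k in range(num_classes)])
            for k in range(num_classes):
                # d(-log softmax_y)/d theta_k = (probs_k - 1[k == y]) * x
                grad[k] += (probs[k] - (1.0 if k == y else 0.0)) * x
        theta = [t - lr * g / len(data) for t, g in zip(theta, grad)]
    return theta, bias

rng = random.Random(0)
data = sample_training_set(2000, prior_y1=0.3, rng=rng)
theta, bias = two_stage_fit(data)
print(theta[1] - theta[0], [math.exp(b) for b in bias])
```

For these Gaussians, [math]\log p^*(x|y=1) - \log p^*(x|y=0) = 2x[/math], so [math]\theta_1 - \theta_0[/math] should come out close to 2, and [math]\exp(b_y)[/math] should be close to [math](0.7, 0.3)[/math].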


This predictive model [math]\hat{p}(y|x)[/math] would work well on a new instance under the iid assumption, that is, when [math]q^*(y|x)=p^*(y|x)[/math]. This is however not the case when the label marginal shifts, i.e., [math]q^*(y) \neq p^*(y)[/math], even if [math]q^*(x|y) = p^*(x|y)[/math]. For instance, imagine we trained a COVID-19 diagnosis model based on various symptoms, including cough sound, temperature and others, during the winter of 2021. During this period, COVID-19 was rampant, that is, [math]p^*(y=1)[/math] was very high. If we however use this model in the winter of 2024, the overall incidence rate of COVID-19 is much lower. In other words, [math]q^*(y=1) \ll p^*(y=1)[/math]. This would lead to overestimating [math]p(y=1|x)[/math], because the prediction increases with [math]\hat{p}(y=1)[/math], which estimates the outdated prior [math]p^*(y=1)[/math] rather than the current prior [math]q^*(y=1)[/math]. The prediction becomes worse as [math]q^*[/math] deviates further away from [math]p^*[/math].

One simple way to address this is to assume that, a priori, the label marginal, i.e., the marginal distribution over the output, is more likely to be close to the uniform distribution. This is a reasonable assumption in many contexts where we have no information about the deployment situation. For instance, it is perfectly sensible to assume that any given coin is likely to be fair (that is, it has an equal chance of landing heads or tails). In that case, we would simply set the bias [math]b[/math] to be an all-zero vector so that

[[math]] \begin{align} \hat{p}(y|x) = \frac{\exp(f_y(x; \hat{\theta}))}{\sum_{y'} \exp(f_{y'}(x; \hat{\theta}))}. \end{align} [[/math]]
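The overestimation can be seen in a minimal numerical sketch. As hypothetical assumptions, the input is one-dimensional with class-conditional distributions [math]\mathcal{N}(-1, 1)[/math] for [math]y=0[/math] and [math]\mathcal{N}(1, 1)[/math] for [math]y=1[/math], so [math]f_y(x)[/math] can be taken as the Gaussian log-density up to a constant shared by all classes, and the label marginals are made up: [math]p^*(y=1)=0.3[/math] at training time and [math]q^*(y=1)=0.02[/math] at deployment. Because the class-conditional distribution is the same in both periods here, setting [math]b_y = \log q^*(y)[/math] gives the correct posterior under [math]q^*[/math].

```python
import math

def posterior_with_bias(logits, bias):
    """Softmax over f_y(x) + b_y."""
    scores = [f + b for f, b in zip(logits, bias)]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def gaussian_logits(x, means=(-1.0, 1.0)):
    """Hypothetical f_y(x) = log N(x; mu_y, 1), dropping the constant shared by all classes."""
    return [-0.5 * (x - mu) ** 2 for mu in means]

p_train = [0.7, 0.3]     # hypothetical p*(y) when the training data were collected
q_deploy = [0.98, 0.02]  # hypothetical q*(y) at deployment time

f = gaussian_logits(0.5)
stale = posterior_with_bias(f, [math.log(p) for p in p_train])     # b_y = log p_hat(y)
no_prior = posterior_with_bias(f, [0.0, 0.0])                      # b = 0 (uniform label marginal)
target = posterior_with_bias(f, [math.log(q) for q in q_deploy])   # correct posterior under q*
print(f"stale={stale[1]:.3f} uniform={no_prior[1]:.3f} target={target[1]:.3f}")
```

With the stale bias, [math]\hat{p}(y=1|x)[/math] is about 0.54, whereas the correct posterior under [math]q^*[/math] is about 0.05. In this particular example the zero bias is even further from the target, because [math]q^*(y)[/math] is far from uniform; the uniform choice is a default for when nothing is known about [math]q^*[/math].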


Sometimes we are given some glimpse into [math]q^*[/math]. In the case of COVID-19, it is difficult to collect [math](x,y)[/math] pairs, but it is often easy to collect [math]y[/math]'s by various means, including surveys and rapid testing at event venues. Let [math]\hat{q}(y)[/math] be the estimate of [math]q^*(y)[/math] from such a source. We can then replace [math]\hat{p}(y)[/math] with this new estimate in Eq. \eqref{eq:y-given-x}, resulting in

[[math]] \begin{align} \hat{p}(y|x) = \hat{q}(y) \frac{\exp(f_y(x; \hat{\theta}))}{\sum_{y'} \hat{q}(y') \exp(f_{y'}(x; \hat{\theta}))}. \end{align} [[/math]]

This is equivalent to replacing the bias [math]b_y[/math] with [math]\log \hat{q}(y)[/math]. In practice, the number of [math]y[/math] samples we can collect is often limited, leading to a high-variance estimate of [math]q^*(y)[/math]. We do not want to rely solely on such an estimate. Instead, we can interpolate between [math]\hat{p}(y)[/math] and [math]\hat{q}(y)[/math], replacing the bias of each output with

[[math]] \begin{align} b_y \leftarrow \log \left(\alpha \hat{p}(y) + (1-\alpha) \hat{q}(y) \right), \end{align} [[/math]]

with [math]\alpha \in [0, 1][/math]. The coefficient [math]\alpha[/math] describes the degree of trust we place in the original estimate of the label marginal. If [math]\alpha = 1[/math], we recover the original iid setup, and with [math]\alpha=0[/math], we fully trust the new estimate of the label marginal.
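The interpolated bias can be sketched as below. The survey labels, the logits for a single input, and the helper names are hypothetical; [math]\hat{q}(y)[/math] is computed with add-one smoothing, an assumption made here only so that no class receives probability zero (which would make [math]\log[/math] undefined at [math]\alpha = 0[/math]).

```python
import math

def posterior_with_bias(logits, bias):
    """Softmax over f_y(x) + b_y."""
    scores = [f + b for f, b in zip(logits, bias)]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def smoothed_marginal(labels, num_classes):
    """q_hat(y) from a small sample of labels, with add-one smoothing."""
    counts = [1] * num_classes
    for y in labels:
        counts[y] += 1
    total = sum(counts)
    return [c / total for c in counts]

def interpolated_bias(p_hat, q_hat, alpha):
    """b_y <- log(alpha * p_hat(y) + (1 - alpha) * q_hat(y)), with alpha in [0, 1]."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    return [math.log(alpha * p + (1.0 - alpha) * q) for p, q in zip(p_hat, q_hat)]

p_hat = [0.7, 0.3]                    # hypothetical training-time label marginal
survey = [1] + [0] * 49               # hypothetical 50 labels collected at deployment
q_hat = smoothed_marginal(survey, 2)  # [50/52, 2/52]
logits = [-1.125, -0.125]             # hypothetical f_y(x) for one input x

for alpha in [1.0, 0.5, 0.0]:
    prob = posterior_with_bias(logits, interpolated_bias(p_hat, q_hat, alpha))[1]
    print(f"alpha={alpha:.1f}  p_hat(y=1|x)={prob:.3f}")
```

Because the mixture [math]\alpha \hat{p}(y) + (1-\alpha)\hat{q}(y)[/math] is itself a distribution, intermediate values of [math]\alpha[/math] move the prediction smoothly between the two extremes.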

Data augmentation. Consider an object classification task, where the goal is to build a classifier that categorizes the object in the center of an image into one of [math]K[/math] predefined classes. Just like before, we assume generative classification in which the object label produces the image. We however further assume that there exists an extra variable [math]z=(i,j)[/math] that determines the precise position of the object, so that the image [math]x[/math] is generated from both the label [math]y[/math] and the position [math]z[/math], i.e., [math]y \to x \leftarrow z[/math]. During training, [math]z[/math] follows a normal distribution centered at the center of the image, i.e., [math]z \sim \mathcal{N}(\mu_z=[0, 0]^\top, I_2)[/math]. Assuming that the background is randomly produced and does not correlate with the identity of the object in the center, a classifier we train on data produced from this data generating process should become blind to peripheral pixels, since [math]\mathrm{cov}(x_{mn}, y) \approx 0[/math] for [math]|m| \gg 0[/math] or [math]|n| \gg 0[/math]. This can be written down as

[[math]] \begin{align} p(x_{mn} | y) \approx p(x_{mn}), \end{align} [[/math]]

meaning that [math]x_{mn}[/math] is (approximately) independent of [math]y[/math]. If we make the naive Bayes assumption, that is, all pixels are independent conditioned on the label, we get the following expression of the posterior over the label:

[[math]] \begin{align} p(y|x) \propto p(y) \prod_{m, n} p(x_{mn} | y) \propto p(y) \prod_{(m, n) \in C} p(x_{mn} | y), \end{align} [[/math]]

where [math]C[/math] is a set of pixels near the center. In other words, if the object is away from the center of the image, the posterior distribution over the label would not capture the actual identity of the object. This dependence on the position arises from the existence of the hidden variable [math]z[/math] and its prior distribution [math]p^*(z)[/math]. If this prior distribution over [math]z[/math] shifts at test time, such that [math]q^*(z) = \mathcal{N}(\mu_z=[100, 100]^\top, I_2)[/math], all objects in the images would be positioned near the top-right corner. The classifier trained on data generated with [math]p^*(z)[/math] will then completely fail to detect and classify these objects. Because we assume that we know the precise type of shift that is possible, we can mitigate this issue by data augmentation [2]. During training, we randomly shift each training image such that the position of the object varies much more than it does in the original training set. This can be thought of as introducing another random variable [math]u[/math] such that

[[math]] \begin{align} p(l | z) = \sum_{u} p(u)\, p(l | z, u) = p(l), \end{align} [[/math]]

where [math]l[/math] indicates the position of the object in an image. In other words, [math]u[/math] makes the position of an object independent of [math]z[/math], such that a classifier trained on the training data with such data augmentation is able to detect objects in any position, making it invariant to the distributional shift of [math]z[/math].
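Random-shift augmentation can be sketched as follows. The image size, the one-pixel "object", the rounding of [math]z[/math] to whole pixels, and the uniform distribution of the shift [math]u[/math] over [math]\{-6, \ldots, 6\}^2[/math] are all hypothetical choices for illustration. The sketch compares how widely the object's row position is spread before and after augmentation.

```python
import random
import statistics

def shift_image(image, di, dj):
    """Translate a 2-D image (list of rows) by (di, dj) pixels; uncovered pixels are set to 0."""
    h, w = len(image), len(image[0])
    out = [[0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            si, sj = i - di, j - dj  # source pixel that moves to (i, j)
            if 0 <= si < h and 0 <= sj < w:
                out[i][j] = image[si][sj]
    return out

def random_shift(image, max_shift, rng):
    """Augmentation: shift by an offset u drawn uniformly from {-max_shift, ..., max_shift}^2."""
    di = rng.randint(-max_shift, max_shift)
    dj = rng.randint(-max_shift, max_shift)
    return shift_image(image, di, dj)

def object_row(image):
    """Row index of the single non-zero pixel, or None if it was shifted out of the frame."""
    for i, row in enumerate(image):
        if any(row):
            return i
    return None

rng = random.Random(0)
size, center = 21, 10
rows_original, rows_augmented = [], []
for _ in range(500):
    # Training-time position z ~ N(0, I), rounded to pixels and clipped to the frame.
    zi = max(-center, min(center, round(rng.gauss(0.0, 1.0))))
    zj = max(-center, min(center, round(rng.gauss(0.0, 1.0))))
    image = [[0] * size for _ in range(size)]
    image[center + zi][center + zj] = 1  # a one-pixel "object"
    rows_original.append(center + zi)
    row = object_row(random_shift(image, max_shift=6, rng=rng))
    if row is not None:
        rows_augmented.append(row)

print(statistics.pstdev(rows_original), statistics.pstdev(rows_augmented))
```

The spread of positions after augmentation is dominated by [math]u[/math] rather than [math]z[/math], which is the sense in which the augmented position becomes (approximately) independent of [math]z[/math]; a wider range for [math]u[/math] makes this independence stronger.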

General references

Cho, Kyunghyun (2024). "A Brief Introduction to Causal Inference in Machine Learning". arXiv:2405.08793 [cs.LG].

Notes

  1. [math]\exp(a + b) = \exp(a) \exp(b)[/math].

References

  1. "Distributionally robust stochastic programming" (2017). SIAM Journal on Optimization 27. SIAM. 
  2. "Effective training of a neural network character classifier for word recognition" (1996). Advances in Neural Information Processing Systems 9.