A Case Study: Language Modeling with Pairwise Preference
An autoregressive language model is described as a repeated application of the next-token conditional probability, as in
[math]p(w_1, w_2, \ldots, w_T) = \prod_{t=1}^{T} p(w_t | w_{<t}).[/math]
A conditional autoregressive language model is exactly the same except that it is conditioned on another variable [math]X[/math]:
[math]p(w_1, w_2, \ldots, w_T | X) = \prod_{t=1}^{T} p(w_t | w_{<t}, X).[/math]
There are many different ways to build a neural network to implement the next-token conditional distribution. We do not discuss any of those approaches, as they are out of the course's scope. An interesting property of a language model is that it can be used for two purposes:
- Scoring a sequence: we can use [math]p(w_1, w_2, \ldots, w_T | x)[/math] to score an answer sequence [math]w[/math] given a query [math]x[/math].
- Approximately finding the best sequence: we can use approximate decoding to find [math]\arg\max_w p(w | x)[/math].
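As a rough illustration of these two uses, consider the sketch below, in which a toy hand-written table plays the role of the neural next-token distribution. Everything here (the vocabulary, the query, and the names next_token_probs, score and greedy_decode) is an illustrative assumption, not part of any actual system.
\begin{verbatim}
import math

# A toy stand-in for a neural next-token model p(w_t | w_<t, x).
def next_token_probs(prefix, x):
    if prefix:  # after the first token, prefer to stop
        return {"<eos>": 0.7, "yes": 0.1, "no": 0.1, "maybe": 0.1}
    if x == "is the sky blue?":
        return {"<eos>": 0.1, "yes": 0.6, "no": 0.2, "maybe": 0.1}
    return {"<eos>": 0.1, "yes": 0.2, "no": 0.6, "maybe": 0.1}

def score(w, x):
    # log p(w_1, ..., w_T | x) = sum_t log p(w_t | w_<t, x)
    return sum(math.log(next_token_probs(w[:t], x)[token])
               for t, token in enumerate(w))

def greedy_decode(x, max_len=10):
    # Approximately find argmax_w p(w | x), one greedy token at a time.
    w = []
    while len(w) < max_len:
        probs = next_token_probs(w, x)
        w.append(max(probs, key=probs.get))
        if w[-1] == "<eos>":
            break
    return w

x = "is the sky blue?"
print(score(["yes", "<eos>"], x))  # scoring: log p(w | x)
print(greedy_decode(x))            # decoding: ["yes", "<eos>"]
\end{verbatim}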
This allows us to perform causal inference and outcome maximization simultaneously. Consider the problem of query-based text generation, where the goal is to produce an open-ended answer [math]w[/math] to a query [math]x[/math]. Because it is often impossible to give an absolute score to the answer [math]w[/math] given a query [math]x[/math], it is customary to ask a human annotator for a relative ranking between two (or more) answers [math]w_+[/math] and [math]w_-[/math] given a query [math]x[/math]. Without loss of generality, let [math]w_+[/math] be the answer preferred over [math]w_-[/math]. We assume that there exists a strict total order among all possible answers. That is,
- Irreflexive: [math]r(w|x) \lt r(w|x)[/math] cannot hold.
- Asymmetric: If [math]r(w|x) \lt r(w'|x)[/math], then [math]r(w|x) \gt r(w'|x)[/math] cannot hold.
- Transitive: If [math]r(w|x) \lt r(w'|x)[/math] and [math]r(w'|x) \lt r(w''|x)[/math], then [math]r(w|x) \lt r(w''|x)[/math].
- Connected: If [math]w \neq w'[/math], then either [math]r(w|x) \lt r(w'|x)[/math] or [math]r(w|x) \gt r(w'|x)[/math] holds.
In other words, we can place all possible answers on a one-dimensional line according to their (unobserved) ratings.
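To make these axioms concrete, the following sketch assumes the unobserved rating [math]r(\cdot | x)[/math] is a scalar attached to each answer and checks all four properties programmatically; the candidate answers and their rating values are made up for illustration.
\begin{verbatim}
from itertools import product

# The unobserved rating r(. | x) as a scalar for a few candidate answers.
rating = {"w_a": 0.3, "w_b": 1.7, "w_c": 2.5}
answers = list(rating)
less = lambda w, v: rating[w] < rating[v]

# Irreflexive: r(w|x) < r(w|x) never holds.
assert all(not less(w, w) for w in answers)
# Asymmetric: r(w|x) < r(w'|x) and r(w|x) > r(w'|x) never hold together.
assert all(not (less(w, v) and less(v, w))
           for w, v in product(answers, repeat=2))
# Transitive: r(w|x) < r(w'|x) and r(w'|x) < r(w''|x) imply r(w|x) < r(w''|x).
assert all(less(w, u) or not (less(w, v) and less(v, u))
           for w, v, u in product(answers, repeat=3))
# Connected: distinct answers are always comparable.
assert all(w == v or less(w, v) or less(v, w)
           for w, v in product(answers, repeat=2))
print("all four axioms hold")
\end{verbatim}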
A non-causal approach. It is then relatively trivial to train this language model, assuming that we have a large number of triplets
[math]D = \left\{ (x^n, w_+^n, w_-^n) \right\}_{n=1}^{N}.[/math]
For each triplet, we ensure that the language model puts a higher probability on [math]w_+[/math] than on [math]w_-[/math] given [math]x[/math] by minimizing the following loss function:
[math]\frac{1}{N} \sum_{n=1}^{N} \max\left(0, m - \log p(w_+^n | x^n) + \log p(w_-^n | x^n)\right),[/math]
where [math]m \in [0, \infty)[/math] is a margin hyperparameter. For each triplet, the loss inside the summation is zero if the language model assigns [math]w_+[/math] a log-probability that exceeds that of [math]w_-[/math] by at least the margin [math]m[/math]. This loss alone is, however, not enough to train a language model from which we can produce a high-quality answer, because we have pairwise preference triplets for reasonable answers only; the language model trained in this way is not encouraged to put low probabilities on gibberish. We avoid this issue by ensuring that the language model puts reasonably high probabilities on all reasonable answers, by minimizing the following extra loss function:
[math]-\frac{1}{N} \sum_{n=1}^{N} \left( \log p(w_+^n | x^n) + \log p(w_-^n | x^n) \right),[/math]
which corresponds to the so-called negative log-likelihood loss.
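The following sketch evaluates both losses on a single triplet, assuming the language model's log-probabilities are available as plain numbers; the values of [math]\log p(w_+|x)[/math] and [math]\log p(w_-|x)[/math] and the margin [math]m[/math] are made up for illustration.
\begin{verbatim}
# Toy log-probabilities assigned by the model to one triplet (x, w+, w-).
logp_pos = -3.2   # log p(w+ | x)
logp_neg = -3.5   # log p(w- | x)
m = 1.0           # margin hyperparameter

# Pairwise margin loss: zero once log p(w+|x) exceeds log p(w-|x) by >= m.
pref_loss = max(0.0, m - (logp_pos - logp_neg))

# Negative log-likelihood on the reasonable answers, keeping their
# probabilities high while the margin loss orders them.
nll_loss = -(logp_pos + logp_neg)

print(pref_loss + nll_loss)
\end{verbatim}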
A causal consideration. This approach works well under the assumption that the answer [math]w[/math] embeds only its content. This is unfortunately not the case: any answer is a combination of content and style, and the latter should not be the basis on which the answer is rated. For instance, one aspect of style is verbosity. A longer answer is often rated more highly because of a subconscious bias of the human rater, who believes that a better answer would be a longer one, even though there is no reason a better answer could not be more concise. This process can be described as the graph below, where [math]r[/math] is the rating and [math]s[/math] is the style: \begin{center} \begin{tikzpicture}
\node[latent] (w) {[math]w[/math]}; \node[latent, right=1.5cm of w] (r) {[math]r[/math]}; \node[latent, above=0.5cm of w, xshift=1cm] (s) {[math]s[/math]}; \node[obs, left=1cm of w] (x) {[math]x[/math]}; \edge{w}{r}; \edge{s}{w}; \edge{s}{r}; \edge{x}{w}; \end{tikzpicture}
\end{center} The direct effect of [math]w[/math] on the rating [math]r[/math] is based on the content, but there is also a spurious correlation between [math]w[/math] and [math]r[/math] via the style [math]s[/math]. For instance, [math]s[/math] could encode the verbosity, which affects both how [math]w[/math] is written and how a human rater perceives its quality and assigns the rating [math]r[/math]. In the naive approach above, the language model, as a scorer, fails to distinguish between these two effects and captures both, which is clearly undesirable; a longer answer is not necessarily a better answer. In other words, a language model [math]p_0[/math] trained in the purely supervised way above scores [math]w[/math] high for both causal and spurious (via [math]s[/math]) reasons. An answer [math]w[/math] sampled from [math]p_0[/math] can then be considered dependent not only on the query [math]x[/math] itself but also on an unobserved style variable [math]s[/math].
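The following toy simulation of this graph, assuming verbosity as the style [math]s[/math], shows how the spurious path manifests: the length of an answer has no direct effect on the rating, yet the two are clearly correlated because [math]s[/math] drives both. All coefficients are made up for illustration.
\begin{verbatim}
import random

random.seed(0)

def corr(xs, ys):
    # Pearson correlation between two equally long lists.
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(xs, ys)) / n
    vx = sum((a - mx) ** 2 for a in xs) / n
    vy = sum((b - my) ** 2 for b in ys) / n
    return cov / (vx * vy) ** 0.5

lengths, ratings = [], []
for _ in range(10_000):
    s = random.gauss(0, 1)                   # style: verbosity
    content = random.gauss(0, 1)             # content: the causal factor
    length = s + random.gauss(0, 0.3)        # answer length, driven only by s
    r = content + s + random.gauss(0, 0.3)   # rating: content plus style bias
    lengths.append(length)
    ratings.append(r)

print(corr(lengths, ratings))  # clearly positive despite no direct effect
\end{verbatim}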
Direct preference optimization~[1]. We can resolve this issue by combining two ideas we have studied earlier: randomized controlled trials (RCT; \SRandomized Controlled Trials) and inverse probability weighting (IPW; \SInverse Probability Weighting). First, we sample two answers, [math]w[/math] and [math]w'[/math], from the model [math]p_0[/math] already trained with supervised learning above:
[math]w, w' \sim p_0(\cdot | x).[/math]
These two answers (approximately) maximize the estimated outcome (rating) by capturing both the content and the style. One interesting side-effect of imperfect learning and inference (generation) is that both of these answers would largely share the same style. If we use [math]s'[/math] to denote that style, we can think of each answer as sampled from [math]w | x, s'[/math]. With a new language model [math]p_1[/math] (potentially initialized from [math]p_0[/math]), we can compute the rating after removing the dependence on the style [math]s[/math] by IPW:
[math]\hat{r}(w | x) = \log \frac{p_1(w | x)}{p_0(w | x)}.[/math]
This reminds us of the [math]\mathrm{do}[/math] operator, resulting in the following modified graph: \begin{center} \begin{tikzpicture}
\node[latent] (w) {[math]w[/math]}; \node[latent, right=1.5cm of w] (r) {[math]r[/math]}; \node[latent, above=0.5cm of w, xshift=1cm] (s) {[math]s[/math]}; \node[obs, left=1cm of w] (x) {[math]x[/math]}; \edge{w}{r}; \edge{s}{r}; \edge{x}{w}; \end{tikzpicture}
\end{center} Of course, this score [math]\hat{r}[/math] does not mean anything, since [math]p_1[/math] does not mean anything yet. We have to train [math]p_1[/math] by asking an expert to provide their preference between [math]w[/math] and [math]w'[/math]. Without loss of generality, let [math]w[/math] be the answer preferred over [math]w'[/math]; that is, [math]w_+ = w[/math] and [math]w_- = w'[/math]. We train [math]p_1[/math] by minimizing
[math]\frac{1}{N} \sum_{n=1}^{N} \max\left(0, m - \hat{r}(w_+^n | x^n) + \hat{r}(w_-^n | x^n)\right),[/math]
where we assume we have [math]N[/math] pairs and [math]m[/math] is a margin as before. It is possible to replace the margin loss with another loss function, such as a log loss or a linear loss. This procedure encourages [math]p_1[/math] to capture only the direct (causal) effect of the answer on the rating, dissecting out the indirect (spurious) effect via the style [math]s[/math]. Once training is done, we use [math]p_1[/math] to produce a better answer, which depends less on the spurious correlation between the answer and the rating via the style. Because this procedure is extremely implicit about the existence of and the dependence on the style, it can be beneficial to repeat it for multiple rounds in order to further remove the effect of the spurious correlation and improve the quality of a generated answer~[2].
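The following sketch computes this loss for a single preference pair, assuming we can query log-probabilities from both [math]p_0[/math] and [math]p_1[/math]; the numbers and the margin are made up for illustration, and the log-ratio form of [math]\hat{r}[/math] follows the reconstruction above.
\begin{verbatim}
# Toy log-probabilities for one preference pair (w+, w-) sampled from p0
# given a query x.
logp1_pos, logp0_pos = -3.0, -3.4   # log p1(w+|x), log p0(w+|x)
logp1_neg, logp0_neg = -3.6, -3.3   # log p1(w-|x), log p0(w-|x)
m = 1.0                             # margin hyperparameter

def r_hat(logp1, logp0):
    # IPW-style rating: log p1(w|x) - log p0(w|x).
    return logp1 - logp0

# Margin loss on the de-styled ratings; p1 is pushed to prefer w+ over w-
# only through the part of the score that p0 (content + style) cannot explain.
loss = max(0.0, m - (r_hat(logp1_pos, logp0_pos) - r_hat(logp1_neg, logp0_neg)))
print(loss)
\end{verbatim}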
General references
Cho, Kyunghyun (2024). "A Brief Introduction to Causal Inference in Machine Learning". arXiv:2405.08793 [cs.LG].