Multinomial stick-breaking

Exploration of multinomial stick-breaking
Author

Karl Tayeb

Published

March 19, 2023

Modelling dependent categorical and multinomial data

We’re interested in modelling multinomial or categorical data in the case where the probability of each category depends on side information. For \(\pi: \mathcal X \rightarrow \Delta^{K-1}\), we model

\[ {\bf y} \sim \text{Multinomial}(n, \pi({\bf x})) \]

Commonly \(\pi({\bf x})\) is written as a composition \(\pi = \sigma \circ \eta\), where \(\sigma: \mathbb R^K \rightarrow \Delta^{K-1}\) is the softmax function, defined elementwise as \(\sigma({\bf \eta})_i = \frac{e^{\eta_i}}{\sum_{j=1}^K e^{\eta_j}}\), and \(\eta:\mathcal X \rightarrow \mathbb R^K\) is some other function mapping the covariates \({\bf x}\) to a set of unnormalized log probabilities.
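For reference, here is a minimal R implementation of the softmax (the function name is ours; subtracting the max is a standard stabilization trick that leaves the result unchanged).

Code
softmax <- function(eta) {
  # subtract max(eta) to avoid overflow; the result is unchanged
  e <- exp(eta - max(eta))
  e / sum(e)
}

softmax(c(0, 1, 2))  # 0.0900 0.2447 0.6652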

The trouble with this formulation is that it is not easy to express uncertainty in the map \(\eta\). As a simple example, consider multinomial linear regression where \(\eta({\bf z})_k = \beta_k^T {\bf z}\) for some \(\beta_k \in \mathbb R^d\). \(\pi = \sigma \circ \eta\) is differentiable, and point estimates of \(B =\{\beta_k\}_{k=1, \dots, K}\) could be obtained through gradient-based optimization. In contrast, if we take a Bayesian approach and specify priors \(\beta_k \sim g_k,\; k \in [K]\), obtaining the posterior distribution over \(B\) involves evaluating a nasty integral of the softmax.

\[ \int \sigma(\eta({\bf z} ; B)) \prod_{k=1}^K g_k(\beta_k) \, dB \]

There is plenty of work on bounding the softmax with functions that are easier to integrate (1,2), but it is a hard problem to get analytic bounds that are easy to work with.

There is also quite a bit of work developing bounds for the sigmoid function (the softmax with \(K=2\); usually people describe the softmax as a generalization of the sigmoid to \(K > 2\)). In particular, techniques for constructing local approximations are popular in variational inference (3,4). These local approximations are tight at a point, but the quality of the bound decays as you move away from that point. Thus, these approximation techniques require selecting/optimizing the point at which the bound is tight.
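To make the idea of a local approximation concrete, the snippet below implements the Jaakkola–Jordan quadratic lower bound on the sigmoid from (3), \(\sigma(x) \geq \sigma(\xi) \exp\left\{ \frac{x - \xi}{2} - \lambda(\xi)(x^2 - \xi^2) \right\}\) where \(\lambda(\xi) = \frac{1}{2\xi}\left(\sigma(\xi) - \frac{1}{2}\right)\), and checks numerically that the bound holds everywhere and is tight at \(x = \pm \xi\).

Code
sigmoid <- function(x) 1 / (1 + exp(-x))

#' Jaakkola-Jordan quadratic lower bound on the sigmoid, tight at x = +/- xi
jj_bound <- function(x, xi) {
  lambda <- (sigmoid(xi) - 0.5) / (2 * xi)
  sigmoid(xi) * exp((x - xi) / 2 - lambda * (x^2 - xi^2))
}

xi <- 1.5
x <- seq(-5, 5, by = 0.1)
all(jj_bound(x, xi) <= sigmoid(x) + 1e-12)                # TRUE: valid lower bound
max(abs(jj_bound(c(-xi, xi), xi) - sigmoid(c(-xi, xi))))  # ~0: tight at +/- xi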

We’re operating under the assumption that it is easier to construct good bounds for the sigmoid function than for the softmax function. We are going to explore a construction of the categorical/multinomial distribution that lets us utilize these bounds.

Multinomial stick-breaking

In the multinomial logit construction, \(\eta\) is a set of unnormalized log probabilities. This is not the only way to construct a multinomial distribution. We can also use a stick-breaking construction. In stick breaking we start with a “stick” of length \(1\). At the first step we break off a fraction of the stick, \(p_1\). The remainder of the stick now has length \(1 - p_1\). At each successive step we break off a fraction of the remaining stick. After \(K-1\) breaks we have broken the stick into \(K\) pieces, giving a discrete probability distribution over \(K\) categories. Clearly, we can use this process to construct any distribution \(\pi\) over \(K\) categories, where

\[ \begin{aligned} \pi_1 &= p_1 \\ \pi_k &= p_k \prod_{j < k}(1 - p_j) \end{aligned} \]

Noting that \(\left(1 - \sum_{j < k} \pi_j \right)\) is the length of the remaining stick after \(k-1\) breaks, we can also write

\[ \begin{aligned} \pi_k &= p_k \left(1 - \sum_{j < k} \pi_j \right) \end{aligned} \]

In the stick-breaking construction, \(\nu_k,\; k \in[K-1]\) will be a set of log odds such that \(p_k = \sigma(\nu_k)\) (here \(\sigma\) is the sigmoid) gives the proportion of the remaining stick broken off at step \(k\). Using the stick-breaking construction, we can write the multinomial pmf as a product of binomial pmfs.

\[ \text{Multinomial}({\bf y}; n, \pi) = \prod_{k=1}^{K-1} \text{Binomial}(y_k; n_k, p_k) \]

where \(n_k = n - \sum_{j < k} y_j\) counts the number of remaining trials, conditional on the first \(k-1\) counts. This construction is not new; it has been proposed by several authors (5,6).
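As a quick numerical sanity check of this identity (ours, not from the references): solving the stick-breaking relation above for the break probabilities gives \(p_k = \pi_k / (1 - \sum_{j < k} \pi_j)\), and the multinomial log-likelihood should match the sum of the \(K-1\) binomial log-likelihoods.

Code
set.seed(1)
pi <- c(0.2, 0.3, 0.4, 0.1)
n <- 20
y <- as.vector(rmultinom(1, n, pi))          # one multinomial observation
K <- length(pi)

n_k <- n - c(0, head(cumsum(y), -1))         # remaining trials before break k
p_k <- pi / (1 - c(0, head(cumsum(pi), -1))) # p_k = pi_k / (1 - sum_{j<k} pi_j)

loglik_multinom <- dmultinom(y, size = n, prob = pi, log = TRUE)
loglik_binom <- sum(dbinom(y[1:(K - 1)], n_k[1:(K - 1)], p_k[1:(K - 1)], log = TRUE))
all.equal(loglik_multinom, loglik_binom)     # TRUE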

To do multinomial regression we will write \(\nu_k = \beta_k^T {\bf z}\). \(\nu_k\) gives the log odds of selecting category \(k\), given that we did not select categories \(1, \dots, k-1\).
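Putting the pieces together, here is a minimal sketch of simulating from this regression model; the helper stick_breaking_pi and the simulated \(B\) and \({\bf z}\) are illustrative.

Code
sigmoid <- function(x) 1 / (1 + exp(-x))

#' Map K-1 stick-breaking log odds nu to a probability vector of length K
stick_breaking_pi <- function(nu) {
  p <- sigmoid(nu)                  # fraction broken off at each step
  remaining <- c(1, cumprod(1 - p)) # stick length left before each break
  c(p * head(remaining, -1), tail(remaining, 1))
}

K <- 4; d <- 3
B <- matrix(rnorm((K - 1) * d), nrow = K - 1) # row k holds beta_k
z <- rnorm(d)
nu <- drop(B %*% z)                           # nu_k = beta_k^T z
y <- rmultinom(1, size = 50, prob = stick_breaking_pi(nu))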

Stick breaking for variational inference

The stick-breaking construction is particularly useful for variational inference. The multinomial log-likelihood can be written as a sum of \(K-1\) terms, each a binomial log-likelihood. By selecting a variational approximation where the \(\nu_k\) factorize, the variational objective can be optimized in an embarrassingly parallel fashion: the multinomial regression reduces to a set of \(K-1\) independent binomial regression problems. Each of these problems still requires an additional approximation of the sigmoid function for tractable inference, but that can be dealt with more easily.
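To make the decomposition concrete, here is a small sketch that splits an \(N \times K\) matrix of multinomial counts into the \(K-1\) binomial problems (the helper stick_breaking_data is ours); each one could then be fit separately, e.g. with a logistic GLM as a point-estimate stand-in for the binomial regression subproblem.

Code
#' Split an N x K matrix of multinomial counts into K-1 binomial problems.
#' Returns successes y_k and trials n_k = n - sum_{j < k} y_j for each break k.
stick_breaking_data <- function(Y) {
  K <- ncol(Y)
  before <- cbind(0, t(apply(Y, 1, cumsum)))[, 1:(K - 1), drop = FALSE]
  list(
    successes = Y[, 1:(K - 1), drop = FALSE],
    trials = rowSums(Y) - before
  )
}

Y <- t(rmultinom(5, size = 30, prob = c(0.2, 0.3, 0.4, 0.1)))
sb <- stick_breaking_data(Y)

# each break k is now an ordinary binomial regression, fit independently, e.g.
# glm(cbind(sb$successes[, k], sb$trials[, k] - sb$successes[, k]) ~ Z,
#     family = binomial())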

A distribution on \(\pi\)

While stick breaking can be used to construct any discrete distribution, we should take note that the distribution on \(\pi\) depends on the distribution we specify for the breakpoints \(p_k\) (equivalently, for the log odds \(\nu_k\)).

The Dirichlet \(\text{Dir}(\alpha_1, \dots, \alpha_K)\) can be constructed through stick breaking, where the breakpoints are

\[p_k \sim \text{Beta}\left(\alpha_k, \sum_{j > k } \alpha_j\right)\]

Again \(\pi_1 = p_1\), and \(\pi_k = p_k \left(1 - \sum_{j < k} \pi_j\right)\). If \(\alpha_i = \alpha\; \forall i \in [K]\), then the Dirichlet is said to be symmetric: permuting category labels won’t change the likelihood of the sample. Notice that in this case \(p_k \sim \text{Beta}(\alpha, (K- k) \alpha)\). We should expect to break off smaller fractions of the stick for small \(k\) than for large \(k\). This makes sense. A necessary condition for the Dirichlet to be exchangeable is that the stick lengths have the same marginal distribution. In order for the stick lengths to have the same marginal distribution, at each successive step we need to balance out the fact that the stick is getting shorter by taking a larger fraction of the stick (ultimately \(\mathbb E[p_{K-1}] = \frac{1}{2}\)).

In the code below we simulate the Dirichlet distribution using stick breaking with a Beta distribution. We see that across 10000 simulations each category is equally likely to show up on top.

Code
#' Sample from a Dirichlet distribution using the stick breaking construction
dirichlet_from_beta_stick_breaking <- function(alpha, K){
  if(length(alpha) == 1){ 
    alpha <- rep(alpha, K)
  }
  beta <- rev(cumsum(rev(alpha))) - alpha # sum_{j > k} alpha_j
  p <- rbeta(K, alpha, beta)              # break fractions; p_K = 1 since beta_K = 0
  tmp <- c(1, head(cumprod(1 - p), -1))   # stick length remaining before break k
  pi <- c(p * tmp)                        # pi_k = p_k * prod_{j < k} (1 - p_j)
  return(pi)
}

# each component equally likely to have the most probability mass
table(purrr::map_int(1:10000, ~which.max(
  dirichlet_from_beta_stick_breaking(1, 4))))

   1    2    3    4 
2458 2539 2492 2511 

TODO: sample for \(K=3\)

Q: What distribution of \(\nu_k\) would give an exchangeable distribution for \(\pi\)? (Basically, what is the stick-breaking construction for a symmetric Dirichlet, on the log-odds scale?)

Ordering of the categories

Successive categories seem to have less and less information, as \(n_k \leq n_j\) for \(k > j\). It seems odd that permuting the category labels would change how certain we are about each \(\nu_k\). Can we make sense of this?

References

1. Bouchard G. Efficient Bounds for the Softmax Function and Applications to Approximate Inference in Hybrid Models.
2. Titsias MK. One-vs-Each Approximation to Softmax for Scalable Estimation of Probabilities. In: Advances in Neural Information Processing Systems. Curran Associates, Inc.; 2016. Available from: https://proceedings.neurips.cc/paper/2016/hash/814a9c18f5abff398787c9cfcbf3d80c-Abstract.html
3. Jaakkola TS, Jordan MI. A variational approach to Bayesian logistic regression models and their extensions. In: Sixth International Workshop on Artificial Intelligence and Statistics. p. 283–94.
4. Saul L, Jordan M. A Mean Field Learning Algorithm for Unsupervised Neural Networks. In: Jordan MI, editor. Learning in Graphical Models. Dordrecht: Springer Netherlands; 1998. p. 541–54. Available from: http://link.springer.com/10.1007/978-94-011-5014-9_20
5. Khan M, Mohamed S, Marlin B, Murphy K. A Stick-Breaking Likelihood for Categorical Data Analysis with Latent Gaussian Models. In: Proceedings of the Fifteenth International Conference on Artificial Intelligence and Statistics. PMLR; 2012. p. 610–8. Available from: https://proceedings.mlr.press/v22/khan12.html
6. Linderman S, Johnson M, Adams RP. Dependent Multinomial Models Made Easy: Stick-Breaking with the Polya-gamma Augmentation. In: Advances in Neural Information Processing Systems; 2015.