COCONUT Summary

Training Large Language Models to Reason in a Continuous Latent Space


High-Level Ideas

Some Details

A standard LLM $\mathcal M$ takes an input sequence $x = (x_1, \dots, x_T)$ and proceeds as follows, for $t \ge T$: the input embeddings $E_t = [e(x_1), \dots, e(x_t)]$ are fed to the model, $h_t$ denotes the last hidden state at position $t$, and the next token $x_{t+1}$ is sampled from $\mathop{\mathsf{softmax}}(W h_t)$.

Contrastingly, COCONUT directly utilises the last hidden state $h_t$ as the next input embedding. Special tokens <bot>/<eot> mark the beginning/end of the latent thought mode. Suppose this occurs between positions $i$ and $j$—ie, $x_i = \texttt{<bot>}$ and $x_j = \texttt{<eot>}$. When the model is in latent mode (ie, $t$ satisfies $i < t < j$), then $E_t = [e(x_1), \dots, e(x_i), h_i, h_{i+1}, \dots, h_{t-1}]$.
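
As a rough sketch of this mechanism (not the authors' code), the latent mode can be mimicked with a Hugging Face causal LM by feeding the last hidden state back in via `inputs_embeds`; the prompt and the number of continuous thoughts `n_latent` are placeholders, and the <bot>/<eot> embeddings are omitted for brevity.

```python
# Minimal sketch of COCONUT's latent mode (illustrative, not the authors' implementation).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

ids = tok("Question: ...", return_tensors="pt").input_ids
embeds = model.get_input_embeddings()(ids)          # [1, i, d] = [e(x_1), ..., e(x_i)]

n_latent = 3                                         # number of continuous thoughts (placeholder)
with torch.no_grad():                                # during training these stay in the autograd graph
    for _ in range(n_latent):
        out = model(inputs_embeds=embeds, output_hidden_states=True)
        h_t = out.hidden_states[-1][:, -1:, :]       # last hidden state h_t
        embeds = torch.cat([embeds, h_t], dim=1)     # h_t becomes the next input embedding

# After appending e(<eot>), generation would continue in ordinary language mode.
```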

The paper focuses on the question–answer regime. The usual negative log-likelihood loss is optimised during training, but the loss is masked on questions and latent thoughts. This does not (explicitly) encourage the continuous thought to compress the removed language thought, but rather facilitates the prediction of future reasoning.
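
A minimal sketch of that masking, assuming the common convention of marking ignored positions with `-100`; the function name and its arguments are illustrative, not taken from the paper's code.

```python
# Illustrative masked NLL: no loss on question tokens or on the continuous thoughts.
import torch
import torch.nn.functional as F

def masked_nll(logits, input_ids, question_len, latent_positions):
    """`latent_positions`: indices (along the sequence) occupied by continuous thoughts."""
    labels = input_ids.clone()
    labels[:, :question_len] = -100                   # mask the question
    labels[:, latent_positions] = -100                # mask the latent thoughts
    return F.cross_entropy(                           # standard next-token prediction (shift by one)
        logits[:, :-1].reshape(-1, logits.size(-1)),
        labels[:, 1:].reshape(-1),
        ignore_index=-100,
    )
```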

COCONUT requires training, so that the output hidden states are appropriately aligned with the expected input embeddings. Reviewer qvXf proposed adding a trainable linear projection between the output and the input; the authors said they tried this, but found minimal difference.
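
For concreteness, the suggested variant might look like the following; this is hypothetical, not part of the released model, and the dimensions are illustrative.

```python
import torch
import torch.nn as nn

hidden_dim = 768                              # GPT-2 small's width; illustrative
proj = nn.Linear(hidden_dim, hidden_dim)      # trainable map from output space to input space
h_t = torch.randn(1, 1, hidden_dim)           # stand-in for the last hidden state
next_input_embed = proj(h_t)                  # would replace feeding h_t back directly
```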

Results

In summary, the results aren't stellar, with CoT frequently outperforming COCONUT in accuracy, often by a large margin. However, the number of tokens generated by COCONUT is typically much smaller than for CoT.

A pre-trained GPT-2 is used as the base model for all the experiments. This is one of the major critiques of the paper: why evaluate on an old, small model? Four baselines are considered.

  1. CoT. The complete reasoning chains are used to train the LLM with supervised fine-tuning; during inference, reasoning chains are generated before outputting an answer
  2. No-CoT. The LLM is trained to directly generate the answer
  3. ICoT. Language reasoning chains are used during training to 'internalise' CoT; direct prediction is used at inference
  4. Pause. Training only uses questions and answers, no reasoning; special <pause> tokens are inserted between the question and the answer, potentially providing additional capacity to derive the answer

[Table: evaluation results]

The model needs to switch from latent to language mode—ie, output <eot>. Two options were considered.

  1. Train a binary classifier on latent thoughts to enable autonomous termination.
  2. Always pad the latent thoughts to a constant length.

The authors report comparable performance, and so use the second, simpler option.
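
Option 2 amounts to fixing `n_latent`, as in the earlier sketch. Option 1 could look roughly like the following, with an illustrative classifier head (again, not the authors' code).

```python
import torch
import torch.nn as nn

hidden_dim = 768                                  # illustrative model width
stop_head = nn.Linear(hidden_dim, 1)              # binary classifier on latent thoughts

def should_emit_eot(h_t: torch.Tensor) -> bool:
    """Decide whether to leave latent mode, given the current continuous thought."""
    return torch.sigmoid(stop_head(h_t)).item() > 0.5
```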

Critiques

The choice of base model (GPT-2) was raised in the reviews. In response, the authors reported results for larger Llama models, but only COCONUT vs No-CoT, and did not mention the benchmark (🤦).

| Model        | No-CoT | COCONUT | Improvement |
|--------------|--------|---------|-------------|
| Llama 3.2-3B | 26.0   | 31.7    | 5.7 (22%)   |
| Llama 3-8B   | 42.2   | 43.6    | 1.4 (3.3%)  |

Regardless of the set-up, the increase over no CoT is minimal for the 8B model. They suggest that a potential reason is that continuous thoughts increase effective network depth, which may particularly benefit smaller models.

The continuous thoughts are fully differentiable, allowing backpropagation through the reasoning chain. Some repeated computation can be avoided using a KV cache, but the sequential nature poses challenges for parallelisation. Further, a multi-stage training framework is required to internalise the explicit (language) reasoning into latent thoughts, which may restrict usability on larger models with longer reasoning chains.

[Figure: multi-stage training procedure]
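
My rough understanding of the curriculum is that stage $k$ replaces the first $k$ language reasoning steps with continuous thoughts ($c$ per removed step). The sketch below only constructs the stage-$k$ training sequence with placeholder tokens; the name and details are illustrative, and the placeholders are filled by the hidden-state feedback loop shown earlier.

```python
# Illustrative construction of a stage-k training sequence.
# `question`/`answer` are token lists; `steps` is a list of reasoning steps (each a token list).
def build_stage_example(question, steps, answer, stage, c=1):
    latent = ["<bot>"] + ["<latent>"] * (stage * c) + ["<eot>"]
    remaining = [tok for step in steps[stage:] for tok in step]   # steps not yet internalised
    # Loss would be applied only to `remaining + answer`, per the masking described above.
    return question + latent + remaining + answer
```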

Take-Aways

COCONUT is far from matching CoT in terms of accuracy (see Table 1). It has potential, however, to significantly shorten reasoning traces—not only by removing the need for textual coherence, but also simply by avoiding language altogether.

There are serious issues around the scalability of training, and the results in the paper pertain only to GPT-2—the authors don't even say which variant. Accuracy improvements on Llama 3-8B were minimal (3.3% relative). The number of tokens generated was not reported.

Comparison with Soft Thinking and Mixture of Inputs

The Soft Thinking and Mixture of Inputs approaches are related, but fundamentally different: there, as in CoT, the softmax distribution $\mathop{\mathsf{softmax}}(W h_t)$ over the vocabulary is still computed, and the next input embedding is built from token embeddings rather than taken to be $h_t$ itself.
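
To make the contrast concrete, here is a hedged sketch of how the next input embedding is formed in each case; `E`, `W` and `h_t` are random stand-ins, and the Soft Thinking line reflects my reading that the distribution is used to mix token embeddings.

```python
import torch

d, V = 768, 50257                       # hidden size and vocabulary size (GPT-2-like, illustrative)
E = torch.randn(V, d)                   # token-embedding matrix
W = torch.randn(V, d)                   # unembedding matrix
h_t = torch.randn(d)                    # last hidden state

p = torch.softmax(W @ h_t, dim=-1)      # distribution over the vocabulary
next_in_cot = E[p.argmax()]             # CoT: embed a discrete token
next_in_soft = p @ E                    # Soft Thinking-style: mixture of token embeddings
next_in_coconut = h_t                   # COCONUT: no softmax needed; feed h_t directly
```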

COCONUT requires training to implement the new approach, whilst Soft Thinking and Mixture of Inputs are training-free.