Soft Thinking Summary

Soft Thinking: Unlocking the Reasoning Potential of LLMs in Continuous Concept Space

High-Level Ideas

Some Details

Instead of injecting the randomly sampled, one-hot encoded token (after embedding) back into the LLM, the concept token is injected:

$$\textstyle \sum_{k=1}^{|V|} \mathsf{ct}[k]\, e(k),$$

where $\mathsf{ct}[k]$/$e(k)$ is the selection probability/embedding of the $k$-th vocab item.
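To make the injection concrete, here is a minimal PyTorch-style sketch of a single soft-thinking step; it is an illustration under assumptions, not the paper's implementation. In particular, `model` is assumed to be a causal LM that accepts an `inputs_embeds` tensor and returns next-token logits, and `embedding_matrix` is assumed to be its input-embedding table of shape $(|V|, d)$.

```python
import torch
import torch.nn.functional as F

def soft_thinking_step(model, embedding_matrix, inputs_embeds, temperature=1.0):
    """One soft-thinking step: run the model on continuous inputs, form the
    concept token ct (the full next-token distribution), and append the
    probability-weighted mixture of embeddings rather than a sampled token."""
    # Forward pass directly on embeddings (no token ids).
    logits = model(inputs_embeds=inputs_embeds).logits[:, -1, :]       # (1, |V|)

    # Concept token: the next-token distribution itself.
    ct = F.softmax(logits / temperature, dim=-1)                       # (1, |V|)

    # Inject sum_k ct[k] e(k): a probability-weighted mixture of embeddings.
    concept_embed = ct @ embedding_matrix                              # (1, d)
    inputs_embeds = torch.cat([inputs_embeds, concept_embed[:, None, :]], dim=1)
    return inputs_embeds, ct
```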

The theoretical justification is based on linear approximations in the expression

$$\textstyle p(y \mid x) = \sum_{t_1} p(t_1 \mid x) \sum_{t_2} p(t_2 \mid x, t_1) ... \sum_{t_m} p(t_m \mid x, t_{1:m-1})\, p(y \mid x, t_{1:m});$$

here, $t_j$ is the $j$-th (intermediate reasoning) token. To emphasise, $p(\cdot \mid x, t_{1:j})$ is the probability of the next token given the input $x$ and intermediate reasoning tokens $t_1, ..., t_j$. Once some stopping criterion is achieved, the model outputs an answer, denoted $y$.

This expansion entails exponentially many (in $m$) paths, namely $|V|^m$ of them, indexed by the choice of intermediate reasoning tokens. If we only expanded one layer,

$$\textstyle p(y \mid x) = \sum_{t_1} p(t_1 \mid x)\, p(y \mid x, t_1).$$

In expectation,

$$\textstyle \mathsf{ct}_1 = \mathbb E[t_1] = \sum_{t_1} p(t_1 \mid x)\, t_1 = p(\cdot \mid x),$$

identifying each token $t_1$ with its one-hot encoding, so that the expectation is exactly the next-token distribution.

Linearising the previous expression about its mean (ie, replacing the random $t_1$ by its non-random mean $\mathsf{ct}_1$),

$$\textstyle p(y \mid x) = \sum_{t_1} p(t_1 \mid x)\, p(y \mid x, t_1) \approx p(y \mid x, \sum_{t_1} p(t_1 \mid x)\, t_1) = p(y \mid x, \mathsf{ct}_1).$$

The approximation is repeated given $\mathsf{ct}_1$:

$$\textstyle p(y \mid x, \mathsf{ct}_1) = \sum_{t_2} p(t_2 \mid x, \mathsf{ct}_1)\, p(y \mid x, \mathsf{ct}_1, t_2) \approx p(y \mid x, \mathsf{ct}_1, \mathsf{ct}_2).$$

Iterating,

$$p(y \mid x) \approx p(y \mid x, \mathsf{ct}_1, \mathsf{ct}_2, ..., \mathsf{ct}_m).$$

In contrast, discrete CoT replaces each summation $\sum_{t_j} p(t_j \mid \cdot)$ with a single sampled token, which discards the mass from all other paths. Soft thinking preserves the distribution through the concept tokens, whilst collapsing the exponential path summation into a single reasoning pass (one forward pass per concept token).

A cold stop is also implemented, where the entropy of each concept token is tracked in real time. This addresses the issue that the continuous concept tokens can place the model in an out-of-distribution regime. If the entropy (ie, uncertainty) of a concept token drops below a threshold $\tau$, the thinking process is stopped early. A basic ablation study is conducted, with details below.
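As a rough sketch of how the cold stop could sit in the generation loop (reusing the hypothetical `soft_thinking_step` from above), the entropy of each concept token is computed on the fly and compared against the threshold; the parameter names `tau` and `max_steps` and their defaults are illustrative, not taken from the paper.

```python
def soft_thinking_with_cold_stop(model, embedding_matrix, inputs_embeds,
                                 tau=1.0, max_steps=512):
    """Run soft-thinking steps, stopping early (cold stop) once the entropy
    of the concept token, ie the model's uncertainty, drops below `tau`."""
    for _ in range(max_steps):
        inputs_embeds, ct = soft_thinking_step(model, embedding_matrix, inputs_embeds)

        # Shannon entropy of the concept token; low entropy = confident model.
        entropy = -(ct * (ct + 1e-12).log()).sum(dim=-1).item()

        # Cold stop: the model is confident enough, so end the thinking phase
        # before the continuous inputs drift too far out of distribution.
        if entropy < tau:
            break
    return inputs_embeds
```

A stricter variant could require the entropy to stay below $\tau$ for several consecutive steps before stopping.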

Results

In summary, I'd suggest that the accuracy increase is modest, but the generation length reduction is significant.

The first table is the accuracy (higher → better) and the second the generation length (lower → better).

Table of results

The soft thinking results all utilise a cold stop, with the threshold $\tau$ optimised for the problem at hand. An ablation study is conducted regarding the cold stop: namely, $\tau = 0$ is forced, ensuring the cold stop is never activated.

The four rows in the table below correspond to different strategies.

Ablation study

This is all summarised in a figure given above the abstract.

Figure of results