CoT2 Summary

Continuous Chain of Thought Enables Parallel Exploration and Reasoning


High-Level Ideas


The paper introduces a continuous-token approach to chain-of-thought reasoning. Classical CoT feeds each sampled thought token back into the LLM to produce the next one. The underlying idea is very similar to soft thinking: instead of committing to a single token, the model feeds back a continuous superposition of tokens. Two primary methods are suggested.

  1. Base (aka soft thinking): deterministically feed the full distribution $\alpha_t$ (ie, its expected embedding) at each step
  2. MTS (multi-token sampling): sample $K$ discrete tokens $t_1, \dots, t_K$ and average their embeddings; classical CoT corresponds to $K = 1$ (both rules are sketched in code below)

For more formal details on how to implement this, see the summary on soft thinking or mixture of inputs, for example.
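
To make the two decoding rules concrete, here is a minimal NumPy sketch of a single continuous-token step; the toy embedding matrix `E`, the logits, and the value of $K$ are made up for illustration rather than taken from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

V, d = 8, 4                      # toy vocabulary size and embedding dimension
E = rng.normal(size=(V, d))      # token-embedding matrix (row i = embedding of token i)
logits = rng.normal(size=V)      # stand-in for the LLM's logits at step t

alpha = np.exp(logits - logits.max())
alpha /= alpha.sum()             # alpha_t: next-token distribution

# 1. Base / soft thinking: deterministically feed the expected embedding E^T alpha_t.
z_base = E.T @ alpha             # shape (d,)

# 2. MTS: sample K discrete tokens and feed the average of their embeddings.
K = 4
tokens = rng.choice(V, size=K, p=alpha)
z_mts = E[tokens].mean(axis=0)   # K = 1 recovers ordinary (discrete) CoT sampling

print(z_base)
print(z_mts)
```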

Unlike soft thinking or mixture of inputs, this paper introduces a training approach: continuous supervised fine-tuning (CSFT). This fits target distributions for the within-trace tokens. For example, in a graph search, the $t$-th target distribution could be an average over (embeddings of) the vertices reachable within $t$ steps. The idea is that this allows the model to explicitly track multiple "teacher traces" in parallel. The strategy is fitted to the problem at hand, rather than to a generic notion of 'reasoning'/'intelligence'. On top of this, GRPO-style RL is employed for policy optimisation. We detail the two parts below.

Some Details

CSFT (Supervised Training)

A target distribution $\alpha^*_t$ is specified for each $t = 1, \dots, m-1$, where $m$ is the (preset) length, and $\alpha^*_m$ is the one-hot distribution on the target token. If $\alpha_t$ is the LLM-provided distribution for the $t$-th token, the final loss is the sum of the relative entropies:

$$L^\textsf{CSFT} = \sum_{t=1}^m \mathsf{KL}(\alpha^*_t \mathrel{\|} \alpha_t).$$

By minimising this loss, the model is taught to fit the soft targets $(\alpha^*_t)_{t=1}^m$; a toy computation of the loss is sketched below.
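
As a sanity check on the loss, here is a small NumPy sketch that computes $L^\textsf{CSFT}$ from stacked target and model distributions; the numbers are invented for illustration.

```python
import numpy as np

def csft_loss(alpha_star, alpha, eps=1e-12):
    """Sum over steps t of KL(alpha*_t || alpha_t).

    alpha_star, alpha: arrays of shape (m, V); each row is a distribution
    over the vocabulary (target and model output, respectively).
    """
    ratio = np.log((alpha_star + eps) / (alpha + eps))
    return float(np.sum(alpha_star * ratio))

# Toy example with m = 2 steps and V = 3 tokens.
alpha_star = np.array([[0.5, 0.5, 0.0],   # soft target for the intermediate step
                       [0.0, 0.0, 1.0]])  # one-hot target on the answer token
alpha      = np.array([[0.6, 0.3, 0.1],
                       [0.1, 0.2, 0.7]])
print(csft_loss(alpha_star, alpha))
```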

Two ways of providing prefixes to the language model are considered.

  1. Teacher forcing, in which each step $t$ is conditioned on the ground-truth prefix: $z^*_{t'} = E^\top \alpha^*_{t'}$ for $t' < t$, where $E$ is the token-embedding matrix.

  2. Self-feeding, in which each step autoregressively uses the model's previously generated outputs: $z_{t'} = E^\top \alpha_{t'}$ (base) or $z_{t'} = \tfrac1K \sum_{k=1}^K E^\top t_k$ (MTS, with $t_k$ the one-hot vector of the $k$-th sampled token) for $t' < t$.

The authors found teacher forcing to lead to better performance; at inference time, on potentially unseen problems, the model of course runs autoregressively (ie, self-feeding). Both conditioning modes are sketched below.
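
The difference between the two conditioning modes can be sketched as follows; `model_step` is a hypothetical stand-in for one forward pass of the LLM, and the soft targets are invented.

```python
import numpy as np

rng = np.random.default_rng(1)
V, d, m = 8, 4, 5
E = rng.normal(size=(V, d))                     # token-embedding matrix

def model_step(prefix):
    """Hypothetical stand-in for one LLM forward pass: continuous prefix -> distribution."""
    logits = rng.normal(size=V)                 # a real model would condition on `prefix`
    p = np.exp(logits - logits.max())
    return p / p.sum()

alpha_star = rng.dirichlet(np.ones(V), size=m)  # made-up soft targets for t = 1, ..., m

# Teacher forcing: step t is conditioned on the ground-truth prefix z*_{t'} = E^T alpha*_{t'}.
prefix, alphas_tf = [], []
for t in range(m):
    alphas_tf.append(model_step(prefix))
    prefix.append(E.T @ alpha_star[t])

# Self-feeding: step t autoregressively reuses the model's own outputs, z_{t'} = E^T alpha_{t'}.
prefix, alphas_sf = [], []
for t in range(m):
    alpha_t = model_step(prefix)
    alphas_sf.append(alpha_t)
    prefix.append(E.T @ alpha_t)
```

The CSFT loss above would then be applied to `alphas_tf` (or `alphas_sf`) against the targets.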

Additionally, a discrete baseline is considered, in which $z^*_t$ is required to be the embedding of a single token in the vocabulary; in other words, the $\alpha^*_t$ are one-hot distributions over the vocab, rather than arbitrary ones.

GRPO (Reinforcement Learning)

The evaluations are all question–answer style, making them ideal for GRPO. Sparse rewards are used: $1$ for the correct final answer (regardless of the intermediate tokens) and $0$ otherwise. The GRPO implementation appears pretty standard; see my summary of GRPO for the DeepSeekMath paper for general GRPO details.
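
For concreteness, here is a minimal sketch of the group-relative advantage computed from such 0/1 rewards, as typically done in GRPO; the rollout generation and answer checking are assumed rather than shown.

```python
import numpy as np

def group_advantages(rewards, eps=1e-8):
    """GRPO-style advantages: standardise the 0/1 rewards within a group of rollouts."""
    r = np.asarray(rewards, dtype=float)
    return (r - r.mean()) / (r.std() + eps)

# e.g. 6 rollouts for one question; reward 1 iff the final answer is correct.
print(group_advantages([1, 0, 0, 1, 0, 0]))
```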

Two methods for policy optimisation—namely, defining the policy ratio in the GRPO objective—are proposed.

  1. Multi-Token Sampling (MTS). A rollout is emulated by sampling $K$ discrete tokens and averaging them. Suppose the step-$t$ tokens are $v_{t,1}, \dots, v_{t,K}$ with respective probabilities $\alpha^\textsf{new/old}_{t,1}, \dots, \alpha^\textsf{new/old}_{t,K}$ under the new/old policy. The policy ratio for these continuous steps is the ratio of geometric means: $r_t = \biggl( \frac{\alpha^\textsf{new}_{t,1} \cdots \alpha^\textsf{new}_{t,K}}{\alpha^\textsf{old}_{t,1} \cdots \alpha^\textsf{old}_{t,K}} \biggr)^{1/K}$.

  2. Dirichlet Sampling. A scaling hyperparameter $\gamma > 0$ is introduced and, given an LLM distribution $\alpha_t$, a distribution $\widehat\alpha_t$ is sampled from the Dirichlet distribution $\mathop{\textsf{Dir}}(\gamma \alpha_t)$. The continuous token is formed as $z_t = E^\top \widehat\alpha_t$. The policy ratio is the ratio of the Dirichlet densities: $r_t = f_{\theta^\textsf{new}}(z_t) / f_{\theta^\textsf{old}}(z_t)$. This parallels the computation for discrete actions, but replaces the categorical pmf with a continuous Dirichlet pdf.

For the final step, in either case, only one token is selected, so the policy ratio is just $\alpha^\textsf{new}_{m,j} / \alpha^\textsf{old}_{m,j}$, where $j$ is the index of the chosen token. Both ratio computations are sketched below.
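
Both ratio computations can be sketched numerically as follows; the distributions, $K$, and $\gamma$ are toy values, and the Dirichlet ratio is evaluated on the sampled mixture weights $\widehat\alpha_t$ (a simplifying assumption, not necessarily the paper's exact density over $z_t$).

```python
import math
import numpy as np

rng = np.random.default_rng(2)
V, K, gamma = 8, 4, 50.0                   # toy vocab size, samples per step, Dirichlet scale

alpha_old = rng.dirichlet(np.ones(V))      # step-t distribution under the old policy
alpha_new = rng.dirichlet(np.ones(V))      # step-t distribution under the new policy

# 1. MTS: geometric mean of per-token ratios for the K tokens sampled under the old policy.
tokens = rng.choice(V, size=K, p=alpha_old)
r_mts = np.exp(np.mean(np.log(alpha_new[tokens]) - np.log(alpha_old[tokens])))

# 2. Dirichlet sampling: ratio of Dirichlet densities of the sampled mixture weights.
def dirichlet_logpdf(x, conc):
    log_norm = math.lgamma(conc.sum()) - sum(math.lgamma(c) for c in conc)
    return log_norm + float(np.sum((conc - 1.0) * np.log(x)))

alpha_hat = rng.dirichlet(gamma * alpha_old)        # sampled weights; z_t = E^T alpha_hat
r_dir = math.exp(dirichlet_logpdf(alpha_hat, gamma * alpha_new)
                 - dirichlet_logpdf(alpha_hat, gamma * alpha_old))

print(r_mts, r_dir)
```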

Results

Two benchmark styles are considered; both require exploration over states. Let $\Gamma_t$ be the set of all states that can result from building upon a step-$(t-1)$ state (ie, an element of $\Gamma_{t-1}$), with $\Gamma_0 = \{g_0\}$ for some initial state $g_0$. For each $g \in \Gamma_t$, assign a probability $\alpha^*_{t,g}$ reflecting how many times $g$ occurs in a search:

$$\alpha^*_{t,g} = \mathop{\textsf{count}}_t(g) \Big/ \sum_{h \in \Gamma_t} \mathop{\textsf{count}}_t(h),$$

where $\mathop{\textsf{count}}_t(g)$ is the number of times state $g$ appears amongst all expansions at step $t$; a sketch of this construction is given directly below.
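
A small sketch of this construction, assuming a problem-specific `children` function (hypothetical) that expands a state into its successors:

```python
from collections import Counter

def target_distributions(g0, children, m):
    """alpha*_t over states: count_t(g) / sum_h count_t(h), for t = 1, ..., m.

    `children(state)` is a problem-specific expansion function (hypothetical).
    Counts carry multiplicity, ie the number of search paths reaching a state.
    """
    frontier = Counter({g0: 1})
    targets = []
    for _ in range(m):
        nxt = Counter()
        for state, cnt in frontier.items():
            for child in children(state):
                nxt[child] += cnt
        total = sum(nxt.values())
        targets.append({g: c / total for g, c in nxt.items()})
        frontier = nxt
    return targets
```

For the tasks below, `children` would return the two signed extensions of a partial sum (MNNS) or the out-neighbours of a vertex (ProntoQA/ProsQA).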

  1. Minimum Non-Negative Sum (MNNS) Task: given a list $d_1, \dots, d_m$ of integers, choose signs $\sigma_1, \dots, \sigma_m \in \{\pm 1\}$ such that $\sum_i \sigma_i d_i$ is minimised subject to being non-negative.

    • Supervision for CoT2: at step $t$, there are $|\Gamma_t| = 2^t$ partial sums; $\alpha^*_{t,i} := \mathop{\textsf{count}}_t(i) / 2^t$ if token $i$ appears $\mathop{\textsf{count}}_t(i)$ many times as a partial sum of length $t$, and $0$ otherwise (a code sketch of this appears after the list).
    • Supervision for the discrete model: a correct chain of partial sums, ie $(\sigma_1 d_1 + \dots + \sigma_t d_t)_{t=1}^m$, is provided.
  2. ProntoQA and ProsQA: graph search tasks requiring exploration over multiple paths.

    • Each question in ProntoQA asks whether a certain target word is reachable from a root word within a fixed number of steps ($5$ here), whilst ProsQA asks which of two candidate targets is reachable.
    • The counts are based on vertices reachable from the root within $t$ steps.
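
To make the MNNS supervision concrete (as referenced in the list above), here is a sketch that enumerates the $2^t$ signed partial sums and normalises the counts; the instance is made up.

```python
from collections import Counter
from itertools import product

def mnns_targets(d):
    """Soft targets for MNNS: alpha*_t over length-t signed partial sums."""
    targets = []
    for t in range(1, len(d) + 1):
        counts = Counter(sum(s * x for s, x in zip(signs, d[:t]))
                         for signs in product((+1, -1), repeat=t))
        targets.append({value: c / 2 ** t for value, c in counts.items()})
    return targets

# Made-up instance d = [3, 1, 2]: the step-3 target spreads mass over all
# achievable sums (0 occurs twice), and the answer is the minimum
# non-negative one, here 0.
for t, tgt in enumerate(mnns_targets([3, 1, 2]), start=1):
    print(t, tgt)
```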

[Figure: CoT2 evaluation on the MNNS task.]

CoT2-MTS improves validation accuracy relative to the discrete baseline. Smaller $K$ values lead to larger reductions in token-level entropies, suggesting the model learns to commit to fewer tokens. A curriculum on $K$ may help: start small and gradually increase.

Appendix C.2 shows that CoT2 with CSFT performs well once the embedding dimension is large, and Dirichlet sampling (rather than MTS) can improve performance even further.

[Figure: CoT2 evaluation on ProsQA and ProntoQA.]

All choices seem to work pretty well, with fairly significant improvements for CoT2 over Discrete CoT. The difference is particularly pronounced on ProsQA, but only minor on ProntoQA.

Finally, we display a figure from earlier in the paper. It shows that continuous SFT outperforms the discrete baseline once the embedding dimension is high enough, and that convergence may actually be faster too; see Figures 2b and 2c. Figure 2a demonstrates that the discrete model requires multiple samples (Pass@k) to approach a single attempt from the CSFT-trained CoT2.

[Figure: CoT2 evaluation (Figure 2 of the paper).]

Further results are given in Appendix C of the paper.