Thinking on the Fly: Test-Time Reasoning Enhancement via Latent Thought Policy Optimization

High-Level Summary

Elevator Pitch

Chain of thought (CoT) prompting was revolutionary, but it has well-known issues - perhaps foremost, its cost and latency. Recently, focus has shifted from text-based to latent reasoning. However, latent-reasoning methods often struggle on challenging, out-of-distribution (OOD) tasks - precisely those where robust reasoning is most valuable.

Enter Latent Thought Policy Optimization (LTPO):

No training of the LLM is needed. In fact, autoregressive (AR) decoding isn't even needed to compute the reward. This makes each RL step rapid.

Performance-wise, LTPO frequently matches full CoT in accuracy for the models studied (typically on the order of 10B parameters), yet is much faster on more challenging questions.

Method

Overview of LTPO

Let $\mathcal M$ denote the frozen LLM, and $E$ its embedding layer. To enable latent reasoning, the embedded prompt is concatenated with $K$ placeholder latent thought tokens, denoted $H$ and initialised as $H^0 = E([\text{THINK}], \dots, [\text{THINK}])$. It is $H$ that is optimised at test time by RL.
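A minimal sketch of this initialisation, with a toy embedding table standing in for the frozen LLM's $E$ (the table, sizes, and token id here are all illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins (hypothetical): a vocabulary embedding table and a
# [THINK] token id; in the paper these come from the frozen LLM's
# embedding layer E.
d_model, vocab_size = 8, 100
embedding_table = rng.normal(size=(vocab_size, d_model))
THINK_ID = 42
K = 4  # number of latent thought tokens

# H^0: K copies of the [THINK] embedding. This is the only state that
# LTPO optimises; the LLM's weights stay frozen.
H0 = np.tile(embedding_table[THINK_ID], (K, 1))
print(H0.shape)  # (4, 8)
```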

The latent thoughts are updated through the policy gradient. The reward function chosen is non-differentiable, so REINFORCE is used instead of standard backprop:

$$\begin{aligned} J(H) &= \mathbb E_{A \sim \pi(\cdot \mid H)}[R(A)], \\ \nabla_H J(H) &= \mathbb E_{A \sim \pi(\cdot \mid H)}[R(A) \nabla_H \log \pi(A \mid H)]. \end{aligned}$$

But $\log \pi(A \mid H) = - \tfrac12 \|A - H\|_2^2 / \sigma^2 + \text{const}$, so taking the gradient,

$$\nabla_H \log \pi(A \mid H) = (A - H) / \sigma^2 = \varepsilon,$$

writing $A = H + \sigma^2 \varepsilon$. The authors then use a single sample to estimate the gradient:

$$\nabla_H J(H) \approx R(H + \sigma^2 \varepsilon)\, \varepsilon,$$

leading to a 'gradient-ascent' update

$$H^{t+1} := H^t + \eta\, R(H^t + \sigma^2 \varepsilon^t)\, \varepsilon^t,$$

where $\eta > 0$ is the learning rate. To emphasise: this is a noisy, single-sample estimate of gradient ascent.
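The single-sample update can be sketched as follows, with a toy quadratic reward standing in for the paper's confidence-based reward (the reward function and hyperparameter values are illustrative assumptions, not the paper's):

```python
import numpy as np

rng = np.random.default_rng(0)

def reward(A):
    # Hypothetical stand-in for the paper's internal confidence reward:
    # a smooth function maximised at A = 1 everywhere.
    return -float(np.sum((A - 1.0) ** 2))

def ltpo_step(H, sigma=0.5, eta=1e-3):
    """One LTPO-style update: sample a perturbation, score the perturbed
    thoughts with a single forward pass, and take a REINFORCE step."""
    eps = rng.normal(size=H.shape)      # eps ~ N(0, I)
    R = reward(H + sigma**2 * eps)      # A = H + sigma^2 * eps
    return H + eta * R * eps            # H <- H + eta * R(A) * eps

H = np.zeros((4, 8))                    # K = 4 thoughts, d = 8
for _ in range(20):
    H = ltpo_step(H)
```

Note that no baseline is subtracted here, so each step is a very noisy estimate of the true gradient.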

After $T$ optimisation steps, the optimised latent thought vectors $H^\star$ are concatenated with the prompt embeddings $E(x)$ and passed through the AR LLM:

$$y = \operatorname{decode}\bigl(\mathcal M(E(x) \mathbin\Vert H^\star)\bigr).$$

We point out that this gradient estimate uses no baseline, which will no doubt lead to very high-variance estimates. It is a natural place to use a GRPO-style group baseline.

Sample $\varepsilon_g \sim N(0, I)$ and set $R_g := R(H + \sigma^2 \varepsilon_g)$ for $g = 1, \dots, G$, independently. The gradient update becomes

$$\nabla_H J(H) \approx \frac1G \sum_{g=1}^G \hat A_g\, \varepsilon_g, \qquad \text{where} \qquad \hat A_g := \frac{ R_g - \operatorname{mean}(\{R_g\}_{g=1}^G) }{ \operatorname{std}(\{R_g\}_{g=1}^G) }.$$
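A sketch of this group-normalised update, reusing the same toy quadratic reward as an illustrative assumption (in practice each $R_g$ would be a forward pass through the frozen LLM):

```python
import numpy as np

rng = np.random.default_rng(0)

def reward(A):
    # Hypothetical smooth reward, maximised at A = 1 everywhere.
    return -float(np.sum((A - 1.0) ** 2))

def grpo_step(H, sigma=0.5, eta=0.1, G=16):
    """Group-normalised update: draw G perturbations, standardise the
    rewards within the group, and average into one gradient estimate."""
    eps = rng.normal(size=(G,) + H.shape)
    R = np.array([reward(H + sigma**2 * e) for e in eps])
    A_hat = (R - R.mean()) / (R.std() + 1e-8)      # group-normalised advantage
    grad = np.tensordot(A_hat, eps, axes=1) / G    # (1/G) sum_g A_hat_g eps_g
    return H + eta * grad

H = np.zeros((2, 4))
for _ in range(300):
    H = grpo_step(H)
```

With the group baseline, the update direction is far less noisy than the single-sample version, at the cost of $G$ reward evaluations per step.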

Each $R_g$ requires a forward pass through the frozen LLM, so the $G$ evaluations can be batched. This is more compute per optimisation step, but likely far more sample-efficient.

As presented, the method is framed as RL, but it is really a random search that happens to use the REINFORCE estimator.

Experiments

LTPO is compared against three baselines.

  1. Zero-Shot CoT
    • the standard, discrete-space CoT, instructing the model to generate explicit, step-by-step thinking. In a variant, untuned [UNK] tokens are appended: "[t]his baseline isolates the contribution of our test-time optimization procedure for latent thought tokens."
  2. SoftCoT performs reasoning in the continuous, latent space, and outperforms Coconut in certain cases, for example.
  3. LatentSeek applies test-time optimisation; unlike LTPO, it uses full AR decoding to evaluate intermediate steps.

Performance of LTPO vs baselines

Further experiments and comparisons are conducted in §4 of the paper; they are not reported here.

Conclusion

The paper decouples the generation of latent reasoning from the LLM, which was trained on discrete tokens: instead, the latent thought vectors are optimised directly at test time, via RL.

Getting a good reward signal at test time is the real challenge. They use an internal confidence-based reward, in essence aiming to sharpen the distribution. This definition is somewhat ad hoc, and certainly leaves open research into a better reward.
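To make the idea concrete, here is one hedged sketch of a confidence-style reward: the mean maximum softmax probability over a batch of logits. This is an illustrative proxy for "sharpening the distribution", not the paper's exact definition:

```python
import numpy as np

def confidence_reward(logits):
    """Mean max-softmax probability across positions: higher when the
    model's next-token distributions are sharper (more confident)."""
    z = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    p = np.exp(z)
    p /= p.sum(axis=-1, keepdims=True)
    return float(p.max(axis=-1).mean())

sharp = np.array([[10.0, 0.0, 0.0]])  # near one-hot distribution
flat = np.array([[1.0, 1.0, 1.0]])    # uniform distribution
assert confidence_reward(sharp) > confidence_reward(flat)
```

A reward of this shape needs only the model's output distribution, which is why no AR decoding is required to compute it.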

Overall, the paper is written well, with the experiments clearly laid out.