AI바라기의 인공지능
LLM : 논문리뷰 : Training Large Language Models to Reason in a Continuous Latent Space 본문
LLM : 논문리뷰 : Training Large Language Models to Reason in a Continuous Latent Space
AI바라기 2025. 1. 26. 19:15논문 정리 노트: Training Large Language Models to Reason in a Continuous Latent Space
Purpose of the Paper
기존 Large Language Models (LLMs)는 복잡한 추론 문제를 해결하기 위해 "language space" 내에서 chain-of-thought (CoT) 와 같은 방식으로 언어적 추론 과정을 생성하도록 학습됩니다. 하지만 이 논문에서는 언어 공간 자체가 추론에 항상 최적인 공간이 아닐 수 있다는 점을 지적하며, 특히 언어 토큰 중 상당수가 텍스트의 일관성을 위해서만 필요하고 실제 추론 과정에 필수적이지 않다는 점을 문제 삼습니다. 기존 연구들은 언어 공간 내에서 추론의 효율성을 높이거나 특정 토큰에 집중하는 방식으로 문제를 해결하려 했지만, 근본적인 문제인 언어 공간의 제약을 벗어나지 못했습니다.
따라서 이 논문의 목적은 LLM이 자연어 대신 비제약적인 latent space에서 추론할 수 있는 새로운 패러다임을 탐색하고, 이를 통해 언어 공간의 제약에서 벗어나 더 효율적이고 강력한 추론 능력을 갖도록 하는 것입니다. 특히, 기존 CoT 방식의 한계인 단일 결정적 경로에 대한 조기 commitment를 극복하고, breadth-first search (BFS) 와 유사한 방식으로 다양한 추론 경로를 탐색할 수 있는 가능성을 제시하고자 합니다.
Key Contributions
- 새로운 추론 패러다임 COCONUT (Chain of Continuous Thought) 제안: LLM의 마지막 hidden state를 "continuous thought"로 정의하고, 이를 언어 토큰으로 디코딩하는 대신 다음 입력 embedding으로 직접 feedback하는 방식을 도입했습니다. 이를 통해 LLM이 언어 공간의 제약 없이 latent space에서 추론하도록 유도합니다.
- 다단계 학습 전략 (Multi-stage training strategy) 적용: latent reasoning 학습을 강화하기 위해, 언어적 reasoning chain을 활용하여 학습 과정을 guidance하는 다단계 학습 전략을 Deng et al. (2024)의 연구에서 영감을 받아 적용했습니다.
- Emergent Advanced Reasoning Patterns 발견: COCONUT 모델이 language-based reasoning과 달리, continuous thought 내에 다수의 잠재적인 다음 추론 단계를 인코딩하여 BFS와 유사한 추론 패턴을 보임을 실험적으로 입증했습니다. 이는 모델이 초기 단계에서 잘못된 결정을 하더라도, continuous thought 내에 다양한 옵션을 유지하고 점진적으로 잘못된 경로를 제거하는 방식으로 발전된 추론 메커니즘을 학습한다는 것을 의미합니다.
- 실험적 검증 및 성능 향상: 다양한 reasoning tasks (GSM8k, ProntoQA, ProsQA) 에서 COCONUT이 기존 CoT 방법보다 성능이 뛰어나거나 견줄만하며, 특히 planning 능력이 중요한 logical reasoning task에서 뛰어난 성능을 보임을 입증했습니다. 또한, COCONUT은 inference 시 더 적은 thinking tokens를 생성하면서도 더 나은 성능을 달성하여 효율성을 높였습니다.
Novelty
- Latent Space Reasoning: 기존 LLM 추론 방식에서 벗어나 latent space에서 직접 추론하는 새로운 패러다임을 제시한 점이 가장 큰 novelty입니다. 이는 LLM 추론의 근본적인 공간 제약을 재고하고, 새로운 가능성을 탐색했다는 점에서 의미가 있습니다.
- Continuous Thought 개념 도입: LLM의 hidden state를 "continuous thought"라는 새로운 개념으로 정의하고, 이를 추론 과정의 핵심 representation으로 활용한 점이 독창적입니다. 이는 LLM의 내부 representation을 활용하는 새로운 접근 방식을 제시합니다.
- BFS-like Reasoning Pattern: COCONUT 모델이 명시적인 instruction이나 training 없이도 BFS와 유사한 advanced reasoning pattern을 emergent하게 학습한다는 점이 놀랍습니다. 이는 latent space reasoning의 잠재력을 보여주는 중요한 발견입니다.
- Multi-stage Training Curriculum for Latent Reasoning: iCoT (Deng et al., 2024)에서 영감을 받은 다단계 학습 curriculum을 latent reasoning에 적용하여 효과적인 학습 전략을 개발한 점도 novelty에 해당합니다.
Experimental Highlights
- GSM8k (Math Reasoning): COCONUT은 language reasoning chains의 효과를 미러링하며 reasoning accuracy를 향상시켰습니다. 이는 continuous thought가 math reasoning에서도 효과적임을 보여줍니다. 특히, iCoT (Deng et al., 2024) 와 비교했을 때도 성능이 우수했습니다.
- ProntoQA & ProsQA (Logical Reasoning): planning 능력이 더 요구되는 logical reasoning task에서 COCONUT은 CoT methods를 능가하는 성능을 보였습니다. 특히, 새롭게 제안된 ProsQA dataset에서 COCONUT의 성능 향상이 두드러졌습니다.
- Token Efficiency: COCONUT은 CoT methods보다 훨씬 적은 thinking tokens를 생성하면서도 comparable하거나 더 나은 성능을 달성했습니다. 이는 latent space reasoning이 언어 공간 추론보다 효율적일 수 있음을 시사합니다.
- Latent Tree Search 분석: COCONUT의 latent reasoning 과정을 tree search 관점에서 분석하여, 모델이 BFS와 유사하게 다양한 추론 경로를 탐색하고, 점진적으로 promising path에 집중하는 경향을 보임을 시각적으로 확인했습니다 (Figure 8).
- Value Function 기반 의사 결정: 모델이 각 node의 potential to reach the target을 평가하는 implicit value function을 학습하며, node의 height (leaf node까지의 최단 거리) 가 낮을수록 정확한 평가가 용이함을 분석했습니다 (Figure 9).
Limitations and Future Work
- Inference 시 <eot> token 결정: latent mode 종료 시점을 결정하는 문제, 특히 <eot> token 자동 결정 classifier 또는 constant length padding 방식 모두 완벽하지 않다는 점을 언급합니다. 더 robust한 latent mode 종료 전략 연구가 필요합니다.
- 학습 효율성: COCONUT 학습 시 multiple forward passes의 sequential nature 때문에 parallelism에 어려움이 있으며, 학습 효율성 개선이 필요합니다.
- 일반적인 학습 전략 부재: COCONUT w/o curriculum 실험 결과에서 보듯이, language reasoning chains의 guidance 없이 latent space에서 효과적으로 reasoning을 학습하는 일반적인 전략 개발이 필요합니다.
- 더 큰 c 값 (latent thoughts 수) 에 대한 불안정성: c=3 이상에서 성능 저하 및 불안정성이 관찰되었으며, finer-grained schedules 또는 incremental 방식의 continuous thought 추가 등 학습 안정성 개선 연구가 필요합니다.
- Pre-training with Continuous Thoughts: future work로 continuous thoughts를 활용한 LLM pre-training을 제안합니다. 이는 모델이 더 넓은 범위의 reasoning scenarios에서 효과적으로 generalize하도록 도울 수 있습니다.
- Language & Latent Reasoning 결합: language reasoning skeleton 생성 후 latent space에서 reasoning process를 완료하는 등 language와 latent reasoning을 결합하는 방식이 성능 및 안정성 향상에 도움이 될 수 있음을 언급합니다.
Future Work 방향:
- Latent reasoning methods의 refining 및 scaling
- Continuous thoughts 기반 LLM pre-training
- Language와 latent reasoning 결합 방식 연구
- 학습 효율성 및 안정성 개선 연구
- 더 일반적인 latent space reasoning 학습 전략 개발
Abstaract
Large language models (LLMs) are “language space”에서 추론하는 데 제한됩니다. 여기서 그들은 일반적으로 복잡한 reasoning problem을 해결하기 위해 chain-of-thought (CoT)로 reasoning process를 표현합니다.
그러나, 우리는 language space가 reasoning에 항상 최적인 것은 아닐 수 있다고 주장합니다. 예를 들어, 대부분의 word token은 주로 텍스트 일관성을 위한 것이며 reasoning에 필수적이지 않지만, 일부 중요한 token은 복잡한 planning을 필요로 하며 LLMs에 큰 어려움을 제기합니다.
natural language를 사용하는 대신 무제한의 latent space에서 LLM reasoning의 잠재력을 탐구하기 위해, 우리는 새로운 paradigm Coconut (Chain of Continuous Thought)를 소개합니다. 우리는 LLM의 마지막 hidden state를 reasoning state("continuous thought"라고 칭함)의 representation으로 활용합니다. 이것을 word token으로 디코딩하는 대신, 우리는 그것을 continuous space에서 직접 후속 input embedding으로서 LLM에 다시 피드백합니다.
Experiments는 Coconut이 여러 reasoning task에서 LLM을 효과적으로 보강할 수 있음을 보여줍니다. 이 새로운 latent reasoning paradigm은 emergent advanced reasoning patterns으로 이어집니다. continuous thought는 여러 대안적인 다음 reasoning step을 인코딩할 수 있으며, 모델이 CoT와 같은 단일 deterministic path에 성급하게 전념하기보다는 problem을 해결하기 위해 breadth-first search (BFS)를 수행하도록 합니다. Coconut은 inference 중에 더 적은 thinking token으로 planning 중에 상당한 backtracking을 필요로 하는 특정 logical reasoning task에서 CoT보다 뛰어납니다.
이러한 발견은 latent reasoning의 가능성을 보여주고 future research에 대한 귀중한 통찰력을 제공합니다.
1 Introduction
Large language models (LLMs)는 인간 언어에 대한 광범위한 pretraining으로부터 나타나는 놀라운 reasoning abilities를 입증했습니다. next token prediction은 효과적인 training objective이지만, 그것은 reasoning machine으로서 LLM에 근본적인 constraint를 부과합니다. 즉, LLMs의 명시적인 reasoning process는 word token으로 생성되어야 한다는 것입니다. 예를 들어, chain-of-thought (CoT) reasoning으로 알려진 prevalent approach는 LLMs에게 natural language를 사용하여 step-by-step으로 솔루션을 생성하도록 prompting하거나 training하는 것을 포함합니다. 그러나 이것은 특정 인간 cognition 결과와 극명한 대조를 이룹니다. Neuroimaging 연구들은 language network (language comprehension과 production을 담당하는 brain regions의 집합)가 다양한 reasoning task 동안 대체로 inactive 상태를 유지한다는 것을 일관되게 보여주었습니다. 더욱이 evidence는 인간 언어가 reasoning보다는 communication에 최적화되어 있음을 나타냅니다.
LLMs가 reasoning을 위해 language를 사용할 때 significant issue가 발생합니다. 각 특정 reasoning token에 필요한 reasoning의 양은 크게 다르지만, current LLM architectures는 모든 token을 predicting하는 데 거의 동일한 computing budget을 할당합니다. reasoning chain의 대부분의 token은 fluency만을 위해 생성되며, actual reasoning process에는 거의 기여하지 않습니다. 반대로, 일부 critical token은 complex planning을 필요로 하며 LLMs에 큰 어려움을 제기합니다. 이전 연구에서는 LLMs에게 succinct reasoning chains를 생성하도록 prompting하거나, 일부 critical token을 생성하기 전에 additional reasoning을 수행함으로써 이러한 problems를 해결하려고 시도했지만, 이러한 solutions는 language space 내에 constraint되어 있으며 fundamental problems를 해결하지 못합니다. 반대로, LLMs가 language constraints 없이 자유롭게 reasoning하고, 필요한 경우에만 findings를 language로 translate하는 것이 ideal일 것입니다.
본 연구에서 우리는 novel paradigm인 Coconut (Chain of Continuous Thought)를 도입하여 latent space에서 LLM reasoning을 탐구합니다. 이것은 traditional CoT process에 대한 간단한 modification을 포함합니다. language model head와 embedding layer를 사용하여 hidden states와 language tokens 간에 mapping하는 대신, Coconut은 마지막 hidden state (continuous thought)를 다음 token에 대한 input embedding으로 직접 feed합니다 (Figure 1). 이 modification은 reasoning이 language space 내에 있는 것으로부터 freeing시키고, continuous thoughts가 fully differentiable하기 때문에 system은 gradient descent에 의해 end-to-end로 optimize될 수 있습니다. latent reasoning의 training을 enhance하기 위해, 우리는 language reasoning chains를 training process를 guide하는 데 효과적으로 활용하는 multi-stage training strategy를 employ합니다.
흥미롭게도, 우리가 제안한 paradigm은 efficient reasoning pattern으로 이어집니다. language-based reasoning과 달리, Coconut의 continuous thoughts는 여러 potential next steps를 동시에 encode할 수 있으며, breadth-first search (BFS)와 유사한 reasoning process를 가능하게 합니다. 모델이 initially 올바른 decision을 내리지 못할 수도 있지만, continuous thoughts 내에서 많은 possible options을 유지하고, 일부 implicit value functions에 의해 guided되어 reasoning을 통해 incorrect paths를 점진적으로 eliminate할 수 있습니다. 이 advanced reasoning mechanism은 traditional CoT를 surpass합니다. 심지어 모델이 이전 연구에서 볼 수 있듯이 이러한 방식으로 operate하도록 명시적으로 trained되거나 instructed되지 않았음에도 불구하고 그렇습니다.
Experimentally, Coconut은 LLMs의 reasoning capabilities를 successfully enhance합니다. math reasoning (GSM8k)의 경우, continuous thoughts를 사용하는 것이 language reasoning chains의 effects를 mirroring하면서 reasoning accuracy에 beneficial한 것으로 나타났습니다. 이것은 더 많은 continuous thoughts를 chaining함으로써 increasingly challenging problems를 scale하고 solve할 수 있는 potential을 나타냅니다. stronger planning ability를 필요로 하는 ProntoQA 및 우리가 newly proposed한 ProsQA (Section 4.1)를 포함한 logical reasoning에서, Coconut과 일부 variants는 inference 중에 significantly fewer tokens를 생성하면서도 language-based CoT methods를 surpass합니다. 우리는 이러한 findings가 latent reasoning의 potential을 underscore하고 future research에 대한 valuable insights를 provide할 수 있다고 믿습니다.
논문 핵심 요약 노트: Coconut - Chain of Continuous Thought (Introduction 섹션)
주요 문제 제기:
- LLMs의 Reasoning 방식의 비효율성: 현재 LLMs는 "language space" 내에서 chain-of-thought (CoT) 방식으로 reasoning을 수행합니다. 하지만 language token 기반 reasoning은 본질적으로 비효율적입니다.
- Reasoning token 중요도 불균형: 대부분의 token은 문장 유창성을 위해 생성될 뿐, 실제 reasoning에는 거의 기여하지 않습니다. 반면, 핵심적인 reasoning token은 복잡한 계획을 요구하며 LLMs에게 큰 부담을 줍니다.
- 계산 자원 낭비: LLM architecture는 모든 token 예측에 거의 동일한 계산 자원을 할당하므로, 중요도가 낮은 token 생성에 불필요한 연산이 소모됩니다.
- 인간 인지와의 괴리: 인간의 reasoning 과정은 언어 네트워크의 활성화가 미미하며, 언어는 본래 communication에 최적화되어 reasoning 도구로는 비효율적일 수 있습니다.
기존 연구의 한계:
- 기존 연구들은 succinct CoT prompting, critical token 이전 추가 reasoning 등의 방법으로 language space 내에서 효율성을 개선하려 했으나, 근본적인 문제 해결에는 미흡했습니다.
Coconut Paradigm의 제안 (핵심 아이디어):
- Latent Space Reasoning 도입: language space의 제약에서 벗어나 "latent space"에서 LLM reasoning을 수행하는 새로운 패러다임 **Coconut (Chain of Continuous Thought)**를 제안합니다.
- Continuous Thought: LLM의 마지막 hidden state를 "continuous thought"라는 reasoning state representation으로 활용합니다.
- Hidden State Feedback: hidden state를 word token으로 decoding하는 대신, 다음 token의 input embedding으로 직접 LLM에 feedback합니다. 즉, reasoning process가 token space를 거치지 않고 hidden state space 내에서 순환합니다.
Coconut의 장점 및 기대 효과:
- Reasoning 효율성 증대: 불필요한 language token 생성을 최소화하고, reasoning에 집중된 연산이 가능해져 효율성이 향상될 것으로 기대됩니다.
- Emergent Advanced Reasoning 패턴:
- BFS (Breadth-First Search) 유사 Reasoning: Continuous thought는 여러 가능한 다음 reasoning step을 동시에 encoding하여, CoT의 단일 결정적 경로 탐색 방식과 달리 BFS처럼 다양한 가능성을 탐색하는 reasoning이 가능해집니다. (모델이 명시적으로 BFS를 학습하거나 지시받지 않아도 emergent하게 나타나는 현상)
- Backtracking 능력 향상: 잘못된 경로에 대한 미련 없이 다양한 옵션을 유지하며 reasoning을 진행, backtracking이 필요한 논리적 추론 task에서 강점을 보일 것으로 예상됩니다.
실험적 검증 및 기대:
- Math reasoning (GSM8k) 및 Logical reasoning (ProntoQA, ProsQA) task에서 Coconut의 성능 향상을 실험적으로 입증합니다. 특히 logical reasoning task에서 CoT 대비 우수한 성능과 더 적은 token 생성 수를 보여줍니다.
- Latent reasoning의 잠재력을 강조하며, future research에 valuable insights를 제공할 것으로 기대됩니다.
핵심 키워드: Latent Space Reasoning, Continuous Thought, Chain-of-Thought (CoT) 한계 극복, Reasoning 효율성, Breadth-First Search (BFS) 유사 Reasoning, Emergent Behavior
Figure 1은 Chain of Continuous Thought (Coconut)와 Chain-of-Thought (CoT)를 비교합니다.
CoT에서, model은 reasoning process를 word token sequence (예: 그림에서 xi, xi+1, ..., xi+j)로 생성합니다.
Coconut은 마지막 hidden state를 reasoning state("continuous thought"라고 칭함)의 representation으로 간주하고, 그것을 다음 input embedding으로 직접 사용합니다.
이것은 LLM이 language space 대신 unrestricted latent space에서 reason할 수 있도록 합니다.
2 Related Work
Chain-of-thought (CoT) reasoning. 우리는 chain-of-thought라는 용어를 final answer를 output하기 전에 language로 intermediate reasoning process를 generate하는 methods를 broadly하게 refer하기 위해 사용합니다. 이것은 LLMs를 prompting하거나 (Wei et al., 2022; Khot et al., 2022; Zhou et al., 2022), reasoning chains를 generate하도록 LLMs를 training하는 것을 포함하며, supervised finetuning (Yue et al., 2023; Yu et al., 2023) 또는 reinforcement learning (Wang et al., 2024; Havrilla et al., 2024; Shao et al., 2024; Yu et al., 2024a)을 사용합니다. Madaan and Yazdanbakhsh (2022)는 CoT의 tokens를 symbols, patterns, and text로 분류하고, LLM을 guide하여 그들의 역할을 분석하여 concise CoT를 generate하도록 제안했습니다. 최근 theoretical analyses는 model expressivity 관점에서 CoT의 usefulness를 입증했습니다 (Feng et al., 2023; Merrill and Sabharwal, 2023; Li et al., 2024). CoT를 employing함으로써, transformer의 effective depth는 generated outputs가 input으로 looped back되기 때문에 증가합니다 (Feng et al., 2023). CoT의 established effectiveness와 결합된 이러한 analyses는 continuous thoughts를 다음 input embedding으로서 LLM에 feedback하는 우리의 design에 동기를 부여했습니다. CoT는 특정 task에 효과적인 것으로 입증되었지만, 그것의 autoregressive generation nature는 더 complex problems (LeCun, 2022; Hao et al., 2023)에 대한 human reasoning을 mimic하는 것을 challenging하게 만듭니다. 이러한 problems는 일반적으로 planning과 search를 필요로 합니다. LLMs에 explicit tree search algorithms를 equip하는 works (Xie et al., 2023; Yao et al., 2023; Hao et al., 2024) 또는 search dynamics and trajectories에 대해 LLM을 training하는 works (Lehnert et al., 2024; Gandhi et al., 2024; Su et al., 2024)가 있습니다. 우리의 analysis에서, 우리는 language space의 constraint를 removing한 후, model이 이러한 방식으로 explicitly trained되지 않았음에도 불구하고 BFS와 유사한 새로운 reasoning pattern이 emerges하는 것을 발견합니다.
Latent reasoning in LLMs. Previous works는 LLMs에서 latent reasoning을 transformers에서의 hidden computation으로 mostly define합니다 (Yang et al., 2024; Biran et al., 2024). Yang et al. (2024)은 two-hop reasoning problems의 dataset을 constructed하고 intermediate variable을 hidden representations로부터 recover하는 것이 possible하다는 것을 discovered했습니다. Biran et al. (2024)은 furthermore hidden representation을 "back-patching"하여 latent reasoning에 intervene하는 것을 proposed했습니다. Shalev et al. (2024)은 LLMs에서 parallel latent reasoning paths를 discovered했습니다. 또 다른 line of work는 model이 reason하기 위해 CoT를 generate하는 경우에도, model이 actually different latent reasoning process를 utilize할 수 있다는 것을 discovered했습니다. 이 phenomenon은 CoT reasoning의 unfaithfulness로 알려져 있습니다 (Wang et al., 2022; Turpin et al., 2024). LLM의 latent reasoning을 enhance하기 위해, previous research는 additional tokens로 augment하는 것을 proposed했습니다. Goyal et al. (2023)은 learnable <pause> tokens를 training corpus에 randomly inserting하여 model을 pretrained했습니다. 이것은 다양한 task, especially <pause> tokens를 사용한 supervised finetuning이 뒤따를 때 LLM의 performance를 improves합니다. On the other hand, Pfau et al. (2024)은 filler tokens (예: “...”)의 usage를 further explored하고, 그것들이 highly parallelizable problems에 잘 작동한다고 concluded했습니다. However, Pfau et al. (2024)은 이러한 methods가 CoT처럼 LLM의 expressivity를 extend하지 않는다고 mentioned했습니다. hence, 그것들은 more general and complex reasoning problems로 scale하지 않을 수 있습니다. Wang et al. (2023)은 다음 reasoning step을 generate하기 전에 discrete latent variable로서 planning token을 predict하는 것을 proposed했습니다. Recently, knowledge distillation (Deng et al., 2023) 또는 CoT를 gradually shorten하는 special training curriculum (Deng et al., 2024)을 사용하여 transformer에서 CoT reasoning을 latent reasoning으로 "internalize"할 수 있다는 것도 found되었습니다. Yu et al. (2024b)은 complex reasoning algorithms로 generated된 data로부터 latently reason할 수 있는 model을 distill하는 것도 proposed했습니다. 이러한 training methods는 우리의 framework에 combined될 수 있으며, specifically, 우리는 iCoT (Deng et al., 2024)에서 inspired된 continuous thoughts의 learning을 multiple stages로 breaking down하는 것이 training에 매우 beneficial하다는 것을 find합니다. Recently, looped transformers (Giannou et al., 2023; Fan et al., 2024)는 algorithmic tasks를 solve하기 위해 proposed되었으며, 이는 continuous thoughts의 computing process와 some similarities를 가지지만, 우리는 common reasoning tasks에 focus하고 language space와 비교하여 latent reasoning을 investigating하는 것을 aim합니다.
논문 핵심 요약 노트: Coconut - Chain of Continuous Thought (Related Work 섹션)
핵심 연구 배경 및 선행 연구 분석:
- Chain-of-Thought (CoT) Reasoning의 현황과 한계:
- CoT는 intermediate reasoning process를 language로 생성하는 방법론들을 포괄하는 용어로 정의 (prompting, finetuning, RL 등 다양한 구현 방식 존재).
- CoT의 유용성은 모델 표현력 관점에서 이론적으로도 입증됨 (transformer depth 증가 효과).
- CoT의 한계 명확히 지적:
- Autoregressive 생성 방식의 본질적인 제약으로 인해 복잡한 문제에서 인간 추론 방식 (planning, search) 모방에 어려움 존재.
- 명시적인 트리 탐색 알고리즘 결합 또는 탐색 궤적 학습 연구 존재하나, 여전히 language space에 갇혀있는 한계.
- Coconut과의 연결: CoT의 효과성은 hidden state feedback 아이디어의 동기가 되었으나, Coconut은 language space 제약을 제거하여 CoT의 근본적인 한계를 극복하고자 함. Emergent한 BFS 유사 reasoning 패턴을 통해 CoT와 차별화.
- LLM Latent Reasoning 연구 동향 분석:
- 기존 연구들은 latent reasoning을 transformer 내부의 hidden computation으로 주로 정의.
- Latent Reasoning 분석 및 개입 시도: hidden representation에서 중간 변수 복구, back-patching을 통한 개입, parallel latent reasoning paths 발견 등 연구 존재.
- CoT Unfaithfulness 문제: 모델이 CoT를 생성해도 실제 latent reasoning process는 다를 수 있음 지적. Coconut은 명시적인 CoT 없이 latent reasoning 자체에 집중.
- Latent Reasoning 강화 방법: <pause> token, filler token 활용 연구 존재하나, CoT처럼 모델 표현력 확장에는 한계 지적. Planning token 예측 연구는 discrete latent variable 활용.
- CoT를 Latent Reasoning으로 "Internalize" 하는 연구: Knowledge distillation, curriculum learning 활용 연구 소개. Coconut은 이러한 연구들과 training 방법론을 공유할 수 있음을 언급, 특히 iCoT의 multi-stage training 전략을 Coconut에 적용하여 효과를 봄.
- Looped Transformer: 알고리즘 task에 적용된 looped transformer와 유사성 언급, but Coconut은 일반적인 reasoning task에 집중하며 language space 대비 latent reasoning 탐구에 초점.
Coconut의 차별성 및 핵심 강조점:
- Language Space 제약 탈피: 기존 연구들이 language space 내에서 CoT 개선 또는 latent reasoning 분석에 집중한 반면, Coconut은 latent space reasoning 패러다임을 전면적으로 제시.
- Emergent BFS-like Reasoning: 명시적인 search 알고리즘 없이 latent space reasoning 자체에서 BFS와 유사한 탐색 능력이 자연스럽게 나타나는 점을 강조.
- Multi-stage Training: iCoT에서 영감을 받은 multi-stage training 전략이 latent reasoning 학습에 효과적임을 밝힘.
결론: Coconut은 CoT의 한계를 극복하고 latent reasoning의 잠재력을 극대화하기 위해 language space를 벗어나 새로운 방향을 제시하는 연구. 특히 emergent BFS-like reasoning 능력은 기존 연구와 차별화되는 중요한 contribution.
3 Coconut: Chain of Continuous Thought
In this section, we introduce our new paradigm Coconut (Chain of Continuous Thought) for reasoning in an unconstrained latent space. 우리는 language models에 사용하는 background와 notation을 introducing하는 것으로 시작하겠습니다. input sequence x = (x1, ..., xT )에 대해, standard large language model M은 다음과 같이 described될 수 있습니다.
Ht = Transformer(Et)
M(xt+1 | x≤t) = softmax(W ht)
여기서 Et = [e(x1), e(x2), ..., e(xt)]는 position t까지의 token embeddings의 sequence입니다. Ht ∈ R^(t×d)는 position t까지의 모든 tokens에 대한 마지막 hidden states의 matrix입니다. ht는 position t의 마지막 hidden state, 즉 ht = Ht[t, :]입니다. e(·)는 token embedding function입니다. W는 language model head의 parameter입니다.
Method Overview. proposed된 Coconut method에서, LLM은 “language mode”와 “latent mode” 사이를 switches합니다 (Figure 1). language mode에서, model은 standard language model로서 operates하며, autoregressively 다음 token을 generating합니다. latent mode에서, 그것은 마지막 hidden state를 다음 input embedding으로서 directly utilizes합니다. 이 마지막 hidden state는 current reasoning state를 represents하며, “continuous thought”라고 termed됩니다.
Special tokens <bot> 및 <eot>는 각각 latent thought mode의 beginning과 end를 mark하기 위해 employed됩니다. 예시로서, 우리는 latent reasoning이 positions i와 j 사이에서 occurs한다고 가정합니다. 즉, xi = <bot>이고 xj = <eot>입니다. model이 latent mode (i < t < j)에 있을 때, 우리는 input embedding을 replace하기 위해 previous token으로부터의 마지막 hidden state를 사용합니다. 즉, Et = [e(x1), e(x2), ..., e(xi), hi, hi+1, ..., ht−1] 입니다. latent mode가 finishes된 후 (t ≥ j), input은 token embedding을 사용하는 것으로 reverts됩니다. 즉, Et = [e(x1), e(x2), ..., e(xi), hi, hi+1, ..., hj−1, e(xj ), ..., e(xt)]입니다. 마지막 hidden states는 final normalization layer에 의해 processed되었으므로, magnitude가 너무 크지 않다는 점에 worth noting합니다. M(xt+1 | x≤t)는 i < t < j일 때 defined되지 않습니다. latent thought가 language space로 mapped back되도록 intended되지 않았기 때문입니다. However, softmax(W ht)는 probing purposes를 위해 still calculated될 수 있습니다 (Section 4 참조).
Training Procedure. In this work, 우리는 model이 question을 input으로 receive하고 reasoning process를 통해 answer를 generate하도록 expected되는 problem-solving setting에 focus합니다. 우리는 Deng et al.(2024)에서 inspired된 multi-stage training curriculum을 implementing하여 continuous thought를 supervise하기 위해 language CoT data를 leverage합니다. Figure 2에서 shown된 바와 같이, initial stage에서, model은 regular CoT instances에 대해 trained됩니다. subsequent stages에서, k-th stage에서, CoT의 처음 k reasoning steps는 k × c continuous thoughts로 replaced됩니다. 여기서 c는 single language reasoning step을 replace하는 latent thoughts의 number를 controlling하는 hyperparameter입니다. Deng et al.(2024)을 following하여, 우리는 training stages가 switch될 때 optimizer state를 reset합니다. 우리는 continuous thoughts를 encapsulate하기 위해 <bot> 및 <eot> tokens (c에 towards counted되지 않음)를 insert합니다. training process 동안, 우리는 normal negative log-likelihood loss를 optimize하지만, questions와 latent thoughts에 대한 loss를 mask합니다. objective는 continuous thought가 removed된 language thought를 compress하도록 encourage하는 것이 아니라, future reasoning의 prediction을 facilitate하기 위한 것임을 important하게 noting해야 합니다. Therefore, LLM이 human language에 compared to reasoning steps의 more effective representations를 learn하는 것이 possible합니다.
Training Details. Our proposed continuous thoughts는 fully differentiable하며 back-propagation을 allow합니다. 우리는 current training stage에서 n latent thoughts가 scheduled될 때 n + 1 forward passes를 perform하며, each pass로 new latent thought를 computing하고 finally remaining text sequence에 대한 loss를 obtain하기 위해 additional forward pass를 conducting합니다. 우리는 KV cache를 사용하여 any repetitive computing을 save할 수 있지만, multiple forward passes의 sequential nature는 parallelism에 대한 challenges를 poses합니다. Coconut의 training efficiency를 further optimizing하는 것은 future research를 위한 important direction으로 remains합니다.
Inference Process. Coconut에 대한 inference process는 latent mode에서, 우리가 마지막 hidden state를 다음 input embedding으로서 directly feed한다는 점을 except하고는 standard language model decoding과 analogous합니다. challenge는 latent mode와 language mode 사이를 언제 switch할지 determining하는 데 lies합니다. 우리가 problem-solving setting에 focus하므로, 우리는 question tokens 바로 following하여 <bot> token을 insert합니다. <eot>에 대해, 우리는 two potential strategies를 consider합니다. a) model이 latent reasoning을 terminate할 시기를 autonomously decide할 수 있도록 latent thoughts에 대해 binary classifier를 train하거나, b) always latent thoughts를 constant length로 pad합니다. 우리는 both approaches가 comparably well하게 work한다는 것을 found했습니다. Therefore, 우리는 otherwise specified되지 않는 한, simplicity를 위해 experiment에서 second option을 사용합니다.
논문 핵심 요약 노트: Coconut - Chain of Continuous Thought (Section 3)
Coconut Paradigm 핵심 아이디어:
- Unconstrained Latent Space Reasoning: 기존 LLM의 "language space" 제약에서 벗어나 latent space에서 reasoning을 수행하는 새로운 패러다임 Coconut (Chain of Continuous Thought) 제시.
- Continuous Thought: LLM의 last hidden state를 "continuous thought"라는 reasoning state의 representation으로 정의. Word token decoding 없이 hidden state 자체를 reasoning 정보로 활용.
- Language Mode vs. Latent Mode: LLM은 두 가지 mode로 동작:
- Language Mode: Standard LLM처럼 autoregressive token generation.
- Latent Mode: Last hidden state를 다음 input embedding으로 직접 사용 (Hidden State Feedback). Reasoning process가 hidden state space 내에서 순환.
- Mode Switching Tokens: <bot> (latent mode 시작), <eot> (latent mode 종료) special tokens을 사용하여 mode 전환 제어.
Method 상세:
- Input Embedding 구성:
- Latent Mode ( <bot>과 <eot> 사이): Input embedding sequence (Et) 구성 시, 이전 token embedding 대신 이전 hidden state (hi, hi+1, ...) 를 사용. Hidden state가 reasoning 정보의 carrier 역할.
- Language Mode (latent mode 외부): Standard token embedding (e(x)) 사용.
- Loss Function: Standard negative log-likelihood loss 사용, but question 및 latent thought 부분의 loss는 masking. Continuous thought가 language thought 압축이 아닌 future reasoning prediction 촉진하도록 학습 유도. 더 효과적인 reasoning representation 학습 목표.
Training Procedure (Multi-stage Curriculum Learning):
- iCoT (Deng et al., 2024) inspired Multi-stage Training: Language CoT data를 활용하여 continuous thought 학습.
- Curriculum 구성:
- Initial Stage: Regular CoT instance 학습.
- Subsequent Stages (k-th stage): CoT의 처음 k reasoning steps를 k × c 개의 continuous thoughts로 점진적으로 대체. c는 latent thought 개수 hyperparameter.
- Optimizer Reset: Training stage 전환 시 optimizer state reset (Deng et al. (2024) 방식 차용).
- <bot>, <eot> token 삽입: Continuous thought 구간 encapsulation.
Training Details:
- Differentiable Continuous Thoughts: Backpropagation 가능.
- n+1 Forward Passes: n개의 latent thought 구간에서 각 latent thought 생성 시 forward pass, remaining text sequence loss 계산 위해 추가 forward pass.
- KV Cache 활용 가능: But, sequential forward pass로 인해 parallelism challenge 존재. Training efficiency 개선은 future work.
Inference Process:
- Standard Decoding 유사: Latent mode에서 hidden state feedback 적용 외에는 standard language model decoding과 유사.
- Mode Switching 결정:
- <bot>: Question token 직후 삽입 (problem-solving setting 가정).
- <eot>: 2가지 전략 고려:
- Binary Classifier: Latent thought 기반 종료 시점 예측 (autonomously).
- Fixed Length Padding: Constant length로 latent thought padding (simplicity 위해 선택).
핵심 차별점 및 시사점:
- Latent Space Reasoning 패러다임: Language token 기반 reasoning의 한계를 극복하고 latent space의 potential 활용.
- Hidden State Feedback Mechanism: Reasoning process를 hidden state circulation으로 구현하는 novel approach.
- Multi-stage Training: Continuous thought 효과적 학습 위한 curriculum learning 전략.
Future Research 방향: Training efficiency 최적화, <eot> token 결정 전략 개선 등