AI바라기의 인공지능

LLM : 논문리뷰 : STaR: Bootstrapping Reasoning With Reasoning 본문

논문리뷰

LLM : 논문리뷰 : STaR: Bootstrapping Reasoning With Reasoning

AI바라기 2025. 1. 18. 14:46

 

STaR: Self-Taught Reasoner - Bootstrapping Reasoning With Reasoning

Purpose of the Paper

기존의 reasoning 능력을 향상시키기 위한 large language models (LLMs) 은 대규모 rationale datasets 을 구축하거나, few-shot inference 만을 사용함으로써 accuracy 가 희생되는 한계가 있었습니다. 이 논문은 LLM 이 스스로 reasoning 하는 방법을 학습하도록 하는 새로운 방법론인 Self-Taught Reasoner (STaR) 을 제안합니다. STaR 은 적은 수의 rationale examplesrationale 이 없는 대규모 dataset 을 반복적으로 활용하여 점진적으로 더 복잡한 reasoning 을 수행할 수 있는 능력을 bootstrapping 합니다. 즉, 모델이 생성한 reasoning 을 스스로 학습에 활용함으로써, 외부의 정답 dataset 에 의존하지 않고도 자체적으로 성능을 향상시키는 것이 이 논문의 핵심적인 purpose 입니다.

Key Contributions

  1. Bootstrapping Mechanism: rationale dataset 을 반복적으로 생성하기 위한 bootstrapping mechanism 을 제안합니다. 이 mechanism 은 새로운 rationales 의 정확성을 확인할 필요 없이, few initial examples with rationales 로부터 시작합니다.
  2. Rationale GenerationRationalization 의 결합: modelanswerjustifying 하도록 한 다음, 마치 hint 없이 rationale 을 생각해낸 것처럼 fine-tuning 합니다. 이러한 rationalizationbootstrapping process 를 가속화하고 개선합니다.
  3. Mathematical and Commonsense Reasoning Domains 에서 검증: 다양한 ablations 을 통해 mathematical and commonsense reasoning domains 에서 STaR 의 효과를 검증합니다.
  4. Iterative Self-Improvement: pre-trained large language model 이 자체 language modeling capacity 를 반복적으로 사용하여 스스로 개선하도록 하는 최초의 기술을 제안합니다.

Novelty

  • Rationale Generation 을 통한 Self-Improvement: 기존 연구들이 외부 dataset 이나 human feedback 에 의존했던 반면, STaRmodel 이 생성한 rationalefeedback 으로 사용하여 자체적으로 reasoning 능력을 향상시킵니다.
  • Rationalization 도입: model 이 정답을 justifying 하도록 함으로써, reasoning 과정을 역으로 추론하고, rationale generation 의 한계를 극복합니다.
  • Iterative Bootstrapping: rationale generationrationalization, 그리고 fine-tuning 을 결합한 반복적인 bootstrapping 을 통해 점진적으로 reasoning 능력을 향상시킵니다.

Experimental Highlights

  • Arithmetic, Math Word Problems, and Commonsense ReasoningSTaR 을 적용하여, few-shot prompts 를 large rationale dataset 으로 효과적으로 변환하고, 성능을 크게 향상시켰습니다.
  • CommonsenseQA 에서는 few-shot baseline (+35.9%)directly predict answers 하도록 fine-tunedbaseline (+12.5%) 보다 STaR 이 우수한 성능을 보였고, 30배 더 큰 fine-tuned model (72.5% vs. 73.0%) 과 비슷한 성능을 달성했습니다.
  • GSM8K 에서는 STaRfew-shot with rationalesanswers 를 직접 예측하도록 training 하는 것보다 훨씬 뛰어난 성능을 보였습니다 (10.7% vs 3.1%, 5.8%).
  • Human Evaluation: STaR 로 생성된 rationalesfew-shot 으로 생성된 rationales 보다 reasoning quality 가 높다는 것을 확인했습니다.

Limitations

  • Initial Few-shot Performance 의존: 첫 번째 iteration 이 성공하려면 few-shot performancechance 보다 높아야 합니다. 즉, 초기 model 이 어느 정도의 reasoning capabilities 를 가지고 있어야 합니다.
  • High Chance Performance 에서의 한계: binary decisions 와 같이 chance performance 가 높은 setting 에서는 poor rationales 가 많이 생성되어 STaR 의 효과가 제한될 수 있습니다.
  • Computational Cost: 특히 large language models 을 사용할 경우, iterative training process 는 상당한 computational resources 를 필요로 합니다.
  • Drift: Sampling 중에 few-shot prompting 을 포함하면 rationales 가 초기 few-shot set 과 점점 달라지는 "drift" 현상이 크게 줄어들지만, 이로 인해 model 이 초기 rationalesqualitydifficulty 에 덜 제약되어 generalization 에 긍정적인 영향을 줄 수 있지만, 원래 prompting style 과의 일치도가 낮아질 수 있습니다.

Future Work

  • Temperature 의 영향 탐구: dataset 확장을 위한 대안으로 higher-temperature sampling 의 가능성을 탐구하고, 다양한 hinting techniques 의 영향과 일반성을 탐구합니다.
  • Prompting 최적화: samplingfew-shot prompting 을 포함하는 것의 장단점을 고려하여, 다양한 datasetsmodels 에 대한 최적의 prompting strategy 를 연구합니다.
  • Rationalization 의 개선: rationalizationreasoning 능력을 향상시키는 메커니즘을 더 깊이 이해하고, rationalization 을 개선할 수 있는 방법을 연구합니다.
  • Faithfulness 향상: 생성된 rationalesmodel 의 실제 reasoning process 를 더 잘 반영하도록 하는 방법을 연구합니다.
  • Scalability 개선: 더 큰 modelsdatasetsSTaR 을 적용하기 위한 scalability 개선 방안을 연구합니다.
  • 다양한 Task 로의 확장: STaR 을 다양한 reasoning tasksdomains 에 적용하여 그 효과와 한계를 탐구합니다.

 

 

Abstract

"chain-of-thought" 단계를 생성하는 것은 수학이나 상식 Question-answering과 같은 복잡한 reasoning tasks에 대한 language model의 성능을 향상시킵니다. 그러나 language model의 rationale generation을 유도하려면 현재 대규모 rationale datasets을 구성하거나 few-shot inference만 사용하여 정확도를 떨어뜨려야 합니다. 우리는 적은 수의 rationale 예제와 rationale이 없는 대규모 dataset을 반복적으로 활용하여, 더 복잡한 reasoning을 수행하는 능력을 bootstrap하는 기술을 제안합니다. 이 기술은 "Self-Taught Reasoner" (STaR)이며, 간단한 루프에 의존합니다. 몇 가지 rationale 예제를 prompt로 사용하여, 많은 질문에 답하기 위해 rationale을 generate 합니다. 생성된 답변이 틀리면 정답이 주어졌을 때 rationale을 다시 generate 합니다. 궁극적으로 정답을 산출한 모든 rationale에 대해 fine-tune을 수행합니다. 그리고 이 과정을 반복합니다. 우리는 STaR이 최종 답을 직접 예측하도록 fine-tune된 model에 비해 여러 datasets에서 성능을 크게 향상시키고, CommensenseQA에서 30배 더 큰 state-of-the-art language model을 fine-tuning 한 것과 비슷한 성능을 보인다는 것을 보여줍니다. 따라서 STaR은 model이 자체적으로 generate한 reasoning을 통해 학습하여 스스로 개선할 수 있도록 합니다.

 

1 Introduction

인간의 의사 결정은 종종 생각의 연장된 chain의 결과입니다. 최근 연구에 따르면 명시적인 중간 reasoning ("rationales")이 large language model (LLM)의 성능 또한 향상시킬 수 있음을 보여주었습니다. 예를 들어, 중간 단계를 위한 "scratchpads"를 사용하도록 명시적으로 trained된 LLMs가 산술에서 완벽한 in-distribution 성능과 강력한 out-of-distribution 일반화를 달성할 수 있는 반면, 정답을 직접 예측하도록 trained된 models은 둘 다 실패한다는 것을 보여주었습니다. 이러한 연구는 최종 답변을 제공하기 전에 명시적인 rationale을 생성하는 것("rationale generation")이 수학적 reasoning, 상식 reasoning, 코드 평가, 사회적 편견 추론 및 자연어 추론을 포함한 다양한 tasks에 걸쳐 LLMs에 가치가 있음을 시사합니다.

그러나 rationale generation을 유도하는 두 가지 주요 방법 모두 심각한 단점이 있습니다. rationale generation에 대한 한 가지 접근 방식은 수동으로 human annotators에 의해 또는 수작업으로 만든 templates를 사용하여 자동으로 rationale의 fine-tuning dataset을 구성하는 것입니다. 수동 방법은 비용이 많이 들고 각 흥미로운 문제에 대해 이러한 dataset을 구성하는 것은 불가능합니다. 한편, template 기반 방법은 자동으로 생성된 rationale에 의존하지만, 일반적인 솔루션이 이미 알려져 있거나 합리적인 하드 코딩된 휴리스틱을 만들 수 있는 경우에만 작동합니다.

대안은 language model prompt에 적은 수의 rationale 예제를 포함하여 in-context learning을 활용하는 것입니다. 이는 rationale이 없는 prompting ("direct" prompting)에 비해 수학적 및 기호적 reasoning tasks에 대한 정확도를 향상시키는 것으로 나타났습니다. 그러나 rationale을 사용한 few-shot 기술은 reasoning이 아닌 대응 기술보다 성능이 뛰어난 경향이 있지만, 일반적으로 더 큰 datasets를 사용하여 정답을 직접 예측하도록 fine-tune된 models보다 실질적으로 성능이 떨어집니다.

본 논문에서는 다른 접근 방식을 채택합니다. LLM의 기존 reasoning 능력을 활용하여 고품질 rationale을 생성하는 능력을 반복적으로 bootstrap합니다. 구체적으로, 우리는 large language model이 자체적으로 rationale을 생성하도록 few-shot prompt를 사용하고, 정답으로 이어지는 rationale에 대해 fine-tuning하여 model의 능력을 더 개선합니다. 우리는 이 절차를 반복하여 매번 개선된 model을 사용하여 다음 training set를 생성합니다. 이것은 상승 작용을 일으키는 과정으로, rationale generation의 개선이 training data를 개선하고, training data의 개선이 rationale generation을 더욱 개선합니다.

그러나 우리는 이 루프가 해결하지 못한 문제에 대한 직접적인 training signal을 받지 않기 때문에 결국 training set의 새로운 문제를 해결하지 못한다는 것을 발견했습니다. 이 문제를 극복하기 위해 우리는 rationalization을 제안합니다. model이 정답을 맞히지 못한 각 문제에 대해, 정답을 제공하여 새로운 rationale을 생성합니다. 이를 통해 model은 정답이 주어지면 더 쉽게 유용한 rationale을 생성할 수 있습니다. 그런 다음 이러한 rationale은 training data의 일부로 수집되어 전반적인 정확도를 향상시킵니다.

따라서 우리는 model이 자체 rationale을 생성하는 법을 배우면서 점점 더 어려운 문제를 해결하는 법을 배울 수 있도록 하는 확장 가능한 bootstrapping 방법인 Self-Taught Reasoner (STaR, 그림 1) 방법을 개발합니다. 우리 방법에서는 다음 프로세스를 반복합니다. 각 반복에서 먼저 현재 model의 rationale generation 능력을 사용하여 dataset을 해결하려고 시도하여 finetuning dataset을 구성합니다. 그런 다음 rationalization을 사용하여 이 dataset을 보강하여 model이 해결하지 못한 문제에 대한 정답을 정당화합니다. 마지막으로 결합된 dataset에 대해 large language model을 finetune합니다.

산술, 수학 단어 문제 및 상식 reasoning에 STaR을 적용하면 적은 수의 few-shot prompts를 큰 rationale dataset으로 효과적으로 변환하여 극적인 성능 향상을 가져올 수 있음을 관찰합니다. CommonsenseQA에서 STaR은 few-shot baseline (+35.9%)과 정답을 직접 예측하도록 fine-tune된 baseline (+12.5%)을 능가하고, 30배 더 큰 fine-tune된 model과 비슷한 성능(72.5% 대 73.0%)을 보입니다.

따라서 우리는 다음과 같은 기여를 합니다.

  1. 우리는 새로운 rationale의 정확성을 확인하지 않고도 적은 수의 rationale 예제에서 rationale dataset을 반복적으로 생성하는 bootstrapping 메커니즘을 제안합니다.
  2. 우리는 rationale generation을 rationalization으로 보완합니다. rationalization은 model이 정답을 정당화한 다음, 아무런 힌트 없이 rationale을 생각해 낸 것처럼 fine-tune되는 것을 말합니다. 우리는 rationalization이 bootstrapping 프로세스를 가속화하고 개선한다는 것을 보여줍니다.
  3. 우리는 수학적 및 상식적 reasoning domains에서 다양한 ablation을 통해 이러한 기술을 평가합니다.
  4. 우리는 우리의 지식으로는, pre-trained large language model이 자체 language modeling 능력을 반복적으로 사용하여 스스로를 개선할 수 있도록 하는 최초의 기술을 제안합니다.

 

이 문서는 STaR 논문의 Introduction 섹션을 AI 연구자들이 빠르게 핵심만 파악할 수 있도록 정리한 노트입니다. 일반적인 내용보다는 이 논문만의 차별화된 아이디어와 기여를 중점적으로 다룹니다.

기존 연구의 한계:

  • Rationale Generation의 어려움: LLMs의 reasoning 성능 향상을 위해 "chain-of-thought" (rationale generation)를 활용하는 연구들이 있었지만, 기존 방법들은 크게 두 가지 한계점이 있었습니다.
    • 데이터 구축: Human-annotation 또는 template 기반으로 rationale dataset을 구축하는 것은 비용이 많이 들고, 모든 문제에 적용하기 어렵습니다.
    • Few-shot의 한계: Rationale 예제를 prompt에 포함하는 few-shot 기법은 direct prediction을 fine-tuning한 모델보다 성능이 떨어지는 경향이 있었습니다.

STaR의 핵심 아이디어:

  • Iterative Bootstrapping: STaR은 LLM의 기존 reasoning 능력을 활용하여 반복적으로 rationale generation 능력을 향상시킵니다. 적은 수의 rationale 예제(few-shot)에서 시작하여, 모델 스스로 rationale을 생성하고 정답을 맞힌 rationale로 fine-tuning 하는 과정을 반복합니다.
  • Rationalization: 모델이 틀린 문제에 대해서는 정답을 제공하고, "거꾸로" rationale을 생성하도록 유도합니다(rationalization). 이를 통해 모델이 풀지 못한 문제에 대한 학습 기회를 제공하고, 성능을 더욱 향상시킵니다.

STaR의 차별점 및 기여:

  • Self-Taught Learning: 모델이 스스로 생성하고 검증한 rationale을 통해 학습하는, 자기 주도 학습(self-taught learning) 방식입니다. 이를 통해 대규모 rationale dataset 구축 없이도 reasoning 능력을 효과적으로 향상시킵니다.
  • Bootstrapping & Rationalization: Iterative bootstrapping과 rationalization을 결합하여, 적은 수의 예제만으로도 고품질의 rationale dataset을 생성하고, 이를 통해 성능을 극대화합니다.
  • 최초의 LLM 자기 개선: STaR은 pre-trained LLM이 자신의 language modeling 능력을 활용하여 반복적으로 스스로를 개선할 수 있도록 하는 최초의 기법으로 평가됩니다.

결론:

STaR은 기존 rationale generation 방식의 한계를 극복하고, LLM이 스스로 학습하여 reasoning 능력을 향상시킬 수 있는 새로운 패러다임을 제시합니다. 특히, 적은 데이터와 컴퓨팅 자원으로도 큰 폭의 성능 향상을 이끌어 낼 수 있다는 점에서, 앞으로 LLM 연구에 큰 영향을 미칠 것으로 기대됩니다.

 

 

 

 

2 Background and Related Work

In-context Learning

최근, large language models의 in-context learning 능력을 탐구하는 연구들이 등장했습니다. 본질적으로, in-context learning은 few-shot learning을 language modeling 문제로 취급합니다. context (즉, prompt)에 몇 가지 예시를 보여주고, model이 새로운 예시에 적용할 패턴을 학습하고 식별하도록 합니다. 어떤 이들은 Bayesian inference 관점에서 language modeling objective를 기반으로 in-context learning을 연구하는 반면, 다른 이들은 "induction heads" 관점에서 이 과정을 보다 기계적으로 설명하려고 시도했습니다. 더욱이, prompt 구성의 차이는 few-shot 성능에 큰 영향을 미치는 것으로 알려져 있습니다. 어떤 이들은 심지어 few-shot prompts를 embedding 공간에서 최적화될 수 있는 "soft prompt"로 대체하는 것이 눈에 띄는 이득을 가져온다는 것을 발견했습니다. 질문의 representation을 강조하는 대신, 우리는 model output에 중점을 둡니다. 특히, 결론에 도달하기 전에 문제에 대해 reasoning하는 model의 능력에 중점을 둡니다.

Rationales

Rationale이 language model 성능에 미치는 영향에 대한 초기 연구 중 하나는 정답 앞에 명시적인 rationale이 있는 dataset으로 language model을 training하는 것이 최종 정답을 생성하는 model의 능력을 향상시킬 수 있음을 보여주었습니다. 그러나 이것은 수천 개의 training 예제가 수동으로 human reasoning으로 주석이 추가되어야 했습니다. 최근, 단계별 "scratchpads"가 산술, 다항식 평가 및 프로그램 평가와 같은 tasks에서 fine-tune된 LLM 성능과 일반화를 향상시킬 수 있음을 보여주었습니다. 유사하게, fine-tuning 없이 일련의 tasks에 대한 model 성능을 개선하기 위해 단일 few-shot "chain-of-thought" reasoning prompt를 사용했습니다. 마지막으로, curriculum learning 접근 방식이 formal 수학 문제를 해결하는 데 도움이 될 수 있음을 보여주었습니다. 단, 1) 그것들이 Lean (정리 증명 언어)으로 변환되고, 2) 증명의 타당성을 직접 평가할 수 있고, 3) 각 문제에 대해 수많은 잠재적 솔루션을 샘플링할 수 있고, 4) 별도의 value function model을 trained했으며, 5) GPT-f (대규모 수학 dataset에 이미 fine-tune된 model)로 시작한 경우에 한합니다. 이러한 조건이 모두 적용되지 않는 많은 domains이 있음을 주목할 필요가 있습니다. 또한, rationale이 왜 이러한 이점을 가지는지 설명하기 위한 연구들이 있었습니다. 어떤 이들은 latent variable models의 관점에서 그 영향을 분석한 반면, 다른 이들은 intermediate task supervision의 이점에 대한 formal 증명을 제공했습니다.

Iterated Learning

발견된 솔루션 또는 성공적인 방법을 사용하여 추가 솔루션을 찾는 다양한 iterated learning 알고리즘이 제안되었습니다. Expert Iteration (ExIt)은 우리 접근 방식에 영감을 준 reinforcement learning 기술입니다. 본질적으로, "apprentice"에 의한 self-play 루프, 느린 "expert"의 피드백을 통한 imitation learning, 그리고 이제 개선된 apprentice로 expert를 교체하는 것으로 구성됩니다. formal reasoning을 위해 ExIt을 기반으로 하는 반면, 모듈식으로 결합될 수 있는 modular networks를 사용하여 visual question answering에 iterated learning을 적용합니다. STaR과 expert iteration 방법 간에는 더 많은 유사점이 있습니다. 예를 들어, 궁극적인 답변이 목표와 일치하는지 여부에 따라 생성된 예제를 filtering하는 것은 expert 피드백으로 볼 수 있습니다. 그러나 우리는 고정된 "expert"를 가지고 있으며 별도의 value function을 train하지 않습니다.

Natural Language Explanations

자연어 설명은 또한 설명 가능한 기계 학습의 관점에서 논의되었으며, reasoning보다는 정당화에 중점을 둡니다. 이 연구 방향에 대한 동기는 주로 설명 가능한 의사 결정에 기반을 두고 있으며, 일반적으로 사후 설명(post-hoc explanations)을 요구하는 것이 model 성능을 향상시키지 않는다는 것을 발견합니다.

 

 

 

핵심 키워드: In-context Learning, Rationales, Iterated Learning, Natural Language Explanations

1. In-context Learning:

  • 주요 연구 동향: LLMs의 in-context learning, 즉, prompt에 예시를 주어 패턴을 학습하고 적용하는 능력을 탐구하는 연구들이 활발히 진행 중입니다.
  • STaR과의 연관성: STaR은 in-context learning의 개념을 활용하지만, 질문의 representation보다는 model output, 특히 reasoning 능력에 집중합니다.

2. Rationales:

  • 선행 연구:
    • Scratchpads: 중간 계산 과정을 위한 "scratchpads"를 사용하도록 LLM을 fine-tune하면 성능과 일반화가 향상됨을 보였습니다.
    • Chain-of-thought: Fine-tuning 없이, 단일 "chain-of-thought" prompt만으로도 여러 tasks에서 성능 향상을 보였습니다.
    • Curriculum Learning (for formal math): 특정 조건 하에서 (Lean 언어 사용, 증명 검증, 샘플링, value function, pre-fine-tuned model) curriculum learning이 formal 수학 문제 해결에 효과적임을 보였습니다.
  • STaR과의 차별성:
    • 데이터 구축: 기존 연구들은 대규모 human-annotated rationale dataset 구축이 필요하거나, 특정 조건(e.g., Lean 언어, 증명 검증)이 필요했습니다. 반면, STaR은 적은 수의 rationale 예제만으로 iterative하게 dataset을 구축합니다.
    • 일반화: STaR은 특정 task(e.g., formal math)에 국한되지 않고, 다양한 domains에 적용 가능합니다.
    • Rationales의 이점 분석: 기존 연구들이 rationale의 효과를 latent variable models, intermediate task supervision 관점에서 분석한 것과 달리, STaR은 rationale generation과 rationalization을 통한 iterative learning에 집중합니다.

3. Iterated Learning:

  • Expert Iteration (ExIt): Self-play와 expert 피드백을 통한 imitation learning을 결합한 reinforcement learning 기법으로, STaR에 영감을 주었습니다.
  • STaR과의 유사성 및 차이점:
    • 유사성: 정답 여부에 따른 filtering을 expert feedback으로 간주할 수 있습니다.
    • 차이점: STaR은 고정된 "expert" (정답을 알고 있는 상태)를 사용하며, 별도의 value function을 train하지 않습니다.

4. Natural Language Explanations:

  • 기존 연구 동향: 설명 가능한 AI 관점에서, reasoning보다는 정당화(justification)에 초점을 맞춘 natural language explanation 연구들이 진행되었습니다.
  • STaR과의 관계: STaR은 사후 설명(post-hoc explanations)을 생성하는 것이 아니라, 모델이 정답을 도출하는 과정에서 reasoning (rationale generation)을 하도록 유도한다는 점에서 차이가 있습니다.

결론:

STaR은 기존 연구들의 한계를 극복하고, rationale generation과 rationalization을 결합한 iterative learning을 통해 LLM의 reasoning 능력을 향상시키는 새로운 접근 방식을 제시합니다. 특히, 데이터 구축, 일반화 가능성, 학습 방식 측면에서 기존 연구들과 차별화되며, LLM의 self-taught learning 가능성을 보여주는 중요한 연구라고 할 수 있습니다.

 

 

 

 

 

 

 

3 Method

3.1 Rationale Generation Bootstrapping (STaR Without Rationalization)

우리는 pre-trained LLM M과 정답 y를 가진 문제 x의 초기 dataset D = {(xi, yi)} (D는 아래첨자 i=1)가 주어집니다. 우리의 기술은 중간 rationale r을 가진 적은 수의 예제로 구성된 prompt set P = {(x p i , r p i , y p i )} (P는 아래첨자 i=1, 여기서 P는 D의 부분집합 (예: P = 10))에서 시작합니다. 표준 few-shot prompting과 마찬가지로, 이 prompt set를 D의 각 예제에 연결합니다. 즉, xi = (x p 1 , r p 1 , y p 1 , . . . , x p P , r p P , y p P , xi)이며, 이는 model이 xi에 대한 rationale rˆi를 생성한 다음 정답 yˆi를 생성하도록 유도합니다. 우리는 정답으로 이어지는 rationale이 오답으로 이어지는 rationale보다 품질이 좋다고 가정합니다. 따라서 생성된 rationale을 필터링하여 정답(yˆi = yi)으로 이어지는 rationale만 포함합니다. 이 필터링된 dataset에 대해 기본 model M을 fine-tune한 다음, 새로 fine-tune된 model로 새 rationale을 생성하여 이 프로세스를 다시 시작합니다. 성능이 안정될 때까지 이 과정을 계속 반복합니다. 이 과정에서 새 dataset을 수집하면 overfitting을 피하기 위해 하나의 model을 지속적으로 training하는 대신 원래의 pre-trained model M에서부터 training을 합니다.

STaR은 RL 스타일의 policy gradient objective에 대한 근사치로 볼 수 있습니다. 이를 확인하기 위해, M은 이산 latent variable model pM(y | x) = P r p(r | x)p(y | x, r)로 볼 수 있습니다. 다시 말해, M은 먼저 latent rationale r을 샘플링한 다음 y를 예측합니다. 이제 indicator reward function 1(yˆ = y)이 주어지면 dataset 전체의 총 예상 보상은 다음과 같습니다.

J(M, X, Y ) = X i Erˆi,yˆi∼pM(·|xi)1(ˆyi = yi), (1)

∇J(M, X, Y ) = X i Erˆi,yˆi∼pM(·|xi) [1(ˆyi = yi) · ∇ log pM(ˆyi , rˆi | xi)] , (2)

여기서 gradient는 policy gradients에 대한 표준 로그 도함수 트릭을 통해 얻어집니다. indicator function은 정답 yi로 이어지지 않는 모든 샘플링된 rationale에 대한 gradient를 폐기합니다. 이것이 STaR의 필터링 프로세스입니다(5행). 따라서 STaR은 (1) 이 추정치의 분산을 줄이기 위해 (ˆri, yˆi) 샘플을 greedy하게 decoding하고(rationale의 편향된 탐색 가능성을 감수하고), (2) 동일한 데이터 배치에 대해 여러 gradient 단계를 수행(일부 policy gradient 알고리즘과 유사)하여 J를 근사합니다. 이러한 근사치는 STaR을 표준 LLM training 장비로 구현할 수 있는 간단하고 광범위하게 적용 가능한 방법으로 만듭니다. 향후 연구에서는 STaR과 위의 RL objective 사이의 연관성을 더 면밀히 조사해야 합니다.

3.2 Rationalization

rationale generation bootstrapping 알고리즘에는 한계가 있습니다. model은 정답을 맞힌 예제에 대해서만 trained되기 때문에, training set에서 새로운 문제를 해결하지 못하면 개선이 끝납니다. 이것은 근본적으로 알고리즘이 실패한 예제에서 어떠한 training signal도 얻을 수 없기 때문입니다. 우리는 "rationalization"이라고 부르는 기술을 제안합니다. 구체적으로, 우리는 정답을 힌트로 제공하고 이전 rationale generation 단계와 동일한 스타일로 rationale을 생성하도록 요청합니다. 정답이 주어지면 model은 거꾸로 추론할 수 있으므로 정답으로 이어지는 rationale을 더 쉽게 생성할 수 있습니다. 예를 들어, 그림 2에서 우리는 "(b) grocery cart"가 rationale을 생성하기 위한 prompt에서 정답이라는 힌트를 제공합니다.

우리는 model이 rationale generation으로 해결하지 못한 문제에 rationalization을 적용합니다. rationalization으로 생성된 rationale을 dataset에 추가할 때, model이 힌트 없이 rationale을 생각해 낸 것처럼 해당 prompt에 힌트를 포함하지 않습니다. 필터링 후, 이전에 생성된 dataset과 rationalization으로 생성된 dataset을 결합하여 fine-tune합니다.

 

 

 

 

핵심 키워드: Rationale Generation Bootstrapping, Rationalization, Fine-tuning, Policy Gradient

3.1 Rationale Generation Bootstrapping (STaR without Rationalization):

  • 목표: 적은 수의 rationale 예제(few-shot prompt)를 사용하여, pre-trained LLM (M)이 스스로 rationale을 생성하고, 이를 통해 reasoning 능력을 향상시키는 것입니다.
  • 과정:
    1. 초기 Prompt Set 구성: Rationale, 정답이 포함된 적은 수의 예제로 구성된 prompt set P를 준비합니다.
    2. Rationale 생성: P를 각 문제 x에 덧붙여(concatenate) LLM(M)에게 제공하고, rationale (rˆ)과 정답 (yˆ)을 생성하도록 합니다.
    3. Filtering: 생성된 rationale 중 정답을 맞힌 것(yˆ = y)만 남기고 필터링합니다.
    4. Fine-tuning: 필터링된 rationale dataset으로 기존 pre-trained LLM (M)을 다시 fine-tune 합니다. (continual training이 아닌, 처음부터 다시 학습)
    5. 반복: 2~4 과정을 반복합니다. 성능이 수렴할 때까지, 새로 fine-tune된 모델로 새로운 rationale을 생성하고, 이를 통해 다시 학습합니다.
  • 핵심 아이디어:
    • 정답을 맞힌 rationale이 더 낫다는 가정: 정답으로 이어진 rationale이 그렇지 않은 rationale보다 더 낫다는 가정 하에, 정답을 맞힌 rationale만으로 모델을 학습시킵니다.
    • Iterative Improvement: 모델이 스스로 생성한 rationale로 자신을 학습시키면서, 점진적으로 reasoning 능력을 향상시킵니다.
  • RL과의 연관성:
    • STaR은 policy gradient objective에 대한 근사로 볼 수 있습니다.
    • Model M은 latent rationale r을 샘플링하여 정답 y를 예측하는 discrete latent variable model로 간주됩니다.
    • 정답 여부를 나타내는 indicator function (1(yˆ = y))이 reward 역할을 합니다.
    • STaR은 이 reward를 최대화하는 방향으로 학습됩니다.
    • STaR은 분산을 줄이기 위해 greedy decoding을 사용하고, 동일한 데이터 배치에 대해 여러 gradient 단계를 수행하여 이 objective를 근사합니다.
  • 한계: 모델이 틀린 문제에 대해서는 학습할 수 없습니다. 즉, 모델이 새로운 문제를 풀지 못하면, 더 이상 성능 향상이 이루어지지 않습니다.

3.2 Rationalization:

  • 목표: Rationale Generation Bootstrapping의 한계를 극복하기 위해, 모델이 틀린 문제에 대해서도 학습할 수 있도록 합니다.
  • 방법:
    1. 정답 힌트 제공: 모델이 틀린 문제에 대해, 정답을 힌트로 제공합니다.
    2. 거꾸로 Rationale 생성: 정답을 힌트로 사용하여, "거꾸로" rationale을 생성하도록 유도합니다. 즉, "문제 + 정답 -> rationale" 순서로 추론하도록 합니다.
    3. Dataset 추가: Rationalization을 통해 생성된 rationale을 dataset에 추가합니다. 이때, 힌트는 prompt에서 제외하여, 마치 모델이 힌트 없이 rationale을 생성한 것처럼 보이게 합니다.
    4. Fine-tuning: 기존 rationale dataset과 rationalization으로 생성된 dataset을 결합하여 모델을 fine-tune합니다.
  • 핵심 아이디어:
    • 틀린 문제로부터의 학습: 정답을 힌트로 제공함으로써, 모델이 틀린 문제로부터 학습할 수 있는 기회를 제공합니다.
    • 더 쉬운 Rationale 생성: 정답을 알고 있는 상태에서 rationale을 생성하는 것이 더 쉽다는 점을 활용합니다.
  • 예시:
    • 문제: "사과가 5개 있고, 그 중 2개를 먹었습니다. 남은 사과는 몇 개일까요?"
    • 모델의 오답: "2개"
    • Rationalization: 정답 "3개"를 힌트로 제공하고, "5개에서 몇 개를 먹어야 3개가 남을까?"와 같은 방식으로 prompt를 구성하여 rationale 생성을 유도합니다.

결론:

STaR은 Rationale Generation BootstrappingRationalization이라는 두 가지 핵심 기술을 통해 LLM의 reasoning 능력을 향상시킵니다. 특히, Rationalization은 기존 bootstrapping 방식의 한계를 극복하고, 모델이 틀린 문제로부터 학습할 수 있도록 함으로써, LLM의 성능을 극대화하는 데 중요한 역할을 합니다.