Introduction
Related Papers
Methods
Results
Discussion
Summary
OPENAI에서 만든 GSM8K(Grade School Math 8.5k) 제안 논문.
Key Contribution
- GSM8K 배포
- LLM output을 채점하는 Verifier training method를 제안
- Dropout이 FineTuning과 verfication 모둥서 강력란 regularizer로 작용하여 performace에 크게 영향을 미친다는 점 보고.
Dataset Configuration
전체 8.5k 개 중, 약 7.5k개의 train / 1k개의 validation set으로 구성됨.
사람이 수작업으로 작성했으며, 난이도는 초등 대수학 이상의 개념을 요구치 않는다.
solution은 순수 수학적 표현이 아닌 자연어로 제공되어 모델의 internal monologues를 이해하는데 도움을 준다.
2~8단계의 풀이 과정으로 요구하는 주로, 기본적인 사칙연산을 요구하는 문제들로 구성됨.
높은 언어적 다양성과 자연어 기반의 솔루션을 포함하여 기존 데이터 셋들의
4. low-quality
5. template dependencies
6. inefficient step
를 보완한다.
FineTuning
본 논문에서는 GSM8K를 사용하여, GPT-3 175B 모델을 SFT함.
Loss는 pre-train과 동일하게 NTP loss 사용.
test 시에는 temp=0, greedy-decoding 환경에서 실시.Learning Curve : Training
Findings 중 하나는 train 시 모델이 2 epochs 이후부터 Test@100 성능(100개의 샘플 중 하나라도 정답을 포함할 확률)이 급격히 저하되는데, 이는 모델이 훈련 데이터에 대해 과적합되어 생성되는 솔루션의 다양성이 줄어들고 불확실성이 증가하기 때문이다. 이는 Verifier 훈련을 위한 generator를 선택할 때 중요한 요소가 된다. 모델이 최종 답안만 직접 출력하도록 finetuning하면 성능이 20.6%에서 5.2%로 크게 하락하므로, 전체 자연어 솔루션을 생성하는 것이 필수적이다. 계산의 정확성을 높이기 위해, 모든 모델은 훈련 시 calculation annotations를 주입하여 계산기를 사용하도록 훈련된다. 테스트 시에는 이 annotation을 만나면 샘플링을 override하고 Python의 eval 함수를 통해 계산 결과를 삽입한다.
Calculation Annotations: 상세 설명 및 예시
GSM8K 연구에서 사용된 “Calculation Annotations”는 언어 모델이 복잡한 산술 연산을 수행할 때, 외부 계산기(Python eval 함수)를 활용하도록 유도하는 특별한 표시입니다. 이는 언어 모델 자체의 산술 능력 한계를 보완하고, 보다 정확하고 신뢰할 수 있는 계산 결과를 얻기 위한 기법입니다.
1. Calculation Annotations의 작동 방식데이터
준비 단계 (Training):
연구팀은 훈련 데이터셋에 있는 문장제 문제의 솔루션 중간에 특정 패턴을 삽입합니다.
이 패턴은 다음과 같습니다: <<[수식]>>
예를 들어, 20 + 10이라는 계산이 필요하다면, 이를 20 + 10 = <<20+10>> 와 같이 표시합니다.
훈련 시에는 이 <<…>> 내부의 텍스트도 일반 토큰처럼 처리됩니다. 즉, 모델은 이 부분을 보고 계산을 수행하거나, 계산기 주석이 있음을 인지하고 학습합니다.
핵심: 모델은 이 주석을 통해 “여기서는 직접 계산하는 대신, <<…>> 안의 수식을 계산하여 결과를 사용해야 한다”는 것을 학습합니다.테스트 단계 (Testing):
모델이 테스트 문제를 풀 때, 생성하는 텍스트 중에 <<…>> 패턴을 발견하면 특별한 처리를 합니다.
계산기 개입: 모델이 <<…>> 패턴에 도달하면, 내부의 수식(예: 20+10)을 추출하여 Python의 eval() 함수를 통해 실제로 계산합니다.
결과 대체: eval() 함수의 결과(예: 30)를 원래의 <<…>> 자리나 그 이후에 오는 토큰과 대체합니다.예시:
모델 생성: She eats 21 omelets per week so over 4 weeks she will eat 4x21 = <<21x4=84>>
계산기 작동: eval(“21x4”) → 84
최종 모델 출력 (계산기 적용 후): She eats 21 omelets per week so over 4 weeks she will eat 4x21 = 84
주의사항: 만약 eval() 함수가 오류를 발생시키거나 타임아웃이 발생하면, 해당 주석은 무시되고 모델은 평소처럼 샘플링을 계속 진행합니다.연구팀은 이 계산기 주석 생성 로직의 구현이 완벽하지는 않지만, 잘못된 주석을 생성할 가능성은 매우 낮다고 언급합니다 (Appendix C).
Verification
Verification은 finetuning baseline의 성능을 향상시키기 위해 모델이 생성한 솔루션의 정확성을 판단하는 Verifier를 훈련하고, 테스트 시에는 이 Verifier를 사용하여 최적의 솔루션을 선택한다. Verifier는 문제와 후보 솔루션을 조건으로 솔루션이 정확할 확률을 출력한다. Verifier 훈련 시 솔루션의 레이블(정확/부정확)은 오직 최종 답안의 정확성만을 기준으로 결정된다.
Verification 훈련 파이프라인은 다음과 같다:
Generator 훈련(위에서 정리해둔 FT과정):
먼저, generator 역할을 하는 모델을 훈련 데이터셋에 대해 2 epochs 동안 finetuning한다. 이 2 epochs는 기본적인 추론 능력을 학습하기에 충분하며, 이보다 오래 훈련하면 생성 솔루션의 다양성이 급격히 감소하는 문제를 피할 수 있다.
솔루션 샘플링 및 레이블링:
훈련 문제당 generator로부터 100개의 솔루션 completions를 high temperature (T=0.7T=0.7T=0.7T=0.7)로 샘플링한다. 각 솔루션은 최종 답안의 정확성에 따라 “correct” 또는 “incorrect”로 레이블링된다. (이 과정에서 잘못된 추론으로도 정답에 도달하는 “false positives”가 발생할 수 있음).
Verifier 훈련:
샘플링 및 레이블링된 데이터셋을 사용하여 Verifier를 1 epoch 동안 훈련한다. Verifier는 generator와 동일한 모델 크기를 사용하며, 솔루션의 정확성을 예측하는 것 외에 generator와 동일한 language modeling objective를 auxiliary objective로 사용하여 함께 훈련된다. 이는 Verifier가 언어 분포를 더 잘 이해하여 샘플 간의 차이를 더 잘 식별하도록 돕는다. Verifier는 모델의 unembedding layer에 있는 특수 토큰의 logits에 대해 bias 및 gain 파라미터를 적용하는 작은 scalar head를 통해 토큰별로(token-level) 예측을 수행한다. 훈련 시에는 문제 토큰은 마스킹되고 솔루션 토큰에 대해서만 loss를 계산한다. Generator의 과적합을 방지하기 위해 generator와 verifier 모델을 분리한다.
테스트 시에는 각 테스트 문제에 대해 100개의 completions를 샘플링한 후, Verifier로 이들을 랭크하여 가장 높은 점수를 받은 솔루션을 최종 답안으로 선택한다.
초기에는 데이터셋 크기가 작을 때 Verifier의 이점이 미미하지만, 충분히 큰 데이터셋을 사용하면 강력한 성능 향상을 보인다. 175B Verifier는 6B Verifier보다 더 적은 훈련 데이터로 finetuning baseline을 능가한다.