COCONUT: 연속 잠재 공간에서 추론하도록 LLM 훈련하기

Digest: LLM은 왜 반드시 언어 공간에서 추론해야 하는가? Chain-of-Thought(CoT)는 자연어로 추론 단계를 표현하지만, 대부분의 토큰은 텍스트 일관성을 위한 것이고 추론에 핵심적인 토큰은 소수다. COCONUT(Chain of Continuous Thought)는 모델의 마지막 은닉 상태(last hidden state)를 다음 입력 임베딩으로 직접 피드백하여, 토큰화를 우회하고 연속 잠재 공간에서 추론하게 한다. 핵심 발견은 이 잠재 추론이 너비 우선 탐색(BFS) 유사 패턴을 자연 발생적으로 보인다는 것이다—연속적 사고(continuous thought)가 여러 대안 경로를 동시에 인코딩하여, CoT의 탐욕적(greedy) 단일 경로 한계를 넘어선다. ProsQA(계획 필요 논리 추론)에서 CoT 77.5% 대비 97.0% 정확도를 달성하면서 토큰 수는 49.4 → 14.2로 줄였다. 다만 GSM8k(수학 추론)에서는 CoT 42.9% 대비 34.1%로 하락하여, 잠재 추론이 탐색형 문제에 특화되었음을 보여준다.


메타데이터

모델GPT-2 (주 실험), Llama 3.2-3B, Llama 3-8B (보조 실험)
데이터셋ProsQA, ProntoQA, GSM8k

왜 이 연구를 하는가?

핵심 질문

LLM의 추론이 반드시 언어 공간에서 이루어져야 하는가? 언어의 제약을 벗어나 연속 잠재 공간에서 추론하면 어떤 이점이 있는가?

기존 접근법의 한계

한계설명
언어의 병목CoT는 추론을 자연어로 표현해야 하므로, 언어로 쉽게 표현되지 않는 추론 패턴이 제한됨
탐욕적 단일 경로CoT는 한 번에 하나의 추론 경로만 추구하여, 탐색이 필요한 문제에서 비효율적
토큰 낭비CoT 토큰의 대부분은 문법적 일관성을 위한 것이지 추론에 핵심적이지 않음

핵심 통찰

  • 모델의 은닉 상태는 이산적 토큰보다 훨씬 풍부한 정보를 인코딩할 수 있으며, 여러 대안적 다음 단계를 동시에 표현할 수 있다
  • 언어 공간을 우회하면 CoT의 “연극적” 부분을 제거하고 순수한 계산적 추론만 보존할 수 있다

방법 (Method)

프레임워크 개요

graph TB
    subgraph "기존 CoT"
        A1["질문"] --> B1["토큰1<br/>(언어)"] --> C1["토큰2<br/>(언어)"] --> D1["..."] --> E1["답변<br/>(언어)"]
    end

    subgraph "COCONUT"
        A2["질문"] --> B2["연속 사고1<br/>(은닉 상태)"] --> C2["연속 사고2<br/>(은닉 상태)"] --> D2["..."] --> E2["답변<br/>(언어)"]
    end

    subgraph "핵심 차이"
        F["CoT: 각 단계를 언어로<br/>디코딩 → 재인코딩"]
        G["COCONUT: 은닉 상태를<br/>직접 다음 입력으로 전달"]
    end

    style B2 fill:#f9f,stroke:#333
    style C2 fill:#f9f,stroke:#333

핵심 구성요소

1. 연속적 사고(Continuous Thought): 모델의 마지막 은닉 상태를 표준 LM head로 디코딩(de-tokenize)하지 않고, 다음 스텝의 입력 임베딩으로 직접 사용한다. 이 은닉 상태가 “연속적 사고”이며, 이산적 토큰과 달리 연속 공간의 전체 표현력을 활용할 수 있다.

2. 다단계 커리큘럼(Multi-stage Curriculum): CoT의 언어적 추론 단계를 점진적으로 연속적 사고(continuous thought)로 대체해 나가는 훈련 스케줄이다. 언어 supervision이 전혀 없는 상태에서 latent reasoning을 직접 end-to-end로 학습시키면 거의 학습되지 않기 때문에(GSM8k 14.4%, 커리큘럼 적용 시 34.1%), 논문은 Deng et al. (2024)의 Stepwise Internalization을 따라 “언어 스텝 → 잠재 스텝”을 점진적으로 내재화한다.

스테이지 정의 (Stage k). 원래의 CoT가 자연어 추론 스텝 로 구성되어 있을 때:

  • Stage 0 (초기 단계): 일반 CoT 데이터로 표준 supervised fine-tuning. 모델이 먼저 “언어로 추론하는 법”을 학습한다.
  • Stage k (k ≥ 1): 앞쪽 k개의 언어 스텝 를 제거하고, 그 자리에 개의 연속적 사고 토큰을 삽입한다. 나머지 스텝 와 최종 답변은 여전히 언어 토큰으로 유지되며 loss 계산에 사용된다.
  • 는 “한 reasoning step(= CoT의 한 줄/문장, 보통 수십 토큰) 당 몇 개의 latent thought로 대체할지”를 정하는 하이퍼파라미터. 수학 추론(GSM8k)은 c = 2, 논리 추론(ProntoQA/ProsQA)은 c = 1을 사용. 여기서 “step”은 토큰이 아니라 CoT의 한 추론 단계 (예: "48/2 = 24 clips in May" 같은 한 줄)임에 주의 — 따라서 c는 매우 공격적인 압축비를 의미한다 (수십 토큰 → 1~2개 latent thought).

시퀀스 포맷. 입력 시퀀스는 특수 토큰으로 latent/language 영역을 구분한다:

[Question] <bot> [k×c개의 continuous thoughts] <eot> [s_{k+1} ... s_T] [Answer]

<bot>/<eot>는 latent reasoning 모드의 시작/끝을 표시. latent 영역에서는 LM head를 거치지 않고, 마지막 hidden state가 다음 스텝의 input embedding으로 직접 feedback 된다.

손실 함수와 마스킹. 목적함수는 표준 NLL이지만 (i) 질문 토큰과 (ii) latent thought 위치 전체에서 loss를 마스킹한다. 즉 gradient가 흐르는 위치는 “남아 있는 언어 스텝 + 최종 답변”뿐이다. Latent thought 자체는 어떤 명시적 타깃도 갖지 않으며, “미래 추론/답변의 likelihood를 높이는 방향”으로만 간접적으로 학습된다 — 즉 제거된 언어 스텝을 복원(compress)하도록 강제하지 않는다.

스테이지 전환 시 옵티마이저 리셋. Deng et al. (2024)을 따라 새 스테이지 진입 시마다 optimizer state (Adam의 momentum·variance)를 초기화한다. 이전 스테이지의 “언어 경로” gradient에 누적된 통계가 새 latent 목표 학습을 방해하는 것을 막기 위함이다.

에폭 스케줄.

  • GSM8k (math): 초기 stage 6 epochs, 이후 각 stage 3 epochs. 총 3개 stage + 초기 stage. 이후 최종 stage에서 50 epoch까지 유지.
  • ProntoQA / ProsQA (logic): 초기 stage 포함 7개 stage, 각 5 epochs. 이후 최종 stage에서 50 epoch까지 유지.

최적화의 계산 비용. 한 학습 예제에서 개의 latent thought가 스케줄되어 있으면 번의 순차적 forward pass가 필요하다 (각 pass가 새 latent thought 하나를 생성하고, 그것이 다음 pass의 입력이 됨). KV cache로 이전 pass의 계산을 재활용하지만, 순차성 때문에 토큰 수준의 parallelism은 제한된다.

추론 시 동작. 추론 단계에서는 EOT(End of Thought)를 예측하게 하지 않고, 훈련 시 사용한 최종 stage의 latent thought 개수로 고정(pad to constant length)한다. 논문은 “단순성”을 이유로 들지만, 이는 모델이 “얼마나 잠재 추론을 해야 하는지”를 스스로 결정하지 못한다는 구조적 한계를 함의한다 (논문도 binary classifier 기반 자동 종료를 대안으로 언급하나 미채택).

왜 커리큘럼이 필수인가. 처음부터 “latent thought로만 답을 맞히라”는 목표는 지나치게 느슨하여(loss 신호가 답변 토큰에만 존재) 학습 신호가 latent 영역까지 역전파되기 어렵다. 커리큘럼은 초기에는 긴 언어 supervision을 제공하다가, stage가 올라갈수록 언어 supervision을 짧게 잘라내어 latent thought가 점점 더 많은 추론 부담을 떠맡도록 점진적으로 “압력”을 가한다. 이 구조적 이유로, 커리큘럼 제거 시 GSM8k 성능이 34.1% → 14.4%로 절반 이하로 붕괴한다.

3. BFS 유사 추론 패턴: 잠재 공간에서의 추론을 분석하면, 초기 연속적 사고는 여러 후보 경로에 대해 “상당한 다양성”을 유지하며, 후기로 갈수록 유망한 경로로 수렴한다. 이는 CoT의 탐욕적 DFS와 대비되는 BFS 유사 패턴이다.


발견 (Findings)

주요 결과

태스크No CoTCoTCOCONUT토큰 수 (CoT → COCONUT)
ProsQA0%77.5%97.0%49.4 → 14.2
ProntoQA72.5%98.8%99.8%92.5 → 9.0
GSM8k3.2%42.9%34.1%25.0 → 8.2

커리큘럼의 중요성

훈련 방식GSM8k 정확도
커리큘럼34.1%
커리큘럼 없이 (end-to-end)14.4%

BFS 패턴 증거

잠재 추론의 분석 결과, 목표에서 먼 노드(높이가 높은 노드)는 모호한 가치 추정을 받는 반면, 가까운 노드는 확신 있게 평가된다. 이는 모델이 탐색 초기에는 여러 가능성을 열어두고, 정보가 축적되면서 유망한 경로로 수렴하는 BFS 전략을 학습했음을 시사한다.

핵심 발견

탐색형 문제에서 극적 개선: ProsQA에서 CoT 대비 19.5%p 향상은, 계획과 탐색이 필요한 문제에서 잠재 추론의 BFS 패턴이 CoT의 탐욕적 단일 경로보다 훨씬 효과적임을 보여준다.

수학 추론의 한계: GSM8k에서의 하락은, 정확한 수치 계산과 단계별 논리적 전개가 필요한 문제에서는 언어적 표현이 중요한 역할을 함을 시사한다. 잠재 공간은 “탐색”에 강하지만 “정밀한 계산”에는 약할 수 있다.

효율성의 극적 향상: ProntoQA에서 토큰 수를 92.5 → 9.0으로 줄이면서 정확도를 유지한 것은, CoT 토큰의 대부분이 실제 추론이 아닌 언어적 장식임을 직접적으로 보여준다. 이는 Reasoning Theater - Disentangling Model Beliefs from Chain-of-Thought의 “연극적 추론” 가설을 다른 각도에서 뒷받침한다.


이론적 의의

언어 없는 추론의 가능성과 한계

COCONUT은 “추론은 반드시 언어로 이루어져야 한다”는 암묵적 가정에 도전한다. 잠재 공간에서의 추론이 특정 문제 유형(탐색, 계획)에서 언어 기반 추론을 크게 능가한다는 발견은, 언어가 추론의 매개체이지 추론 자체가 아님을 시사한다. 이는 Reasoning Theater의 핵심 발견—모델의 진정한 추론은 내부 활성화에서 이루어지며, CoT는 이를 불완전하게 반영하는 “창”에 불과하다—과 일맥상통한다.

CoT의 “연극성”에 대한 구조적 증거

COCONUT이 CoT보다 적은 토큰으로 더 높은 성능을 달성한다는 것은, CoT 토큰의 상당 부분이 추론에 불필요한 “연극적” 요소임을 구조적으로 증명한다. 언어적 형식을 완전히 제거해도(오히려 제거하면) 추론 성능이 향상될 수 있다.


관련 연구


핵심 용어 정리

용어정의
연속적 사고 (Continuous Thought)모델의 마지막 은닉 상태를 다음 입력 임베딩으로 직접 사용하는 추론 표현. 이산적 토큰과 달리 연속 공간의 전체 표현력을 활용
COCONUT (Chain of Continuous Thought)연속적 사고를 연쇄적으로 활용하여 잠재 공간에서 추론하는 패러다임
너비 우선 탐색 (BFS) 유사 패턴잠재 추론에서 자연 발생적으로 관찰되는 패턴. 초기에 여러 후보 경로를 동시에 탐색하고 점진적으로 수렴
다단계 커리큘럼CoT의 언어적 추론 단계를 점진적으로 연속적 사고로 대체하는 훈련 전략
잠재 추론 (Latent Reasoning)이산적 언어 토큰이 아닌 연속 잠재 공간에서 이루어지는 추론 과정
ProsQA방향 비순환 그래프(DAG) 탐색을 요구하는 논리 추론 데이터셋. 계획과 탐색이 필요하여 COCONUT의 BFS 패턴이 특히 효과적