AI바라기의 인공지능

개념 정리(심화) : LLM의 PPO란 Proximal Policy Optimization 본문

인공지능

개념 정리(심화) : LLM의 PPO란 Proximal Policy Optimization

AI바라기 2026. 3. 14. 12:53

PPO란 Proximal Policy Optimization

SFT에서는 기본적으로 다음 토큰을 예측하도록 학습된다.
하지만 모델이 생성한 전체 응답만 목적에 맞으면 점수를 주고 싶다!!

이럴때 필요한게 PPO 같은 강화학습 방법이다. 토큰 하나하나에 신경쓰지 않고 전체적인 보상을 통해 모델 파라미터를 업데이트 하는 것이 포인트이다.

PPO 목적 함수 (Loss):

$$L^{PPO}(\theta) = \mathbb{E} \left[ \min \left( \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} \hat{A}_t, \text{clip}\left(\frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}, 1-\epsilon, 1+\epsilon\right) \hat{A}_t \right) \right]$$

토큰별 실제 보상: $r_t = r_{RM}(x, y)\mathbb{1}_{[t=T]} - \beta \log \frac{\pi_{\theta_{old}}(a_t|s_t)}{\pi_{\text{SFT}}(a_t|s_t)}$

순간 변동폭 (TD Error): $\delta_t = r_t + \gamma V_\omega(s_{t+1}) - V_\omega(s_t)$

가중 누적 이점 (GAE): $\hat{A}_t = \delta_t + (\gamma \lambda) \hat{A}_{t+1}$

이렇게만 보면 이해할 수 없다. 이해하면 천재이다.

PPO는 데이터를 많이 생성하기에 old 모델과 현재 학습 모델의 차이가 있다. 그렇기에 턴을 들어가기 전 모델을 old 모델로 한다.



생성모델이 처음부터 끝까지 답변 생성.
수식: $a_t \sim \pi_{\theta_{old}}(a_t|s_t)$


원래 SFT모델의 확률과 비교해서 멀어진 정도를 측정 후 벌점을 줌. 상점은 없음. 리워드 모델이 답변 전체에 대한 점수 부여. 아까 받은 토큰별 멀어진 벌점이 각 토큰별 점수가 되지만 마지막 토큰의 경우 최종 보상점수가 추가됨.

수식: $r_t = r_{RM}(x, y)\mathbb{1}_{[t=T]} - \beta \log \frac{\pi_{\theta_{old}}(a_t|s_t)}{\pi_{\text{SFT}}(a_t|s_t)}$


그후 밸류모델이 각 토큰 지점 별로 이대로면 마지막에 받을 점수가 몇점일지를 예상.
수식: $V_\omega(s_t)$


순간 변동폭(δ) = (이 자리에 놓인 실제 보상) + (다음 토큰의 예상 점수) - (현재 토큰의 예상 점수) 그럼 각 토큰별 순간 변동폭이라는 점수가 부여되는데.
수식: $\delta_t = r_t + \gamma V_\omega(s_{t+1}) - V_\omega(s_t)$


100번 토큰의 이점: 자기 자신의 순간 변동폭
99번 토큰의 이점: 자기 자신의 순간 변동폭 + (100번 토큰의 이점 × 0.95)
98번 토큰의 이점: 자기 자신의 순간 변동폭 + (99번 토큰의 이점 × 0.95)
...
1번 토큰의 이점: 자기 자신의 순간 변동폭 + (2번 토큰의 이점 × 0.95)
그래서 각 토큰별 이점을 얻게됨.

수식: $\hat{A}_t = \delta_t + (\gamma \lambda) \hat{A}_{t+1}$




클리핑 0.8~1.2[(현재 업데이트 중인 모델의 생성 확률) / (얼려둔 옛날 모델의 생성 확률)] * 각 토큰의 이점을 계산하고
수식: $\min \left( \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} \hat{A}_t, \text{clip}\left(\frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}, 1-\epsilon, 1+\epsilon\right) \hat{A}_t \right)$


그 값을 평균내고 w를 그 값이 높아지는 방향으로 학습
수식: $L^{PPO}(\theta) = \mathbb{E} [ \dots ]$


즉 토큰 하나하나의 정확도에 굳이 신경쓰지 않고, 전체적으로 좋은 답변을 생성하기 위해 파라미터가 업데이트되는 방식. 이때 너무 많은 업데이트를 하지 않게 하기 위해, 클리핑을 진행.