AI바라기의 인공지능
Diffusion : 빠른 논문 리뷰 : Learnable Sampler Distillation for Discrete Diffusion Models 본문
Diffusion : 빠른 논문 리뷰 : Learnable Sampler Distillation for Discrete Diffusion Models
AI바라기 2025. 12. 24. 14:25용어 설명 (Terminology)
- Discrete Diffusion Models (DDMs): 이미지 픽셀과 같은 연속적인(continuous) 데이터가 아니라, 텍스트 토큰이나 DNA 서열과 같이 이산적인(discrete) 구조를 가진 데이터를 생성하기 위한 Diffusion 모델.
- NFEs (Number of Function Evaluations): 샘플링 과정에서 모델(함수)을 호출하는 횟수. NFEs가 낮을수록 생성 속도가 빠름.
- Accumulated Error: DDMs에서 적은 스텝으로 샘플링할 때 발생하는 오류의 총합. 이는 독립적인 토큰 예측으로 인한 Compounding Decoding Error와 큰 step size로 인한 수치적 근사 오차인 Discretization Error가 합쳐진 것을 의미함.
- Intermediate Score Trajectory Alignment: Student sampler가 Teacher sampler의 최종 결과물(discrete output)을 직접 모방하는 대신, 중간 단계에서의 score(denoising 방향) 값을 모방하도록 학습하는 방식.
- Effective Transition Term: Reverse process의 중간 단계에서 step size와 concrete score를 결합한 항으로, LSD+에서 시간 스케줄(time schedule)을 학습할 때 사용됨.
- Relaxed Objective: Student와 Teacher 간의 엄격한 일치가 어려울 때, 입력값에 약간의 변형(Hamming distance 기준)을 허용하여 학습을 용이하게 만드는 목적함수.
Purpose of the Paper
- Inefficiency of DDMs: 기존 Discrete Diffusion Models (DDMs)는 고품질 데이터를 생성하기 위해 매우 많은 sampling steps(예: 1024회 이상)를 필요로 하여 추론 비용이 매우 높음.
- Challenge of Acceleration: 단순히 step size를 키워 NFEs를 줄이면 Accumulated Error가 증폭되어 생성 품질이 급격히 저하됨. 연속형 Diffusion 모델을 위한 기존 가속화 기법(예: S4S)들은 샘플링 과정의 미분 가능성(differentiability)에 의존하는데, DDMs는 Categorical sampling의 비미분성(non-differentiability) 때문에 이를 직접 적용할 수 없음.
- Goal: DDMs의 고유한 특성(이산적 상태, 비미분성)을 고려하여, 적은 스텝(low NFEs)으로도 Teacher sampler의 고품질 궤적을 모방할 수 있는 Learnable Sampler Distillation (LSD) 프레임워크를 제안함.
Key Contributions
이 논문은 DDMs를 위한 새로운 Distillation 접근법인 LSD와 이를 확장한 **LSD+**를 제안하며 다음과 같은 기여를 함:
- Distillation via Intermediate Score Alignment (Novelty):
- Discrete data의 비미분성 문제를 해결하기 위해, 최종 출력(discrete output)을 비교하여 gradient를 역전파하는 대신, 중간 단계의 score trajectory를 Teacher와 일치시키는 방식을 제안함. 이를 통해 discrete sampling 연산을 우회하여 sampler parameter를 학습 가능하게 만듦.
- Learnable Sampler Coefficients:
- LSD는 고정된 수치해석적 계수를 사용하는 대신, 학습 가능한 계수(learnable coefficients)를 도입함. 이는 각 time step에서 sampling dynamics를 적응형(adaptively)으로 조절하여 큰 step size로 인한 discretization error를 보정함.
- LSD+: Learnable Time Schedules:
- LSD의 확장판인 LSD+는 계수뿐만 아니라 non-uniform time schedules (step sizes)까지 함께 학습함.
- Reverse process의 effective transition term을 Teacher와 Student 간에 정렬시키는 방식으로 최적의 step size를 찾아내어 Accumulated Error를 더욱 효과적으로 감소시킴.
- Relaxed Training Objective:
- Student sampler가 Teacher의 궤적을 완벽히 따라가는 것이 어렵다는 점을 감안하여, 입력 토큰에 작은 Hamming distance 내의 perturbation(변형)을 허용하는 Relaxed objective를 도입함. 이는 학습의 수렴성을 높이고 성능을 안정화하는 데 기여함.
Experimental Highlights
- Tasks & Baselines: Text generation (SEDD, RADD backbones), Image generation (CIFAR-10, ImageNet), Synthetic countdown task에서 검증. 비교 대상은 Euler, Tweedie, JYS 등 기존 SOTA samplers.
- Performance on Text Generation:
- SEDD-small backbone (GPT-2 level): 64 NFEs 기준, 기존 Euler sampler의 Perplexity는 약 56.2인 반면, LSD+는 20.4를 기록하여 압도적인 품질 향상을 보여줌.
- 극단적으로 적은 8 NFEs에서도 LSD+는 Perplexity 128.4를 기록하며 Euler(423.1) 대비 훨씬 사용 가능한 수준의 텍스트를 생성함.
- Performance on Image Generation:
- CIFAR-10: 기존 CTMC baseline 대비 더 낮은 FID score를 달성하며 시각적 품질 우위를 입증.
- ImageNet (256x256): MaskGIT backbone에 적용 시, 최신 기법인 Halton sampler보다 더 낮은 FID를 기록함 (NFE=4일 때 Halton 54.05 vs LSD+ 48.29).
- Efficiency:
- LSD+는 NVIDIA RTX4090 GPU에서 약 5분 만에 학습이 완료되며, 이는 경쟁 방법론인 JYS(약 10분)보다 빠르고 효율적임.
Limitations and Future Work
- Dependency on Teacher Quality (Limitation):
- LSD의 성능 상한선은 근본적으로 Teacher sampler의 품질에 종속됨. Student는 Teacher를 모방할 뿐 Teacher의 성능을 뛰어넘기는 어려움.
- Theoretical Analysis in Discrete Space (Limitation):
- 연속 공간(continuous space)과 달리, 이산 공간에서는 미분 불가능성으로 인해 Relaxed objective가 분포 일치(distribution matching)를 보장한다는 엄밀한 이론적 증명이 더 까다로움. 현재는 경험적(empirical) 효과에 더 의존함.
- Future Work:
- Teacher와 Student sampler 간의 분포적 불일치(distributional discrepancy)에 대한 이론적 보장(theoretical guarantees)을 제공하는 연구.
- 더 발전된 Teacher sampler를 개발하여 Student의 성능 상한을 높이는 연구.
Overall Summary
이 논문은 Discrete Diffusion Models (DDMs) 의 느린 추론 속도 문제를 해결하기 위해, 고품질 Teacher sampler의 궤적을 적은 스텝의 Student sampler로 증류(distillation)하는 LSD (Learnable Sampler Distillation) 기법을 제안했습니다. 이산 데이터의 비미분성 문제를 Intermediate Score Alignment와 Relaxed Objective로 우회하고, 샘플링 계수와 시간 스케줄을 학습함으로써 기존 방법론 대비 월등히 적은 NFEs로 고품질 데이터를 생성할 수 있음을 입증했습니다. 이는 텍스트나 바이오 시퀀스와 같은 이산 데이터 생성 분야에서 DDMs의 실용성을 크게 높인 중요한 연구입니다.
쉬운 설명 (Easy Explanation)
이 논문의 핵심 아이디어는 **"선생님의 답안지(최종 결과)만 베끼는 것이 아니라, 선생님이 문제를 푸는 과정(중간 생각의 방향)을 배우는 것"**과 비슷합니다.
- 문제: 기존의 DDMs(텍스트 생성 AI 등)는 좋은 글을 쓰려면 수천 번의 수정(step)을 거쳐야 해서 너무 느렸습니다. 그렇다고 단계를 대충 건너뛰면(step size를 키우면) 글이 엉망이 되었습니다.
- 어려움: 보통의 가속화 방법은 미분(기울기 계산)을 통해 정답을 역추적하며 배우는데, 텍스트 같은 이산 데이터는 중간에 뚝뚝 끊어져 있어서 미분을 할 수가 없습니다.
- 해결책 (LSD):
- 그래서 이 모델은 최종 완성된 글을 보고 배우는 게 아니라, 선생님 모델이 중간중간 "이쪽 방향으로 고쳐야 해"라고 가리키는 나침반(Score)의 방향을 흉내 내도록 학습합니다.
- 또한, "어느 시점에 얼마나 많이 고쳐야 하는지"(Time Schedule) 도 스스로 학습해서, 적은 횟수의 수정만으로도 선생님 모델만큼 훌륭한 결과를 만들어냅니다.
- 결과: 이렇게 학습한 학생 모델(Student)은 선생님보다 훨씬 적은 노력(단계)만으로도 거의 비슷한 수준의 고품질 텍스트나 이미지를 아주 빠르게 만들어냈습니다.
LSD 프레임워크: "거북이 선생님 따라잡기"
1. [시작] 노이즈 입력 (Input)
아무 의미 없는 노이즈(엉망인 상태) 하나를 선생님과 학생에게 동시에 던져줍니다.
2. [선생님] 정답 방향 제시 (Teacher's Score)
1000걸음 걷는 꼼꼼한 선생님 모델이 **"정답(깨끗한 데이터)으로 가려면 북쪽 15도로 가야 해"**라고 **정확한 방향(Score)**을 미리 계산해 둡니다. (고정값)
3. [학생] 나의 방향 예측 (Student's Prediction)
10걸음 만에 가고 싶은 성격 급한 학생 모델이 **"제 생각엔 북쪽 10도 같은데요?"**라고 자신의 방향을 내놓습니다.
4. [핵심] 보정 계수 적용 (Learnable Coefficients)
학생은 한 번에 멀리 뛰어야 하므로(Large Step), 자신의 예측 방향에 **'마법의 보정 숫자(계수)'**를 곱해서 **"그럼 북쪽 15도로 맞춰서 크게 뛰겠습니다!"**라고 방향을 튜닝합니다.
5. [채점] Loss 계산 (Objective)
**"선생님이 찍어준 방향"**과 **"학생이 튜닝한 방향"**이 얼마나 일치하는지 숫자로 계산합니다. (다르면 Loss 높음, 같으면 Loss 낮음)
6. [학습] 업데이트 (Optimization)
Loss를 줄이는 방향으로, 학생은 **'어느 방향으로 뛰어야 하는지'**와 **'얼마나 크게 뛰어야 하는지(Time Schedule)'**를 머릿속(파라미터)에 저장합니다.
7. [결과] 초고속 추론 (Inference)
학습이 끝나면 학생은 선생님 없이도 혼자서 **선생님과 똑같은 길을, 100배 빠른 속도(적은 스텝)**로 주파하여 고품질 텍스트/이미지를 만들어냅니다.
한 줄 요약:
"선생님은 천천히 걷지만 길을 정확히 알고 있으니, 학생은 선생님의 나침반(방향)만 컨닝해서 지름길로 축지법(점프)을 쓰는 법을 배우는 겁니다."