AI바라기의 인공지능
단백질 : 논문리뷰 : FoldToken2: Learning compact, invariant and generative protein structure language 본문
단백질 : 논문리뷰 : FoldToken2: Learning compact, invariant and generative protein structure language
AI바라기 2025. 4. 10. 13:45쉬운 설명:
이 논문은 3D 단백질 구조를 마치 **'단백질 맞춤형 바코드(discrete tokens)'**로 변환하는 기술(FoldToken2)을 제안합니다. 이 바코드는 단백질을 어떻게 돌려보아도 항상 동일하게 유지되며(invariant), 원래 3D 구조 정보를 거의 손실 없이 담고 있어 다시 원래 모습으로 복원(recoverable)할 수 있습니다. FoldToken2는 이전 버전(FoldToken1)보다 훨씬 더 똑똑한 '번역기'(encoder/decoder)와 '바코드 생성 규칙'(quantizer)을 사용하여, 단일 단백질뿐 아니라 여러 단백질이 뭉친 복합체(multi-chain)에 대해서도 정확한 바코드를 빠르고 효율적으로 만들어냅니다. 이렇게 만들어진 바코드는 컴퓨터가 이해하기 쉬운 언어 형태이므로, 앞으로 텍스트나 이미지를 다루던 기존 AI 모델들이 단백질 구조를 분석하고 생성하는 데 더 쉽게 활용될 수 있습니다.
FoldToken2 학습 노트
Purpose of the Paper:
- 단백질 구조 데이터의 SE(3) equivalent 특성 (회전/병진 시 좌표는 변하지만 구조 자체는 동일)으로 인한 representation learning, alignment, generation의 어려움 극복.
- 기존 FoldToken1의 한계를 넘어, 더욱 정확하게 복원 가능(recoverable) 하면서도 compact하고 invariant한 단백질 구조 언어 (discrete tokens)를 생성하는 것을 목표.
- Equivariant 구조를 invariant한 latent space/language로 변환하여, 기존 NLP/CV 모델을 단백질 구조 분석 및 생성에 더 쉽게 활용할 수 있는 기반 마련.
- Single-chain 뿐만 아니라 multi-chain 단백질 구조까지 tokenization을 확장하고자 함.
Key Contributions & Novelty:
- Invariant Structure Encoder 개선:
- Contribution: 기존 FoldToken1의 angle-based Transformer 대신, coordinate-based 입력과 BlockGAT [15] GNN을 사용하여 3D 의존성을 더 효과적으로 포착하는 invariant encoder 제안.
- Novelty: Residue를 block으로 표현하고 frame-based GNN (BlockGAT)을 사용하여 단백질 구조의 invariant 특성을 효율적으로 학습.
- Vector-Quantized Compressor 개선 (T-SoftCVQ):
- Contribution: FoldToken1의 SoftCVQ를 개선하여 Teacher-guided VQ와 Adaptive temperature annealing 전략을 도입한 T-SoftCVQ 제안.
- Novelty: 학습 안정성 (gradient explosion 방지) 및 수렴 속도를 개선하고, 더 나은 이산 표현 (discrete token) 학습 유도.
- Equivariant Structure Decoder 개선 (SE(3)-BlockGAT):
- Contribution: Discrete token으로부터 3D 좌표를 생성하기 위해, BlockGAT 레이어와 결합된 새로운 SE(3) Frame Passing Layer를 이용한 decoder 제안. Iterative refinement 방식으로 구조 정밀도 향상.
- Novelty: Plug-and-play 방식의 효율적인 SE(3) layer를 통해 경량화된 decoder (4.9M parameters)로도 높은 구조 복원 성능 달성. AlphaFold2와 같은 복잡한 구조 모방 없이 자체 개발.
- Multi-chain Protein Structure Tokenization:
- Contribution: Single-chain을 넘어 multi-chain complex에 대한 구조 tokenization 및 복원 가능성 입증.
- Novelty: 본 연구가 multi-chain protein structure tokenization을 성공적으로 수행한 거의 첫 번째 시도일 수 있음을 시사.
Experimental Highlights:
- Datasets: CATH4.3 (single-chain), PDB (multi-chain).
- Metrics: TMScore (↑), RMSD (↓).
- Key Result 1 (Single-chain Reconstruction vs. FoldToken1):
- FoldToken2† (adaptive temp. 적용)가 FoldToken1 대비 TMScore 약 20% 향상, RMSD 약 81% 감소 (T116 dataset 기준). reconstruction 성능 대폭 개선.
- Worst-case 성능도 크게 향상되어 모델 안정성 증가 (T116: min TMScore 0.88, max RMSD 0.93).
- Key Result 2 (AutoEncoder Performance w/o VQ):
- Vector Quantization 없이도 매우 높은 복원 성능 (T116: TMScore 0.98, RMSD 0.41) 달성, 제안된 encoder-decoder 구조의 우수성 입증.
- ESMFold, AlphaFold2/3와 비교 시, 훨씬 작은 decoder (4.9M)로 경쟁력 있는 정확도 달성.
- Key Result 3 (Multi-chain & Efficiency):
- Multi-chain complex에 대해서도 낮은 RMSD로 복원 가능함을 시각적으로 확인 (Fig. 4).
- 전체 PDB 데이터셋 학습에 단 40 GPU hours 소요 (OpenFold 5000 GPU hours 대비 매우 효율적).
- Ablation Study Insights:
- Adaptive temperature annealing (FT2†)이 성능 향상에 중요.
- Codebook size가 클수록 (e.g., 65536) 성능 향상.
- Encoder/Decoder 8 layer, hidden dim 128 설정이 CATH4.3 데이터셋에서 최적 (더 깊거나 넓은 모델은 성능 향상 미미).
- Decoder의 kNN 수를 30개 이상 늘려도 성능 향상 없음.
Limitations and Future Work:
- Limited Scalability on Test Data: CATH4.3 데이터셋에서는 모델 크기 (layer/hidden dim) 증가가 성능 향상으로 이어지지 않음.
- Reason: 데이터셋 크기 제약 또는 더 섬세한 large model 학습 전략 필요.
- Future Work: 더 큰 데이터셋 (e.g., AF2DB)에서의 학습, large model 안정화 및 튜닝 연구.
- Downstream Task Performance: 제안된 TokenFlow (sequence to structure) 및 FoldGPT (masked structure generation) 모델이 초기 실험에서 기대만큼 잘 작동하지 않음 (overfitting 등).
- Reason: 모델/학습 규모 부족.
- Future Work: 관련 downstream task 모델의 scale-up 및 학습 전략 개선.
- Computational Cost: FoldToken1보다는 학습 속도가 느림 (데이터 표현, backbone, refinement 전략 차이).
Overall Summary:
FoldToken2는 단백질 구조의 SE(3) equivalent 문제를 해결하기 위해, invariant하고 compact하며 정확하게 복원 가능한 discrete language를 학습하는 새로운 프레임워크입니다. 개선된 BlockGAT encoder, T-SoftCVQ quantifier, SE(3)-BlockGAT decoder를 통해 기존 FoldToken1 대비 reconstruction 성능을 혁신적으로 향상시켰으며, multi-chain complex tokenization의 가능성을 처음으로 제시했습니다. 특히 경량화된 모델과 높은 학습 효율성은 주목할 만하며, 향후 invariant한 단백질 언어를 활용한 다양한 downstream task 연구의 발전을 촉진할 잠재력을 지닙니다.
1과의 차이점
FoldToken1 vs. FoldToken2: 주요 차이점
FoldToken2는 FoldToken1의 핵심 아이디어를 계승하면서 세 가지 주요 구성 요소 (encoder, quantifier, decoder)를 모두 개선하여 성능과 적용 범위를 크게 확장했습니다. 주요 차이점은 다음과 같습니다.
- Input Representation & Encoder Backbone:
- FoldToken1: 단백질 backbone 구조를 bond angle과 torsion angle로 표현하고, Transformer 기반의 encoder를 사용했습니다. 이는 3D 공간 정보를 간접적으로만 포착하는 한계가 있었습니다.
- FoldToken2: 3D 좌표(coordinates) 기반으로 residue를 block 단위로 표현하고, BlockGAT [15]라는 Graph Neural Network (GNN) 기반의 encoder를 사용합니다. 이를 통해 3D 공간 의존성을 더 직접적이고 효과적으로 학습하여 invariant한 feature를 추출합니다.
- Vector Quantization (VQ) Module:
- FoldToken1: 기본적인 SoftCVQ를 사용했습니다.
- FoldToken2: SoftCVQ를 개선한 T-SoftCVQ를 제안합니다.
- Teacher-guided VQ: 학습 초기 단계에서 encoder의 embedding을 VQ 결과로 일부 사용하여 수렴을 가속화합니다.
- Adaptive Temperature Annealing: 학습 loss에 따라 temperature 조절 파라미터(β)를 동적으로 조절하여, 기존의 선형적인 temperature 감소 방식보다 더 빠르고 안정적으로 학습합니다 (Gradient explosion 문제 완화 및 수렴 속도 향상).
- Decoder (Structure Generation):
- FoldToken1: Transformer 기반 decoder를 사용
- FoldToken2: 새로운 SE(3)-BlockGAT decoder를 제안합니다.
- BlockGAT으로 residue 간 상호작용을 파악하고, SE(3) Frame Passing Layer를 통해 local frame 정보를 명시적으로 업데이트하며 3D 구조를 생성합니다.
- Iterative Refinement: 여러 SE(3)-BlockGAT layer를 쌓아 점진적으로 구조를 정교화합니다.
- 기존의 복잡한 SE(3) 모델 구조 대신, 더 가볍고 효율적인 plug-and-play SE(3) layer를 설계하여 사용합니다.
- Scope & Performance:
- FoldToken1: 주로 single-chain 단백질 구조 tokenization에 초점을 맞췄습니다.
- FoldToken2: Single-chain 성능을 대폭 향상시켰을 뿐만 아니라 (TMScore 20%↑, RMSD 81%↓), multi-chain 단백질 complex tokenization까지 확장하여 가능성을 보였습니다.
요약: FoldToken2는 FoldToken1 대비 데이터 표현 방식, 핵심 네트워크 구조 (Transformer -> GNN), VQ 전략, Decoder 구조 등 거의 모든 면에서 개선이 이루어졌으며, 이를 통해 재구성(reconstruction) 성능, 학습 효율성, 적용 범위(multi-chain) 에서 큰 발전을 이루었습니다.
Abstract
3D coordinates의 equivalent 특성은 protein structure representation learning, alignment, 그리고 generation에서 오랫동안 어려움을 제기해 왔습니다. protein structures를 equivalently represents하는 compact하고 invariant language를 만들 수 있을까요?
이 목표를 향해, 우리는 FoldToken2를 제안하여 equivariant structures를 discrete tokens로 변환하면서 original structures의 recoverability를 유지합니다.
FoldToken1에서 FoldToken2로 오면서, 우리는 세 가지 핵심 구성 요소를 개선했습니다: (1) invariant structure encoder, (2) vector-quantized compressor, 그리고 (3) equivalent structure decoder.
우리는 protein structure reconstruction task에서 FoldToken2를 평가했으며, 이전 FoldToken1보다 TMScore에서 20%, RMSD에서 81% 더 우수한 성능을 보임을 보여줍니다. FoldToken2는 아마도 single-chain 및 multi-chain protein structures quantization 모두에서 잘 작동하는 첫 번째 방법일 것입니다.
우리는 FoldToken2가 protein structure representation learning, structure alignment, 그리고 structure generation tasks에서 추가적인 개선을 이끌어낼 것이라고 믿습니다.
1 Introduction
"SE-(3) structure는 특별하거나 어려워서는 안 됩니다. 장벽을 낮추자." – Our Goal
Protein structure modelling은 computational biology에서 기초적인 역할을 하며 machine learning 분야에서 점점 더 많은 관심을 받고 있습니다. SE-(3) equivariant nature 때문에, protein structure를 encoding하고 generating하는 것은 결코 간단한 일이 아니며, protein structures에 특화된 특별한 design을 요구합니다. 예를 들어, PiFold는 structure patterns를 encode하기 위해 invariant featurizer를 제안했고, AlphaFold2는 equivariant 3D coordinates를 generate하기 위해 frame-based model을 design했습니다. protein structure models를 designing하는 데 수많은 innovations이 제안되었지만, structures data 자체는 SE-(3) nature를 유지합니다. equivariant structures를 invariant form으로 transform한 다음, 기존 NLP/CV models를 사용하여 structures를 encode하고 generate할 수 있을까요?
우리는 SE-(3) structures를 invariant representations으로 transform하는 새로운 방법인 FoldToken2를 소개합니다. 핵심 아이디어는 self-reconstruction을 통해 structure information을 preserve하는 compact invariant latent space를 만드는 것입니다. pretraining 후, invariant latent representation은 latent space에서 editable한 equivariant structures의 prototype 역할을 할 수 있습니다. image와 text와 유사하게, 우리는 또한 latent space를 discretize하여 SE-(3) invariant language를 만들기 위해 vector quantization module을 도입합니다. invariant latent embedding 또는 language를 input으로 사용하여, 기존 CV 또는 NLP models를 활용하여 protein structures를 encode하고 generate할 수 있습니다. FoldToken2는 세 가지 핵심 구성 요소, 즉 (1) invariant encoder, (2) vector quantization module, (3) equivariant decoder를 포함합니다.
frame-based GNN (BlockGAT)은 equivariant structures를 invariant embeddings으로 encoding하는 데 사용됩니다. FoldToken1은 backbone structures를 bond 및 torsion angles로 represent하여 3D dependencies를 capture하는 능력이 부족했습니다. 이에 대한 해결책으로, FoldToken2는 AlphaFold2처럼 residues를 block으로 represent하며, 효율적이고 강력한 graph neural network BlockGAT을 사용합니다. BlockGAT은 informative 3D dependencies를 capturing하기 위한 단순화된 featurizer와 high-level representations을 learning하기 위한 optimized graph network module을 포함합니다. sparse graph attention mechanism은 SE-(3) transformers보다 훨씬 efficient하게 만들어주며, 이는 large-scale pre-training에 중요합니다.
optimized vector quantization module (T-SoftCVQ)은 continuous embeddings를 discrete tokens(fold language라고 함)로 quantize하는 데 사용됩니다. image 및 video modelling에 vector quantization을 적용하여 큰 성공을 거두었지만, protein structure modelling에 적용하려는 시도는 거의 없었습니다. BERT 및 GPT와 같은 advanced sequence models를 강력한 structure learner로 만들기 위해서는 continuous embeddings를 discrete tokens로 quantizing하는 것이 중요합니다. FoldToken1을 기반으로, 우리는 새로운 temperature annealing 및 encoding strategy를 제안하여 vector quantization module (SoftCVQ)을 더욱 개선하여 T-SoftCVQ를 만들었습니다. 정교한 design 덕분에, reconstruction results는 이전 VQ methods보다 지속적으로 우수한 성능을 보입니다.
conditional SE-(3) decoder는 structure generation을 위해 제안되었습니다. decoder는 discrete tokens를 conditional features로 사용하고, equivariant graph message passing을 통해 3D coordinates를 generates합니다. 각 SE-(3) layer는 BlockGAT과 plug-and-play SE-(3) module을 포함합니다. 여러 개의 SE-(3) BlockGATs를 stacking함으로써, 우리는 discrete tokens에 conditioned된 Gaussian noise로부터 protein structure를 iteratively refine합니다. reconstruction 동안, BlockGAT은 pairwise residue interactions를 online으로 extracts하고, SE-(3) module은 frame-level message passing을 사용하여 local frames를 refine합니다.
우리는 single-chain 및 multi-chain settings 모두에서 FoldToken2의 reconstruction performance를 평가합니다. single chain reconstruction에서, 아마도 single-chain을 tokenizing한 첫 번째 방법일 수 있는 FoldToken1에 이어, FoldToken2는 TMScore와 RMSD 모두에서 각각 20%와 81% 향상된 reconstruction performance를 보입니다. 또한, FoldToken2는 이 idea를 multi-chain protein structure reconstruction으로 확장하여 유망한 결과를 이끌어냅니다. 우리는 FoldToken2가 protein structure representation learning, structure alignment, 그리고 structure generation tasks에서 추가적인 개선을 이끌어낼 것이라고 믿습니다.
FoldToken2 Introduction: 정리 노트 (AI 연구자용)
핵심 목표:
- Protein structure modeling의 SE(3) equivariance 문제를 해결하여, 기존 NLP/CV model들을 protein structure 연구에 더 쉽게 활용할 수 있도록 장벽을 낮추는 것.
핵심 아이디어:
- Equivariant 3D protein structure를 **discrete invariant tokens ("fold language")**로 변환.
- 이를 통해 structure 정보를 invariant latent space에서 self-reconstruction으로 보존하고, 이 latent representation (또는 token)을 input으로 사용하여 표준 sequence/CV model을 적용 가능하게 함.
핵심 구성 요소 및 개선점:
- Invariant Encoder (BlockGAT):
- Equivariant structure를 invariant embedding으로 변환.
- FoldToken1 (bond/torsion angles) 대비 개선: AlphaFold2와 유사한 block 단위 representation 사용 + BlockGAT (Graph Neural Network) 적용으로 3D dependencies 효과적 capture.
- SE(3) transformer 대비 sparse graph attention으로 효율성 증대 (large-scale pre-training에 유리).
- Vector Quantization (T-SoftCVQ):
- Continuous invariant embedding을 discrete tokens ("fold language")로 quantize.
- FoldToken1의 SoftCVQ 개선 (new temperature annealing & encoding strategy).
- VQ를 통해 BERT, GPT 등 advanced sequence model을 structure learner로 활용 가능하게 하는 것이 목표.
- Equivariant Decoder:
- Discrete tokens를 condition으로 받아 3D coordinates를 generation.
- Stacked SE(3)-BlockGATs 구조 사용: 각 layer는 BlockGAT (pairwise interactions 추출) + SE(3) module (frame refinement)로 구성되어 Gaussian noise로부터 structure를 iteratively refine.
주요 성과 (Introduction 기준):
- Reconstruction 성능: FoldToken1 대비 single-chain에서 TMScore 20%, RMSD 81% 향상.
- 확장성: 기존 single-chain tokenization 아이디어를 multi-chain protein structure로 성공적으로 확장.
기대 효과:
- Protein structure representation learning, structure alignment, structure generation 연구 촉진.
쉬운 설명 : FoldToken2 Introduction
문제: 단백질은 복잡한 3D 구조를 가지고 있고, 이걸 컴퓨터로 다루려니 좀 까다로워요. 단백질을 이리저리 돌리면 컴퓨터가 보는 좌표 값은 계속 바뀌는데 (이걸 SE(3) equivariant 하다고 해요), 구조 자체의 본질은 그대로잖아요? 이런 특성 때문에 단백질 구조를 분석하거나 새로 만드는 AI 모델은 특별하게 디자인해야 해서 어려움이 있었어요.
FoldToken2의 아이디어: "단백질 3D 구조를, 마치 우리가 글자를 읽듯이, **고유한 '단어'(token)**들로 바꿔보면 어떨까?" 이 '단어'들은 단백질을 어떻게 돌리든 변하지 않는(invariant) 정보만 담고 있어요. 이렇게 '단백질 구조 언어'(fold language)를 만들면, 이미 이미지나 텍스트 분석에 강력한 성능을 보이는 기존 AI 모델들 (NLP/CV models)을 가져다가 단백질 구조를 분석하고 생성하는 데 쓸 수 있지 않을까요? 이게 FoldToken2의 핵심 목표예요. 단백질 연구의 진입 장벽을 낮추려는 거죠.
어떻게?: 크게 3단계로 작동해요.
- 구조 분석기 (Encoder): 단백질의 3D 구조 정보를 받아서, 방향과 상관없는 핵심 특징(invariant embedding)을 뽑아내요. 이전 버전(FoldToken1)보다 더 똑똑한 방법(BlockGAT)을 써서 3D 관계를 잘 파악해요.
- 단어 변환기 (Quantizer): 분석기가 뽑아낸 특징 정보를 여러 개의 정해진 '단어'(discrete token) 중 가장 비슷한 것으로 바꿔줘요. 이걸 '단백질 구조 언어'로 만드는 과정이죠. 이전 버전보다 이 변환 성능을 더 높였어요(T-SoftCVQ).
- 구조 생성기 (Decoder): 이렇게 만들어진 '단어'들만 보고도 원래의 3D 단백질 구조를 다시 만들어낼 수 있어요.
결론: 이 새로운 방법(FoldToken2)은 이전 버전보다 훨씬 더 정확하게 원래 구조를 복원해냈고(TMScore 20%, RMSD 81% 향상), 여러 개 단백질이 뭉쳐있는 복잡한 구조에도 잘 작동해요. 앞으로 단백질 구조를 이해하고, 비교하고, 새로 만드는 AI 연구에 큰 도움이 될 거라고 기대하고 있어요.
2 Method
2.1 Overall Framework
Fig.1에 나타난 바와 같이, 전반적인 framework는 encoder, quantifier, decoder를 포함하여 FoldToken1과 동일하게 유지됩니다. FoldToken2는 reconstruction performance를 향상시키기 위해 각 module을 포괄적으로 개선했으며, 요약하면 다음과 같습니다:
- Data Form: Angle-based representation에서 coordinate-based representation으로 변경.
- Backbone: Transformer backbone을 BlockGAT이라는 새로운 GNN으로 교체.
- VQ: Teacher-guided temperature annealing strategy를 도입.
- Generator: Protein structures를 iteratively refine하기 위한 새로운 SE-(3) layer를 제안.
2.2 Invariant Graph Encoder
Rotation 및 translation equivariant nature 때문에 동일한 protein이라도 다른 coordinates 기록을 가질 수 있으며, 이는 동일한 protein에 대해 compact invariant representations를 learning하는 데 어려움을 제기합니다. 이전 연구들에서는 invariant featurizer가 invariant structure patterns를 encode할 수 있음을 보여주었으며, 우리도 동일한 방식을 따릅니다: protein structures를 invariant node 및 edge features로 구성된 graph로 representing합니다. 그런 다음 BlockGAT을 사용하여 high-level representations를 learn합니다.
Frame-based Block Graph. (n)개의 blocks를 포함하는 protein (M = {B_s}{s=1}^n)이 주어졌을 때 (여기서 각 block은 아미노산을 represent함), 우리는 kNN algorithm을 사용하여 block graph (G({B_s}{s=1}^n, E))를 구축합니다. block graph에서 (s)-번째 node는 (B_s = (T_s, f_s))로 represented되고, ((s, t)) 사이의 edge는 (B_{st} = (T_{st}, f_{st}))로 represented됩니다. (T_s)와 (T_{st} = T_s^{-1} \circ T_t)는 각각 (s)-번째 block의 local frames와 (s)-번째 block과 (t)-번째 block 사이의 relative transform입니다. (f_s)와 (f_{st})는 node 및 edge features입니다.
BlockGAT Encoder. 우리는 BlockGAT layer (f_{\theta})를 사용하여 block-level representations를 learn합니다: $$ f^{(l+1)}s, f^{(l+1)}{st} \leftarrow \text{BlockGAT}s(f^{(l)}s, f^{(l)}{st} | T_s, T{st}, E) \quad (1) $$ 여기서 (f^{(l)}s)와 (f^{(l)}{st})는 (l)-번째 layer의 input node 및 edge features를 represent합니다. (T_s = (R_s, t_s))는 (s)-번째 block의 local frame이고, (T_{st} = T_s^{-1} \circ T_t = (R_{st}, t_{st}))는 (s)-번째 block과 (t)-번째 block 사이의 relative transform입니다. (T_s), (T_{st}), (f^{(0)}s) 및 (f^{(0)}{st})는 이전에 제안된 invariant featurizer를 사용하여 ground truth structures로부터 initialize됩니다.
2.3 Quantifier
FoldToken1을 따라, 우리는 SoftCVQ를 사용하여 invariant embeddings를 discrete tokens(fold language라고 함)로 quantize합니다. SoftCVQ는 continual embeddings를 discrete tokens로 projecting하는 대신, 미리 정의된 binary number ((b_j))를 continuous token embeddings (v_j)로 maps하고, 그 다음 token embeddings (v_j)와 latent embeddings (h_s) 사이의 soft alignment를 수행합니다.
Decimal integer (z)와 codebook size (m)이 주어졌을 때, 우리는 (z)를 길이 (\log_2(m))의 binary vector (b_j)로 represent합니다. 예를 들어, (m=4)이면 (b_1 = [0, 0], b_2 = [0, 1], b_3 = [1, 0], b_4 = [1, 1])입니다. quantization operation은 continuous embeddings (h_s)를 discrete tokens (z)로 transforms합니다: $$ \begin{cases} a_{sj} = \frac{\exp(f_s^T v_j / T)}{\sum_{j=1}^m \exp(f_s^T v_j / T)} \ z = \text{arg max}j a{sj} \end{cases} \quad (2) $$ 여기서 (v_j)는 (j)-번째 token embedding이며, conditional network에 의해 generated되고, 이는 또한 de-quantization operation 역할도 합니다: $$ \begin{cases} b_s = \text{Bit}(z, \log_2(m)) \ \hat{f}s = \text{ConditionNet}(b_s) \end{cases} \quad (3) $$ 식 (2)에서, "argmax"는 non-differentiable하므로, training 중에는 soft approximation을 사용합니다: $$ \begin{cases} \hat{f}i = \sum{j=1}^m a{ij} v_j \ b_j = \text{Bit}(j, \log_2(m)) \ v_j = \text{ConditionNet}(b_j) \end{cases} \quad (4) $$ Temperature parameter (T)는 attention query operation의 softness를 controls합니다. (T)가 크면 attention weights는 uniform에 가까워지고, 그렇지 않으면 attention weights는 one-hot에 가까워집니다. Training 동안, temperature는 1.0에서 1e-8까지 annealed됩니다. ConditionNet: (\mathbb{R}^{\log_2(m)} \rightarrow \mathbb{R}^d)는 MLP입니다. 만약 codebook size가 (2^{16})이면, MLP는 65536개의 16-dimension boolvectors를 65536개의 (d)-dimension vectors로 projects합니다. Softmax operation의 gradient를 고려하면: $$ [s_1, s_2, \dots, s_k] = \text{Softmax}(z_1/T, z_2/T, \dots, z_k/T) $$ $$ \frac{\partial s_i}{\partial z_j} = \frac{1}{T} s_i \cdot (\mathbb{1}(i=j) - s_j) \quad (5) $$ (T)가 0에 가까워지면 gradient가 폭발합니다. 이를 피하기 위해, temperature가 1e-5보다 낮아지면 encoder와 VQ module을 frozen합니다. Temperature scheduler는 다음과 같습니다: $$ \begin{cases} x = 2\pi(\frac{t % T_{period}}{T_{period}}) - \pi \ T = \max(10^{-8}, \frac{1+\cos(x)}{2}) \times \beta \end{cases} \quad (6) $$ (Note: 원문에서 (T_{period})가 명시되지 않았으나, 주기적인 스케줄러이므로 (T_{period})로 표기함)
Model convergence를 가속화하기 위해 다음을 도입합니다:
- Teacher-Guided VQ: 확률 (T) (temperature와 동일한 기호 사용에 유의)로 encoder embedding (f_s)를 quantized feature (\hat{f}_s)로 무작위로 복사합니다. 이를 teacher-guided vector quanternion(quantization의 오타로 보임)이라고 합니다.
- Adaptive tuner: Training loss에 기반하여 temperature scale (\beta)를 adaptively 조정합니다.
Adaptive temperature tuner는 다음과 같이 정의됩니다: $$ \beta = \begin{cases} 1.0, & \text{if } 0.1 < L \ 0.05, & \text{if } 0.05 < L \le 0.1 \ 0.01, & \text{if } 0.02 < L \le 0.05 \ 0.0001, & \text{if } L \le 0.02 \end{cases} \quad (7) $$
2.4 Equivariant Graph Decoder
Invariant representations를 condition으로 하여 protein structures를 generating하는 것은 computing efficiency 측면에서 상당한 어려움을 제기합니다. 예를 들어, 잘 알려진 AlphaFold2를 처음부터 training하는 데는 128개의 TPUv3 cores로 11일이 걸리고, OpenFold는 training에 50000 GPU hours가 소요됩니다. 이 연구에서는 structure prediction을 위해 어떤 GNN layer에도 추가할 수 있는 efficiency plug-and-play SE3Layer를 제안합니다. SE3Layer의 단순화된 module과 sparse attention을 갖춘 BlockGAT 덕분에, 우리는 8개의 NVIDIA-A100을 사용하여 전체 PDB dataset에 대해 model을 1일 안에 train할 수 있습니다.
SE-(3) Frame Passing Layer. 우리는 frame-level message passing을 도입하여, (s)-번째 block의 local frame을 이웃들로부터 relative rotation (R_s)와 translation (t_s)를 aggregate하여 update합니다: $$ \begin{cases} \text{vec}(R_s) = \sum_{j \in N_s} a^r_{sj} \text{vec}(R_{sj}) \ R_s \leftarrow \text{Quat2Rot} \circ \text{Norm} \circ \text{MLP}{9 \to 4}(\text{vec}(R_s)) \quad \text{Normalize quanternion} \ t_s = \sum{j \in N_s} a^t_{sj} t_{sj} \end{cases} \quad (8) $$ 여기서 (a^r_{sj})와 (a^t_{sj})는 rotation 및 translation weights이고, (N_s)는 (s)-번째 block의 neighbors입니다. ( \text{vec}(\cdot) )은 (3 \times 3) matrix를 9-dimensional vector로 flattens합니다. ( \text{MLP}{9 \to 4}(\cdot) )은 9-dim rotation matrix를 4-dim quaternion으로 maps하고, ( \text{Norm}(\cdot) )은 quaternion을 normalize하여 valid rotation을 represent하도록 보장합니다. ( \text{Quat2Rot}(\cdot) )은 quaternion to rotation function입니다. 세부 사항을 더 소개하면 다음과 같습니다: $$ \begin{cases} w^r{st}, w^t_{st} = \sigma(\text{MLP}(f_{st})) \ \text{vec}(R_{st}) \leftarrow w^r_{st}\text{vec}(R_{st}) + (1 - w^r_{st})\text{MLP}{d \to 9}(f{st}) \ t_{st} \leftarrow w^t_{st}t_{st} + (1 - w^t_{st})\text{MLP}{d \to 3}(f{st}) \ a^r_{st}, a^t_{st} = \text{Softmax}(\text{MLP}{d \to 1}(f{st})) \end{cases} \quad (9) $$ 여기서 (w^r_{st})와 (w^t_{st})는 rotation 및 translation을 위한 updating weights이고, (a^r_{st})와 (a^t_{st})는 attention weights입니다. 제안된 SE-(3) layer는 local frame updating을 위해 어떤 graph neural network에도 추가될 수 있습니다.
Iterative Refinement 우리는 SE-(3) layer를 BlockGAT에 추가하여 SE-(3) BlockGAT이라는 새로운 module을 제안합니다. 우리는 multi-layer SE-(3) BlockGAT을 stack하여 structures를 iteratively refine합니다: $$ \begin{cases} f^{(l+1)}s, f^{(l+1)}{st} = \text{BlockGAT}^{(l)}(f^{(l)}s, f^{(l)}{st}) \ T^{(l)}{st} = T_s^{-1} \circ T_t \ T^{(l+1)}s = \text{SE3Layer}(\text{sg}(T^{(l)}{st}), f^{(l+1)}{st}) \end{cases} \quad (10) $$ 여기서 ( \text{sg}(\cdot) )은 stop-gradient operation이고, ( \text{SE3Layer}(\cdot) )는 식 (9)에서 설명된 SE-(3) layer입니다. Predicted local frame (T^{(l)}_s)가 주어지면, 다음과 같이 3D coordinates를 얻을 수 있습니다: $$ \begin{cases} h_s = \text{MLP}(f^{(l)}_s) \ x_s = T^{(l)}_s \circ h_s \end{cases} \quad (11) $$
2.5 Reconstruction Loss
Chroma에서 영감을 받아, 우리는 model을 train하기 위해 multiple losses를 사용합니다. Overall loss는 다음과 같습니다: $$ L = L_{\text{global}} + L_{\text{fragment}} + L_{\text{pair}} + L_{\text{neighbor}} + L_{\text{distance}} \quad (12) $$ Loss 항들을 설명하기 위해, aligned RMSD loss를 ( L_{\text{align}}(\hat{X}, X) = |\text{Align}(\hat{X}, X) - X| )로 정의합니다. 여기서 (X \in \mathbb{R}^{n,3})은 ground truth 3D coordinates이고, (\hat{X} = {x_1, x_2, x_3, \dots, x_n} \in \mathbb{R}^{n,3})은 predicted 3D coordinates입니다. Global, fragment 및 pair loss는 aligned MSE loss로 정의되지만, 다른 input data shape를 갖습니다:
- Global Loss: (X) shape ([n, 4, 3]). Global structure의 RMSD.
- Fragment Loss: (X) shape ([n, c, 4, 3]). 각 residue에 대한 (c)개 neighbors의 RMSD.
- Pair Loss: (X) shape ([n, K, c \cdot 2, 4, 3]). 각 kNN pair에 대한 (c)개 neighbors의 RMSD.
- Neighbor Loss: (X) shape ([n, K, 4, 3]). 각 residue에 대한 (K)개 neighbors의 RMSD.
여기서 (n)은 residues의 수, (c = 7)은 fragments의 수, (K = 30)은 kNN의 수, 4는 네 개의 backbone atoms {N, CA, C, O}를 고려함을 의미하고, 3은 3D coordinates를 의미합니다. Distance loss는 predicted pairwise distances와 ground truth pairwise distances 사이의 MSE loss로 정의됩니다: $$ L_{\text{distance}} = |\text{Dist}(\hat{X}) - \text{Dist}(X)| \quad (13) $$ 여기서 ( \text{Dist}(X) \in \mathbb{R}^{n,n} )은 3D coordinates (X)의 pairwise distance matrix입니다. 우리는 각 decoder layer에 loss를 적용하며, final loss는 그 평균값으로, 이는 좋은 performance에 중요합니다.
FoldToken2 Method: 정리 노트 (AI 연구자용)
Overall Framework (vs. FoldToken1):
- 기본 구조(Encoder-Quantizer-Decoder)는 유지.
- 주요 개선점:
- Input: Angle-based → Coordinate-based (Frame-based representation 활용).
- Encoder Backbone: Transformer → BlockGAT (GNN).
- Quantizer (VQ): SoftCVQ 기반 + Teacher-Guided Temperature Annealing & Adaptive Tuner.
- Decoder: Novel SE(3) Layer 기반 Iterative Refinement.
2.2 Invariant Graph Encoder:
- Input Representation: kNN 기반 Frame-based Block Graph.
- Node ((B_s)): Amino acid block, Local Frame (T_s) 및 Node feature (f_s) 포함.
- Edge ((B_{st})): Relative Transform (T_{st} = T_s^{-1} \circ T_t) 및 Edge feature (f_{st}) 포함.
- Model: BlockGAT ((f_{\theta})) 사용. (T_s, T_{st}) 등에서 추출한 invariant features 를 입력받아 high-level representation 학습 (Eq. 1). Feature는 [15] 방식 사용.
2.3 Quantifier (SoftCVQ 개선):
- 기본 방식: SoftCVQ - Predefined binary vector ((b_j)) → Continuous token embedding ((v_j)) mapping (via ConditionNet MLP) 후, latent embedding ((h_s))과 soft alignment (Eq. 2, 4).
- 개선점 (Novelty):
- Temperature Annealing: Cosine schedule (Eq. 6) 사용, (T)를 1.0에서 1e-8까지 점진적 감소. (T < 1e-5) 시 Encoder/VQ frozen하여 gradient 폭발 방지 (Eq. 5).
- Teacher-Guided VQ: Training 시 확률 (T) (temperature)로 quantized feature (\hat{f}_s) 대신 원본 encoder embedding (f_s) 사용 → Convergence 도움.
- Adaptive Temperature Tuner: Training loss (L) 값에 따라 annealing scale (\beta) 자동 조절 (Eq. 7).
2.4 Equivariant Graph Decoder:
- 목표: Invariant token/embedding 조건 하에 Equivariant 3D coordinate 생성.
- 핵심: Efficiency (8x A100으로 1일 학습 주장) + Novel SE(3) Layer.
- SE(3) Frame Passing Layer (Eq. 8, 9):
- Plug-and-play 모듈. GNN layer에 추가 가능.
- Local frame ((T_s)) 업데이트: 이웃들의 relative transform ((R_{sj}, t_{sj})) 정보를 attention ((a^r_{sj}, a^t_{sj})) 및 update gate ((w^r_{st}, w^t_{st})) 가중합으로 aggregate.
- Rotation은 9D matrix → 4D quaternion 변환 후 처리.
- Iterative Refinement (Eq. 10):
- SE(3)-BlockGAT module (BlockGAT + SE(3) Layer) stacking.
- 각 layer: BlockGAT으로 feature 업데이트 → SE(3) Layer로 frame 업데이트 (이전 frame (T_{st}^{(l)})은 stop-gradient 처리).
- Coordinate Generation (Eq. 11): 최종 refined frame (T_s^{(l)})와 feature (f_s^{(l)})로부터 (x_s = T_s^{(l)} \circ \text{MLP}(f_s^{(l)})) 계산.
- SE(3) Frame Passing Layer (Eq. 8, 9):
2.5 Reconstruction Loss:
- Chroma [16] 방식 multi-term loss (Eq. 12).
- 주요 Loss: Aligned RMSD ((L_{\text{align}})) 기반 loss들 (Global, Fragment, Pair, Neighbor) + Pairwise distance MSE ((L_{\text{distance}}), Eq. 13).
- Key Detail: Loss를 모든 decoder layer에서 계산 후 평균 → 성능에 중요.
쉬운 설명 : FoldToken2 Method
FoldToken2가 단백질 3D 구조를 '단백질 구조 언어'(token)로 바꾸고, 다시 3D 구조로 복원하는 방법을 좀 더 자세히 알아볼게요. 크게 3단계로 이루어집니다.
1단계: 구조 분석 및 특징 추출 (Invariant Graph Encoder)
- 뭘 보나?: 이전 버전(FoldToken1)과 달리, 이제 단백질의 3D 좌표를 직접 봅니다.
- 어떻게?: 각 아미노산(단백질 구성 블록)마다 자신만의 기준 좌표계('local frame')를 설정하고, 주변 아미노산들과의 상대적인 위치/방향 관계('relative transform')를 계산해요. 이걸 kNN(가까운 이웃 찾기) 알고리즘으로 연결해서 'Block Graph'라는 관계망 지도를 만들어요.
- 분석 도구: 이 관계망 지도 정보를 BlockGAT이라는 똑똑한 그래프 분석 도구(GNN)에 넣어서, 단백질을 어떻게 돌려도 변하지 않는 고유한 특징(invariant features)을 뽑아냅니다.
2단계: 특징을 '단어'로 바꾸기 (Quantifier)
- 목표: 1단계에서 뽑아낸 연속적인 특징 값들을, 정해진 몇 개의 '단어'(discrete token) 중 하나로 딱딱 끊어서 바꿔줍니다.
- 방법 (SoftCVQ 개선):
- 미리 '단어 사전'(codebook)을 만들어두고, 각 단어에 해당하는 숫자 코드(binary vector)를 준비해요.
- 1단계 특징과 가장 비슷한 단어를 '부드럽게'(soft alignment) 찾습니다.
- 개선점:
- 온도 조절 학습: 처음엔 '뜨겁게'(높은 온도 T) 시작해서 여러 단어를 고려하다가, 점차 '식히면서'(온도 낮춤, annealing) 가장 확실한 단어 하나만 고르도록 유도해요. (마치 금속을 천천히 식혀 안정시키는 것처럼요!) 특정 온도 이하에선 학습을 잠시 멈춰서 계산이 불안정해지는 걸 막아요.
- 선생님 찬스: 가끔씩(확률 T) 단어로 바꾸는 대신, 1단계에서 얻은 원본 특징을 그대로 사용해서 학습이 더 잘 되도록 도와줘요 (Teacher-Guided).
- 자동 튜닝: 학습이 얼마나 잘 되는지(loss) 봐가면서 '식히는 속도'(온도 조절 변수 (\beta))를 알아서 조절해요 (Adaptive Tuner).
3단계: '단어'로부터 3D 구조 만들기 (Equivariant Graph Decoder)
- 목표: 2단계에서 만들어진 '단어'(token) 정보만 가지고 원래의 3D 구조를 복원(생성)합니다.
- 핵심 기술:
- 효율적인 조립 블록 (SE(3) Layer): 새롭게 만든 이 부품은 각 아미노산의 3D 위치와 방향('frame')을 주변 아미노산들과의 관계를 바탕으로 업데이트해줘요. 기존의 복잡한 방식보다 훨씬 빠르고 효율적이라고 해요. (다른 GNN에도 쉽게 끼워 쓸 수 있대요!)
- 단계별 정교화 (Iterative Refinement): 1단계의 BlockGAT과 3단계의 SE(3) Layer를 합친 'SE(3)-BlockGAT' 모듈을 여러 층 쌓아서, 구조를 한 단계씩 점진적으로 더 정확하게 다듬어 나갑니다.
- 최종 조립: 마지막 층에서 얻어진 각 아미노산의 위치/방향 정보('frame')와 특징 정보를 이용해 최종 3D 좌표를 계산해냅니다.
학습 방법 (Reconstruction Loss):
- 모델이 얼마나 잘 만들었는지 평가하기 위해, 원본 3D 구조와 모델이 만든 구조를 여러 측면에서 비교해요. 전체적인 모양, 부분적인 조각들, 특정 아미노산 쌍의 관계, 주변 이웃 구조, 원자 간 거리 등을 모두 고려해서 종합 점수(Loss)를 매깁니다.
- 특히, 3단계의 각 정교화 단계마다 점수를 매겨서 평균을 내는 방식으로 학습하는데, 이게 성능 향상에 중요하다고 하네요.
