논문리뷰

architecture : 논문리뷰 : Titans: Learning to Memorize at Test Time

AI바라기 2025. 1. 15. 16:59

Titans: Learning to Memorize at Test Time 논문 정리

Purpose of the Paper

기존의 Recurrent models과 Attention 기반 모델들은 각각 장단점이 존재했습니다. Recurrent models은 데이터를 fixed-size memory (hidden state)로 압축하여 long-term dependency를 포착하는 데 한계가 있었고, Attention은 모든 토큰 간의 direct dependency를 포착하여 정확도가 높지만 quadratic cost로 인해 fixed-length context로 제한되었습니다.

이 논문의 특별한 목적은 Recurrent models의 장점인 효율성과 Attention의 장점인 정확한 dependency modeling을 결합하여, 긴 Context에서도 효과적으로 작동하는 새로운 neural long-term memory module을 설계하는 것입니다. 이를 통해 test time에도 지속적으로 historical context를 memorize하고, 이를 바탕으로 current context에 대한 attention을 조절할 수 있는 Titans architecture를 제안합니다. 즉, "Learning to Memorize at Test Time" 이라는 새로운 패러다임을 제시하는 것이 이 논문의 핵심적인 Purpose입니다.

Key Contributions

  • Neural Long-term Memory Module: human long-term memory system에서 영감을 받아, surprise(예측과의 불일치 정도)가 큰 데이터를 더 잘 기억하도록 설계된 (deep) neural long-term memory를 제안합니다. 이 모듈은 associative memory loss를 기반으로 입력의 surprise를 측정하고, 이를 통해 historical context의 abstraction을 parameters에 encoding 합니다. 또한, limited memory를 효율적으로 관리하기 위해 memory size와 data surprise에 따라 과거 정보를 adaptive하게 잊는 decaying mechanism을 도입했습니다.
  • Titans Architecture: 설계된 long-term memory를 deep learning architecture에 효과적으로 통합하는 세 가지 variants의 Titans architecture (MAC, MAG, MAL)를 제안합니다. 이들은 각각 long-term memory를 context, gated branch, layer로 활용하는 방식을 보여줍니다.
  • Parallelizable Training Algorithm: Deep neural long-term memory를 위한 빠르고 parallelizable한 training algorithm을 제시합니다. Mini-batch gradient descent, momentum, weight decay를 tensorized mini-batch gradient descent로 재구성하여 matmul operations을 최대로 활용함으로써 GPU/TPU acceleration을 극대화합니다.
  • Comprehensive Experimental Evaluation: Language modeling, commonsense reasoning, genomics, time series 등 다양한 task에서 Titans의 성능을 평가합니다. 실험 결과는 Titans가 Transformers와 최신 linear recurrent models를 능가하며, 특히 2M 이상의 context window size에서도 높은 정확도를 유지함을 보여줍니다.

Novelty

  • Meta In-Context Model: 기존의 연구들이 fixed parameters를 가진 memory를 사용한 것과 달리, Titans는 test time에도 data-dependent하게 parameters를 update하는 meta in-context model을 제안합니다. 이를 통해 online으로 data를 memorize/forget하는 function을 학습하여 generalization 성능을 향상시킵니다.
  • Surprise-based Memory Update: Human memory system에서 영감을 받아 input의 surprise를 gradient로 측정하고, 이를 기반으로 memory를 update하는 새로운 mechanism을 제안합니다. 또한, momentary surprise뿐만 아니라 past surprise까지 고려하는 momentum mechanism을 도입하여, 중요한 정보가 소실되는 것을 방지합니다.
  • Adaptive Forgetting Mechanism: Memory size와 data surprise를 고려하여 adaptive하게 과거 정보를 잊는 decaying mechanism을 제안합니다. 이는 기존의 recurrent models에서 사용되는 forgetting mechanism의 generalization이며, 더 나은 memory management를 가능하게 합니다.
  • Hybrid Architecture Design: Short-term memory 역할을 하는 attention과 long-term memory 역할을 하는 neural memory를 결합한 hybrid architecture를 제안합니다. 이를 통해 두 모듈의 장점을 결합하여, long context에서도 정확하고 효율적으로 작동하는 model을 구축합니다.

Experimental Highlights

  • Language Modeling: Titans는 Transformer++, RetNet, GLA, Mamba, DeltaNet, TTT, Gated DeltaNet 등 다양한 SOTA model들을 능가하는 성능을 보였습니다. 특히, neural memory module (LMM)은 Transformer++를 포함한 non-hybrid model 중 가장 좋은 perplexity와 accuracy를 달성했습니다.
  • Needle-in-Haystack: Titans는 2M 이상의 context window size에서도 높은 정확도를 유지하며, GPT4, Llama3 with RAG, RecurrentGemma2-9B, Mistral 등의 대규모 모델들을 능가하는 성능을 보였습니다.
  • BABILong Benchmark: Few-shot 및 fine-tuning setting 모두에서 Titans (MAC)는 Mamba2.8B, RWKV-6-7B, RecurrentGemma-9B, Gemma-9B, Llama3.1-8B, GPT-4, GPT40-mini 등 다른 모든 baseline model을 능가했습니다. 특히 fine-tuning setting에서는 RMT-FT, Llama3.1-8B + RAG, Qwen2.5-72B, Llama3.1-70B, GPT-4, GPT40-mini 등 extremely large model보다도 우수한 성능을 보였습니다.
  • Time Series Forecasting: Simba framework를 사용한 실험에서, Titans의 neural memory module은 Mamba-based, iTransformer, RLinear, PatchTST, Crossformer, TIDE, TimesNet, DLinear 등 다양한 baseline model을 능가하는 성능을 보였습니다.
  • DNA Modeling: GenomicsBenchmarks를 사용한 실험에서, Titans (LMM)은 CNN, DNABERT, GPT, HyenaDNA, Transformer++, Mamba-based 등 다양한 baseline model과 경쟁력 있는 성능을 보였습니다.
  • Ablation Study: Linear memory, convolution, momentum, weight decay, persistent memory 등 Titans의 각 구성 요소가 성능에 긍정적인 영향을 미치는 것을 확인했습니다.

Limitations

  • Computational Cost of Deep Memory: Deep memory를 사용할 경우 training throughput이 감소할 수 있습니다.
  • Hyperparameter Sensitivity: Titans는 여러 hyperparameters (e.g., learning rate, weight decay, momentum)를 가지고 있으며, 최적의 성능을 얻기 위해서는 적절한 hyperparameter tuning이 필요합니다.
  • Limited Evaluation on Extremely Long Context: 2M 이상의 context window size에서의 평가는 수행되었지만, 더 긴 context에서의 성능 평가는 제한적입니다.

Future Work

  • More Efficient Memory Architectures: Memorization에 더 효과적이고 효율적인 neural architecture를 설계하는 연구가 필요합니다.
  • Scaling to Larger Models and Datasets: 더 큰 model과 dataset에 Titans를 적용하여 scalability를 평가하고 개선하는 연구가 필요합니다.
  • Applications to Other Domains: Language modeling, genomics, time series 외에도, video understanding, robotics 등 다양한 domain에 Titans를 적용하는 연구가 필요합니다.
  • Theoretical Analysis: Titans의 generalization 성능, memory capacity, expressiveness 등에 대한 theoretical analysis가 필요합니다.
  • Exploration of Different Memory Mechanisms: Surprise-based memory update 외에도, attention-based memory update, episodic memory 등 다양한 memory mechanism을 탐구할 수 있습니다.
  • Simplifications of the model: Parameters as the Function of Chunks와 같이 간소화된 버전의 모델을 만들고 이를 더 큰 모델 학습에 적용하여 효율성을 개선할 수 있습니다.

 

 

 

 

Abstract

10년 이상 recurrent models 및 attention을 효과적으로 활용하는 방법에 대한 광범위한 연구 노력이 있어 왔습니다. recurrent models은 데이터를 고정 크기 메모리(hidden state라고 함)로 압축하는 것을 목표로 하는 반면, attention은 전체 context window에 attending하여 모든 tokens의 직접적인 dependencies를 포착할 수 있습니다. 그러나 이러한 보다 정확한 dependencies의 모델링은 quadratic cost를 수반하여 모델을 고정 길이 context로 제한합니다. 우리는 과거 context를 memorize하는 방법을 학습하고 attention이 긴 과거 정보를 활용하면서 현재 context에 attend하도록 돕는 새로운 neural long-term memory 모듈을 제시합니다. 우리는 이 neural memory가 빠른 추론을 유지하면서도 빠른 병렬화 가능한 training의 이점을 가지고 있음을 보여줍니다. 메모리 관점에서, 우리는 attention이 제한된 context이지만 정확한 dependency modeling으로 인해 short-term memory로 작동하는 반면, 데이터를 memorize하는 능력으로 인해 neural memory는 long-term, 더 persistent한 memory로 작동한다고 주장합니다. 이 두 모듈을 기반으로, 우리는 Titans라고 하는 새로운 architectures 제품군을 소개하고, 이 architecture에 memory를 효과적으로 통합할 수 있는 방법을 해결하기 위해 세 가지 variants를 제시합니다. language modeling, common-sense reasoning, genomics 및 time series tasks에 대한 우리의 실험 결과는 Titans가 Transformers 및 최신 linear recurrent models보다 더 효과적임을 보여줍니다. 또한 needle-in-haystack tasks에서 기준선에 비해 더 높은 정확도로 2M보다 큰 context window size로 효과적으로 scale 할 수 있습니다.

 

 

1 Introduction

"The true art of memory is the art of attention!" — Samuel Johnson, 1787

Transformers, pure attention-based architectures는 sequence modeling에서 state-of-the-art models로 확고히 자리 잡았으며, 주로 in-context learning과 scale at scale을 학습하는 능력 때문입니다. Transformers의 주요 구성 요소인 attention modules은 associative memory blocks로 기능하며, key-value associations을 저장하고 queries(예: search signals)와 keys(예: contexts) 간의 pairwise similarity를 계산하여 이를 검색하는 방법을 학습합니다. 따라서 설계상 Transformer의 output은 현재 context window에 있는 tokens의 직접적인 dependencies에만 독점적으로 조건화됩니다. 그러나 이러한 dependencies의 정확한 모델링은 context length 측면에서 quadratic time 및 memory complexity를 수반합니다. 복잡한 실제 tasks(예: language modeling, video understanding, long-term time series forecasting)에서는 context window가 극도로 커질 수 있으므로 이러한 downstream tasks에서 Transformers의 적용 가능성이 어려워집니다.

Transformers의 scalability 문제를 극복하기 위해 최근 연구에서는 linear Transformers의 다양한 variants를 설계하는 것을 목표로 합니다. 여기서 softmax는 attention에서 kernel function으로 대체되어(자세한 내용은 §2.1 참조) memory consumption이 크게 감소합니다. 효율성과 더 긴 context로 scale할 수 있는 능력에도 불구하고, linear Transformers는 kernel trick이 모델을 linear recurrent network로 만들기 때문에 Transformers에 비해 경쟁력 있는 performance를 보여주지 못합니다. 여기서 데이터는 matrix-valued states로 압축됩니다. 그러나 이것은 linear recurrent (또는 linear Transformers) models에 대한 모순적인 사실을 가져옵니다. 한편으로는 scalability와 efficiency(linear vs. quadratic complexity)를 향상시키기 위해 이러한 linear models을 사용하는데, 그 장점은 매우 긴 context에서 나타납니다. 반면에, 매우 긴 context는 작은 vector-valued 또는 matrix-valued states에서 적절하게 압축될 수 없습니다.

더 나아가, 효율성을 넘어, Hopfield Networks에서 LSTMs, Transformers에 이르기까지 대부분의 기존 architectures는 generalization, length extrapolation 및/또는 reasoning을 다룰 때 어려움에 직면하며, 이들 모두는 많은 어려운 실제 tasks의 불가분의 일부입니다. 이러한 architectures는 인간의 뇌에서 영감을 얻었지만, 각각은 (1) learning process에 중요한 구성 요소(예: short-term memory, long-term memory, meta-memory, attending to current context 등); (2) 이러한 구성 요소가 독립적으로 작동할 수 있는 상호 연결된 systems인 방법; 및/또는 (3) 데이터로부터 적극적으로 학습하고 과거 history의 추상화를 memorize하는 능력이 부족합니다. 우리는 효과적인 learning paradigm에서 인간의 뇌와 유사하게, 각각이 learning process에 중요한 구성 요소를 담당하는 별개의 상호 연결된 modules이 있다고 주장합니다.

 

Memory Perspective

Memory는 기본적인 정신 process이며 인간 학습의 불가분의 구성 요소입니다. 적절하게 기능하는 memory system이 없으면 인간과 동물은 기본적인 반사 작용과 고정관념적인 행동에 제한됩니다. 따라서 memory는 Hopfield Networks, LSTMs, Transformers와 같은 machine learning 문헌에서 많은 중대한 연구에 영감을 주었습니다.

신경 심리학 문헌에서 memory와 learning에 대한 일반적인 정의에서 영감을 얻어, 대부분의 기존 architectures는 memory를 입력으로 인한 neural update로 간주하고, learning을 objective가 주어졌을 때 효과적이고 유용한 memory를 획득하는 process로 정의합니다. 이러한 관점에서, Recurrent Neural Networks (RNNs)는 두 가지 주요 단계를 가진 vector-valued memory module M(hidden state라고도 함)을 가진 models로 정의할 수 있습니다. 시간 t에 새로운 입력 xt가 주어지면, 모델은 (1) 함수 f(Mt-1, xt)를 사용하여 memory를 업데이트하고(압축 포함), (2) 함수 g(Mt, xt)를 사용하여 입력에 해당하는 memory를 검색합니다(자세한 내용은 §2.1 참조).

유사하게, Transformers는 growing memory와 두 가지 유사한 단계를 가진 architectures로 볼 수 있습니다. 즉, key 및 value matrices 쌍은 모델의 memory 역할을 하며, 모델은 (1) key와 value를 memory에 추가하여 memory를 업데이트하고(압축 없이), (2) query 및 key vectors의 similarity를 찾아 query vectors의 해당 memory를 검색하며, 이는 output을 위한 value vectors에 가중치를 부여하는 데 사용됩니다.

이러한 관점은 기존 paradigms, 그들의 중요한 차이점을 더 잘 이해하고 더 효과적인 architectures를 설계하는 데 도움이 될 수 있습니다. 예를 들어, Transformers와 linear Transformers의 주요 차이점은 memory structure와 memory updating step입니다. linear Transformers는 과거 데이터를 고정 크기의 matrix-valued memory로 압축하는 반면, Transformers는 (context length 내에서) 압축 없이 모든 과거 데이터를 유지합니다. linear Transformers와 linear RNNs(state space models 포함)는 모두 memory update 단계에서 정보를 압축하지만, 중요한 차이점은 memory structure에 있습니다. linear RNNs(vs. linear Transformers)는 vector-valued memory(vs. matrix-valued memory)를 사용합니다. 따라서 이러한 관점은 우리에게 다음과 같은 질문을 던지도록 동기를 부여합니다. (Q1) memory에 좋은 structure를 구성하는 것은 무엇입니까? (Q2) 적절한 memory update mechanism은 무엇입니까? 그리고 (Q3) 좋은 memory retrieval process는 무엇입니까?

인간 memory에 대한 우리의 이해를 되돌아보면, 그것은 단일 process도 아니고 단일 기능을 수행하지도 않습니다. 실제로 memory는 short-term, working, long-term memory와 같은 systems의 연합체이며, 각각은 서로 다른 neural structures를 가진 서로 다른 기능을 수행하며, 각각은 독립적으로 작동할 수 있습니다. 이 사실은 우리에게 다음과 같은 질문을 던지도록 동기를 부여합니다. (Q4) 서로 다른 상호 연결된 memory modules을 통합하는 효율적인 architecture를 어떻게 설계할 수 있습니까? 마지막으로, memory를 저장하는 것은 과거의 추상화를 encode하고 저장해야 하는 neural process입니다. 매개변수가 데이터를 linear 방식으로 encoding하는 단일 vector 또는 matrix가 long-term history를 저장하기에 충분하다고 가정하는 것은 지나친 단순화일 수 있습니다. (Q5) 먼 과거를 효과적으로 저장/기억하기 위해 deep memory module이 필요합니까?

 

Contributions and Roadmap

이 논문에서 우리는 test time에 효율적이고 효과적으로 memorize하는 방법을 학습할 수 있는 long-term neural memory module을 설계하여 위의 다섯 가지 질문에 답하는 것을 목표로 합니다. 그 설계를 기반으로, 우리는 그것이 architecture에 어떻게 통합될 수 있는지 논의합니다.

 

Neural Memory (§3). 우리는 test time에 데이터를 그 매개변수에 memorize/저장하는 방법을 학습하는 (deep) neural long-term memory를 (meta in-context model로서) 제시합니다. 인간의 long-term memory system에서 영감을 받아, 우리는 이 memory module을 설계하여 예상을 위반하는 사건(놀라운 사건)이 더 기억에 남도록 합니다. 이를 위해 associative memory loss에서 입력에 대한 neural network의 gradient로 입력의 놀라움을 측정합니다(자세한 내용은 §3.1 참조). 제한된 memory를 더 잘 처리하기 위해 memory size의 비율과 data surprise의 양을 고려하는 decaying mechanism을 제시하여 더 나은 memory 관리를 제공합니다. 우리는 이 decay mechanism이 실제로 현대 recurrent models에서 forgetting mechanism의 일반화임을 보여줍니다. 흥미롭게도, 우리는 이 mechanism이 mini-batch gradient descent, momentum, weight decay를 사용하여 meta neural network를 최적화하는 것과 동등하다는 것을 발견했습니다. 더 많은 matmul operations를 사용하기 위해 tensorizing mini-batch gradient descent를 기반으로, 우리는 deep neural long-term memory를 train하기 위한 빠르고 병렬화 가능한 algorithm을 제시합니다.

 

Titans Architectures (§4). long-term neural memory를 설계한 후, 남은 중요한 질문은 memory를 deep learning architecture에 효과적이고 효율적으로 통합하는 방법입니다. 우리는 세 가지 hyper-heads로 구성된 deep models 제품군인 Titans를 제시합니다. (1) Core: 이 module은 short-term memory로 구성되며, 데이터 처리의 주요 흐름을 담당합니다(제한된 window size를 가진 attention을 사용합니다). (2) Long-term Memory: 이 branch는 먼 과거를 저장/기억하는 역할을 하는 우리의 neural long-term memory module입니다. (3) Persistent Memory: 이것은 task에 대한 지식을 encoding하는 학습 가능하지만 날짜에 독립적인 매개변수 집합입니다. 마지막으로, 개념 증명으로, 우리는 memory를 (i) context, (ii) layer, (iii) gated branch로 통합하는 Titans의 세 가지 variants를 제시합니다.

 

Experimental Results (§5). 우리는 language modeling, commonsense reasoning, recall-intensive, needle in haystack, time series forecasting, DNA modeling tasks에 대한 실험적 평가를 수행합니다. 우리는 우리의 Titan architecture가 포괄적인 benchmarks 세트에 걸쳐 모든 현대 recurrent models와 그들의 hybrid variants(sliding-window attention과 결합)를 능가하는 것을 관찰합니다. 또한 Titans는 동일한 context window를 가진 Transformers를 능가하고, 전체 context를 사용하는 Transformers와 경쟁력 있는 performance를 보여줍니다. 이러한 결과는 Transformers와 달리 Titans가 2M보다 큰 context window size로 scale하는 동안 달성됩니다.

 

 

핵심 키워드: Transformers, Attention, Scalability, Linear Transformers, Recurrent Models, Memory (Short-term, Long-term), Titans Architectures

기존 연구의 한계:

  1. Transformers의 Scalability 문제:
    • Transformers는 sequence modeling에서 state-of-the-art를 달성했지만, quadratic time 및 memory complexity로 인해 긴 context를 다루는 데 어려움이 있습니다.
    • 실제 task (language modeling, video understanding 등)에서는 context window가 매우 커져 Transformers 적용이 어려워집니다.
  2. Linear Transformers의 한계:
    • Linear Transformers는 scalability를 개선했지만, kernel trick 사용으로 인해 모델이 linear recurrent network로 변형되어 성능이 저하됩니다.
    • Linear recurrent models는 긴 context를 작은 vector-valued 또는 matrix-valued states에 효과적으로 압축하지 못합니다.
  3. 기존 Architectures의 공통적인 한계:
    • Hopfield Networks, LSTMs, Transformers를 포함한 기존 architectures는 generalization, length extrapolation, reasoning에 어려움을 겪습니다.
    • 인간 뇌에서 영감을 받았지만, learning process에 중요한 memory components (short-term, long-term, meta-memory 등) 또는 이들의 상호 연결된 systems에 대한 고려가 부족합니다.
    • 데이터로부터 적극적으로 학습하고 과거 history의 추상화를 memorize하는 능력이 부족합니다.

본 논문의 핵심 아이디어 및 기여:

  1. Memory 관점 재정립:
    • 기존의 memory를 입력에 의한 neural update로 정의하는 관점을 넘어, 인간의 memory system (short-term, working, long-term)에서 영감을 받아 새롭게 정의합니다.
    • Memory를 별개의 상호 연결된 modules로 구성된 시스템으로 간주합니다.
    • 다섯 가지 핵심 질문을 통해 memory의 이상적인 structure, update mechanism, retrieval process, architecture 통합 방법, long-term memory 저장을 위한 deep memory module의 필요성을 탐구합니다.
  2. Neural Long-term Memory 제안:
    • Test time에 데이터를 memorize하도록 학습하는 deep neural long-term memory를 제안합니다.
    • "Surprising"한 사건이 더 기억에 남는다는 인간의 long-term memory system에서 영감을 받아, 입력의 "surprise"를 측정하는 방법을 사용합니다.
    • 효율적인 memory 관리를 위해 decaying mechanism을 도입하고, 이것이 현대 recurrent models의 forgetting mechanism과 연관됨을 보입니다.
    • Meta neural network 최적화 관점에서 neural long-term memory training을 해석하고, 빠르고 병렬화 가능한 algorithm을 제시합니다.
  3. Titans Architectures:
    • 새로운 long-term neural memory를 통합한 deep learning architecture 제품군인 Titans를 제안합니다.
    • Titans는 Core (short-term memory), Long-term Memory, Persistent Memory의 세 가지 hyper-heads로 구성됩니다.
    • Memory를 context, layer, gated branch로 통합하는 세 가지 Titans variants를 제시합니다.

정리 노트 요약:

  • 본 논문은 Transformers의 scalability 한계와 기존 architectures의 memory 관련 한계를 지적하며, 새로운 neural long-term memory와 이를 통합한 Titans architectures를 제안합니다.
  • 인간의 memory system에서 영감을 받아 memory를 재정의하고, "surprise" 기반의 memory update, decaying mechanism, 효율적인 training algorithm을 제시합니다.
  • Titans는 language modeling, reasoning, time series 등 다양한 task에서 Transformers와 최신 recurrent models을 능가하며, 2M 이상의 context window size로 scale 가능함을 보입니다.

 

 

2 Preliminaries

이 섹션에서는 논문 전체에서 사용하는 표기법과 몇 가지 배경 개념에 대해 논의합니다.

  • 𝑥 ∈ R^(N × d_in)을 input으로, M을 neural network (neural memory module)로, Q, K, V를 attention mechanism의 query, key, value로, M을 attention mask로 둡니다.
  • sequence를 분할할 때 S^(i)를 사용하여 i번째 segment를 나타냅니다.
  • 논문 전체에서 표기법을 남용하여 subscripts를 사용하여 matrix, vector 또는 segments의 특정 요소를 나타냅니다. 예를 들어, S^(i)_j는 i번째 segment의 j번째 token입니다. 유일한 예외는 t가 있는 subscripts이며, 이는 시간에 따른 recurrence를 인덱싱하거나 시간 t에서의 neural network 상태를 나타내기 위해 예약되어 있습니다.
  • neural network N과 data sample x가 주어지면, N(x) (또는 N*(x))를 사용하여 weight 조정이 있는 (또는 없는) forward pass를 나타냅니다. 또한 표기법을 남용하여 N^(k)를 사용하여 neural network의 k번째 layer를 나타냅니다.
  • 다음에서는 먼저 attention 및 효율적인 variants에 대한 배경을 논의하고, 그 다음으로 최신 linear RNNs를 검토합니다. 마지막으로, 이러한 architectures의 memory 관점을 논의하여 Titans를 설계하도록 동기를 부여합니다.

2.1 Backgrounds

  • Attention. 많은 deep learning models의 사실상 backbone인 Transformers는 attention mechanism을 기반으로 합니다. 입력 𝑥 ∈ R^(N × d_in)이 주어지면, causal attention은 입력에 종속적인 key, value, query matrices에 대한 softmax를 기반으로 출력 y ∈ R^(N × d_in)을 계산합니다.여기서 WQ, WK, WV ∈ R^(d_in × d_in)은 학습 가능한 parameters입니다. recall에서의 강력함과 효과성에도 불구하고 Transformers는 출력을 계산하기 위해 최소 N × d 연산자가 필요하므로 더 긴 sequences에 대해 더 큰 memory consumption과 더 낮은 throughput이 발생합니다.
  • Q = xWQ, K = xWK, V = xWV, (1)
    y_i = (∑︁_(j=1)^i exp(Q^⊤_i K_j / √d_in) * V_j) / (∑︁_(ℓ=1)^i exp(Q^⊤_i K_ℓ / √d_in)), (2)
    
  • Efficient Attentions. 더 긴 sequences에 대한 softmax attention의 memory consumption과 throughput을 개선하기 위해 다양한 연구가 attention의 I/O 인식 구현, attention matrix를 sparsifying하여 보다 효율적인 attention mechanisms을 설계, softmax 근사 또는 kernel-based (linear) attentions 개발에 중점을 두었습니다. 이 부분에서는 후자, 즉 linear attentions에 중점을 둡니다. 여기서 표준 attention의 softmax는 대체 kernel function φ(., .)로 대체되어 φ(x, y) = φ(x)φ(y)가 됩니다. 따라서 attention은 다음과 같이 작성할 수 있습니다.결과적으로 ∑︁_(j=1)^i φ(K_j) 및 ∑︁_(ℓ=1)^i φ(K_ℓ) 항이 각 단계에서 다시 사용되므로 더 높은 throughput이 발생합니다. kernel을 identity matrix로 선택하면 위의 공식은 recurrent 형식으로 작성할 수도 있습니다.이를 통해 linear attentions에 대한 효율적인 추론이 가능합니다.
  • M_t = M_(t-1) + K^⊤_t V_t, (4)
    y_t = Q_t M_t, (5)
    
  • y_i = (∑︁_(j=1)^i φ(Q^⊤_i K_j) * V_j) / (∑︁_(ℓ=1)^i φ(Q^⊤_i K_ℓ)) 
        = (∑︁_(j=1)^i φ(Q_i)^⊤φ(K_j) * V_j) / (∑︁_(ℓ=1)^i φ(Q_i)^⊤φ(K_ℓ))
        = (φ(Q_i)^⊤ ∑︁_(j=1)^i φ(K_j)V_j) / (φ(Q_i)^⊤ ∑︁_(ℓ=1)^i φ(K_ℓ)), (3)
    
  • Modern Linear Models and Their Memory Perspective. 앞서 논의한 바와 같이, 학습을 효과적이고 유용한 memory를 획득하는 process로 정의할 수 있습니다. 이를 바탕으로, Recurrent Neural Networks (RNNs)의 hidden state를 정보를 압축하려는 memory unit으로 볼 수 있습니다. 따라서 일반적인 형태의 recurrent neural network에서 hidden state는 memory unit으로 취급될 수 있고 recurrence process는 memory unit의 읽기 및 쓰기 작업으로 분할될 수 있습니다. 즉, 𝑥 ∈ R^(N × d_in)을 입력으로, M ∈ R^d를 memory unit으로, y ∈ R^d_in을 출력으로 하면 recurrent neural network의 일반적인 형태는 다음과 같이 정의됩니다.여기서 f(., .)는 읽기이고 g(., .)는 해당 쓰기 함수입니다. 여기서 M_t의 아래 첨자는 시간 t에서의 memory 상태를 나타냅니다.
  • 이러한 관점에서 linear Transformers의 recurrence 공식(식 4 참조)은 keys와 values, (K_t, V_t)를 matrix-valued memory unit M_t에 가산적으로 압축하고 쓰는 것과 같습니다. 따라서 긴 context 데이터를 처리할 때 이 process의 가산적 특성으로 인해 memory overflow가 발생하여 모델 성능이 크게 저하됩니다. 이를 해결하기 위해 연구는 두 가지 유망한 방향에 중점을 두었습니다. (1) 망각 메커니즘 추가: 여러 연구에서 필요할 때 memory를 지울 수 있는 linear models에 대한 적응형(데이터 종속적) 망각 게이트 메커니즘을 제시했습니다. 이러한 모델의 예로 GLA, LRU, Griffin, xLSTM, Mamba2를 참조합니다. 후자는 전통적인 state space models의 이산화된 버전과도 연결됩니다. (2) 쓰기 작업 개선: 전통적인 recurrent models에서 memory 쓰기 작업의 가산적 특성을 극복하기 위해 Widrow와 Hoff는 Delta Rule을 제시했습니다. 여기서 memory(즉, key와 value 쌍)를 추가하기 전에 모델은 먼저 과거 값을 제거합니다. 병렬화 가능한 training 및 scaling을 향상시키기 위해 S. Yang, B. Wang, Yu Zhang 등은 빠른 병렬화 가능한 algorithm을 제시합니다. 마지막으로, 최근 S. Yang, Kautz, Hatamizadeh는 DeltaNets에 망각 게이트를 추가하여 개선했습니다.
  • M_t = f(M_(t-1), x_t),  Write Operation (6)
    y_t = g(M_t, x_t), Read Operation (7)
    
  • Memory Modules. Memory는 항상 neural network 설계의 핵심 부분 중 하나였습니다. linear layers를 key-value (associative) memory system으로 보는 아이디어는 동적 빠른 프로그램이 쓰기 가능한 memory 역할을 하기 위해 recurrent neural networks에 통합되는 빠른 weight 프로그램으로 거슬러 올라갑니다. Hebbian과 delta의 두 가지 학습 규칙은 빠른 weight 프로그램에 가장 널리 사용되는 학습 규칙이며, 다양한 연구에서 광범위하게 탐구되었습니다. 그러나 이러한 모든 모델은 순간적인 놀라움을 기반으로 하며 sequences의 token 흐름이 누락되고(섹션 3.1 참조) 대부분 망각 게이트가 없어 memory 관리가 제대로 이루어지지 않습니다.

부록 C에서 우리 architectures와 최신 models의 연관성에 대해 더 자세히 논의합니다. 추가 관련 작업은 부록 A에서 논의됩니다.

 

 

핵심 키워드: Attention, Efficient Attentions, Linear Transformers, Recurrent Models, Memory Perspective, Memory Modules

1. 표기법 및 기본 정의:

  • x: 입력
  • M: Neural Network (Neural Memory Module)
  • Q, K, V: Attention Mechanism의 Query, Key, Value
  • S^(i): i번째 segment
  • S^(i)_j: i번째 segment의 j번째 token
  • N(x): Weight 조정이 있는 forward pass
  • N*(x): Weight 조정이 없는 forward pass
  • N^(k): Neural Network의 k번째 layer
  • 주의: t를 subscript로 사용하는 경우는 시간에 따른 recurrence index 또는 neural network의 상태를 나타냅니다.

2. Attention 및 Efficient Attentions:

  • Causal Attention (식 1, 2): Transformers의 핵심 메커니즘으로, 입력에 종속적인 Key, Value, Query matrices에 대한 softmax를 기반으로 출력을 계산합니다.
  • Efficient Attentions: Softmax Attention의 memory consumption과 throughput 개선을 위한 다양한 방법들이 연구되었습니다.
  • Linear Attentions (식 3): 본 논문은 Softmax를 kernel function φ(., .)로 대체하는 Linear Attentions에 중점을 둡니다. (φ(x, y) = φ(x)φ(y))
  • Recurrent Format (식 4, 5): Kernel을 identity matrix로 선택하면 Linear Attentions를 recurrent 형태로 표현하여 효율적인 추론이 가능합니다.

3. Modern Linear Models 및 Memory Perspective:

  • 학습 = 효과적인 Memory 획득: 본 논문은 학습을 효과적이고 유용한 memory를 획득하는 process로 정의합니다.
  • RNN의 Hidden State = Memory Unit: RNN의 hidden state를 정보를 압축하는 memory unit으로 간주합니다.
  • Recurrent Process = Read & Write (식 6, 7): 일반적인 RNN에서 recurrent process를 memory unit에 대한 Read & Write 작업으로 분리합니다.
  • Linear Transformers의 Memory: Linear Transformers의 recurrence 공식은 Key와 Value를 matrix-valued memory unit에 가산적으로 압축하여 쓰는 것과 같습니다.
  • Linear Transformers의 한계: 긴 context data를 처리할 때 가산적인 memory write 방식으로 인해 memory overflow가 발생하여 성능이 저하됩니다.
  • 두 가지 개선 방향:
    1. 망각 메커니즘 추가: 필요할 때 memory를 지우는 adaptive forgetting gate mechanism (GLA, LRU, Griffin, xLSTM, Mamba2 등)
    2. 쓰기 작업 개선: Delta Rule (과거 값을 제거한 후 memory 추가), 병렬화 가능한 algorithm, 망각 게이트를 추가한 DeltaNets 등

4. Memory Modules:

  • Associative Memory로서의 Linear Layers: Linear layers를 key-value (associative) memory system으로 보는 아이디어는 Fast Weight Programs에서 비롯되었습니다.
  • Hebbian & Delta Learning Rules: Fast Weight Programs에 널리 사용되는 두 가지 학습 규칙입니다.
  • 기존 Memory Modules의 한계: 대부분 순간적인 "surprise"에 기반하고 sequence의 token flow를 고려하지 않으며, 망각 게이트가 없어 memory 관리가 비효율적입니다.

정리 노트 요약:

  • 본 섹션은 논문에서 사용되는 표기법과 Attention, Efficient Attentions, Linear Transformers, Recurrent Models에 대한 기본적인 배경 지식을 제공합니다.
  • 핵심은 학습을 "효과적인 memory 획득"으로 정의하고, RNN의 hidden state를 "memory unit"으로 간주하여 Linear Transformers를 memory 관점에서 분석한다는 것입니다.
  • Linear Transformers의 한계를 지적하고, 이를 개선하기 위한 망각 메커니즘과 쓰기 작업 개선 방향을 제시합니다. 특히, memory 관리가 비효율적인 기존 memory modules의 한계를 지적하며, 이는 다음 섹션에서 제안될 새로운 memory module의 동기가 됩니다.

이 섹션은 Linear Transformers를 memory 관점에서 재해석하고, 기존 연구의 한계를 명확히 함으로써, 새로운 neural long-term memory module의 필요성을 강조하는 데 중요한 역할을 합니다.

 

 

 

 

 

3 Learning to Memorize at Test Time

long-term memory가 부족한 문제를 극복하고 모델이 정보를 학습, 망각, 검색할 수 있도록 하기 위해, 이 섹션에서는 test time에 memorize하는 법을 배우는 meta model인 neural long-term memory module을 제시합니다. 섹션 3.1에서는 먼저 neural memory의 동기와 설계를 논의합니다. 섹션 3.2에서는 우리 architecture 설계가 빠르고 병렬화 가능한 training의 이점을 어떻게 활용할 수 있는지 논의합니다. 마지막으로, 섹션 3.3에서는 persistent memory module을 사용하여 architecture를 보강합니다. 여기서 우리는 학습 가능하지만 데이터 독립적인 parameters를 사용하여 task에 대한 meta 정보를 학습합니다.

3.1 Long-term Memory

neural long-term memory module을 설계하려면 과거 history의 추상화를 그 parameters에 encode할 수 있는 모델이 필요합니다. 이에 대한 예는 training data를 memorize하는 것으로 나타난 LLMs일 수 있습니다. 따라서 간단한 아이디어는 neural network를 train하고 training data를 memorize하기를 기대하는 것입니다. 그러나 memorization은 거의 항상 neural networks에서 바람직하지 않은 현상으로 알려져 있습니다. 모델 generalization을 제한하고, 개인 정보 보호 문제를 야기하며, 결과적으로 test time에 성능이 저하되기 때문입니다. 더욱이, training data의 memorization은 data가 out-of-distribution일 수 있는 test time에는 도움이 되지 않을 수 있습니다. 우리는 test time에 데이터를 memorize/forget하는 방법을 배우는 online meta-model이 필요하다고 주장합니다. 이 설정에서 모델은 memorization이 가능한 함수를 학습하지만 training data에 overfitting되지 않아 test time에 더 나은 generalization을 제공합니다.

  • Learning Process and Surprise Metric. long-term memory를 train하기 위한 핵심 아이디어는 그 training을 online learning 문제로 취급하는 것입니다. 여기서 우리는 과거 정보 x1, ..., x(t-1)을 long-term neural memory module Mt의 parameters에 압축하는 것을 목표로 합니다. 앞에서 논의한 바와 같이, 예상을 위반하는 (즉, 놀라운) 사건은 인간에게 더 기억에 남습니다. 이로부터 영감을 받아, 모델에 대한 놀라움의 간단한 정의는 입력에 대한 gradient일 수 있습니다. gradient가 클수록 입력 data가 과거 data와 더 다릅니다. 따라서 이 surprise score를 사용하여 다음과 같이 memory를 update할 수 있습니다.그러나 이 surprise metric은 큰 놀라운 순간 이후에 오는 중요한 정보를 놓칠 수 있습니다. 즉, 여러 번의 놀라운 단계 후에 gradient가 극도로 작아져 평평한 영역(즉, local minima)에 고착되고 sequence의 일부 부분에 대한 정보가 누락될 수 있습니다. 인간의 memory 관점에서 보면, 사건이 기억에 남더라도 장기간에 걸쳐 지속적으로 우리를 놀라게 하지 않을 수 있습니다. 그 이유는 초기 순간이 장기간에 걸쳐 우리의 주의를 끌기에 충분히 놀랍기 때문에 전체 기간을 memorize하게 되기 때문입니다. 위의 surprise metric(식 8)을 개선하기 위해, 우리는 surprise metric을 (1) 매우 최근 과거의 놀라움의 양을 측정하는 과거 놀라움; (2) 들어오는 data의 놀라움을 측정하는 순간적 놀라움으로 나눕니다.흥미롭게도 이 공식은 S_t가 momentum 요소인 momentum을 사용한 gradient descent와 유사합니다. 따라서 여기서 momentum은 시간(sequence length)에 따른 놀라움의 memory 역할을 합니다. 이 공식에서 η_t 항은 data-dependent surprise decay(x_t의 함수)로, 시간에 따른 놀라움의 decay를 제어하고, θ_t 항은 순간적 놀라움이 data-dependent 방식으로 최종 surprise metric에 얼마나 통합되어야 하는지를 제어합니다. 이 data-dependency는 이 설계에서 특히 중요합니다. 이전 tokens의 놀라움이 다음 token의 놀라움에 영향을 미치기 위해 필요할 수 있지만, 이는 모든 tokens가 관련되고 동일한 context에 있는 경우에만 유효합니다. 따라서 data-dependent η는 memory가 (1) η_t를 0으로 설정하여 마지막 놀라움을 무시해야 하는지(context 변경으로 인해 가능), 또는 (2) η_t를 1로 설정하여 마지막 놀라움을 완전히 통합해야 하는지(token이 최근 과거 tokens와 관련성이 높기 때문에 가능)를 제어할 수 있습니다.
  • M_t = M_(t-1) + S_t, (9)
    S_t = η_t S_(t-1)   -   θ_t ∇ℓ(M_(t-1); x_t). (10)
          |{z}            | {z }
        Past Surprise    Momentary Surprise
    
  • M_t = M_(t-1) - θ_t ∇ℓ(M_(t-1); x_t)  (8)
            | {z }
            Surprise
    
  • Objective. 위의 surprise metric은 loss function ℓ(.; .)을 기반으로 하며, 이는 우리 memory가 test time에 작동하도록 학습하는 objective입니다. 즉, 우리 memory module은 loss function ℓ(.; .)을 기반으로 함수를 학습하는 meta model입니다. 이 연구에서 우리는 과거 data를 key와 value 쌍으로 저장하는 것을 목표로 하는 associative memory에 중점을 둡니다. x_t가 주어지면, Transformers와 유사하게 두 개의 linear layers를 사용하여 x_t를 key와 value로 투영합니다.여기서 W_K와 W_V ∈ R^(d_in × d_in)입니다. 다음으로, 우리 memory module이 keys와 values 간의 associations을 학습하기를 기대합니다. 이를 위해 다음과 같이 loss를 정의합니다.meta model(memory)의 inner-loop에서 위의 loss function을 최적화함으로써 모델은 test time에 keys와 values 간의 mapping을 memorize하는 방법을 학습합니다. meta-learning models와 유사하게 memory의 training은 inner-loop에 있으므로 parameters W_K와 W_V는 위의 loss function에서 hyperparameters입니다. 따라서 inner loop에서는 M의 weights를 최적화하고, outer-loop에서는 전체 architecture의 다른 parameters를 최적화합니다.
  • ℓ(M_(t-1); x_t) = ∥M_(t-1)(k_t) - v_t∥^2_2 (12)
    
  • k_t = x_t W_K, v_t = x_t W_V, (11)
    
  • Forgetting Mechanism. 매우 큰 sequences(예: 수백만 개의 tokens)를 처리할 때는 deep 또는 매우 큰 matrix-valued memory를 사용하더라도 어떤 과거 정보를 잊어야 하는지 관리하는 것이 중요합니다. 이를 위해 우리는 memory가 더 이상 필요하지 않은 정보를 잊을 수 있도록 하는 adaptive forgetting mechanism을 사용하여 memory의 제한된 용량을 더 잘 관리합니다. 즉, 다음 token x_t가 주어지면 update 규칙을 다음과 같이 수정합니다.여기서 α_t ∈ [0, 1]은 memory를 유연하게 제어하는 gating mechanism입니다. 즉, 얼마나 많은 정보를 잊어야 하는지 결정합니다. 예를 들어, α_t를 0으로 설정하여 과거 추상화에 영향을 주지 않으면서 memory를 update할 수 있고, α_t를 1로 설정하여 전체 memory를 지울 수 있습니다. 이 섹션의 뒷부분에서 이 weight decay mechanism이 최신 RNNs의 gating mechanism과 밀접한 관련이 있음을 보여줍니다.
  • M_t = (1 - α_t)M_(t-1) + S_t, (13)
    S_t = η_t S_(t-1) - θ_t ∇ℓ(M_(t-1); x_t), (14)
    
  • Memory Architecture. 이 논문에서는 long-term memory의 architecture로 L_M ≥ 1 layers를 가진 간단한 MLPs에 중점을 둡니다. 이러한 선택의 주된 이유는 long-term memory 설계와 architecture에 통합될 수 있는 방법을 더 잘 동기 부여하는 데 중점을 두고 싶기 때문입니다. 그러나 우리의 공식과 architectural design은 data memorization에 더 효과적이고 효율적인 neural architectures를 설계하는 새로운 연구 방향을 엽니다. 최근에는 이러한 architectures를 설계하는 유망한 연구 라인이 있습니다. 이러한 architectures를 우리 framework에 통합하는 것(즉, 간단한 MLPs를 이러한 architectures로 대체하는 것)은 흥미로운 미래 연구가 될 수 있습니다.
  • vector-valued 또는 matrix-valued memory를 사용할 때 memory module은 과거 data를 압축하여 한 줄에 맞춥니다. 즉, meta learning 또는 online learning 관점에서 matrix-valued memory M = W ∈ R^(d_in × d_in)을 사용하는 것은 ℓ(W_(t-1); x_t) = ∥W_(t-1)k_t - v_t∥^2_2를 최적화하는 것과 같으며, 이는 online linear regression objective이므로 최적의 솔루션은 과거 data의 기본 종속성이 linear라고 가정합니다. 반면에, 우리는 deep memory modules(즉, L_M ≥ 2)이 더 효과적이라고 주장합니다. 적어도 두 개의 layers를 가진 MLPs가 linear models보다 엄격하게 더 표현력이 있다는 이론적 결과와 일치합니다. 섹션 5.5에서 우리는 deep memory modules이 실제로 더 효과적임을 보여줍니다.
  • Retrieving a Memory. 위에서 우리는 test time에 memorize하도록 학습하는 long-term memory module을 어떻게 설계하고 train할 수 있는지 논의했습니다. 남은 주요 질문은 memory에서 어떻게 정보를 검색할 수 있습니까? 우리는 단순히 weight update가 없는 forward pass(즉, 추론)를 사용하여 query에 해당하는 memory를 검색합니다. 공식적으로, 입력 x_t가 주어지면 linear layer W_Q를 사용하여 입력을 투영, 즉 q_t = x_t W_Q하고 다음과 같이 memory y_t에서 해당 (또는 유용한) 정보를 검색합니다.
  • y_t = M*(q_t). (15)
    

3.2 How to Parallelize the Long-term Memory Training

위에서 논의한 바와 같이, long-term memory module의 설계는 momentum과 weight decay를 사용한 gradient descent를 사용하여 associative memory loss function ℓ(M_(t-1); x_t) = ∥M_(t-1)(k_t) - v_t∥^2_2를 최적화하여 meta model을 training하는 것과 같습니다. 따라서 이론적으로 long-term memory module의 training에는 O(N) FLOPs가 필요합니다. 여기서 N은 sequence length입니다. 그러나 실제로는 training process를 병렬화해야 하고 하드웨어 가속기(예: TPUs, GPUs)를 최대한 활용하려면 process를 tensor화하고 더 많은 matmuls를 사용해야 합니다.

다음으로, mini-batch gradient descent, data-dependent learning rate, weight decay를 사용한 inner loop의 weights 계산이 matmuls와 sum만 사용하도록 재구성될 수 있음을 보여줍니다. 우리는 mini-batch gradient descent(상수 learning rate 사용)로 최적화하는 모델의 forward pass가 matmuls를 사용하여 계산될 수 있음을 보여주는 Yu Sun 등의 연구를 기반으로 합니다. sequence를 크기 b ≥ 1의 chunks로 분할하고 mini-batch gradient descent를 다음과 같이 작성할 수 있습니다.

M_t = (1 - α_t)M_(t-1) - θ_t∇ℓ(M_(t-1); x_t) = β_t M_0 - ∑︁_(i=1)^t (θ_i / (β_t / β_i)) ∇ℓ(M_t'; x_i), (16)

여기서 t' = t - mod(t, b)이고 β_i = ∏︁_(j=1)^i (1 - α_j)입니다. 단순화를 위해 첫 번째 chunk, 즉 t = b에 초점을 맞추므로 t' = 0입니다. 또한 M_t = W_t가 linear인 경우에 대한 process를 설명합니다. N_p ≥ 2인 MLPs에 대한 process는 유사합니다. loss function을 사용하면 다음과 같습니다.

∇ℓ(W_0; x_t) = (W_0 x_t - x_t)x^⊤_t ⇒ ∑︁_(i=1)^b (θ_i / (β_b / β_i)) ∇ℓ(W_0; x_i) = Θ_b B_b (W_0 X - X)X^⊤, (17)

여기서 Θ_b = diag([θ_1 θ_2 . . . θ_b])이고 B_b는 β_b / β_i에 대해 유사하게 정의됩니다. k = 1, ..., N/b에 대해 모든 Θ_kb와 B_kb를 저장할 필요가 없으며, 대신 각 chunk에 대해 이러한 matrices를 저장하여 memory를 덜 사용합니다. 다음으로, momentum 항을 통합할 수 있도록 이 표현을 확장합니다. momentum을 사용한 chunk 단위 gradient descent에서 momentum 항을 보면 다음과 같습니다.

S_t = η_t S_(t-1) - θ_t u_t, (18)

여기서 u_t = ∇ℓ(M_t'; x_t)입니다. 모든 u_t를 동시에 계산할 수 있으므로 식 18은 u_t를 입력으로, S_t를 hidden state로, η_t를 입력 종속 전이 값으로 하는 linear recurrence입니다. 따라서 parallel associative scan을 사용하여 이 chunk에서 S_t를 계산할 수 있습니다.

  • Parameters as the Function of Chunks. α_t, θ_t, η_t와 같은 parameters를 입력 종속(즉, token x_t의 함수)으로 만드는 대신 chunk의 함수로 만들 수 있습니다. 표현력을 잃는 것에도 불구하고 이 공식은 training을 더 빠르게 만드는 데 도움이 될 수 있습니다. 이 경우 각 chunk에서 α, θ, η 각각에 대해 동일한 값을 사용합니다. 따라서 식 17에서 Θ를 단일 스칼라를 사용하여 저장할 수 있습니다. 유사하게 식 18을 더 빠르게 만들 수 있습니다. 즉, η와 θ가 각 chunk 내에서 학습 가능하지만 시간 불변일 때 이 방정식은 global convolution으로 계산될 수 있는 linear time-invariant system (LTI)이 됩니다. 실험에서는 이러한 parameters를 tokens의 함수로 만듭니다. 그러나 이러한 단순화(즉, chunks의 함수로서)는 더 큰 models을 더 효율적인 방식으로 training하기 위한 미래 연구의 관심사가 될 수 있습니다.

3.3 Persistent Memory

우리의 long-term memory는 contextual memory로도 볼 수 있습니다. 즉, output이 context에 전적으로 의존합니다. 따라서 long-term memory 외에도 학습 가능하지만 입력 독립적인 parameters 집합을 사용하여 task 관련 memory 역할을 합니다. 이러한 유형의 memory는 문헌에서 persistent 또는 meta-memory로 언급되었습니다. N_p ≥ 1이 주어지면 학습 가능한 parameters P = [p_1 p_2 . . . p_Np]를 사용하고 이를 sequence의 시작 부분에 추가합니다. 즉, context window size N이 주어지면 다음과 같이 입력을 수정합니다.

x_new = [p_1 p_2 . . . p_Np] || x, (19)

여기서 ||는 concatenation입니다. 다음으로, 세 가지 관점에서 persistent memory의 동기를 논의합니다.

  • Memory Perspective. 앞서 논의한 바와 같이, 우리의 neural long-term memory는 모든 parameters가 입력 종속적인 contextual memory입니다. 그러나 효과적인 memory system은 task 지식의 추상화를 저장하기 위해 입력 독립적인 parameters도 필요합니다. 즉, task를 마스터하려면 task를 수행하는 방법에 대한 지식을 memorize해야 하며 이러한 parameters는 그러한 지식을 저장하는 역할을 합니다.
  • Feedforward Network Perspective. Transformer architectures에는 attention module 뒤에 fully connected layers가 있으며, 이는 data-independent parameters를 가진 attention weights와 유사한 것으로 나타났습니다. 즉, Sukhbaatar, Grave 등은 fully connected layers의 ReLU를 Softmax로 대체하면 weights가 data-independent인 attention과 유사한 weights가 생성될 수 있음을 보여주었습니다.실제로 W_K와 W_V는 입력 독립적일 때 attention module의 K와 V matrices와 유사하게 작동합니다. persistent memory weights는 동일한 기능을 가질 것으로 예상됩니다. 즉, sequence의 첫 부분에서 이를 사용하면 입력 독립적인 attention weights를 갖게 됩니다.
  • FFN(x) = W_V Softmax(W_K x). (20)
    
  • Technical Perspective. causal mask가 있는 attention은 sequence의 초기 tokens에 대한 암시적 편향을 가지므로 attention weights는 거의 항상 초기 tokens에 대해 매우 활성화되어 성능이 저하됩니다. 기술적 관점에서 sequence 시작 부분의 이러한 학습 가능한 parameters는 attention weights를 보다 효과적으로 재분배하여 이러한 효과를 완화할 수 있습니다.

 

핵심 키워드: Long-term Memory, Neural Memory, Memorization, Surprise Metric, Meta-model, Online Learning, Forgetting Mechanism, Parallelization, Persistent Memory

기존 연구의 한계:

  • 기존 models은 long-term memory가 부족하여 정보를 효과적으로 학습, 망각, 검색하는 데 어려움이 있었습니다.
  • Memorization은 generalization 저하, 개인 정보 문제 등을 야기하는 바람직하지 않은 현상으로 여겨졌습니다.
  • Test time에 data가 out-of-distribution일 경우 training data memorization은 도움이 되지 않을 수 있습니다.

본 논문의 핵심 아이디어 및 기여:

  1. Neural Long-term Memory Module 제안:
    • Test time에 memorize하는 법을 배우는 meta-model인 neural long-term memory module을 제안합니다.
    • Online learning 관점에서 과거 정보를 memory module의 parameters에 압축합니다.
  2. Surprise Metric 기반 Memory Update:
    • 인간의 기억 원리에서 영감을 받아 "Surprising"한 사건을 더 잘 기억하도록 설계했습니다.
    • 입력에 대한 모델의 gradient를 "Surprise"의 척도로 사용합니다. (식 8)
    • "Past Surprise"와 "Momentary Surprise"를 결합하여 더 정교한 Surprise Metric을 제안합니다. (식 9, 10)
      • Momentum을 사용하여 시간에 따른 Surprise의 기억을 유지합니다.
      • Data-dependent surprise decay term η_t를 사용하여 context 변화에 따라 과거 surprise의 영향력을 조절합니다.
  3. Associative Memory Loss를 사용한 Memory Training:
    • Memory module은 test time에 key-value associations을 학습하는 meta-model입니다.
    • Associative Memory Loss (식 12)를 사용하여 memory를 train합니다.
    • Inner-loop에서는 memory (M)의 weights를 최적화하고, outer-loop에서는 전체 architecture의 다른 parameters를 최적화합니다.
  4. Forgetting Mechanism:
    • 제한된 memory 용량을 효율적으로 관리하기 위해 adaptive forgetting mechanism을 도입합니다. (식 13, 14)
    • Gating mechanism α_t를 사용하여 memory를 유연하게 제어하고, 필요 없는 정보를 망각합니다.
  5. Deep Memory Modules:
    • Long-term memory의 architecture로 L_M ≥ 1 layers를 가진 MLPs를 사용합니다.
    • Deep memory modules (L_M ≥ 2)이 linear models보다 더 효과적임을 실험적으로 검증합니다. (섹션 5.5)
  6. Memory Retrieval:
    • Weight update가 없는 forward pass (inference)를 통해 query에 해당하는 memory를 검색합니다. (식 15)
  7. Parallelizable Training:
    • Long-term memory module training은 이론적으로 O(N) FLOPs가 필요하지만, 실제로는 병렬화가 중요합니다.
    • Mini-batch gradient descent, data-dependent learning rate, weight decay를 사용한 inner loop 계산을 matmuls와 sum만 사용하도록 재구성하여 병렬화를 가능하게 합니다. (식 16, 17, 18)
    • Parameters를 chunks의 함수로 만들어 training을 더욱 가속화할 수 있는 방법을 제시합니다.
  8. Persistent Memory:
    • Contextual memory인 long-term memory 외에도, task 관련 memory 역할을 하는 learnable, input-independent parameters인 Persistent Memory를 사용합니다. (식 19)
    • Persistent Memory는 task knowledge 저장, feedforward network와의 유사성, attention weights의 효율적인 재분배 등 여러 관점에서 그 유용성이 설명됩니다.

정리 노트 요약:

  • 본 논문은 test time에 memorize하는 법을 배우는 neural long-term memory module을 제안합니다.
  • "Surprise" 기반 memory update, adaptive forgetting mechanism, associative memory loss를 사용한 training 등 새로운 기법들을 제시합니다.
  • Deep memory modules의 효과를 보이고, 효율적인 parallelizable training 방법을 제안합니다.
  • Contextual memory와 Persistent Memory를 결합하여 더욱 강력한 memory system을 구축합니다.