Reformer: The Efficient Transformer

Inc Lomin's avatar
Mar 03, 2021
Reformer: The Efficient Transformer
본 논문은 구글 리서치에서 아카이브에 2020년 1월에 공개한 연구로서, ICLR 2020 오럴 발표에 선정되었습니다.
N. Kitaev, Ł. Kaiser, and A. Levskaya, “Reformer: The Efficient Transformer,” in ICLR, 2020, pp. 1–11.
notion image
제목에서 나타난 것처럼, 본 논문은 기존의 트랜스포머 구조를 효율적으로 구현하기 위한 방안을 제시합니다. 트랜스포머가 소개된 이후로, 시퀀스를 분석하거나 예측하는 태스크를 포함하여 다양한 분야에서 획기적인 성능 개선을 보여주는 연구결과들이 나왔습니다. CNN에 기반한 연구에서와 마찬가지로 트랜스포머를 이용한 연구에서도 입력시퀀스 길이, 레이어수, feature dimension이 늘어날 수록 성능이 개선되는 효과를 확인할 수 있었습니다. 하지만, 문제는 점점 모델의 사이즈가 커짐에 따라 학습에 필요한 GPU 메모리가 늘어나서, 현재까지 SOTA 성능을 달성하는 연구에 사용된 모델들은 distributed 환경에서 대규모의 GPU를 이용하여 학습이 가능한 일부 연구기관 및 기업에서만 학습이 가능한 상황이 되었습니다.
본 논문에서는 트랜스포머 구조에서 메모리가 크게 사용되는 원인을 분석하고, 이 문제를 해결하기 위한 방안 3가지를 제시합니다.

Introduction

트랜스포머의 메모리 소모량 분석

다음과 같은 설정에서 학습에 필요한 메모리량을 계산해 봅시다.
  • 0.5B parameters per layer: 현재까지 보고된 가장 큰 모델 경우, 메모리사용량 2GB
  • Activations for 64K tokens(=input sequence length) with embedding size 1024, batchsize 8: 64K x 1K x 8 = 0.5B floats, 메모리사용량 2GB
위 사용량은 오직 1개의 레이어 계산에만 필요한 메모리용량입니다. 레이어당 필요한 메모리는 4GB에 불과하지만 실제 학습시에는 아래와 같은 이유로 훨씬 많은 메모리를 필요로 합니다.
  1. N개의 레이어를 사용하는 모델은 back-propagation 계산시 필요한 activation을 저장하기 위해 N배 많은 메모리를 필요로합니다.
  1. feed-forward layer의 depth d_ff는 일반적으로 모델의 depth d_model 보다 훨씬 크기 때문에, 훨씬 많은 메모리를 사용합니다.
  1. 길이 L의 시퀀스에 대한 attention 연산은 O(L^2)의 메모리 복잡도와 연산 복잡도를 가집니다. 따라서, 64K 길이의 시퀀스만으로도 메모리를 모두 소모하게 됩니다.
 

Contribution

본 논문에서는 위에서 제시한 3가지 문제에 대한 해결책 3가지를 제시합니다.
  1. Reversible Layers: 이 방법은 아래 논문에서 제안된 방법으로, back-prop 연산시, 모델 전체에서 오직 한 개 레이어의 activation만 저장해 두면, 나머지 레이어의 activation을 연산해 낼 수 있는 방식입니다.
    1. Gomez, Aidan N., et al. "The reversible residual network: Backpropagation without storing activations." Advances in neural information processing systems. 2017.
  1. feed-forward layers의 activation을 나눠서 연산하여, 메모리 사용량을 줄입니다.
  1. attention 연산을 locality-sensitive hashing 을 이용하여, 훨씬 작은 양의 연산으로 approximate합니다. 그 결과, attention layer의 메모리 및 연산 복잡도가 O(L^2)에서 O(L logL)로 크게 줄어들어, 매우 긴 시퀀스에 대한 연산이 가능하도록 합니다.
 

Proposed Method

Locality-Sensitive Hashing(LSH) Attention

본 챕터에서는 Attention의 연산량을 줄이기 위한 방법을 제안한니다.
먼저, 기존 트랜스포머에서 사용하는 scaled dot-product attention 연산의 수식은 다음과 같습니다.
notion image
트랜스포머에서는 위와 같은 연산을 각 레이어마다 패러럴하게 여러 개를 동시에 수행합니다.(=Multi-headed attention)
수식 (1)의 연산에 대한 메모리 사용량을 살펴봅시다.
  • Q, K, V 의 shape: [batch_size, length, d_model]
  • 메인이슈는 QK^T term 입니다. shape [batch_size, length, length]
  • length=64K 인 경우, 32bit float형의 64K x 64K 행렬은 16GB 메모리를 차지합니다.
하지만, 행렬 Q 대신 하나의 쿼리벡터 q_i를 위한 연산은 훨씬 작은 메모리를 사용합니다. 따라서, 각 쿼리벡터마다 attention을 따로 계산하고, backward path에서 필요한 attention을 다시 계산하여 사용한다면, 연산 측면에서는 비효율적이지만, 메모리 사용량은 크게 줄어들게 됩니다. → 제안 논문에서는 이 방식을 full-attention baseline으로 사용합니다.
Shared-QK Transformer
attention 연산에서 사용되는 Q,K,V 행렬은 본래 하나의 행렬 A으로부터 서로다른 파라미터를 projection(행렬곱) 하여 생성됩니다. 하지만, 제안방법에서는 Q 행렬과 K 행렬을 같은 값을 사용합니다.(=동일한 projection 행렬을 사용) 본 논문의 실험에서는 shared-QK의 사용이 트랜스포머의 성능에는 아무런 영향을 미치지 않음을 보여줍니다.
  • 위와 같은 결과는 Query와 Key 간의 attention 연산이 사실은 K행렬의 칼럼 벡터 중에서 입력 벡터(q_i)와 유사한 벡터를 찾는 연산(=cross-correlation)에 기반하기 때문이라고 해석할 수 있습니다. 따라서, 입력 행렬 A에서 Key와 Query로 변환하는 projection 행렬이 동일한 경우, 변환된 벡터 간의 cross-correlation을 구하는 원래의 목적에 여전히 부합합니다.
Hashing Attention
LSH attention 연산은 두 개의 텐서(Q=K, V)로 부터 시작됩니다.(shape [batch_size, length, d_model]) Attention 연산의 대부분을 차지하는 연산은 [batch_size, length, length]의 shape을 가지는 QK^T 연산입니다. 하지만, 우리가 정말 관심이 있는 연산은 이 QK^T이 softmax 함수를 통과한 결과값 입니다. softmax 함수의 결과는 입력 값 중 가장 큰 원소 값들에만 dominant하게 반응하므로, 각 쿼리벡터 q_i 와 가장 가까운 몇개(32 or 64) key 벡터들에 대해서만 attention 연산을 구하고 나머지 key 벡터와의 attention 결과는 무시해도 거의 유사한 softmax 결과 값을 얻을 수 있습니다. 따라서, 제안 방법은 시퀀스 길이의 제곱에 비례하는 attention 연산 대신에 서로 유사한 벡터임을 판별할 수 있는, (훨씬 작은 연산량을 가지는) hashing function을 사용하여, 벡터 간의 유사도를 먼저 판별하고, 유사한 벡터 간에만 attention 연산을 수행합니다.
Locality sensitive hashing
High-dimensional 공간에서 nearest neighbor를 빠르게 찾는 문제는 locality-sensitive hashing(LSH)에 의해 풀 수 있습니다.
각 입력 벡터 x에 대해, 서로 유사한 벡터끼리는 값은 값을 할당하고, 서로 거리가 있는 벡터에는 다른 값을 할당할 확률이 높은 hash 함수 h(x)를 locality-sensitive 하다고 합니다.
이러한 hashing 함수는 사이즈 [d_k, b/2]의 랜덤 행렬 R에 의해 입력 벡터를 projection 하는 연산을 통해 얻을 수 있습니다. 길이 b의 hash 벡터를 얻기 위한 hash 함수는 다음과 같습니다.
notion image
Locality sensitive hash의 원리는 아래의 그림으로 이해 할 수 있습니다. 아래 그림에서 두 점 x,y는 2차원상의 구면에 projection 된 후 hash 값에 의해 random하게 회전하게 됩니다. 이 후, singed axes projection에 대한 argmax 함수에 의해, 해당 벡터의 hash bucket index가 결정됩니다. 첫번째 row의 두 점 서로 거리가 먼 x,y는 다른 hash 값을 가지는데 비해, 서로 가까운 점 x,y는 값은 hash 값을 할당 받을 확률이 높습니다.
notion image
 
LSH attention
LSH를 transformer의 attention 연산에 적용하는 수식은 아래와 같습니다.
먼저, 기존의 attention 연산을 벡터 연산으로 표현하면 다음과 같습니다.
notion image
위 주식에서 z()는 softmax의 normalize term 이고 P는 i 번째 query가 attention을 하게 되는 index의 set을 의미 합니다. 위 수식에서 scaling term (d_k)^0.5 는 생략되었습니다.
위 수식을 배치 연산을 위한 수식으로 변경하고자, P_tilde를 새로 정의하고 mask 연산 term을 추가하면 아래와 같습니다.
notion image
notion image
이제, LSH attention 연산을 적용하면, 우리는 set P_i 를 i번째 query가 attention을 하게 되는 target item이라고 보면, 동일한 hash bucket에 속한 key-query 벡터간에만 attention 연산을 수행합니다.
notion image
동일한 hash bucket 간의 attention 연산을 효율적으로 수행하기 위해서는 동일한 hash 값을 가지는 벡터들이 서로 근처에 오도록 sorting하는 것이 좋습니다. 이와 같은 연산을 도식화 하면 아래 그림은 (a)~(b)에 해당합니다.
notion image
이 연산을 batch 단위로 적용할 때 발생하는 문제는 hash bucket간의 사이즈가 달라서, batch 연산이 어렵다는 점입니다. 이 문제를 해결하기 위해서 먼저, 값은 인덱스의 key와 query가 동일한 hash를 가지도록, k_j를 q_j의 유닛 벡터가 되도록 할당합니다.
notion image
그리고, 각 query를 bucket number에 따라 정렬하고, 또 동일한 bucket 안에서도 sequence position에 따라 정렬합니다. 이후, m개의 연속된 쿼리에 대해, 해당 블록과 직전 블록의 동일한 bucket에 해당하는 쿼리에 대해서만 attention을 수행하게 됩니다. 이를 수식으로 표현하면 다음과 같습니다.
notion image
Multi-round LSH attention
hash 함수를 사용하면, 적은 확률로 유사한 벡터가 서로 다른 hash 값을 가지될 확률이 존재합니다. 이 부분을 보완하기 위해, 동시에 여러개의 hash 함수를 사용하고, 해당 hash 함수들에 의해 동일한 hash 값을 가지게 되는 벡터들의 합집합을 attention에 활용하는 방법을 적용합니다.
notion image
Casual masking for shared-QK attention
시퀀스 연산에서 미래의 입력에 대한 attention 연산을 방지 하기 위한 masking 연산을 적용할 때, 원래의 transformer에서는 입력 query 인덱스와 동일한 key 벡터에 대한 attention 연산을 허용합니다. 하지만, shared-QK 를 사용하는 경우, 자기 자신의 대한 attention 이 항상 가장 dominant한 값을 가지게 되므로, 자신의 index에 대한 attention이 되지 않도록 마스킹을 적용합니다.
Analysis on a Synthetic Task
Full-attention 모델 대비 제안 방법에 의한 예측 성능 차이를 확인하기 위해, Synthetic dataset에 의한 Language modeling task에 대해 성능 비교 실험을 수행하였다. 길이 511의 word가 주어졌을 때, 길이 511 word를 예측하는 태스크에서, full-attention 모델은 1개의 레이어를 가지는 transformer만으로 accuracy 100%를 달성할 수 있다. 이 모델을 n_rounds 개의 LSH 함수를 사용한 모델과 성능을 비교 하면 다음과 같다.
notion image
위 결과를 보면, LSH-4 학습 모델은 거의 100%에 달하는 정확도를 달성하였고, 학습에 사용한 hash 함수의 수와는 상관없이 평가시 8개의 해쉬함수를 사용하면 거의 100%에 달하는 정확도를 달성하는 것을 확인할 수 있었다.
 

Reversible Transformer

notion image
Trasnformer 모델의 메모리 사용량은 레이어 수 n_l 에 비례한다. 또, 각 레이어의 fully-connected layer의 크기 d_ff도 많은 양의 메모리를 차지 한다. 이 두가지 요인에 의한 메모리 사용량을 줄이기 위해 아래와 같은 방법을 적용한다.
Reversible Residual Networks
Reversible Residual Network는 아래의 논문에서 제안된 방법으로, 주어진 레이어의 activation을 다음 레이어의 activation 값과 모델 파라미터로 부터 복원해 낼 수 있는 방법이다. 이 방법을 사용하면, backward pass 연산을 위해 전체 레이어의 activation 값을 저장해 두지 않고, 네트워크 아웃풋으로 부터 입력방향으로 한 레이어씩 activation 값을 복원해 나가면서 메모리를 절약할 수 있습니다.
Aidan N Gomez, Mengye Ren, Raquel Urtasun, and Roger B Grosse. The reversible residual net- work: Backpropagation without storing activations. In Advances in neural information processing systems, pp. 2214–2224, 2017.
일반적인 residual 레이어의 연산은 y = x + F(x) 인데 비해, reversible 레이어는 두개의 입력과 출력 pair (x_1, x_2) → (y_1, y_2)로 표현됩니다.
notion image
backward pass에서 y_1, y_2가 주어지면, x_1, x_2를 다음의 연산으로 복원 가능 합니다.
notion image
 
Reversible Transformer
Reversible 레이어를 transformer에 도입하는 경우, F연산을 attention layer로 변경하고, G 연산을 feedforward 연산으로 변경하여 구현이 가능합니다.
notion image
reversible transformer는 각 레이어의 activation을 저장할 필요가 없어, 메모리 사용량에서 n_l term을 제거할 수 있습니다. reversible transformer의 성능은 실험에서 일반적인 transformer와 동일한 성능을 보여줬습니다.
 
Chunking
최신의 큰 사이즈 transformer 연구에서는 d_ff = 4k 이상의 feedforward 레이어를 사용함에 따라, 이 레이어에서 사용하는 메모리량도 매우 큽니다. 하지만, feedforward 레이어의 연산에서 시퀀스의 위치별 연산은 서로 완전히 독립적이므로, 해당 연산을 c개의 chunks로 분리해서 연산을 수행함으로서 순간적인 메모리 사용량을 줄일 수 있습니다.
notion image
 
Chunking, large batches and parameter reuse
chunking과 reversible layer를 사용함에 따라, 메모리 사용량이 레이어의 숫자에 비례하는 부분을 제거할 수 있었습니다. 하지만, 모델 파라미터의 수는 여전히 레이어 수에 비례하여 증가합니다. 제안방법에서는 레이어 파라이터를 CPU 메모리에 저장하여 이러한 요인을 제거합니다. 기존의 transformer 구현에서는 CPU memory로의 메모리 전달이 느려서 이러한 구현이 매우 비효율적이었지만, 제안된 방법에서는 batchsize x sequence length 의 크기가 기존에 비해 훨씬 커짐에 따라, 각 레이어에서 수행하는 연산량의 비중도 함께 증가하여 CPU 메모리 transfer에 의한 지연이 차지하는 비중이 크게 줄어들게 됩니다.
 
notion image

Experiment

제안 방법의 성능을 확인하기 위해, imagenet64 이미지를 생성해내는 태스크와, 64K 길이의 언어 모델링(enwik8-64K) 태스크를 수행하였다.
 
notion image

Effect of sharing QK

figure 3의 좌측 그래프를 보면, Q, K 행렬을 동일하게 사용하는 경우의 성능은 원래의 transformer의 경우 동일하다. enwik8 실험에서는 오히려, shared QK 의 경우가 더 빨리 수렴하는 것을 알수 있다.

Effect of reversible layers

reversible layers를 사용하는 경우, 기존의 방식과 차이가 없는 것을 figure 3의 우측에서 확인할 수 있다.

LSH attention in Transformer

notion image
LSH attention은 full attention를 근사화한 연산이다. LSH attention에 사용되는 hash 함수의 수에 비례해서 연산량이 증가하는데, 위 그림에서와 같이 hash 함수의 숫자가 증가할 수록 점점 full attention의 성능에 근접하는 것을 확인할 수 있다. Table2 에서 알수 있듯이 hash 함수의 수는 test time에서만 늘려서 사용할 수도 있다.
아래 그림5의 우측 그래프는 동일한 sequence length x batch 크기를 유지하면서, sequence length를 점점 증가시킬때(=batch size는 감소) 기존의 attention 연산은 점점 속도가 느려지는데 비해, LSH attention은 거의 동일한 연산속도를 유지하는 것을 확인 할 수 있다.
notion image

Large Reformer models

제안된 Reformer가 large 모델이 실제로 단일 코어의 메모리에서 학습될 수 있는지를 확인하였다. figure 5의 좌측 실험에서 20개의 레이어의 모델이 하나의 gpu에서 학습되었다. 이중 12-layer Reformer가 1.05 bits/dim을 달성하였다.

Conclusions

제안된 방법에 의해 앞으로 기존보다 훨씬 큰 사이즈의 transformer 모델도 구현이 가능하게 되었다. 본 연구에 의해 향후 텍스트 위주의 sequence 모델링 태스크 뿐만이 아니라, 음악, 이미지, 비디오 generation 등의 연구에도 transformer가 활용되어 질 것을 기대한다.
 
Share article