COLT5: Faster Long-Range Transformers with Conditional Computation
이 논문은 LONGT5를 기반으로 어텐션 및 피드포워드 레이어에 대한 구조를 개선하여 긴 입력을 빠르게 처리할 수 있는 새로운 모델인 COLT5를 제안합니다.
Jun 22, 2023
Introduction
많은 자연어 처리 작업은 긴 텍스트를 인코딩하는 데 필요하며, 트랜스포머 모델을 사용하여 긴 문서를 처리하는 것은 계산 비용이 매우 많이 듭니다.
효율적인 트랜스포머 접근법이 제안되어왔지만, 대형 모델의 경우 feedforward 및 projection 레이어가 계산 부담의 대부분을 차지하여 긴 입력을 처리하기 어렵습니다.
이 논문은 LONGT5를 기반으로 어텐션 및 피드포워드 레이어에 대한 구조를 개선하여 긴 입력을 빠르게 처리할 수 있는 새로운 모델인 COLT5를 제안합니다.
COLT5는 일부 토큰이 다른 토큰보다 중요하다는 직관에 기반하여 중요한 토큰에 더 많은 계산을 할당함으로써 더 낮은 비용으로 더 나은 성능을 달성할 수 있다는 것을 이용합니다.
또한, 중요한 토큰의 비율은 문서 길이가 길어짐에 따라 감소하므로 긴 문서를 처리하는 것이 가능합니다.
COLT5는 각 피드포워드 레이어 및 어텐션 레이어를 모두 모든 토큰에 적용하는 Light 브랜치와 중요한 토큰 집합에 적용되는 Heavy 브랜치로 분할하여 처리하는 방식을 사용합니다.
결과적으로 COLT5는 긴 입력(64k 토큰) 작업에서 더 나은 품질과 속도 향상을 달성합니다.
Background
Transformer FLOPs
- Transformer 인코더 레이어의 각 구성 요소의 FLOPs
n은 입력 길이
d는 모델 차원
w는 로컬 어텐션 윈도우의 크기
Sparse attention
Transformer를 긴 입력에 적용하는 첫 번째 문제는 자기 어텐션 메커니즘의 FLOPs가 입력 길이의 제곱에 비례하여 급격하게 증가하며, 긴 입력에 대해 처리하기 어려워진다는 것입니다.
sparse attention 방식
- 입력의 하위 집합 사이에서 어텐션을 제한
- 어텐션을 일부 레이어에만 적용
- LONGT5 :
- Local Attention : 입력의 각 16 토큰 블록마다 로컬 창 안에서 어텐션을 수행
- TGlobal Attention : global token들은 각각의 token들을 블록으로 묶고 해당 블록 내의 모든 token embedding들을 합한 결과를 Layer Normalization하여 global token생성, Input Tokens와 Global Tokens 에 대해 어텐션을 수행
Conditional computation
sparse attention 메커니즘을 적용한 후, 피드포워드 및 어텐션 프로젝션 레이어가 대부분의 FLOPs를 차지합니다.
이 비용은 입력의 길이와 비례하여 증가하기 때문에, 긴 입력을 처리하는 것은 여전히 inference 시간이 많이 소요됩니다.
나머지 비용을 줄이기 위한 일반적인 방법은 조건부 계산을 사용하여 모든 모델 매개변수를 전체 입력에 적용하지 않는 것입니다.
Device utilization
텍스트의 경우, autoregressive decoder 추론은 반복적으로 긴 시퀀스의 키와 값을 로드하는 메모리 대역폭 제약 때문에 매우 느립니다.
Multi-query Attention(MQA)을 통해 키와 값을 공유하는 헤드를 사용하여 메모리 대역폭 오버헤드를 줄입니다.
Training objectives
- T5의 span corruption objective(SCO)
- LONGT5의 PEGASUS: 문서 내에서의 중요 문장을 찾아내고 이를 masking한 후에 예측
- COLT5 모델은 UL2 방식을 사용
Proposed Method (COLT5)
3.1 Conditional computation
이전 섹션에서 설명한 대로, Transformer FLOPs의 상당 부분은 입력 시퀀스의 길이에 비례하여 확장되는 feedforward 및 projection 레이어에서 비롯됩니다.
COLT5는 일부 토큰에 대해서만 더 많은 계산을 하고 중요하지 않는 토큰에 대해 sparse하게 계산하여 처리 비용을 줄였습니다.
COLT5의 Conditional computation 메커니즘은 라우팅 모듈, Conditional feedforward 레이어 및 Conditional attention 레이어로 구성됩니다.
Routing
(1) 입력에 학습 된 임베딩을 곱하여 라우팅 점수를 얻고,
(2) 상위 k개의 최고 점수 입력을 선택하는 두 단계 메커니즘
Xi : 토큰 i의 representation
u : d-차원의 학습 가능한 임베딩
라우팅 점수에 softmax를 사용하여 전체적으로 라우팅 점수를 정규화하여 s˜i를 얻음
각 COLT5 레이어에는 feedforward 레이어, attention queries 및 attention key-values 각각에 대한 독립적인 라우터가 존재
Conditional Feedforward
Conditional Feedforward 레이어는 선택된 토큰에 대해 heavy feedforward layer를 추가로 적용시킵니다.
Xi를 i번째 토큰의 모델 상태, s˜i를 정규화된 라우팅 점수라고 하면, COLT5의 피드포워드 연산은 다음과 같습니다.
(상위 k개에 포함되지 않는 토큰은 0으로 설정)
light와 heavy Feedforward 브랜치는 hidden dimension만 다릅니다.
light 브랜치는 기본 T5 피드포워드 레이어보다 작은 hidden dimension을 가지고 있으며,
heavy 브랜치는 더 큰 hidden dimension을 가지고 있습니다.
n : 입력 토큰 수
m : 선택된 토큰 수
rL, rH : light와 heavy hidden dimension과 T5 hidden dimension에 대한 비율
COLT5 레이어의 FLOPs는 다음과 같습니다:
light와 heavy dimension 비율을 각각 1/2와 4로 설정하였습니다.
토큰의 1/16이 heavy 브랜치로 라우팅되도록 설정했습니다.
결과적으로, COLT5 피드포워드 레이어의 FLOPs는 다음과 같습니다.
이는 기본 T5 feedforward 레이어의 75%를 사용합니다.
Conditional Attention
COLT5 Conditional Attention 레이어는 선택된 쿼리 토큰에서 선택된 키-값 토큰으로 어텐션하는 고용량 어텐션 레이어를 추가로 적용합니다.
Xi를 i번째 토큰의 모델 상태, s˜i를 정규화된 라우팅 점수라고 하면, COLT5의 attention 연산은 다음과 같습니다.
(상위 k개에 포함되지 않는 토큰은 0으로 설정)
COLT5 Conditional Attention의 light branch와 heavy branch는 head 수와 attention target 토큰 수에서 차이가 있습니다.
light branch는 더 적은 head를 가지고 local context window attention을 진행합니다.
heavy branch는 더 많은 head를 가지고 라우팅된 key-value 토큰을 어텐션합니다.
q, v : 선택된 쿼리 및 키-값 토큰의 수
w : 로컬 어텐션 창의 크기
rL, rH : T5에 대한 head의 비율
COLT5 attention layer의 FLOPs는 다음과 같습니다.
head 비율을 각각 1/4, 3/4로 설정
선택된 쿼리 토큰과 키-값 토큰의 비율을 1/16, 1/8로 설정
이를 토대로 COLT5 총 계산은 다음과 같습니다.
질의응답 데이터셋 : TriviaQA (TQA), NarrativeQA (NQA), QASPER (QAS), QuALITY (QuAL), NLI 데이터셋 ContractNLI (CNLI)
요약 데이터셋 : arXiv, SummScreenFD (SumS), QMSum (QMS), GovReport (GovR)가 포함된다. SCROLLS 결과는 COLT5-XL이 SOTA를 달성한 리더보드 테스트 세트에서의 결과이다
평균 속도는 추론(inference)과 fine-tuning (fn)을 위한 초당 샘플 수
Rgm은 ROUGE-1,2,L의 기하학적 평균
3.2 Multi-query Attention
Conditional computation은 인코더의 계산 비용을 줄이는 데 효과적입니다.
그러나 인코더-디코더 모델의 경우 입력이 길 경우 디코더에서 대부분의 추론 시간이 소비되는데, 이는 반복적으로 긴 key와 value 시퀀스를 로드하기 때문입니다.
따라서 MQA를 cross-attention 레이어에 적용하여 훨씬 빠른 추론을 수행합니다.
Multi-query Attention
- 서로 다른 헤드가 단일 키와 값 집합을 공유
- K, V "heads" 차원을 나타내는 문자 "h"를 tf.einsum 방정식에서 제거
3.3 UL2
COLT5 모델을 UL2 pre-training objective로 학습시켰습니다.
UL2는 다양한 denoising objectives를 결합하여 in-context learning을 향상시키는 데 도움이 되는 것으로 알려져 있습니다.
- X-denoising(extreme denoising): 극단적인 span 길이와 corruption rate. 약 50%가 마스크 처리 된 공격적인 노이즈
- S-denoising(sequential denoising): 타겟이 미래 정보를 의존하지 않도록 하는 노이즈 방식
- R-denoising(regular denoising): T5에서 소개된 일반적인 방식의 span corruption 스팬 길이 2~5개의 토큰 범위를 사용하여 입력 토큰의 약 15%를 마스크 처리
Experiments
4.1 Experimental setup
Configurations
T5 아키텍처를 기반
LONGT5와 마찬가지로 Base, Large 및 XL 모델 크기로 실험
COLT5 모델은 동일한 임베딩 차원, 레이어 수 및 전체 어텐션 헤드 수를 갖는 LONGT5 모델과 같은 크기의 모델과 동일
Pre-training
UL2 pre-training objective
batch size 256
input size 4096
output size 910
Fine-tuning
learning rate 0.001
batch size 128
dropout rate 0.1
Question answering 데이터셋은 output length 128,
summarization 데이터셋은 output length 512
Data
데이터셋 별 중앙값과 90% 분위수에 해당하는 입력 길이를 SentencePiece 토큰으로 측정
Timing
성능 측정은 xprof를 사용하여 TPUv4 칩의 샘플 당 시간 측정
추론에는 TPUv4 하나를 사용하고 배치 크기는 16 또는 메모리에 맞는 가장 큰 배치 크기를 사용합니다.
Fine-tuning에는 8개의 TPUv4 칩을 사용하여 각 모델별로 처리량을 극대화하도록 프로파일링합니다.
4.2 Main results
LONGT5와 COLT5의 성능과 속도 트레이드 오프를 비교
16k 입력 길이에서 Large와 XL에서 COLT5는 LONGT5와 같거나 더 나은 성능을 보이며, 학습 속도가 35-75% 빨라지고 추론 속도가 50-100% 빨라집니다.
4.3 In-context learning
In-context Learning은 태스크 설명과 같은 프롬프트형 텍스트(prompt or text description)와 몇 개(few)의 예시들을 조합하여 다음 예시의 결과를 언어 모델의 자연어 생성 성능을 통해 풀고자 하는 few-shot learning 방식입니다.
in-context learning에서 더 많은 샷을 사용하여 COLT5의 긴 입력 기능을 활용할 수 있다는 것을 보여주는 그래프
COLT5 모델은 긴 입력에 대한 추론을 가능하게 하기 때문에, UL2 objective에서 학습된 모델은 작은 크기에서도 강력한 few-shot in-context learning (ICL) 기능을 보이고 있습니다.
따라서 COLT5를 활용하여, in-context learning에 사용되는 예제 수를 늘릴 수 있게 되었습니다.
위 가설을 검증하기 위해, Natural Questions와 TriviaQA 데이터셋에서 input 길이에 따른 few-shot learning 성능을 평가합니다.
이를 위해 가능한 많은 예시를 context로 사용하고, COLT5는 미리 학습된 input 길이까지만 in-context learning을 수행할 수 있다는 것을 발견했습니다.
아래 표는 입력 길이의 함수로서 COLT5의 few-shot 성능을 나타내며, COLT5가 증가하는 예제 수에서 정보를 추출하는 데에 그의 장문 입력 기능을 적용할 수 있다는 것을 보여줍니다.
4.4 Ablations
Routing
입력을 균일하게 라우팅시키는 정적 라우팅은 성능의 큰 하락을 유발한다는 것을 알 수 있습니다.
라우팅의 중요성은 모델이 중요한 토큰에 대한 weight를 학습하며,
COLT5의 이점이 추가 매개변수에 의한 것이 아니라는 증거를 제공합니다.
Query 및 KV 토큰에 대한 라우팅 결과을 공유하는 것은 v=q와 비교하여 적은 품질 감소 및 속도 증가를 초래합니다.
효과적인 성능과 더 중요한 레이어에 대한 계산 비용 사이에는 최적의 경로 토큰 수가 존재합니다.
Attention
v=all, 즉 라우팅된 토큰이 전체 입력에 주의를 기울이는 방식과 v=q, 라우팅된 키와 값이 쿼리와 같은 수로 사용되는 방식의 두 가지 다른 어텐션 설정으로 성능을 비교하여 이를 확인합니다.
입력 전체에 주의를 기울이는 경우 크게 증가하는 비용에 비해 성능 향상이 거의 없습니다.
Other
PEGASUS가 UL2보다 미세하게 더 잘 세밀 조정되는 것으로 나타납니다.
4.5 Routing analysis
토큰을 (1) 질문 토큰, (2) 답변 토큰, (3) 기타 토큰 세 가지 범주로 나누어 라우팅되는 각 유형의 토큰의 평균 비율을 파악
실제로 의미 있는 (1) 질문 토큰, (2) 답변 토큰, (3) 기타 토큰 순서대로 라우팅이 많이 포함된 걸 확인
Which tokens are selected
COLT5-Large 모델의 각 레이어와 라우팅 구성 요소별로 라우팅된 토큰의 종류(질문 토큰, 정답 토큰, 기타 토큰)의 비율을 보여줍니다.
층별로 라우팅 결정을 분리하면 흥미로운 패턴이 나타납니다.
초기 층에서는 질문 및 답변 토큰이 선택될 가능성이 적지만, 나중에는 급격하게 증가하여 마지막 층에서 최고점에 달합니다.
초기 층에서는 모델이 중요한 토큰과 문서의 부분을 식별할 기회가 없기 때문입니다.
그러나 증가는 단조적이지 않으며 층 간에 강한 변동이 있습니다.
이 변동은 서로 다른 층이 다른 유형의 토큰에 집중하거나 일부 라우팅 구성 요소가 중요한 토큰을 성공적으로 식별하지 못하는 것일 수 있다는 것을 나타낼 수 있습니다.
Correlation between routing processes.
COLT5 Large 모델에서 각 레이어에서의 라우팅 가중치 간의 Pearson 상관계수를 보여줍니다.
각 레이어에서
- MLP 라우팅과 KV 라우팅
- MLP 라우팅과 Q 라우팅
- KV 라우팅과 Q 라우팅
간의 상관관계를 보여줍니다.
key/value 라우팅과 query 라우팅은 모든 레이어에서 매우 상관 관계가 높습니다.
MLP 내의 토큰 라우팅은 다른 두 프로세스와 상관관계가 낮습니다.
흥미롭게도, MLP와 어텐션 라우팅 간의 상관관계는 모델의 마지막 레이어에서 증가합니다.
Conclusions
COLT5는 조건부 계산을 활용하여 더 높은 성능과 더 빠른 속도를 제공하는 long text용 새로운 모델입니다.
COLT5는 입력 전체에 적용되는 light 피드포워드 및 어텐션 레이어와 학습 라우터가 선택한 중요한 토큰의 일부에만 적용되는 heavy 브랜치로 구성됩니다.
COLT5는 LONGT5에 비해 속도에서도 강력한 성능을 발휘하며, 최대 64k 토큰까지 매우 긴 입력을 효과적이고 효율적으로 활용할 수 있음을 보여줍니다.
Share article