LongT5: Efficient Text-To-Text Transformer for Long Sequences
최근 NLP task들에서 long input을 다룰 수 있는 Transformer 모델들이 좋은 성능을 기록하였습니다. 또한, Transformer 모델의 크기를 키우는 것이 성능에 도움이 된다는 연구들이 보고되고 있습니다.
May 19, 2022
Introduction
Motivation
최근 NLP task들에서 long input을 다룰 수 있는 Transformer 모델들이 좋은 성능을 기록하였습니다. 또한, Transformer 모델의 크기를 키우는 것이 성능에 도움이 된다는 연구들이 보고되고 있습니다.
위의 두 가지 가설을 확인하기 위해 논문의 저자들은 모델의 input length와 size를 동시에 늘리고 그 효과를 확인하고자 했습니다. 구체적으로, 논문에서는 long-input transformer attention과 scalable T5의 pre-training ideas을 사용하여 두 가지 목표를 해결합니다.
ROGUE score
Summarization task에서 사용하는 metric으로 사람이 만든 reference sentence와 모델이 예측한 predicted sentence간의 N-gram에 대한 f1 score.
Contributions
- 모델의 input length와 size를 동시에 키울 수 있는 LongT5 모델을 제안.
- ETC의 local/global attention을 모방하여 만든 새로운 attention mechanism(TGlobal)을 제안.
- Vanilla T5와 LongT5 모델의 input length와 size를 늘렸을 때의 성능 분석 제공.
- ArXiv, PubMed, BigPatent, MediaSum과 같은 여러 데이터셋에서 SOTA 성능 달성.
Proposed Method
Background
1. Text-To-Text Transfer Transformer (T5)
T5는 NLP의 모든 task들을 Text-to-Text format으로 변환하여 해결합니다. 따라서, T5는 Transformer의 encoder와 decoder를 모두 사용하며 아래와 같이 task에 해당하는 string들과 input string을 결합하여 모델의 입력으로 사용합니다.
T5는 SpanBERT의 pre-training 아이디어를 차용하여 기본 BERT처럼 token 하나만을 masking하고 이를 예측하던 MLM(Masked Language Model) task에서 벗어나 아래의 그림과 같이 연속된 여러 token들을 masking하고 이를 예측하는 SCO(Span Corrpution Objective) task를 통해 pre-training됩니다.
2. Extended Transformer Construction (ETC)
ETC는 Transformer 모델에서 long input을 처리하기 위해 Global-Local Attention을 제안했습니다. Global-Local Attention은 아래의 그림과 같이 full attention을 하는게 아니라, global token과 local token들에 대해서만 sparse하게 attention합니다.
하지만 ETC의 경우, model의 input을 구성할 때 global token을 따로 만들어주어야 한다는 단점이 있습니다. → 이러한 단점 극복을 위해 LongT5에서는 TGlobal 제안.
3. Pre-training with Extracted Gap-sentences for Abstractive Summarization (PEGASUS)
PEGASUS 논문의 저자들은 fine-tuning task에 맞는 pre-training task가 성능 향상에 도움이 될 수 있다고 설명하고 있습니다.
따라서, PEGASUS는 summarization task만을 위해 만들어진 모델로 아래의 그림과 같이 중요한 문장들(principal sentence)을 masking하고 이를 예측해내도록 하는 GSG(Gap Sentences Generation) task를 pre-training task로 사용하였습니다.
PEGASUS의 GSG 방식 설명
GSG는 문서 내에서의 중요 문장을 찾아내고 이를 masking한 후에 예측하도록 하는 pre-training task입니다. 이 때, 각 문장의 중요도를 loop를 돌면서 현재 문장과 남은 문장들간의 ROUGE-F1 score를 통해 구합니다. 이렇게 구해진 중요 문장들의 top m개를 선택하여 masking하고 이 문장들을 예측하도록 합니다.
LongT5
1. Architecture
LongT5 모델은 기존 T5 encoder에 global-local attention을 추가하여 사용하고 decoder에는 기존 T5 decoder를 그대로 사용합니다.
또한, 기존 T5와의 호환성을 최대한 유지하기 위해 Relative Position Bias, Example Packing 등은 그대로 유지했습니다.
마지막으로, 논문에서 제안하는 TGlobal attention이 효과있는지 확인하기 위해 Local attention만 사용하는 모델과 Local + TGlobal attention을 사용하는 모델의 두 가지 버전을 구현했습니다.
2. Local Attention
Local Attention은 아래의 그림과 같이 local radius 을 통해 선택된 neighborhood tokens에 대해서만 attention을 수행하는 mechanism입니다.
Local Attention에서 input token length가 이라면 복잡도는 이 됩니다.
3. Transient Global Attention (TGlobal)
TGlobal은 앞서 언급하였듯이 ETC의 Global-Local Attention에서 global token을 모델의 input에서 만들어주어야 한다는 단점을 극복하고자 제안한 attention mechanism입니다.(아마도 저자들은 T5 모델의 기본 구조를 훼손하지 않는 선에서 long input을 처리할 수 있도록 attention mechanism을 구현하고 싶었던 것으로 보입니다.)
구체적으로, TGlobal은 Local Attention에 각각의 layer에서 동적으로 생성되는 global token들과의 Global Attention을 결합한 것입니다. 이 때, global token들은 각각의 token들을 블록으로 묶고 해당 블록 내의 모든 token embedding들을 합한 결과를 Layer Normalization하여 생성됩니다.
TGlobal에서 block size가 라면 복잡도는 가 됩니다. 선형 복잡도로는 볼 수 없지만 가 충분히 크다면 복잡도가 상당히 줄어들 것으로 생각됩니다.
PEGASUS Principle Sentences Generation Pre-training
NLP에서는 pre-training에 어떤 task를 쓰는지에 따라, 성능이 상당히 달라집니다.
논문의 저자들은 T5에서 사용한 Span Corruption Object보다 PEGASUS의 Gap Sentences Generation 방식을 사용하는게 더 좋다고 주장하며 이를 pre-training task로 사용합니다.
PEGASUS의 pre-training task가 요약 task뿐만 아니라, 다른 task를 위해서도 좋다고 주장하고 있습니다.
Experiments
Configurations
- baseline으로 T5.1.1 checkpoint 모델을 사용
- base(~220M), large(~770M), xl(~3B) 크기의 3가지 버전의 모델 구현 (참고로, BERT의 base모델은 110M, large모델은 340M, GPT-3는 175B)
- 32000개의 sentence piece를 갖는 T5.1.1의 vocab 사용
- 128 batch size, Adafactor optimizer 사용
Pre-training
- 4096 input sequence length과 910 output sequence length를 갖는 LongT5 모델을 1M steps 동안 pre-training 수행 (이 때, pre-training에서의 input sequence length를 이렇게 고정해두고 fine-tuning시에 늘린 것인지 의문? )
- Dropout 사용하지 않음
- T5 논문에서 배포한 C4 dataset을 통해 pre-training 수행
- T5와 동일한 inverse square-root learning rate scheduler 사용
- masked sentence ratio : 0.2
Fine-tuning
- Learning rate : 0.001, Dropout rate : 0.1
- Summarization task의 경우, 4096 ~ 16384개의 input length, 512개의 output length로 실험
- QA task의 경우, 512 ~ 36864개의 input length, 128개의 output length로 실험
Experiment Settings & Experiment Results
1. Summarization
Summarization task에서는 CNN/Daily Mail, arXiv, PubMed, BigPatent, MediaSum, Multi-News 데이터셋들에 대한 LongT5 모델의 성능(TGlobal 포함)을 다른 요약 모델들과 비교하는 실험을 수행합니다.
각 데이터셋들에 대한 통계량과 실험 성능은 다음과 같습니다.
LongT5는 arXiv, PubMed, BigPatent, MediaSum 데이터셋들에서 SOTA 성능을 달성했습니다.
Multi-News 데이터셋의 경우, PRIMER 모델보다 성능이 낮았는데, 저자들은 그 이유를 PRIMER 모델은 뉴스 관련 도메인으로 이루어진 dataset을 pre-training에 사용했기 때문일 것이라고 추측하고 있습니다.
또한, CNN/Daily Mail 데이터셋의 경우, LongT5 모델이 full attention을 사용하지 않음에도 SOTA 모델인 HAT-BART와 성능 차이가 거의 없으며 ROUGE-2 score에서 오히려 능가하기도 한다는 것을 언급하고 있습니다.
2. Question & Answering
LongT5에서 사용한 evaluation method는 QA dataset들에 대한 benchmark의 evaluation 방식과 차이가 있습니다.
구체적으로, TriviaQA dataset에서는 공식적으로 train set, dev set을 제공하고 test set은 비공개이지만 LongT5에서는 training set의 90%를 train set으로 나머지 10%를 dev set으로 사용하여 hyperparameter 튜닝을 하고 dev set을 test set으로 사용했습니다. 또한, NQ dataset에서는 early stopping을 이용하여 dev set을 평가했다고 합니다. 따라서, 논문의 저자들은 LongT5 모델과 다른 SOTA QA 모델들을 비교하는 것은 의미가 없으므로 T5.1.1 모델과의 비교로 대체했다고 합니다.
다음으로, 논문의 저자들은 Local Attention과 TGlobal Attention의 실험 성능 차이를 알아보기 위해 LongT5의 두 가지 버전(only Local, Local+TGlobal)에 대한 실험을 수행했습니다.
마지막으로, base와 large 모델에는 4x8 TPU v3, xl 모델에는 8x16 TPU v3와 8개의 model partitions(model parallelism)을 통해 실험하고 메모리 용량이 지원될때까지 input sequence length를 늘려보는 실험을 했다고 합니다.
각 데이터셋들에 대한 통계량과 성능은 다음과 같습니다.
위의 실험 결과를 통해 알 수 있듯이 input sequence length가 커질수록 성능이 더 좋아지고 Local Attention을 단독으로 사용할 때보다 TGlobal을 같이 사용할 때 성능이 더 좋은 것을 확인할 수 있습니다.
마지막으로, T5.1.1에 비해 LongT5 모델이 max input sequence length를 더 많이 늘릴 수 있음을 확인할 수 있습니다.
Ablation Study
1. Input Length vs. Speed
저자들은 input length와 속도를 비교하기 위해 input length를 늘려가면서 inference speed를 측정했습니다.
위의 figure를 통해 확인할 수 있듯이, input length가 짧을 때는 T5.1.1, LongT5 Local, LongT5 TGlobal 모델들간의 속도 차이가 거의 없으나, 길어질수록 점점 그 차이가 심해지는 것을 확인할 수 있습니다.
또한, 한 가지 더 중요한 사실은 T5.1.1 모델은 메모리 용량의 한계치에 훨씬 빨리 도달하는 것을 확인할 수 있습니다.
2. Input Length vs. Performance
저자들은 input length와 성능을 비교하는 실험을 수행합니다. (사실, 아래의 figure에서 볼 수 있듯이, 성능과 속도를 비교한 것으로 보입니다...)
위의 figure에서 확인할 수 있듯이, LongT5 모델은 같은 input length를 가짐에도 성능이 거의 유사하거나 좋으며(주로, input length가 길어질 때) 속도도 더 빠른 것을 확인할 수 있습니다.
3. Principle Sentences Generation vs. Span Corruption
마지막으로, 저자들은 pre-training task로서 Principle Sentences Generation(PSG)을 사용하는게 좋은지 Span Corruption(SC)을 사용하는게 좋은지 확인하기 위한 실험을 수행했습니다.
Summarization task와 QA task 모두에서 PSG task를 단독으로 사용하는게 가장 성능이 좋았습니다. 따라서, PSG task를 언어모델의 pre-training task로 사용하면 요약 task뿐만 아니라, 다른 task들에서도 좋은 성능을 달성할 수 있습니다.
Related Works
Language model pre-training
BERT 기반의 모델들은 Masked Language Model(MLM) task를 통해 pre-training됩니다. 하지만 MLM task는 auto-regressive하지 않기 때문에 BERT family들이 generation task에서 약한 성능을 보이게 되는 이유가 되었습니다.
이러한 문제를 해결하기 위해 T5는 pre-training object를 span corrution task를 도입하였으며 BART 모델은 T5와 유사하지만 마스킹되지 않은 토큰들도 예측한다는 점에서 다릅니다.
Long text modeling
Long sequence를 Transformer 모델에서 처리하려는 방법론들은 다음과 같이 3가지 종류로 나눌 수 있습니다.
- Using word-level embeddings (not token-level embeddings) : https://arxiv.org/abs/1606.07869, Doc2Vec
- Modeling long document through hierarchical training : https://aclanthology.org/N16-1174/
Conclusion
논문의 저자들은 모델의 input length와 size를 vanilla T5에 비해 늘릴 수 있는 LongT5 모델을 제안하고 이에 대한 실험을 수행했습니다.
LongT5의 TGlobal attention은 BigBird, ETC 모델과 다르게 attention mechanism 구현이 그렇게 어렵지 않은 것으로 생각되어 저는 TGlobal attention 방식을 상당히 좋게 생각했습니다.
기존의 NLP와 Document Understanding 모델들은 Transformer encoder에 head layer를 마지막에 붙히는 방식으로 구현되어 왔습니다. 하지만 이러한 방식들은 모델 내의 class 수, label 명명 규칙, 새로운 데이터셋에 대한 학습 등의 수요가 발생할 때마다 fine-tuning을 다시 해야한다는 문제점이 있습니다. 이는 상당히 귀찮고 유지보수와 모델 관리에 있어서 힘이 많이 드는 작업입니다. 따라서, 이러한 문제를 해결하기 위해 language model의 output을 항상 text가 되도록 하는 것이 좋겠다는 생각을 했습니다.
또한, 실제 현업에서는 학회 발표용 모델들에 비해 long sequence를 처리해야할 때가 많습니다.
이러한 점에서 T5와 long sequence를 처리하기 위한 LongT5는 앞으로 로민에서 참고할 수 있는 중요한 모델이 될 것이라고 생각했습니다.
References
- longT5 공식 블로그 : https://medium.com/syncedreview/googles-transformer-based-longt5-achieves-performance-gains-by-scaling-both-input-length-and-model-687afb8a3274
- huggingface issue :
Share article