QuadTree Attention for Vision Transformers
컴퓨터비전에서도 트랜스포머의 제곱 복잡도는 문제가 됩니다. 본 논문은 Quadtree 어텐션을 도입하여 제곱의 연산 복잡도를 선형으로 줄입니다. Quadtree 트랜스포머는 토큰 피라미드를 생성하여 어텐션을 Coarse-To-Fine 방식으로 연산합니다.
Apr 20, 2022
초록
컴퓨터비전에서도 트랜스포머의 제곱 복잡도는 문제가 됩니다.
본 논문은 Quadtree 어텐션을 도입하여
제곱의 연산 복잡도
를 선형
으로 줄입니다. Quadtree 트랜스포머는
토큰 피라미드
를 생성하여 어텐션을 Coarse-To-Fine
방식으로 연산합니다.피라미드의 각 레벨에서 어텐션 점수가 높은
상위 K개
의 패치를 선택하고, 다음 레벨에서 상위 K개 패치와 Relevant한 Regions 안에서만 어텐션을 수행합니다.Quadtree 어텐션은 Feature Matching의 ScanNet의 Flops를 50% 줄이면서 4.0% 개선합니다.
이미지 분류에서는 이전 SOTA 모델의 Top-1 정확도를 0.4~1.5% 개선합니다.
이외의 태스크들에서도 이전 SOTA 트랜스포머 모델의 성능을 개선했습니다.
도입
컴퓨터비전에서 표준 트랜스포머를 활용하기 위해 기존 연구들은 저해상도 혹은 Sparse Tokens에 트랜스포머를 적용합니다.
하지만 역시 고해상도에 트랜스포머를 적용하는 것이 더 좋은 성능을 내기 때문에 많은 연구들이 트랜스포머의 계산 복잡도를 줄이기 위해 디자인을 연구하고 있습니다.
그 중 Linear Approximate 트랜스포머인 PVT는 선형 방식으로 어텐션 연산을 진행했지만 여전히 비전 태스크에서 좋지 않은 성능을 냈습니다. 그 이유는 키와 밸류를 다운샘플링하여 연산 비용을 줄였지만,
픽셀 레벨의 디테일
을 포착하지 못했기 때문입니다.이와 반대로 스윈 트랜스포머는 트랜스포머의 가장 중요한 이점인
Long-Range Dependencies
를 손해보면서도, 하나의 어텐션 블록안에 있는 Local Windows로 어텐션을 제한했습니다.이러한 이전 방법들과 다르게 본 논문은 효율적인 비전 트랜스포머를 구성하여 Fine Image Details와 Long-Range Dependencies 모두 잡았습니다. 이는 대부분의 이미지 영역은
서로 Irrelevant
하다는 관측에 기반하여, 토큰 피라미드를 생성하고 Coarse-To-Fine 방식으로 어텐션을 계산하기 때문입니다. 이러한 방식을 통해,
거친 수준(Coarse Level)
영역이 유망하지 않을 경우 미세 수준(Fine Level)
에서 관련없는 영역을 빠르게 건너뛸 수 있습니다.위 그림은 크로스 어텐션을 통해 피쳐 매칭을 하는 과정의 일부입니다.
위 그림의 Level 1에서
이미지 A 속 파란 이미지
의 어텐션을 계산할 때 이미지 B 안의 모든 패치들
을 함께 사용합니다. 그리고 파란색으로 하이라이트된 상위 K개(K=2) 패치
를 선택합니다.Level 2에서 이미지 A 속
4개의 Framed Sub-Patches
(Level 1의 파란 패치의 자식 패치들)에 대해, Level 1의 이미지 B 속 상위 K개의 패치에 대응되는 Sub-Patches와 함께 어텐션을 계산합니다.다른 모든 그림자 처리된 서브 패치들은 연산을 줄이기 위해 건너뜁니다.
이미지 A의 2개 서브 패치들을 노란색과 초록색으로 하이라이트합니다. 이들에 대응되는 이미지 B의 상위 K개 패치들 또한 같은 색으로 하이라이트 됩니다.
이러한 방식으로 Level 3까지 반복하며, Fine Scale 어텐션과 Long-Range Connections를 얻습니다.
Quadtree 구조가 이러한 방식으로 형성되므로, 본 논문의 방식을 QuadTree Attention 혹은 QuadTree Transformer라고 부릅니다.
뒤에 나올 실험 파트에서, Feature Matching과 같은 크로스 어텐션을 요구하는 태스크와 이미지 분류와 같이 Self-Attention만을 활용하는 태스크에서 Quadtree 트랜스포머의 효과를 입증합니다.
방법론
트랜스포머의 어텐션
트랜스포머의 어텐션 모듈은 피쳐 임베딩 간 넓은 범위의 정보를 포착합니다.
두 개의 이미지 임베딩 과 가 있을 때, 어텐션 모듈은 이 둘 사이에서 정보를 Pass합니다.
Self-Attention은 과 가 같은 경우이며, Cross Attention은 과 가 서로 다를 때와 같이 더 일반적인 상황을 커버합니다.
어텐션은 먼저 쿼리 Q, 키 K, 밸류 V를 다음의 식으로 생성합니다.
여기서 그리고 는 학습 가능한 파라미터입니다.
그 후, 쿼리와 키 간의 어텐션 점수를 계산함으로써
메시지 집계(Message Aggregation)
를 수행합니다.여기서 C는 임베딩 채널 차원입니다.
위 과정은 의 계산 복잡도를 가지며, N은 이미지 패치의 수를 뜻합니다. 이 제곱 복잡도로 인해 트랜스포머가 고해상도의 출력이 필요한 태스크에 적합하지 않습니다.
QuadTree Attention
이름에서 알 수 있듯이, 자식 노드가 4개인 쿼드트리 자료구조의 아이디어를 차용하여, 2차원 공간을 재귀적으로 하위 분할하여 4개의 사분면으로 나눕니다.
Quadtree 어텐션은 Coarse-To-Fine 방식으로 어텐션을 계산합니다. Coarse 레벨의 결과에 따라 Fine 레벨에서 상관없는 이미지 영역을 건너뜁니다. 이 구조가 높은 효율을 유지하면서 적은 정보 손실을 가능하게 합니다.
일반적인 트랜스포머와 동일하게, 과 를 선형적으로 투영하여 쿼리, 키 그리고 밸류 토큰을 만듭니다.
빠른 어텐션 계산을 가능하게 하기 위해, 피쳐맵을 다운샘플링하여 쿼리,키, 밸류를 위한 L 단계의 피라미드를 구축합니다.
쿼리와 키 토큰에 대해서 Average Pooling 레이어를 사용하며, 밸류 토큰에 대해서는 특별한 상황이 아니면 Average Pooling이 크로스 어텐션에서 사용되며, Stride 2를 가지는 Convolutional-Normalization-Activation 레이어가 Self-Attention에서 사용됩니다.
위 그림에서 보이다 싶이, Coarse 레벨에서 어텐션 점수를 계산한 후, 각 쿼리 토큰에 대해 어텐션 점수가 가장 높은
상위 K개의 키 토큰
을 선택합니다. Fine 레벨에서는 쿼리 서브 토큰들은 이 상위 K개의 키 토큰의 서브 토큰들만 사용합니다. 이 과정은 Finest 레벨에 다다를 때까지 반복되며, 어텐션 점수 계산이 끝난 후 모든 레벨에서 메시지(Messages)를 집계합니다.여기서 본 논문은 2개의 구조를 디자인하는데 이를 QuadTree-A와 QuadTree-B라고 부릅니다.
QuadTree-A
Finest 레벨에서의 i번째 쿼리 토큰 를 생각해봅니다. 모든 키 토큰으로부터 받은
쿼리 토큰의 Received Message
를 계산해야합니다. 이 구조는 다른 피라미드 레벨로부터 Partial Messages를 수집하여 전체 메시지를 조립(Assembles)합니다.여기서 은 레벨 에서 계산된 Partial Message를 뜻합니다. 번째 레벨에서, 이 Partial Message 는 영역 안에 있는 토큰들을 조립합니다. 이러한 방법으로 Coarse 레벨에서 관련성이 적은 영역들에 있는 메세지들이 계산됩니다. 그리고 서로 관련이 높은 영역들은 Fine 레벨에서 계산됩니다.
이 방식은 위 그림의 (b) 부분입니다. 메시지 는
다른 색
을 가진 다른 이미지 영역
으로부터 계산된 3개의 Partial Messages 조립함으로써 생성되며, 이는 집합적으로(Collectively) 전체 이미지 공간을 커버합니다.초록색 영역은 가장 관련이 있는 영역을 가리키며 Finest 레벨에서 계산됩니다.
빨간색 영역은 가장 관련이 없는 영역을 가리키며 Coarse 레벨에서 계산됩니다.
영역
은 로 정의되며, 여기서 영역 은 레벨 의 상위 K개 토큰에 대응되며 위 그림의 (c)입니다. 영역 은 전체 이미지를 포괄합니다.Partial Message는 다음과 같이 계산됩니다.
여기서 은 레벨 에서
쿼리와 키 토큰 사이의 어텐션 점수
입니다. 위 그림의 (a)는 과 같이 같은 색을 가진 을 계산하는데 연관되어있는 쿼리와 키 토큰을 하이라이트합니다.
어텐션 점수는
재귀적
으로 계산됩니다.여기서 은 대응되는 부모 쿼리와 키 토큰의 점수이며, 입니다.
잠정적인(Tentative) 어텐션 점수 는 같은 부모 쿼리 토큰의 토큰들 사이에서 Equation 1에 따라 계산됩니다.
QuadTree-A에 대해서, Average Pooling 레이어를 사용하여 모든 쿼리, 키 그리고 밸류 토큰을 다운샘플링합니다.
QuadTree-B
QuadTree-A의 어텐션 점수 는 모든 레벨에서 재귀적으로 계산되며, 이로 인해 Finer 레벨에서 점수가 더 작아지게 하고, 미세한 이미지 피쳐의 기여를 줄이게 합니다.
게다가, Fine 레벨 점수는 Coarse 레벨에서의 Inaccuracy에 크게 영향을 받습니다. 따라서 저자는 다른 전략인 QuadTree-B를 설계합니다.
다른 레벨로부터 얻은 Partial Messages의 Weighted Average로써 를 계산합니다.
여기서 은 학습된 가중치입니다.
위 그림의 (c)에서 보이다 싶이, Partial Messages는 서로 겹치며, 다음과 같이 계산됩니다.
여기서 어텐션은 Equation 1과 같이 Attention Message Computation입니다. 는 영역 내에 있는 모든 키와 밸류를 쌓아 만든 행렬입니다.
QuadTree-A와 QuadTree-B 모두 Sparse Attention Evaluation을 사용합니다. 따라서 본 논문의 방법은 계산 복잡도를 크게 줄입니다.
복잡도 계산
쿼리 토큰, 키 토큰 그리고 밸류 토큰의 길이가 모두 라 가정합니다.
레벨의 토큰 피라미드를 구축하고, 번째 레벨은 길이의 토큰을 가집니다.
Quadtree 어텐션의 Flops는 다음과 같습니다.
여기서 과 은 토큰 피라미드의 Coarest 레벨의 높이와 너비입니다. 따라서 는 상수이며 계산 복잡도는 입니다. K가 상수이기 때문에 Quadtree 어텐션의 복잡도는 토큰 길이에 선형 비례합니다.
실험
Cross Attention Tasks
Feature Matching
저자는 피쳐 매칭의 SOTA 기술인 LoFTR를 기반으로 연구를 진행합니다.
LoFTR는 CNN 기반의 Feature Extractor와 트랜스포머 기반의 Matcher로 구성되는데, 여기서 LoFTR의 선형 트랜스포머를 Quadtree 트랜스포머로 교체하여 실험합니다.
실험은 ScanNet이라는 데이터를 기반으로 진행하며 1513 장의 이미지를 사용합니다.
빠른 학습을 위해 LoFTR-lite라는 세팅을 만들었습니다. 이는 LoFTR의 피쳐 채널의 절반을 사용하고, 453장의 학습 스캔본을 사용합니다.
배치 크기는 8, 30 에폭으로 학습시키며, Quadtree 트랜스포머를 위해 3단계의 피라미드를 구축하고 가장 낮은 레벨의 해상도는 픽셀입니다.
파라미터 K는 Finest 레벨에서는 8로 설정하며, 단계가 낮아질수록 2배씩 늘립니다.
위 표는 () 하에서의 Camera Pose Errors의 AUC를 나타냅니다. Camera Pose Error는 카메라 방향, 평행이동 방향을 Degree로 예측한 값과 GT의 차이를 나타내는 값입니다.
Self-Attention Task
Image Classification
이미지 분류 태스크에서 기존의 비전 트랜스포머와 본 논문의 방법을 비교하기 위해, 저자는 PVTv2 모델을 사용했으며, 이 모델의 Spatial Reduction Attention을 모두 Quadtree Attention으로 대체합니다.
데이터셋은 ImageNet-1K를 사용합니다.
토큰 피라미드는 K=8, Coarest 레벨의 해상도로 설정합니다.
이미지 크롭과 리사이즈로 입력 크기로 만들며, 미니 배치 크기는 128로 설정합니다.
모든 모델은 8개의 GPU로 300 에폭동안 학습되며, 이외의 세팅은 모두 동일합니다.
위 결과표는 Top-1 정확도입니다.
결과를 비슷한 네트워크 복잡도로 묶어 5개의 섹션으로 나누었고, 이는 파라미터 수로 파악할 수 있습니다.
QuadTree-B와 PVTv2를 비교해보면 파라미터 수는 섹션 별로 엇갈리지만 정확도는 QuadTree 방식이 더 높은 것을 알 수 있습니다.
이를 통해 Global 정보가 더 중요하다는 것을 알 수 있다고 합니다.
일반적으로, 본 논문의 Quadtree 트랜스포머는 Coarse 레벨에서 Global 정보를 활용하고 Fine 레벨에서 Local 정보를 활용하기에 PVTv2와 Swin 트랜스포머를 능가합니다.
결론
이 논문 이후에 피쳐 매칭 모델인 LoFTR 코드는 QuadTree Attention을 적용하여 모델을 업데이트 되었습니다.
현재 템플릿 OCR에서 베이스라인 매칭 모델로 Quadtree가 반영된 모델을 사용하고 있고, 문서 이미지들로 파인 튜닝을 하지 않았음에도 현재까지 피쳐 매칭과 관련된 이슈는 없었습니다.
더 정확한 결과를 위해 추후에 매칭에 대한 평가지표를 설계하여 모델 성능을 정량적으로 평가할 예정입니다.
LoFTR 논문은 이전 스터디에 있습니다. 참고하시면 이해에 도움이 될 것 같습니다.
Share article