Test-Time Adaptation for Visual Document Understanding
대부분 Visual Document Understanding(VDU) 태스크는 self-supervised pre-training 뒤에 이어지는 fine-tuning으로 이루어집니다.
Oct 11, 2022
Introduction
대부분 Visual Document Understanding(VDU) 태스크는 self-supervised pre-training 뒤에 이어지는 fine-tuning으로 이루어집니다. 이러한 방식은 task-agnostic 방식으로 라벨링이 없는 데이터를 사용하여 학습하지만, source domain에서 라벨링된 데이터에 지도학습이 이루어진 뒤에 unseen 데이터에 적용되면 심각한 성능 하락을 겪을 수밖에 없습니다.
Domain shift 문제를 해결하기 위해 source domain과 target domain을 공통 feature space로 매핑하는 Domain Adaptation(DA) 알고리즘이 연구되었지만, 실제로는 개인정보 이슈 또는 학습과 추론 영역의 다른 개발 환경으로 인해 이런 방식을 적용하기 어려울 수 있습니다. Test-time adaptation(TTA) 방식은 unseen target data에 모델을 적응시키기 위해 연구되었습니다. 하지만 VDU는 image classification과 같은 computer vision 태스크와는 좀 다른 특성을 가지고 있기 때문에 기존 TTA 알고리즘을 그대로 적용하기는 어려울 것입니다.
Proposed Method
- DocTTA framework
- 텍스트 시퀀스 :
- 이미지
- 레이아웃 : - 텍스트 시퀀스 각 word의 bounding box
Domain을 다음과 같이 input의 분포 와 라벨링 의 조합으로 표현합니다:
Source domain 에 학습된 모델을 (파라미터: ) 라고 하겠습니다.
TTA는 GT가 없는 도메인 ( )에 대하여 타겟 모델 를 학습하는 것으로 정의할 수 있습니다. 아래 Algorithm 1이 제안하는 DocTTA를 보여줍니다.
Computer vision에서 흔히 사용하는 single-modality input과 다르게, 문서는 멀티모달 input입니다. 본 연구에서 input은 다음과 같이 세 개의 구성요소로 이루어집니다.
여기서는 source 와 target domain이 모두 같은 클래스 라벨을 사용하는 closed-set TTA를 가정합니다.
- DocTTA objective functions
Objective I: masked visual language modeling (MVLM).
테스트 데이터에서 2D position과 텍스트 토큰을 사용하여 더 나은 text representation을 얻기 위하여 DocTTA에서도 MVLM을 수행합니다. 구체적으로, 15%의 토큰을 랜덤하게 선택하고 이 중 80%를 스페셜 토큰 [MASK]로, 나머지 20%를 전체 vocab 중에서 랜덤하게 교체합니다.
인코더의 feature는 전체 vocab에 해당하는 logits를 출력하는 classifier에 사용됩니다. 모델은 마스킹된 텍스트 토큰을 정확하게 예측하도록 NLL을 사용하여 학습됩니다.
Objective II: self training with pseudo labels
MVLM으로 모델을 adapt 하는 중에, 라벨링이 없는 target 데이터에서 pseudo label을 온라인으로 생성하여 이를 target 데이터의 GT로 사용합니다. Noisy pseudo label을 방지하기 위하여 불확실성이 낮은 label을 선택할 수 있도록 uncertainty-aware selection mechanism을 사용합니다. 이 알고리즘에서 선택 기준은 score와 MC-Dropout을 사용합니다.
실험적으로, raw confidence score를 그대로 사용한 것은 over-confident 문제가 있었습니다. 따라서 이 실험에서는 pseudo-label을 선택하기 위하여 Shannon’s entropy 형식의 불확실성을 근거로 사용하였습니다.
Target 샘플 에 대하여 클래스 가 맞는 클래스일 확률을 라고 할 때, pseudo label 은 불확실성이 특정 기준치 이하일 때에만 선택됩니다.
Objective III: diversity objective
모델이 pseudo label에 따라 가장 확률이 높은 클래스에 편향되는 것을 막기 위해 다음 objective도 사용하였습니다.
여기서 는 다음과 같이 target 데이터에 대해 평균된 target 모델 임베딩입니다.
위 3개의 objective를 모두 사용하여 DocTTA를 학습합니다.
- DocTTA vs. DocUDA
만약 target 데이터에 adapt 하는 과정에서 source data에 접근할 수 있다면 TTA는 UDA 태스크로 확장됩니다. 이 추가적인 데이터는 TTA에 비해 이점이 될 수 있지만, 본 실험에서는 거의 차이가 없거나 domain gap이 큰 경우 오히려 독이 되기도 하였습니다. 또한 UDA는 source 데이터로 동시에 학습해야 하기 때문에 학습 시간이 더 길고 메모리도 더 많이 필요합니다.
Experiment
- DocTTA benchmarks
- FUNSD-TTA: 9,707 semantic entity와 7개의 class label을 가진 데이터셋입니다. 여기서는 train과 test split을 모두 합친 다음 폼에 텍스트가 상대적으로 많이 채워진 149개의 이미지를 source domain으로, 텍스트가 상대적으로 적게 채워진 50개의 이미지를 target domain으로 하였습니다.
- SROIE-TTA: Source domain은 600개의 문서로, 적당한 각도와 검정색 잉크가 확실하게 보이는 것들로 구성하였습니다. Target domain은 블러리하거나 회전되어있거나 컬러 잉크를 사용하였거나, 여백이 많은 것들로 347개를 골랐습니다.
- DocVQA-TTA: DocVQA 데이터셋은 거의 20개의 서로 다른 종류의 문서를 포함하는 VQA 데이터셋입니다. 본 연구에서는 이 중에서 1) Email & Letters (E) 2) Tables & Lists (T) 3) Figure & Diagrams (F) 4) Layout (L) 총 4개의 문서 종류를 사용하였습니다.
Distribution shift에 대한 연구를 위해, 잘 알려진 공개 VDU 데이터셋을 사용하여 real-world challenge를 모사하기 위한 벤치마크를 구성하였습니다.
- Results
- Evaluation metric: Entity recognition와 Key-value recognition 태스크에서는 엔터티 레벨의 F1-score을 사용하였습니다. DocVQA 태스크에서는 Average Normalized Levenshtein Similarity (ANLS)을 사용하였습니다.
- Model architecture
- Baselines
- Results
- FUNSD-TTA & SROIE-TTA
- DocVQA-TTA
- Confidence calibration with DocTTA
- Ablation study
모든 실험에서, 레이어 갯수와 헤드 갯수가 모두 12이고 hidden size는 768인 Transformer encoder를 사용했습니다. Visual backbone은 ResNeXt101-FPN을 사용했습니다.
VDU 태스크에서 TTA 접근법은 처음이기 때문에 비교할 baseline이 없습니다. 따라서 image classification의 TTA, UDA 알고리즘을 비교대상으로 사용하였습니다. Source-only는 아무런 adaptation 없이 source domain에서만 학습된 것이고 Train-on-target은 target domain에서 학습이 된 것으로, 비교 대상으로써 성능의 lower-bound와 upper-bound 역할을 할 수 있을 것입니다.
DocTTA는 source domain을 학습에 사용하지 않는데에도 다른 베이스라인 UDA 알고리즘보다 높은 성능을 보여주었습니다.
DocTTA가 DocUDA보다 성능이 높은 경우가 있었습니다. 이것은 target domain 크기가 작고 domain gap이 커서 그런 것으로 유추됩니다.
Confidence가 정확하지 않으면 사람이 속을 수 있기 때문에 confidence를 보정하는 작업이 매우 중요합니다. 아래 Table4는 confidence에 따라 prediction을 10개의 bin으로 나누고 confidence에 대한 accuracy를 도식화한 것입니다. DocTTA를 거친 후 훨씬 calibration이 잘 되고 ECE 값도 낮아졌습니다.
Share article