Introduction
Relational reasoning은 오브젝트들의 관계를 추론해내는 태스크입니다. 오브젝트들은 이미지 내 물체일수도, 텍스트 내의 특정 단어일수도, 혹은 속성 정보가 밝혀져 있는 대상일수도 있습니다. 이러한 관계 추론은 '지능'의 필수적인 요건이지만, AI에게는 아직까지는 어려운 태스크로 인식되고 있습니다.
관계추론을 위해 쉽게 떠올릴 수 있는 접근법인 symbolic approach는 태생적으로 symbol grounding problem을 갖고 있어서, 작은 태스크나 입력 변화에 robust하지 않습니다. 다른 여러 방법들의 경우에도 딥러닝에서 자주 맞닥뜨리는 데이터 부족 문제로부터 자유롭지 않습니다.
이 논문에서는 이에 대한 일반적인 해결책으로 Relation networks (RN)을 제안합니다.
Relation Networks
Relation network는 relational reasoning을 가장 잘 수행하기 위해 고안된 네트워크 구조로, 기존의 네트워크에 RN을 추가적으로 붙여서 (RN-augmented network) relational reasoning 문제들을 해결하게 됩니다.
RN의 기본적인 구조는 다음과 같습니다. 여기서 O는 오브젝트들의 집합이며, f, g는 각각 파라미터 \phi와 \theta를 지닌 MLP에 해당합니다.
g는 임의의 두 오브젝트를 입력으로 받으며, 입력 오브젝트들 사이에 관계가 있는지 없는지, 있다면 어떤 관계인지를 추론하는 네트워크가 됩니다. f는 오브젝트들 간의 모든 관계들과 질문을 합쳐, 최종적으로 답을 내놓게 됩니다.
RN은 다음과 같은 장점이 있습니다:
- 관계를 추론하는 방법을 학습하는 것이기 때문에, 기존에 특정 관계에 대해 학습할 필요가 없습니다.
- RN에서는 기본적으로 모든 오브젝트 쌍에 대해 동일한 g를 사용하기 때문에, g가 특정 오브젝트 조합에 과적합하게 학습되지 않아 generalization이 잘 됩니다. 또한 MLP와 비교했을 때, 관계를 추론하는 함수 f가 다뤄야하는 차원 자체가 작기 때문에, 적은 데이터로도 학습이 가능합니다.
- 오브젝트의 집합(set)에 대해 동작하기 때문에, 추론 결과가 입력되는 오브젝트들의 순서에 invariant합니다. 즉, 오브젝트들에서 관계를 추출하기 위해 오브젝트들을 구조화하거나 특정 순서로 넣어주어야 할 필요가 전혀 없습니다.
Tasks and Datasets
이 논문에서는 relational reasoning 중 visual QA, text-based QA, dynamic physical systems 등의 문제를 해결하며, 사용한 데이터셋과 각 데이터셋의 태스크들은 다음과 같습니다.
CLEVR
이미지를 보고 질의문에 답변하는 태스크(VQA)를 위한 데이터셋으로, 입체도형들이 3D 렌더링된 이미지들과 다양한 유형의 질문들로 구성되어 있습니다. 질문의 유형 예시를 들자면 다음과 같습니다:
- Query attribute questions: What is the color of the sphere?
- Compare attribute questions: Is the cube the same material as the cylinder?
- Count questions
데이터는 이미지들로만 이루어진 pixel version (visual input)과, 이미지내 각 물체들의 3D 좌표, 색깔, 모양, 재질, 크기 등의 정보가 포함된 state description version (language input) 이렇게 두 개의 버젼이 있습니다.
기존의 데이터셋에는 모호성과 언어적 편향이 있어서, 모델들이 reasoning 과정이 아니라 질문에 답변하는 법만을 학습하게 하는 경우들이 있었는데, 이 데이터셋은 그러한 문제들을 해결하였다고 합니다.
Sort-of-CLEVR
일반적인 relational reasoning을 위해 CLEVR과 비슷하게 만든 데이터셋입니다. 각 이미지에는 6개의 도형(원 또는 사각형)이 6개의 색깔로 렌더링 되어 있어, 이미지 처리 과정에서의 복잡함과 모호함을 최대한 배제하였습니다. 질문들은 관계형 질문과 비관계형 질문이 각각 10개씩 있으며, 질문 역시 자연어 처리의 어려움을 고려하지 않아도 되는 방식으로 구성되어 있습니다.
bAbI
텍스트로만 구성된 QA용 데이터셋으로, 각기 다른 reasoning에 속하는 20개의 태스크들로 이루어져 있습니다. 각 질문은 여러 개의 문장(supporting facts)과 연관되어 있습니다. 예를 들어, "Where is the football?"이라는 질문에는 "Sandra picked up the football" 과 "Sandra went to the office"의 두 문장을 supporting facts 삼아 답("Office")을 추론해낼 수 있는 식입니다.
Dynamic physical systems
MuJoCo라는 물리 엔진을 이용해서 mass-spring 시스템을 모사해 만든 데이터셋입니다. 각 장면에는 서로 다른 색깔의 공들이 10개씩 있는데, 어떤 공들은 완전히 독립적으로 랜덤하게 움직이며, 어떤 공들은 보이지 않는 걸로 연결되어서 스프링 혹은 강체로 연결되어 연관된 움직임을 보입니다.
시간에 따른 각 공들의 좌표 히스토리가 주어졌을 때, 어떤 공들이 어떤 식으로 연결되어 있는지 판별하는 두 가지 태스크를 수행했는데, 1) 여러 프레임을 관찰하여 공들 사이에 관계가 있는지, 없는지를 판단하는 것과 2) 해당 클립에 몇 개의 시스템(연결 관계; 각 공들을 node로 하고 공들이 연결된 관계를 edge로 하는 그래프)이 존재하는지 찾아내는 것 입니다.
Models
RN은 기본적으로 구조화되지 않은 "오브젝트" 위에서 동작합니다. 먼저 각 태스크에서 오브젝트를 어떻게 추출하는지를 살펴보겠습니다.
- 이미지 픽셀에서 정보를 추출하는 경우에는 CNN을 사용했습니다. CNN의 가장 마지막 피쳐맵의 사이즈가 d*d, 채널이 k개라고 할 때, 각 픽셀을 하나의 오브젝트로 취급하여 각 오브젝트가 1*1*k 크기의 벡터가 되도록 합니다. 이렇게 하면 배경, 특정 물체, 특정 질감, 심지어 물체들의 연결부 까지도 "오브젝트"로 취급할 수 있기 때문에 모델의 굉장히 유연해지게 됩니다.
- State description으로 이루어진 데이터셋의 경우에는, 이미 각 오브젝트들의 속성이 추출되어 있기 때문에 바로 RN에 입력하였습니다.
- 자연어 형태의 input을 오브젝트로 만들어야 하는 bAbI의 경우에는 공간적 정보가 없기 때문에, support set에서 질문 직전 문장 20문장을 set 내에서의 상대적인 위치와 함께 저장했다가, LSTM에 단어 단위로 넣습니다. 이렇게 단어 단위로 나온 최종 상태들을 오브젝트로 간주하여, RN에 입력하게 됩니다.
- 질문에 따라 오브젝트들 사이의 관계 중 어느 부분을 추출할지가 완전히 달라지기 때문에, 관계를 추출할 때 질의문을 함께 고려할 필요가 있습니다. LSTM으로 임베딩된 질의문을 q라고 할 때, 질의문이 고려된 RN의 구조를 다음과 같이 표현합니다.
이렇게 CNN이나 LSTM 등으로 임베딩된 오브젝트들이 RN에 입력됩니다. 아래 그림은 CLEVR-from-pixels의 VQA 태스크를 예시로 하여 전체적인 RN augmented network의 구조를 보여줍니다.
질문은 LSTM을 이용해서 임베딩되고, 이미지는 CNN을 통해 오브젝트들(위의 'Final CNN feature maps'에서 각각 빨강, 노랑, 파랑으로 나타낸 부분들)의 집합으로 나타내어집니다. 이러한 오브젝트들을 RN에 입력하면, RN은 모든 오브젝트들 사이의 관계를 고려하여 질문에 대한 답을 출력합니다.
예시로 든 태스크에서는 4개의 컨볼루션 레이어를 이용해서 피쳐를 임베딩하고, 질의문 처리와 단어 탐색에 각각 128, 32개의 LSTM을 사용했습니다. RN의 g에는 256 unit의 MLP 4-layer와 ReLU를, f에는 256, 256, 29 unit의 MLP 3-layer와 ReLU를 사용했습니다. 가장 마지막 linear 레이어 후에는 정답 후보 단어들 사이에서 softmax를 취해서 답을 찾습니다.
기존에 CLEVR 데이터셋에서 VQA를 수행한 네트워크들이 ResNet이나 VGG로 피쳐를 추출하고, 큰 규모의 LSTM을 이용해 language embedding을 했으며, node가 4천 개 이상 되는 MLP를 차용했던 것과 비교하면 매우 간단한 구조를 갖습니다.
Experiments and Results
CLEVR from pixels
CLEVR에서 95.5%로 당시로써는 압도적으로 state-of-the-art를 달성했습니다. (현재는 2018 MAC의 98.9%?) RN을 더한 모델이 모든 task에서 다른 모델들을 능가하고, 심지어 전반적으로 사람보다 좋은 성능을 보이는 것을 확인할 수 있습니다. 특히 다른 모든 모델들이 고전을 면치 못하던 "compare attribute"와 "task" 유형의 태스크에서 사람과 유사한 수준의 성능을 보였습니다.
RN이 이렇게 간단한 구조에도 불구하고 잘된다는 사실로부터, 이미지나 자연어 처리가 문제가 아니라 relational reasoning이 제대로 되지 않고 있었다는 것을 알 수 있습니다.
다음은 틀린 문제들의 예시입니다. 전체적으로 틀린 문제들을 보았을 때, 물체들이 많이 가려져있거나, 물체들의 정확한 위치 정보가 필요할 때 틀리는 경향을 보였다고 합니다. 그리고 사람에게도 쉽지 않은 문제들이었다고 합니다(!)
CLEVR from state descriptions
이 버전의 태스크에서는 96.4%의 정확도를 달성했습니다. 이는 굳이 visual task 뿐만 아니라, 다양한 형태의 오브젝트들에 대해서도 RN이 높은 수준의 relational reasoning을 해낼 수 있다는 것을 보여줍니다.
Sort-of-CLEVR
CLEVR과는 다르게 Sort-of-CLEVR에서는 관계형 질문과 비관계형 질문을 분리하기 때문에, 각 유형별로 결과를 보면 다음과 같습니다. 관계형과 비관계형 모두에서 CNN+RN은 94% 이상의 정확도를 보인 데 반해, CNN+MLP는 관계형 질문에서 63%로 부진했습니다. 특히, 각 오브젝트들 간의 거리를 모두 측정한 후에 비교해야 하기 때문에 난도가 높은 질문에 속하는 "closest-to", "furthest-from" 유형에서 CNN+MLP는 52.3%에 불과했습니다.
bAbI
bAbI는 각 task에서 95% 이상을 맞춘 경우 성공으로 판정하는데, 전체 20개 중 18개의 태스크에서 성공을 거두었습니다. 기존에는 (모든 태스크에 대해 공통의 네트워크로 학습하는) joint training 기준, Sparse DNC가 19개, DNC가 18개, EntNet은 16개의 태스크에서 성공했던 것과 비교하면 나쁘지 않은 결과입니다. 특히 기초 귀납추론 영역에서 97.9%의 정확도로, 기존의 Sparse DNC (54%), DNC (55.1%), EntNet(52.1%)를 모두 압도했습니다. 95%를 넘지 못한 두 개의 태스크는 "two supporting fact"와 "three supporting fact"로 각각 3.1%, 11.5% 차이로 기준을 맞추지 못한 것으로, 크게 실패한 수준은 아니었습니다. 또한 다른 모델들 (특히 Sparse DNC)과 다르게 hyperparameter 최적화를 위한 노력을 하지 않았기 때문에 더 나아질 수 있을 것으로 보입니다.
Dynamic physical systems
Connection inference task에서는 93%, Counting task에서는 95%의 정확도를 보였습니다. 비슷한 파라미터 수의 MLP로 구성된 다른 네트워크는 찍어서 맞추는것보다 딱히 나을 게 없었습니다.
이 RN은 걸어가는 사람의 움직임을 학습한 적이 없음에도 불구하고, 사람의 움직임으로부터 관적 위치를 추론하는 태스크로 transfer 했을 때에도 system을 잘 잡아내는 것을 확인할 수 있었습니다. (https://www.youtube.com/channel/UCIAnkrNn45D0MeYwtVpmbUQ)
Discussion and Conclusions
기존의 이미지를 "process"하는데 좋은 성능을 보인 ResNet 같은 경우 "reasoning"에서는 그다지 좋은 모델이 아니었습니다. 이 논문에서는 간단한 CNN이나 LSTM 기반의 VQA 아키텍쳐에 RN을 끼워넣는 것 만으로도 심지어 사람을 뛰어넘는 수준의 상당한 개선이 있다는 것을 보였습니다. RN은 reasoning을 위한 모듈이며, CNN이 로컬의 공간 구조를 process하는데에만 집중할 수 있도록 만들어 줄 수 있습니다. 또한 RN의 입력 오브젝트들은 어떠한 구조나 의미들도 추출되지 않은 상태로 바로 reasoning이 가능하다는 점에서, 확장성이 매우 높다고 볼 수 있습니다.
Share article