Self-RAG: Learning to Retrieve, Generate, and Critique through Self-Reflection
LLM의 quality와 factuality를 향상시키 위해 Self-RAG를 제안, LM을 on-demand로 조건에 따라 passage retrieve를 하도록 학습한다. Retrieved passage를 reflection token 이라는 special token을 이용해 self-check를 하여 generation의 quality와 factuality를 향상시킨다.
Apr 25, 2024
Preprint, Oct 17, 2023
Introduction
Motivation
- LLM에는 여전히 factual inaccurate response가 존재함.
- Retrieval-Augmented Generation (RAG)의 등장으로 문제가 어느정도 보완 되기는 했으나,
- Non-factual prompt: 질문 prompt에 따라 검색이 필요하지 않은 경우가 있을 수 있다. 이 경우 LM의 답변의 다양성을 떨어뜨린다.
- Fixed-number of document: 항상 고정된 갯수의 문서를 retrieve 하는 경우 관련성이 낮은 정보를 사용하게 될 수 있다.
Propose
- LLM generation quality 향상을 위해 Self-RAG 제안
- on-demand retrieval and self-reflection
- LLM에 대해 self-check 방식을 도입하고 학습을 유도
Related work
Retrieval-Augmented Generation
Reinforcement Learning from Human Feedback (RLHF)
GPT-3 ➝ ChatGPT에서 큰 성능 변화를 이끌어낸, 사람을 이용한 LLM 학습 방식
- Supervised Fine-Tuning (SFT)
- 적은 양의 labeled dataset으로 Pre-trained LM을 fine-tuning
- Reward Model 학습
- human labeler는 SFT 모델이 생성한 답변 후보를 특정 기준에 따라 랭킹. 이를 점수화한 데이터셋을 수집함.
- 수집된 데이터셋으로 reward model 학습
- Proximal Policy Optimization
- SFT 모델에 여러 입력을 주고 Reward model을 통해 강화학습 수행
학습과정
1➝ (2➝3➝2➝3 ... )
Methodology
Problem formalization
- 변수
- x: input prompt
- y: generated output
- d: retrieved document
- 다음과 같은 네 가지 reflection token을 통해 판단이 이루어짐
- [Retrieve]: input x가 document retrieve가 필요한지 판단
- [IsREL]: Retrieved document가 x와 관련성 있는지 판단
- [IsSUP]: y의 생성에 d가 얼마나 도움이 되었는지 판단
- [IsUse]: y가 x에 대해 얼마나 유용한 답변이었는지 5-scale scoring
Inference 과정
- 각 decision step에서 special token을 생성하도록 함
Supervised model training
Critic model
- Data collection
- 기존의 (Input, output) 데이터에 대해 reflection token annotation이 추가적으로 필요함.
- GPT-4 prompting: GPT-4가 각 reflection token을 생성하도록 instruction prompting을 수행하였다.
- 위 방법으로 20k training data 생성
[Retrieve] token annotation 예시)
기존 데이터에 다음과 같은 isntruction을 같이 주어 [Retrieve] 토큰을 생성하게 함: “Given an instruction, make a judgment on whether finding some external documents from the web helps to generate a better response.”
- Critic learning
- 주어진 critic dataset에서 x, y에 대해 올바른 reflection token을 생성할 수 있도록 conditional language modeling을 수행. critic model은 GPT-4와 90% 이상의 agreement를 보임
Generator model
- Data collection
- 기존의 (x,y) 데이터로부터 생성된 augmented output에 critic model을 이용해 reflection token을 추가함.
- Generator learning
- 주어진 generator dataset에서 input prompt x에 대해 next token과 reflection token을 predict하도록 학습됨.
- Training data example
Self-RAG Inference
- Adaptive retrieval with threshold
- text 생성 과정에서 retrieval가 필요한 순간을 dynamic하게 결정함.
- Hard constraint: [Retrieve] 토큰이 yes로 생성된 경우 retrieve
- Soft constraint: 모든 출력 토큰에 대해 정규화 된 [Retrieve] = yes 를 생성할 확률이 threshold값 이상일 경우 retrieval이 triggered 되도록 설정함.
- Tree-decoding with critique tokens
- Segment를 생성하는 매 순간 마다 critic score S를 기준으로 beam search를 통해 candidate을 선택함.
Experiments
Task & Dataset
- Closed-set tasks
- PubHealth: fact verification dataset about public health
- ARC: multiple-choice reasoning dataset from scientific exams
- Short-form generation tasks (open-domain QA)
- PopQA
- TriviaQA
- Long-form generation tasks
- Biography generation task
- ALCE-ASQA
Model
- LMs with proprietary data
- Baseline without retrievals
- Baseline with retrievals
- Self-RAG
Result
- Main result
- Ablation study
- Training ablation
- No Retrieval R: retrieval 없이 training한 경우
- No critic C: Top-1 document를 항상 사용한 경우
- Test ablation
- Inference time customization
- [IsSUP] token의 weight가 증가할 수록 precision은 올라가지만, Mauve score는 감소한다.
- Efficiency and accuracy trade-off
- Retrieval threshold가 증가할 수록 retrieve frequency가 감소하여 efficiency가 높아지지만 성능에 하락을 보인다.
- 공중 보건 이라는 특정 주제에 대한 내용인 Pubhealth 데이터셋에서는 retrieve frequency가 낮은 경우에도 성능 변화가 작지만, open-domain QA 데이터셋인 popQA 에서는 성능 하락이 크다.
- Training scale and Human analysis
- Training 데이터가 증가할 경우 대부분 모델 성능도 향상되는 모습을 보임
- Human evaluation 결과로부터 self-RAG가 어느 정도 유효한 결과를 낸다는 것을 확인할 수 있음
Conclusion
LLM의 quality와 factuality를 향상시키 위해 Self-RAG를 제안,
- LM을 on-demand로 조건에 따라 passage retrieve를 하도록 학습한다.
- Retrieved passage를 reflection token 이라는 special token을 이용해 self-check를 하여 generation의 quality와 factuality를 향상시킨다.
- 위 방법이 실제로 다양한 task에서 모델의 성능을 향상시켰고, reflection 토큰을 통해 추론 단계에서 LM을 제어할 수 있으므로 다양한 task에서 필요 사항에 따라 동작을 조정할 수 있다.
Discussion
- LLM을 이용한 dataset 생성이 얼마나 신뢰가능할까?
Share article