SliceGPT: Compress Large Language Models by Deleting Rows and Columns
SliceGPT는 Transformer 기반 언어 모델의 효율성을 극대화하기 위해 Structured Pruning 방식을 제안합니다. 주성분 분석(PCA)으로 weight matrix를 최적화하여 최대 64%의 연산 비용을 줄이면서 성능을 유지합니다.
Apr 25, 2024
Introduction
Motivation
- LLM의 computational complexity를 줄이기 위해 많은 weight pruning 방식이 고안됨
- Unstructured pruning 방식은 모델 구조에 관계 없이 범용적으로, 효율적으로 사용될 수 있지만 다음과 같은 단점 있음
- weight matrix의 중간 중간이 제거되므로 정확도 감소가 클 가능성이 높고, 추가적으로 fine-tuning이 필요함
- matrix의 크기가 유지되므로 (embedding dimension이 동일) 메모리 사용량은 감소하지 않음
- Structured pruning은 차원의 감소로 연산량과 메모리 사용량에서 큰 이점을 가지지만 방법을 찾기 어렵고 범용성이 낮다는 단점이 있음
Contribution
- Computational invariance
- Model을 변경하지 않고 transformer 모델 내의 weight matrix에 orthogonal matrix transformation을 수행할 수 있음을 보임
- 이를 이용하여 transformer 구조에서 주성분 분석을 통해 weight matrix의 차원을 감소시킴
- 제안하는 방법으로 Dense model의 90% 이상의 수준의 성능을 유지하면서 최대 30% pruning이 가능함을 보임
Related work
- Magnitude-based sparsification
- 상대적으로 값이 적은 weight를 0으로 바꿔서 sparse matrix로 변환
- Optimal Brain Sergeon (OBS)
- loss function에 기여도가 낮은 weight를 제거하는 방식 사용
- 매 번 hessian matrix를 저장하여 계산해야 하므로 큰 모델에서는 impractical 함
- SparseGPT
Methodology
Transformer 구조에 적용 가능한 pruning 방식 제안
Computational Invariance in Transformer Networks
- Orthogonal matrix와의 곱 연산은 vector norm을 동일하게 유지한다.
- RMSNorm을 통과하기 이전에 orthogonal matrix Q를 곱하고 통과한 후 Q의 전치행렬을 곱하면 결과 값이 동일하게 유지될 수 있다.
pf)
- Computational invariance 결과를 유지하면서 직교 변환 Q를 transformer weight에 적용할 수 있으므로 어떤 변환된 상태에서도 연산을 수행할 수 있다
Theorem 1에 의해 transformer network를 위와 같이 치환해도 동일한 결과 값을 얻게 된다.
pf)
Layernorm Transformers can be converted to RMSNorm
- 다음과 같은 방법으로 모델 구조를 변경하여 LayerNorm을 RMSNorm으로 변환하여도 동일한 결과를 얻을 수 있다.
- scale matrix diag(alpha)를 다음 W_in matrix에 곱한다
- mean-subtraction matrix M을 W_out matrix에 곱한다
- alpha’ 를 이전 weight matrices에 곱한다
A transformation per Block
- Transformer 내의 모든 LayerNorm이 RMSNorm으로 변환됨
- Orthogonal matrix 계산
- Training dataset을 이용해 transformed network의 output에 PCA를 적용해 orthogonal matrix Q를 계산함
- dataset의 i 번째 sequence에 대한 l 번째 RMSNorm block의 output을 X_l,i 라 할 때 공분산 행렬 C를 다음과 같이 계산:
- Q_l 은 C_l의 sorted eigenvector로 구해진다.
- Orthogonal matrix 추가
- input weight matrix 앞에 Q^T 곱합
- output weight matrix 이후에 Q 곱합
- 이 때 residual connection의 경우 이전, 이후 레이어에 맞게 (Q^T_l-1 Q_l) linear transformation을 추가함
Slicing
- 다음과 같이 weight matrix X를 slicing 하여 차원 축소
- Q: X^T X의 eigenvector
- D: D*D_small 차원의 deletion matrix, column 수가 더 적은 identity matrix
- X의 lower dimensional representaion Z를 구한 후, 이후 Z를 이용해 weight matrix의 근사치 계산
위 방식을 통해 weight matrix의 차원이 감소되는 효과를 볼 수 있음
Experiments
Generation Task
- Dataset: WikText-2
- 30%까지 slice했을 때 2.7B OPT 모델을 제외하고 모두 2:4 sparseGPT보다 높은 성능 보임
Zero-shot tasks
- Dataset: PIQA, WinoGrande, HellaSwag, ARC-e, ARC-c
- OPT 모델이 다른 모델에 비해 pruning에 더 높은 효율성 보임
- 크기가 큰 모델일 수록 pruning에 의한 정확도 감소가 낮음
Benchmarking Throughput
- 25% sliced 모델의 경우 PPL의 큰 증가 없이 최대 1.55배의 throughput 증가 보임
- 50% sliced 모델의 경우 PPL이 크게 증가하였으나 최대 6.26배의 throughput 증가를 보였으며 OPT 66B, LLAMA-2 70B의 경우 1개의 GPU로도 연산이 가능함
Inference Time
- A100에서 11-13%, RTX6000에서 16-17%의 속도 향상 보임
- GPU time으로 봤을 때 RTX6000에서 계산량이 최대 64% 감소
Compute cost
- LLAMA-2와 Phi-2 모델의 Slicing에 걸린 시간은 1-3h, recovery fine-tuning에 1-5 시간 소요됨
Conclusion
- Transformer based LLM에 적용할 수 있는 새로운 structured pruning 방식 제안
- 제안한 방식으로 여러 LLM에 대한 추가적인 code optimization 없이 inference computational cost를 최대 64% 감소시킴
- Cost를 감소시키면서도 dense model에 대해 90% 내외의 성능을 보임
Discussion
- 큰 모델의 pruned 모델이 이 원래 작은 모델보다 낮은 성능을 보인다면 pruning의 의미가 있을까?
- Zero-shot task에서 LLAMA-2 70B의 30% sliced 모델이 13B 모델보다 낮은 성능을 보임
Share article