Mamba: Linear-Time Sequence Modeling with Selective State Spaces
DetPTQ와 ODOL을 활용해 Document AI 모델의 PTQ 성능을 혁신적으로 개선하고, 성능 저하 없이 효율적 양자화를 실현합니다.
Apr 25, 2024
1. Introduction
- 현재 대부분의 foundation 모델들은 transformer 구조와 attention 메커니즘을 기반으로 함.
- transformer 기반 모델은 시퀀스가 길어질수록 계산량이 exponentially 하게 증가함.
- 최근 SSM을 강화한 structured state space sequence models (S4) 가 시퀀스 모델링을 위한 유망한 아키텍처로 등장.
- S4 은 linear or near-linear 수준의 계산량을 가짐.
- 하지만 텍스트와 같은 이산적이고 정보 밀도가 높은 데이터를 모델링하는 데 덜 효과적.
- 이러한 문제를 해결한 Selective State Spaces model (Mamba) 를 소개
[ Mamba’s contribution ]
Selection Mechanism.
입력에 따라 데이터를 효율적으로 선택하는 능력
관련 없는 정보를 걸러내고 관련 있는 정보를 계속해서 저장하는 선택 메커니즘 설계
Hardware-aware Algorithm.
SSM 모델은 계산 효율성을 위해 시간 및 입력에 불변해야 함
GPU 메모리 계층 간의 I/O 액세스를 피하기 위해 확장된 상태를 생성하지 않고 모델을 스캔 방식으로 반복적으로 계산하는 하드웨어 인식 알고리즘으로 이를 극복
Architecture.
selective state spaces 을 SSM 아키텍처와 결합하고, 이를 확장한 Mamba 아키텍처 구현
- sequence length 1M
- Transformer보다 5배 빠른 추론 속도
- Mamba-3B 에서 두 배 큰 크기의 Transformer 성능과 일치
2. Background
a. State Space Models (SSM)
u : 입력 | y : 출력 | x : 상태 변수
입력 u, 출력 y 및 상태 x는 모두 시간에 따라 달라지며 u(t), y(t), x(t)를 참조
- x에 행렬 A를 곱하고, u에 행렬 B를 곱한 후 합산하여 x'를 생성
- x'에 행렬 C를 곱하고, u에 행렬 D를 곱하고 합산하면 y가 생성
A는 잠재 상태 x를 제어하는 행렬
B는 입력 상태을 얼마나 고려할 것인지 결정하는 행렬
C는 최종 출력을 생성할 때 hidden state 를 얼마나 고려할 것인지 결정하는 행렬
D는 ResNet 의 skip connection 과 같은 역할
1. Continuous Representation
- State Space Models (SSM) 의 기본 구조
2. Recurrent Representation
- 현실 세계에서 처리하는 데이터는 이산적이기 때문에 이산화해야 함
- 각각의 개별 데이터 점을 계산
- Zero-order hold (ZOH) 와 같은 규칙으로 이산화를 진행
- 델타는 이산화 샘플링을 할 시간을 나타냄
- 입력 데이터의 세분성
A 바 : A에 샘플링 시간 델타를 곱한 후 지수 함수 적용
B 바 : 델타 A의 역행렬과 A 바에서 단위 행렬 I 를 뺀 행렬의 곱을 입력 행렬 B 와 곱함
이산화 된 각각의 점들을 그려보면 RNN 과 유사한 반복 표현을 얻음
3. Convolutional Representation
각각 시점에 대한 상태 변수 X
각각 시점에 대한 출력 상태 Y
Y 방정식의 일반화
위의 방정식에서 입력을 제외하고 계수만 추출하면 SSM 커널인 K바를 얻음
순차적으로 계산하는 대신 입력 벡터 u에 대해 컨볼루션을 수행하여 병렬 계산을 수행할 수 있음
b. structured state space sequence models (S4)
- SSM 의 잠재 영역 X 의 차원이 크게 증가하면서 계산 시간이 오래 걸림
- SSM 의 병목은 A 바 행렬의 반복된 곱셈 수행 단계
- 행렬을 대각선으로 만들어 모델을 단순화
- A 행렬을 HiPPO Matrix 를 통해 초기화
3. Selective State Space Models
SSM 의 문제점
- 텍스트와 같이 개별적이고 정보 밀도가 높은 데이터를 모델링하는 데 성능이 부족
- 입력과 관련하여 적절하게 데이터를 효과적으로 선택하는 능력이 부족
Motivation
- attention은 context를 전혀 압축하지 않고 사용하는 효과적이면서도 비효율적인 방법
- RNN과 같은 순환 모델은 전체 컨텍스트를 저장하지 않고 정보를 유한 상태로 압축
- 시퀀스 모델의 효율성 대 성능 트레이드오프는 상태를 얼마나 효과적으로 압축하는지에 따라 결정
1. Improving SSMs with Selection
Copying
- 입력과 출력 요소 사이의 일정한 간격을 가짐
- 시간 불변 모델로 해결 가능
Selective Copying
- 의미 있는 정보만 기억하고 의미 없는 데이터는 무시
- 토큰의 중요도에 따라 정보를 선택하는 것이 가능
- 무작위성 때문에 시간 변동 모델 필요
- Induction Head는 이전 컨텍스트를 기반으로 다음 토큰을 생성하는 기능을 테스트하는 데 사용
SB(x), SC(x) 는 linear function
Time-invariant 였던 기존 SSM 로직에서 Time-varying 으로 바뀜
따라서 convolution representation 에서의 병렬 처리가 불가능해짐
하지만
recurrent 계산은 O(BLDN) FLOP
convolution 계산은 O(BLD log(L)) FLOP
시퀀스 길이(L) 가 매우 길고 상태 차원 (N) 이 크지 않다면 recurrent 이 효율적일 수 있음
2. Efficient Implementation of Selective SSMs
광범위한 메모리 사용량을 사용하지 않기 위해 상태 ‘h’ 를 구체화하지 않음
SRAM 은 (HBM) DRAM에 비해 액세스 속도가 100배 이상 빠름
상태 ‘h’는 저장되지 않은 상태로 SRAM 에 올라가있고
입력과 출력인 x, y 만 DRAM 에 저장되어 연산 시만 SRAM 에 로드하여 계산하는 방식
메모리 I/O 양을 줄이기 위해 커널 퓨전을 사용
3. A Simplied SSM Architecture
기존의 SSM 구조에서 사용하는 아키텍쳐인 H3 형식에서
Mamba 는 Gated MLP 구조를 결합한 아키텍쳐로 변경
activation function 으로는 SiLU / Swish 을 사용
4. Additional Model Details
Real vs. Complex
- 이전의 SSM 모델은 강력한 성능을 위해 상태 h 에서 복소수를 사용
- Mamba 에서는 어떤 경우에는 실수가 더 나은 결과를 낼 수 있다는 경험적 관찰이 있음
- 대부분의 작업에서 실수를 사용
Initialization
HiPPO Matrix 를 사용하여 초기화
4. Evaluation
Scaling Laws
125만부터 약 13억 개의 매개변수 크기를 가진 모델
Transformer++ 와 맞먹는 수준의 모델
Scaling: Context Length
Zero-shot Evaluations
Speed and Memory Benchmarks
학습(왼쪽): 표준 구현보다 40배 빠름
추론(오른쪽): Mamba는 Transformers보다 5배 높은 처리량을 달성할 수 있음
5. Conclusion
- 선택성으로 인해 언어나 유전체학과 같은 밀집된 데이터에 대해 우수한 성능을 발휘
- 시퀀스 길이에 따른 계산 및 메모리의 선형 증가로 인해 더 빠른 훈련 및 추론이 가능
- 품질과 효율성을 모두 자랑하며약 1M의 매우 긴 시퀀스에서도 잘 작동
Share article