Summary

transformer arch.를 따르는 모델에서 inference 시 중복 연산 과정에서 발생하는 throughput bottleneck 해소를 위핸 K, V matrix를 저장해놓는 방법.

Important

attn 연산의 결과는 결국 token 하나하나당 embedding이 완성되는 과정.

Details


Summary

Attn 연산을 생각해보면, 결국 input인 hidden state를 3개의 weight matrix로 affine-transformation시켜서, q,k,v를 만든 다음 이들간 행렬 연산.

attention map을 그려보면, 정방 행렬꼴로 그려질텐데, 그 행렬은 결국 q@k^T
auto-regressive한 token-generation 과정을 생각해보면, 새로운 token(이전 step에서 생성된 token)만 새로운 input으로 사용될 거임.
이전에 계산된 k, v가 있을 거고 q token은 현재 step을 포함한 이후에도 사용되지 않으므로 caching대상이 아님.
caching해둔 k, v를 사용한다면 결국 attn-map에서 추가되어야 하는 부분은 현재 token에 대한 행과 열.
그마저도 gpt 같은 decoder-only 구조에서는 casual-attn이니, 행 부분만 채우면 됨.
이전에 연산된 attn-score 같은 경우, inference 시에는 저장해두지 않는데, 이유는 마찬가지로 이후 연산에 사용되지 않기 때문. new-token generation에는 새로운 query에 대한 연산이 진행되는 방향만 영향을 주니까.

useful image: self-attn in Attention