LLM

flash-attention Usage: a Worknote for LLM inference

LLM tech

Background

The flash-attention project provides the flash_attn package in Python, which exposes multiple APIs. These APIs involve many LLM optimization concepts such as the paged KV-cache and variable-length attention (continuous batching). This post aggregates information about these concepts and focuses on inference only: we will not cover the modules defined for training, and restrict ourselves to a few basic functional APIs used when running inference with flash_attn.
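As a starting point, here is a minimal sketch of the basic fixed-length call, assuming flash-attention 2.x and its flash_attn_func entry point with the (batch, seqlen, nheads, headdim) tensor layout; the tensor sizes below are illustrative only.

import torch
from flash_attn import flash_attn_func  # assumes the flash-attn 2.x package is installed

# flash_attn_func expects q/k/v shaped (batch, seqlen, nheads, headdim),
# in fp16 or bf16, on a CUDA device.
batch, seqlen, nheads, headdim = 2, 128, 32, 128
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# causal=True applies the autoregressive mask used by decoder-only LLMs.
out = flash_attn_func(q, k, v, causal=True)  # (batch, seqlen, nheads, headdim)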

Count the parameters in LLaMA V1 model

LLM tech

Let’s load the model and count its parameters:

from transformers import LlamaModel

# Load the bare LLaMA backbone (LlamaModel has no lm_head, unlike LlamaForCausalLM).
model = LlamaModel.from_pretrained("llama-7b-hf-path")

def count_params(model, is_human: bool = False):
    # Sum the element counts of all trainable parameter tensors.
    params: int = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return f"{params / 1e6:.2f}M" if is_human else params

print(model)
print("Total # of params:", count_params(model, is_human=True))

Printing the model shows the layers:

LlamaModel(
  (embed_tokens): Embedding(32000, 4096, padding_idx=0)
  (layers): ModuleList(
    (0-31): 32 x LlamaDecoderLayer(
      (self_attn): LlamaSdpaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)
Total # of params: 6607.34M

Transformers reports 6607.34M float16 parameters, roughly 13 GB, which matches the actual size of the weight files.
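As a sanity check, the total can be reproduced by hand from the shapes in the printout above; the short script below just redoes that arithmetic.

# Reproduce the count from the printed shapes (LLaMA-7B, LlamaModel only,
# i.e. without the lm_head that LlamaForCausalLM adds on top).
vocab, d_model, d_ffn, n_layers = 32000, 4096, 11008, 32

attn = 4 * d_model * d_model      # q/k/v/o projections, no bias
mlp = 3 * d_model * d_ffn         # gate/up/down projections
norms = 2 * d_model               # two RMSNorm weight vectors per layer
per_layer = attn + mlp + norms

total = vocab * d_model + n_layers * per_layer + d_model  # embedding + layers + final norm
print(f"{total / 1e6:.2f}M")      # -> 6607.34M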

Notes on LLM technologies (keep updating)

LLM tech

Brief notes on LLM technologies.

Models

GPT2

Model structure

The GPT model employs a repeated structure of Transformer Blocks, each containing two sub-layers: a Masked Multi-Head Attention (MMHA) layer and a Position-wise Feed-Forward Network.

The MMHA is a central component of the model. It operates by splitting the input into multiple ‘heads’, each of which learns to attend to different positions within the input sequence, allowing the model to focus on different aspects of the input simultaneously. The output of these heads is then concatenated and linearly transformed to produce the final output.
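To make the split/attend/concatenate flow concrete, here is a minimal PyTorch sketch of masked multi-head self-attention; the function name, weight tensors, and dimensions are illustrative and do not mirror GPT-2's exact parameter layout.

import torch
import torch.nn.functional as F

def masked_multi_head_attention(x, n_heads, w_qkv, w_out):
    # x: (batch, seq_len, d_model); w_qkv: (d_model, 3*d_model); w_out: (d_model, d_model)
    B, T, D = x.shape
    head_dim = D // n_heads
    # Project to queries, keys, values, then split the model dimension into heads.
    q, k, v = (x @ w_qkv).chunk(3, dim=-1)
    q, k, v = (t.reshape(B, T, n_heads, head_dim).transpose(1, 2) for t in (q, k, v))
    # Scaled dot-product scores with a causal mask: token i only attends to j <= i.
    scores = (q @ k.transpose(-2, -1)) / head_dim**0.5
    causal_mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
    scores = scores.masked_fill(causal_mask, float("-inf"))
    heads = F.softmax(scores, dim=-1) @ v                # (B, n_heads, T, head_dim)
    # Concatenate the heads and apply the final linear transformation.
    return heads.transpose(1, 2).reshape(B, T, D) @ w_out

x = torch.randn(2, 16, 64)
out = masked_multi_head_attention(x, n_heads=4,
                                  w_qkv=torch.randn(64, 192) * 0.02,
                                  w_out=torch.randn(64, 64) * 0.02)
print(out.shape)  # torch.Size([2, 16, 64])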