flash-attention Usage: a Worknote for LLM inference
Background
The flash-attention project provides the flash_attn Python package, which exposes multiple APIs. These APIs touch on several LLM optimization concepts, such as the paged KV-cache and variable-length attention (continuous batching). This post aggregates information about those concepts and focuses on inference only: we will not cover the modules defined for training, and will look at just a few basic functional APIs of flash_attn used in inference.
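As a minimal sketch of the functional interface (assuming flash-attn 2.x on a CUDA GPU; exact signatures may differ across versions), the simplest entry point flash_attn_func computes causal attention over fp16/bf16 tensors laid out as (batch, seqlen, nheads, headdim); the variable-length and KV-cache variants follow the same calling pattern:
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 128, 32, 128
# flash-attention expects fp16/bf16 tensors on a CUDA device,
# with layout (batch, seqlen, nheads, headdim).
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")

# causal=True applies the decoder-style mask used in LLM inference.
out = flash_attn_func(q, k, v, causal=True)  # (batch, seqlen, nheads, headdim)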
Count the parameters in LLaMA V1 model
Let’s load the model:
from transformers import LlamaModel, LlamaConfig
model = LlamaModel.from_pretrained("llama-7b-hf-path")
def count_params(model, is_human: bool = False):
    params: int = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return f"{params / 1e6:.2f}M" if is_human else params
print(model)
print("Total # of params:", count_params(model, is_human=True))
Print out the layers:
LlamaModel(
  (embed_tokens): Embedding(32000, 4096, padding_idx=0)
  (layers): ModuleList(
    (0-31): 32 x LlamaDecoderLayer(
      (self_attn): LlamaSdpaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)
Total # of params: 6607.34M
Transformers reports 6607.34M parameters; stored as float16 that is roughly 13 GB, which matches the actual size of the weight files.
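As a sanity check, the total can be reproduced by hand from the shapes printed above (hidden size 4096, intermediate size 11008, vocabulary size 32000, 32 layers); note that LlamaModel excludes the LM head:
# Recompute the parameter count from the printed shapes.
hidden, inter, vocab, n_layers = 4096, 11008, 32000, 32

embed = vocab * hidden                    # embed_tokens
attn = 4 * hidden * hidden                # q/k/v/o projections
mlp = 3 * hidden * inter                  # gate/up/down projections
norms = 2 * hidden                        # two RMSNorm weights per layer
per_layer = attn + mlp + norms

total = embed + n_layers * per_layer + hidden  # + final norm
print(f"{total / 1e6:.2f}M")                   # 6607.34M
print(f"{total * 2 / 1e9:.1f} GB in float16")  # ~13.2 GB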
Notes on LLM technologies (continuously updated)
Brief notes on LLM technologies.
Models
GPT2
Model structure

The GPT model employs a repeated structure of Transformer Blocks, each containing two sub-layers: a Masked Multi-Head Attention (MMHA) layer and a Position-wise Feed-Forward Network.
The MMHA is a central component of the model. It operates by splitting the input into multiple ‘heads’, each of which learns to attend to different positions within the input sequence, allowing the model to focus on different aspects of the input simultaneously. The output of these heads is then concatenated and linearly transformed to produce the final output.
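A minimal PyTorch sketch (with illustrative sizes, not the actual GPT-2 configuration) makes the head-splitting, causal masking, and final concatenation concrete:
import torch
import torch.nn.functional as F

def masked_multi_head_attention(x, w_qkv, w_out, n_heads):
    """x: (batch, seq, d_model); w_qkv: (d_model, 3*d_model); w_out: (d_model, d_model)."""
    b, t, d = x.shape
    head_dim = d // n_heads
    q, k, v = (x @ w_qkv).chunk(3, dim=-1)
    # Split into heads: (batch, n_heads, seq, head_dim).
    q, k, v = (z.view(b, t, n_heads, head_dim).transpose(1, 2) for z in (q, k, v))
    scores = q @ k.transpose(-2, -1) / head_dim**0.5
    # Causal mask: each position may attend only to itself and earlier positions.
    mask = torch.triu(torch.ones(t, t, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    attn = F.softmax(scores, dim=-1) @ v
    # Concatenate the heads and apply the final linear projection.
    out = attn.transpose(1, 2).reshape(b, t, d)
    return out @ w_out

# Toy usage (sizes are illustrative only).
x = torch.randn(2, 8, 64)
out = masked_multi_head_attention(x, torch.randn(64, 192), torch.randn(64, 64), n_heads=4)
print(out.shape)  # torch.Size([2, 8, 64])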