Count the parameters in the LLaMA V1 model
Let’s load the model:
from transformers import LlamaModel, LlamaConfig

model = LlamaModel.from_pretrained("llama-7b-hf-path")

def count_params(model, is_human: bool = False):
    # Count only the trainable parameters.
    params: int = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return f"{params / 1e6:.2f}M" if is_human else params

print(model)
print("Total # of params:", count_params(model, is_human=True))
Print out the layers:
LlamaModel(
  (embed_tokens): Embedding(32000, 4096, padding_idx=0)
  (layers): ModuleList(
    (0-31): 32 x LlamaDecoderLayer(
      (self_attn): LlamaSdpaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)
Total # of params: 6607.34M
Transformers reports 6607.34M float16 parameters, which is roughly 13GB and matches the size of the actual weight files.
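A quick back-of-the-envelope check of that size (float16 stores 2 bytes per parameter):
print(f"{6_607_343_616 * 2 / 1e9:.1f} GB")  # 13.2 GB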
The basic settings of the 7B model (these can also be read from the config, as shown after the list)
- model dimension \(d_{model}=4096\)
- number of heads \(n_{head}=32\)
- head size \(d_{head} = \frac{d_{model}}{n_{head}} = 128\)
- dimension of the feed-forward network’s inner layer \(d_{ff}=11008\)
- number of tokens \(n_{token}=32000\)
- number of transformer layers \(n_{layer}=32\)
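These settings can be read off the checkpoint’s config; a minimal sketch using the standard LlamaConfig attribute names (same placeholder path as above):
from transformers import LlamaConfig

config = LlamaConfig.from_pretrained("llama-7b-hf-path")
print(config.hidden_size)          # d_model = 4096
print(config.num_attention_heads)  # n_head  = 32
print(config.intermediate_size)    # d_ff    = 11008
print(config.vocab_size)           # n_token = 32000
print(config.num_hidden_layers)    # n_layer = 32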
Layer-by-Layer Parameter Count
Embedding layer
The vocabulary embedding has \(n_{token}\times d_{model}=131.07M\) parameters, while the position embedding contributes nothing, since RoPE does not use a separate embedding table.
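A quick sanity check on the loaded model:
print(model.embed_tokens.weight.numel())  # 32000 * 4096 = 131,072,000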
Transformer layers
input_layernorm and post_attention_layernorm
Both are RMSNorm layers with \(d_{model}\) parameters each, so together they contribute \(2\times d_{model}=8192\) parameters per layer.
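Verified on the first decoder layer of the loaded model:
layer = model.layers[0]
print(layer.input_layernorm.weight.numel()
      + layer.post_attention_layernorm.weight.numel())  # 2 * 4096 = 8192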
multi-head self-attention
Each of the Q, K, V, and O projections is a Linear layer of size \(d_{model} \times d_{model}\), so together they contribute \(4\times d_{model}^2=67.11M\) parameters.
One subtle point: why can a single linear layer produce Q, when in the original Transformer paper each head is computed separately, e.g. \(Q_i=QW^Q_i\) where \(i\) is the head index? Because concatenating all the per-head projections is equivalent to a single linear layer of size \(d_{model} \times (n_{head} \times d_{head})\), which is \(d_{model} \times d_{model}\) in LLaMA v1.
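To see this concretely, here is a small sketch (with randomly initialized weights, not the model’s actual projections) showing that concatenating per-head projections equals one \(d_{model} \times d_{model}\) linear layer:
import torch

d_model, n_head = 4096, 32
d_head = d_model // n_head

x = torch.randn(1, d_model)
# One projection matrix per head, each of shape (d_model, d_head).
per_head_w = [torch.randn(d_model, d_head) for _ in range(n_head)]

# Project with each head separately, then concatenate the results.
q_per_head = torch.cat([x @ w for w in per_head_w], dim=-1)

# Equivalent single projection of shape (d_model, n_head * d_head) = (d_model, d_model).
q_single = x @ torch.cat(per_head_w, dim=-1)

print(torch.allclose(q_per_head, q_single))  # True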
The attention operation itself adds no extra parameters, since it simply applies the following formula:
\[ \mathrm{Attention}(Q,K,V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]
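Checking against the first decoder layer of the loaded model (the rotary embedding only holds a buffer, not trainable parameters):
attn_params = sum(p.numel() for p in model.layers[0].self_attn.parameters())
print(attn_params)  # 4 * 4096**2 = 67,108,864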
mlp
The LlamaMLP layer contains three separate Linear layers:
- gate_proj: \(d_{model} \times d_{ff}\)
- up_proj: \(d_{model} \times d_{ff}\)
- down_proj: \(d_{ff} \times d_{model}\)
So in total, they have \(3\times d_{model} \times d_{ff} = 135.27M\) parameters.
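Checked the same way on the loaded model:
mlp_params = sum(p.numel() for p in model.layers[0].mlp.parameters())
print(mlp_params)  # 3 * 4096 * 11008 = 135,266,304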
Total count of parameters
The overall parameter count is composed of two major parts, the vocabulary embedding and the transformer layers, that is embed + 32 * (mha + mlp + norm), plus a final RMSNorm of \(d_{model}\) parameters, which is negligible:
- \(embed=n_{token}\times d_{model}=131.07M\)
- \(mha=4\times d_{model}^2=67.11M\)
- \(mlp=3\times d_{model}\times d_{ff}=135.27M\)
- \(norm=2\times d_{model}=8192\approx 0.01M\)
Plugging in the numbers gives roughly 6607.34M parameters, which matches the number reported by Transformers.
def count_llama_params(d_model, d_ff, n_tokens, n_layers):
    embed = n_tokens * d_model
    mha = 4 * d_model**2
    mlp = 3 * d_model * d_ff  # gate_proj + up_proj + down_proj
    norm = 2 * d_model        # input_layernorm + post_attention_layernorm
    return embed + n_layers * (mha + mlp + norm)
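Plugging in the 7B settings reproduces the number above:
print(count_llama_params(d_model=4096, d_ff=11008, n_tokens=32000, n_layers=32) / 1e6)  # ≈ 6607.34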
For example, the Llama 65B model
LlamaModel(
  (embed_tokens): Embedding(32000, 8192, padding_idx=0)
  (layers): ModuleList(
    (0-79): 80 x LlamaDecoderLayer(
      (self_attn): LlamaSdpaAttention(
        (q_proj): Linear(in_features=8192, out_features=8192, bias=False)
        (k_proj): Linear(in_features=8192, out_features=8192, bias=False)
        (v_proj): Linear(in_features=8192, out_features=8192, bias=False)
        (o_proj): Linear(in_features=8192, out_features=8192, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=8192, out_features=22016, bias=False)
        (up_proj): Linear(in_features=8192, out_features=22016, bias=False)
        (down_proj): Linear(in_features=22016, out_features=8192, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)
Total # of params: 65023.52M
And let’s use the function:
count_llama_params(d_model=8192,
                   d_ff=22016,
                   n_tokens=32000,
                   n_layers=80)
It gives 65023.5M, which roughly matches; the small gap (8192 parameters) is the final RMSNorm that the function leaves out.