Count the parameters in LLaMA V1 model




Let’s load the model and count its parameters:

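from transformers import LlamaModel

# "llama-7b-hf-path" is a placeholder for the directory holding the 7B weights in Hugging Face format
model = LlamaModel.from_pretrained("llama-7b-hf-path")

def count_params(model, is_human: bool = False):
    # Sum the element counts of all trainable tensors (buffers are not parameters, so they are excluded)
    params: int = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return f"{params / 1e6:.2f}M" if is_human else params

print(model)
print("Total # of params:", count_params(model, is_human=True))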

The output lists the layers and the total:

LlamaModel(
  (embed_tokens): Embedding(32000, 4096, padding_idx=0)
  (layers): ModuleList(
    (0-31): 32 x LlamaDecoderLayer(
      (self_attn): LlamaSdpaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)
Total # of params: 6607.34M

Transformers reports 6607.34M parameters. Stored as float16 (2 bytes per parameter), that is roughly 13 GB, which is consistent with the actual weight size.
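The size estimate is plain arithmetic: parameter count times bytes per element. A minimal sketch (the helper weight_bytes and its dtype table are illustrative, not from any library) that uses the exact count behind the 6607.34M figure; here 1 GB means 10^9 bytes, while in binary units (2^30 bytes) the same weights come to about 12.3 GiB.

```python
# Bytes per element for a few common dtypes (illustrative table, not exhaustive)
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}

def weight_bytes(n_params: int, dtype: str = "float16") -> int:
    """Storage for the raw weights only (no optimizer state, no activations)."""
    return n_params * BYTES_PER_PARAM[dtype]

n_params = 6_607_343_616  # exact LlamaModel 7B count, i.e. the 6607.34M above
size = weight_bytes(n_params)
print(f"{size / 1e9:.2f} GB")     # 13.21 GB (decimal gigabytes)
print(f"{size / 2**30:.2f} GiB")  # 12.31 GiB (binary gibibytes)
```

Loading in float32 instead would double this, which is why the dtype is a parameter rather than hard-coded.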

Basic settings of the 7B model

  • model dimension \(d_{model}=4096\)
  • number of heads \(n_{head}=32\)
  • head size \(d_{head} = \frac{d_{model}}{n_{head}} = 128\)
  • dimension of the feed-forward network’s inner layer \(d_{ff}=11008\)
  • vocabulary size (number of tokens) \(n_{token}=32000\)
  • number of transformer layers \(n_{layer}=32\)
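The same settings can be kept together in a small Python dataclass. LlamaDims is a hypothetical name used only for illustration (it is not a Transformers class); the field names follow this article's notation, and the comments give the corresponding LlamaConfig attribute names. d_head is derived rather than stored, so it cannot drift out of sync with d_model and n_head.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class LlamaDims:
    d_model: int = 4096   # LlamaConfig.hidden_size
    n_head: int = 32      # LlamaConfig.num_attention_heads
    d_ff: int = 11008     # LlamaConfig.intermediate_size
    n_token: int = 32000  # LlamaConfig.vocab_size
    n_layer: int = 32     # LlamaConfig.num_hidden_layers

    @property
    def d_head(self) -> int:
        # Every head gets an equal slice of the model dimension
        assert self.d_model % self.n_head == 0, "d_model must be divisible by n_head"
        return self.d_model // self.n_head

dims_7b = LlamaDims()
print(dims_7b.d_head)  # 128
```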

Layer-by-Layer Parameter Count

Embedding layer

The vocabulary embedding has \(n_{token}\times d_{model}=131.072M\) parameters. The position embedding contributes 0: RoPE encodes positions by rotating Q and K inside each attention layer, so there is no separate position-embedding table to learn.
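To see why RoPE adds no parameters, here is a minimal pure-Python sketch of the rotation as described in the RoPE paper: each pair of dimensions (x[2i], x[2i+1]) is rotated by the angle pos * base^(-2i/d), with base = 10000. Everything is computed from the position and fixed constants, so there is nothing to learn. The function name rope_rotate is hypothetical, and the Transformers implementation pairs dimension i with i + d/2 (its rotate_half helper) instead of interleaving, which is a different ordering of the same rotation.

```python
import math

def rope_rotate(x: list[float], pos: int, base: float = 10000.0) -> list[float]:
    """Rotate consecutive pairs (x[2i], x[2i+1]) by pos * base**(-2i/d). No learned weights."""
    d = len(x)
    assert d % 2 == 0, "RoPE needs an even head dimension"
    out = [0.0] * d
    for i in range(d // 2):
        theta = pos * base ** (-2 * i / d)  # fixed frequency per pair, scaled by position
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out

q = [1.0, 0.0, 0.5, -0.5]
print(rope_rotate(q, pos=0))  # position 0: angle 0, vector unchanged
print(rope_rotate(q, pos=3))
```

Because each pair is rotated by an angle proportional to its position, the dot product between a rotated query and a rotated key depends only on their relative distance, which is the property attention relies on.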

Transformer layers

input_layernorm and post_attention_layernorm

Both are RMSNorm layers, each holding a learned scale vector of \(d_{model}\) parameters, so together they have \(2\times d_{model}=8192\) (about 8.19K) parameters.
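RMSNorm rescales the input by its root mean square and then multiplies by one learned weight per channel, so each instance holds exactly d_model parameters and, unlike LayerNorm, no bias. A minimal pure-Python sketch; eps = 1e-6 is assumed here (it is the LlamaConfig default for rms_norm_eps), and the class is illustrative rather than the Transformers LlamaRMSNorm.

```python
import math

class RMSNorm:
    def __init__(self, d_model: int, eps: float = 1e-6):
        self.eps = eps
        # The only learned tensor: one scale per channel, initialised to 1
        self.weight = [1.0] * d_model

    def num_params(self) -> int:
        return len(self.weight)

    def __call__(self, x: list[float]) -> list[float]:
        # Root mean square of the input (no mean subtraction, unlike LayerNorm)
        rms = math.sqrt(sum(v * v for v in x) / len(x) + self.eps)
        return [w * v / rms for w, v in zip(self.weight, x)]

d_model = 4096
norms = [RMSNorm(d_model), RMSNorm(d_model)]  # input_layernorm, post_attention_layernorm
print(sum(n.num_params() for n in norms))  # 8192
```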

multi-head self-attention

Q, K, V and O are each a Linear layer of size \(d_{model} \times d_{model}\) without bias, so together they have \(4\times d_{model}^2=67.1M\) parameters.

One subtlety: why can a single Linear layer produce Q for all heads, when the original Transformer paper projects each head separately, for example \(Q_i=QW^Q_i\) with \(W^Q_i \in \mathbb{R}^{d_{model}\times d_{head}}\), where \(i\) is the head index? Because concatenating the per-head matrices \(W^Q_1,\dots,W^Q_{n_{head}}\) along the output dimension gives one matrix of size \(d_{model} \times (n_{head} \times d_{head})\), which is \(d_{model} \times d_{model}\) in LLaMA V1. A single multiplication by that matrix yields all heads at once, concatenated.
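A small numerical check of that argument in pure Python, with toy sizes (d_model = 4, n_head = 2, d_head = 2) and random values chosen only for illustration: projecting with each head's matrix separately and concatenating the outputs gives the same numbers as one projection with the column-wise concatenation of the per-head matrices.

```python
import random

def matmul(a, b):
    """Plain matrix product of row-major nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def hconcat(mats):
    """Concatenate matrices with equal row counts along the column axis."""
    return [sum((m[r] for m in mats), []) for r in range(len(mats[0]))]

random.seed(0)
d_model, n_head = 4, 2
d_head = d_model // n_head
seq_len = 3

# Input hidden states: seq_len x d_model
x = [[random.uniform(-1, 1) for _ in range(d_model)] for _ in range(seq_len)]
# One d_model x d_head projection per head, as in the original Transformer paper
w_heads = [[[random.uniform(-1, 1) for _ in range(d_head)] for _ in range(d_model)]
           for _ in range(n_head)]

per_head = hconcat([matmul(x, w) for w in w_heads])  # project per head, then concatenate
fused = matmul(x, hconcat(w_heads))                   # one d_model x (n_head * d_head) projection
print(all(abs(p - f) < 1e-12 for rp, rf in zip(per_head, fused) for p, f in zip(rp, rf)))  # True
```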

The attention operation itself adds no parameters; it simply applies the following formula to the projected Q, K and V, where \(d_k=d_{head}\):

\[ Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V \]
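The formula maps directly to a few lines of code, with softmax taken row by row. This pure-Python sketch handles a single head with toy inputs and takes already-projected Q, K and V; it holds no weights of its own, which is why this step contributes nothing to the count. LLaMA additionally applies a causal mask, which is omitted here for brevity.

```python
import math

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V for one head; q, k, v are lists of row vectors."""
    d_k = len(q[0])
    scores = [[sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_k) for kj in k] for qi in q]
    weights = [softmax(row) for row in scores]
    # Each output row is a weighted average of the value rows
    return [[sum(w * vj[c] for w, vj in zip(wrow, v)) for c in range(len(v[0]))]
            for wrow in weights]

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[10.0, 0.0], [0.0, 10.0]]
print(attention(q, k, v))
```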

mlp

The LlamaMLP layer contains three separate Linear layers:

  1. gate_proj: \(d_{model} \times d_{ff}\)
  2. up_proj: \(d_{model} \times d_{ff}\)
  3. down_proj: \(d_{ff} \times d_{model}\)

So in total, they have \(3\times d_{model} \times d_{ff} = 135.27M\) parameters.
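The three projections combine as down_proj(SiLU(gate_proj(x)) * up_proj(x)), where * is element-wise; this is the SwiGLU feed-forward used by LLaMA. A pure-Python sketch with toy sizes that also counts its own weights; LlamaMLPSketch is a hypothetical name, not the Transformers class, and the random initialisation is only there to make it runnable.

```python
import math
import random

def silu(z: float) -> float:
    return z / (1.0 + math.exp(-z))

def linear(w, x):
    """y = W x with W stored as (out_features x in_features) and no bias, like bias=False above."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]

class LlamaMLPSketch:
    def __init__(self, d_model: int, d_ff: int, seed: int = 0):
        rnd = random.Random(seed)

        def rand(rows, cols):
            return [[rnd.gauss(0, 0.02) for _ in range(cols)] for _ in range(rows)]

        self.gate_proj = rand(d_ff, d_model)
        self.up_proj = rand(d_ff, d_model)
        self.down_proj = rand(d_model, d_ff)

    def num_params(self) -> int:
        return sum(len(w) * len(w[0]) for w in (self.gate_proj, self.up_proj, self.down_proj))

    def __call__(self, x):
        gate = [silu(g) for g in linear(self.gate_proj, x)]
        up = linear(self.up_proj, x)
        return linear(self.down_proj, [g * u for g, u in zip(gate, up)])

mlp = LlamaMLPSketch(d_model=8, d_ff=22)
print(mlp.num_params())     # 3 * 8 * 22 = 528
print(len(mlp([1.0] * 8)))  # 8: the output is back in d_model
```

The gate and up projections both map to d_ff, which is why the MLP needs three matrices instead of the two in the original Transformer feed-forward block.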

Total count of parameters

The total is made up of the vocabulary embedding, the 32 transformer layers, and the final RMSNorm after the last layer (norm in the printout, \(d_{model}\) parameters), that is embed + 32 * (mha + mlp + norm) + \(d_{model}\):

  • \(embed=n_{token}\times d_{model}=131.07M\)
  • \(mha=4* d_{model}^2=67.1M\)
  • \(mlp=3* d_{model}\times d_{ff}=135.27M\)
  • \(norm=2*d_{model}=8.19K\)

Summing these gives 6607.34M parameters, which matches the number reported by Transformers.
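Written out for the 7B model, with the final RMSNorm (the norm module in the printout, \(d_{model}\) weights) included as the last term:

\[
\begin{aligned}
N &= n_{token}\,d_{model} + n_{layer}\left(4d_{model}^2 + 3d_{model}d_{ff} + 2d_{model}\right) + d_{model} \\
&= 131{,}072{,}000 + 32\times(67{,}108{,}864 + 135{,}266{,}304 + 8{,}192) + 4{,}096 \\
&= 131{,}072{,}000 + 32\times 202{,}383{,}360 + 4{,}096 \\
&= 6{,}607{,}343{,}616 \approx 6607.34M
\end{aligned}
\]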

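def count_llama_params(d_model, d_ff, n_tokens, n_layers):
    embed = n_tokens * d_model   # vocabulary embedding (RoPE adds no parameters)
    mha = 4 * d_model**2         # q_proj, k_proj, v_proj, o_proj, no biases
    mlp = 3 * d_model * d_ff     # gate_proj, up_proj, down_proj, no biases
    norm = 2 * d_model           # input_layernorm + post_attention_layernorm
    final_norm = d_model         # the RMSNorm after the last decoder layer
    return embed + n_layers * (mha + mlp + norm) + final_norm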

For example, printing the LLaMA 65B model gives:

LlamaModel(
  (embed_tokens): Embedding(32000, 8192, padding_idx=0)
  (layers): ModuleList(
    (0-79): 80 x LlamaDecoderLayer(
      (self_attn): LlamaSdpaAttention(
        (q_proj): Linear(in_features=8192, out_features=8192, bias=False)
        (k_proj): Linear(in_features=8192, out_features=8192, bias=False)
        (v_proj): Linear(in_features=8192, out_features=8192, bias=False)
        (o_proj): Linear(in_features=8192, out_features=8192, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=8192, out_features=22016, bias=False)
        (up_proj): Linear(in_features=8192, out_features=22016, bias=False)
        (down_proj): Linear(in_features=22016, out_features=8192, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)
Total # of params: 65023.52M

Now let’s apply the function to the 65B settings:

count_llama_params(d_model=8192,
    d_ff=22016,
    n_tokens=32000,
    n_layers=80)

It gives about 65023.5M, which agrees with the 65023.52M reported by Transformers.
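With the final RMSNorm included, the formula reproduces both Transformers totals to two decimal places. A standalone sketch (llama_param_breakdown is a hypothetical helper) that splits the count by component for both models; it also makes visible that the MLP blocks hold roughly two thirds of the weights in each.

```python
def llama_param_breakdown(d_model, d_ff, n_tokens, n_layers):
    """Per-component parameter counts for a LLaMA V1 style LlamaModel (no LM head)."""
    return {
        "embed": n_tokens * d_model,
        "attention": n_layers * 4 * d_model**2,
        "mlp": n_layers * 3 * d_model * d_ff,
        "norms": n_layers * 2 * d_model + d_model,  # two RMSNorms per layer + final norm
    }

configs = {
    "7B": dict(d_model=4096, d_ff=11008, n_tokens=32000, n_layers=32),
    "65B": dict(d_model=8192, d_ff=22016, n_tokens=32000, n_layers=80),
}
for name, cfg in configs.items():
    parts = llama_param_breakdown(**cfg)
    total = sum(parts.values())
    shares = ", ".join(f"{k} {v / total:.1%}" for k, v in parts.items())
    print(f"{name}: {total / 1e6:.2f}M ({shares})")
```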
