
Why Transformer?

Before Transformers

Model Structure

attention mechanism

The attention mechanism allows the model to attend to every token in the sequence, with a different amount of focus on each token.

scaled dot-product attention

Before applying softmax to the dot-product attention scores, they should be scaled by a factor of $\frac{1}{\sqrt{d_{k}}}$ to avoid vanishing gradients and slow training.
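A minimal PyTorch sketch of scaled dot-product attention, assuming inputs already projected to queries, keys, and values (the optional mask argument anticipates the masking discussed below):

```python
import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (batch, seq_len, d_k); mask: boolean, True = blocked position
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)      # scale before softmax
    if mask is not None:
        scores = scores.masked_fill(mask, float("-inf"))   # blocked positions get -inf
    weights = torch.softmax(scores, dim=-1)
    return weights @ v
```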

self-attention

Interactions between two tokens can be masked out by setting the corresponding attention scores to $-\infty$ before the softmax layer.
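For example, a causal mask that blocks attention to future positions can be built with torch.triu; the convention here (True marks a blocked position) matches the masked_fill call in the sketch above:

```python
import torch

seq_len = 5
# True above the diagonal = positions a token is not allowed to attend to
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
```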

cross-attention

In self-attention, we work with a single input sequence, while in cross-attention we mix or combine two different input sequences. In the case of the vanilla transformer architecture, these are the sequence returned by the last/top encoder layer on the left and the input sequence being processed by the decoder part on the right.
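Reusing the scaled_dot_product_attention sketch from above, the only difference in cross-attention is where the queries, keys, and values come from; the shapes and projection matrices below are illustrative assumptions:

```python
import torch

batch, src_len, tgt_len, h = 2, 7, 5, 64
encoder_output = torch.randn(batch, src_len, h)   # from the top encoder layer
decoder_hidden = torch.randn(batch, tgt_len, h)   # the decoder-side sequence
w_q, w_k, w_v = (torch.randn(h, h) for _ in range(3))  # illustrative projections

q = decoder_hidden @ w_q          # queries come from the decoder
k = encoder_output @ w_k          # keys and values come from the encoder output
v = encoder_output @ w_v
out = scaled_dot_product_attention(q, k, v)   # (batch, tgt_len, h)
```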

causal/masked attention

layer norm



Calculating Transformer Parameters

On a high level, the transformer model consists of $L$ identical blocks, each block composed of an attention module and an MLP module, or FFN for feed-forward neural network.

The weight matrices for the query $Q$, key $K$, value $V$, and output $O$ are $W_{q}, W_{k}, W_{v}, W_{o}\in\mathbb{R}^{h\times h}$, respectively, each accompanied by a bias vector in $\mathbb{R}^{h}$.[1] Hence the parameter count for this part is $4h^{2}+4h$.

The FFN module has two linear layers: the first scales up to a higher, intermediate dimension, and the second scales back down to $h$. Back in GPT's early days, the scaling factor was 4 (recent models adopt different intermediate dimensions, but typically around 3 to 5 times $h$)[2], i.e., the weight matrix for the first layer is $W_{1}\in\mathbb{R}^{h\times 4h}$ and the weight matrix for the second layer is $W_{2}\in\mathbb{R}^{4h\times h}$. The bias vectors are in $\mathbb{R}^{4h}$ and $\mathbb{R}^{h}$, respectively. Hence the parameter count for the MLP module is $8h^{2}+5h$.
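A sketch of this FFN under the classic 4x expansion; the GELU activation is an assumption here (used by the GPT family), and other models use different activations:

```python
import torch.nn as nn

h = 2048
ffn = nn.Sequential(
    nn.Linear(h, 4 * h),   # W1: h x 4h, plus a bias of size 4h
    nn.GELU(),
    nn.Linear(4 * h, h),   # W2: 4h x h, plus a bias of size h
)
num_params = sum(p.numel() for p in ffn.parameters())
assert num_params == 8 * h * h + 5 * h
```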

Don't forget about LayerNorm. Both the self-attention module and the MLP module are equipped with a layer norm, whose learnable parameters are the weights $\gamma$ and biases $\beta$, all in $\mathbb{R}^{h}$. Hence the parameter count for the layer norms is $4h$ per block.

As for positional encoding, it adds a relatively small number of parameters if the encoding is learnable. Relative positional encodings such as RoPE and ALiBi introduce no trainable parameters.

Before the blocks, the model starts with tokenization, followed by word embedding and positional embedding. The word embedding matrix has shape $\mathbb{R}^{V\times h}$, where $V$ is the vocabulary size. To reduce the memory footprint, many models tie the weights of the final output projection (the LM head) with the word embedding matrix.
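A minimal sketch of this weight tying in PyTorch; the dimensions are taken from the gpt-neo-1.3B listing below, and the module names are illustrative:

```python
import torch.nn as nn

V, h = 50257, 2048
wte = nn.Embedding(V, h)               # word embedding, V x h
lm_head = nn.Linear(h, V, bias=False)  # output projection, h -> V
lm_head.weight = wte.weight            # tie: both point at the same V x h matrix
```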

Take a look at the model layers of EleutherAI’s gpt-neo-1.3B, a replication of the GPT-3 architecture.
Layer: transformer.wte.weight, Size: torch.Size([50257, 2048])
Layer: transformer.wpe.weight, Size: torch.Size([2048, 2048])
Layer: transformer.h.0.ln_1.weight, Size: torch.Size([2048])
Layer: transformer.h.0.ln_1.bias, Size: torch.Size([2048])
Layer: transformer.h.0.attn.attention.k_proj.weight, Size: torch.Size([2048, 2048])
Layer: transformer.h.0.attn.attention.v_proj.weight, Size: torch.Size([2048, 2048])
Layer: transformer.h.0.attn.attention.q_proj.weight, Size: torch.Size([2048, 2048])
Layer: transformer.h.0.attn.attention.out_proj.weight, Size: torch.Size([2048, 2048])
Layer: transformer.h.0.attn.attention.out_proj.bias, Size: torch.Size([2048])
Layer: transformer.h.0.ln_2.weight, Size: torch.Size([2048])
Layer: transformer.h.0.ln_2.bias, Size: torch.Size([2048])
Layer: transformer.h.0.mlp.c_fc.weight, Size: torch.Size([8192, 2048])
Layer: transformer.h.0.mlp.c_fc.bias, Size: torch.Size([8192])
Layer: transformer.h.0.mlp.c_proj.weight, Size: torch.Size([2048, 8192])
Layer: transformer.h.0.mlp.c_proj.bias, Size: torch.Size([2048])
...<23 identical layers omitted>...
Layer: transformer.ln_f.weight, Size: torch.Size([2048])
Layer: transformer.ln_f.bias, Size: torch.Size([2048])
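As a rough sanity check (a sketch, not an exact count: for instance, the q/k/v projections in this listing have no bias terms, so the formula slightly overcounts), plugging $h = 2048$, $L = 24$, and $V = 50257$ into the formulas above lands close to the advertised 1.3B parameters:

```python
h, L, V = 2048, 24, 50257

attn = 4 * h**2 + 4 * h        # Wq, Wk, Wv, Wo and their biases
ffn  = 8 * h**2 + 5 * h        # two linear layers with biases
ln   = 4 * h                   # two LayerNorms per block
per_block = attn + ffn + ln    # 12h^2 + 13h

embeddings = V * h + 2048 * h  # word embeddings + learned positional embeddings
final_ln   = 2 * h             # transformer.ln_f

total = L * per_block + embeddings + final_ln
print(f"{total:,}")            # ~1.32e9, close to the reported 1.3B
```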

Memory Footprint During Training

During the training process, the memory footprint is mainly divided into four parts: model parameters, intermediate activations produced during the forward pass, gradients computed during the backward pass, and optimizer states. Here we focus on the memory footprint of parameters, gradients, and optimizer states. When training large language models, the AdamW optimizer is commonly used, together with mixed precision training to accelerate the process. Based on this premise, we now analyze the memory footprint during training.

Inside a typical training iteration, each learnable parameter corresponds to one gradient and two optimizer states (the first- and second-order moments from AdamW). Denoting the number of learnable parameters in the model as $\varPhi$, the number of gradients is also $\varPhi$, and the number of optimizer states is $2\varPhi$.

A float16 value occupies 2 bytes, and a float32 value occupies 4 bytes. In mixed precision training, float16 is used for the forward and backward passes, so the gradients are stored in float16. During the parameter update, float32 optimizer states, float32 gradients, and float32 model parameters are used. Therefore, each learnable parameter occupies:

$$\underbrace{2+4}_{\text{weights}}+\underbrace{2+4}_{\text{gradients}}+\underbrace{4+4}_{\text{optimizer states}}=20\text{ bytes}$$
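As a quick worked example under these assumptions (20 bytes per parameter, activations excluded; the numbers below are rough estimates):

```python
def training_memory_gib(num_params: float, bytes_per_param: int = 20) -> float:
    """Parameters + gradients + AdamW states under mixed precision, excluding activations."""
    return num_params * bytes_per_param / 1024**3

print(training_memory_gib(1.32e9))   # gpt-neo-1.3B scale: roughly 25 GiB
print(training_memory_gib(7e9))      # a 7B model: roughly 130 GiB
```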

Memory Footprint During Inference

During the inference process, there are no optimizer states or gradients, and we don't need to store intermediate activations. The memory footprint is therefore significantly smaller than during training, with the majority coming from the model parameters. If float16 is used for inference, the memory footprint of the model parameters is about $2\varPhi$ bytes. Moreover, if a KV-Cache is used to speed up inference, it introduces additional memory footprint.
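A sketch of that extra KV-Cache footprint, assuming standard multi-head attention (one key and one value vector of size $h$ per token per layer) stored in float16; the function name and settings are illustrative:

```python
def kv_cache_bytes(batch: int, seq_len: int, num_layers: int, hidden: int,
                   bytes_per_elem: int = 2) -> int:
    # per token, per layer: one key vector and one value vector of size hidden
    return 2 * batch * seq_len * num_layers * hidden * bytes_per_elem

# gpt-neo-1.3B-like settings: 24 layers, hidden 2048, one sequence of 2048 tokens
print(kv_cache_bytes(1, 2048, 24, 2048) / 1024**2, "MiB")   # ~384 MiB
```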

Estimating FLOPs

Footnotes

  1. 分析transformer模型的参数量、计算量、中间激活、KV cache (Analyzing the parameter count, compute, intermediate activations, and KV cache of transformer models)

  2. For instance, Llama 2 uses an intermediate dimension of 11008 (2.6875 times $h$), Qwen2 uses 22016 (5.375 times $h$), while Mistral and Llama 3 use 14336 (3.5 times $h$). They all use 4096 as the hidden dimension.