Accurate Quantized Training (AQT) for TPU v5e

AI models continue to get bigger, requiring larger compute clusters with exa-FLOPs (10^18 FLOPs) of computing. While large-scale models continue to unlock new capabilities, driving down the cost of training and serving these models is the key to sustaining the pace of this innovation.

Typically, the tensor operations (ops) are the most compute-intensive part of large artificial intelligence (AI) models. The recently announced Cloud TPU v5e can execute INT8 tensor ops up to 2x faster than the default BFLOAT16 tensor ops. Similarly, some NVIDIA GPUs can execute FLOAT8 or INT8 tensor ops up to 2x faster than BFLOAT16 tensor ops. To benefit from these capabilities, especially in production settings, comprehensive software support is needed.

This is where quantization comes into play. Quantization enables reduced (e.g. INT8) precision operations (i.e. tensor ops) and is one of the few effective methods for significantly increasing the efficiency of modern machine learning (ML) hardware. Quantized training reduces the hardware cost of training ML models.

Quantized training acceleration is hard

There are three families of model quantization algorithms:

  • Post Training Quantization (PTQ)

  • Quantization Aware Training (QAT)

  • Quantized Training (QT)

PTQ converts the weights of an already-trained model from BFLOAT16 to INT8 (or a similar low-precision format). It has the advantage of not requiring access to the training data or the training hardware. However, because the model that is served is not the model that was trained, PTQ often suffers from quality loss. PTQ also usually does not utilize accelerated INT8 tensor operations.
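
For intuition, below is a minimal sketch of symmetric, per-output-channel PTQ of a weight matrix in JAX. The function and variable names are illustrative only and are not part of the AQT library.

```python
import jax.numpy as jnp

def ptq_int8(w):
    # One scale per output channel (last axis), chosen so that max |w| maps to 127.
    scale = jnp.maximum(jnp.max(jnp.abs(w), axis=0, keepdims=True) / 127.0, 1e-8)
    w_int8 = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)
    return w_int8, scale

# At serving time the INT8 weights are dequantized (or the matmul output is
# rescaled): w is approximately w_int8.astype(float) * scale.
w = jnp.array([[0.3, -1.2], [0.8, 0.05]], dtype=jnp.float32)
w_q, s = ptq_int8(w)
```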

QAT improves on PTQ by introducing the quantization logic into the forward pass, before training or before fine-tuning. This allows the training process to take the quantization numerics into account and learn to work around them, making it easier to preserve model quality. QAT also allows quantization of both tensor operation inputs, which enables the use of INT8 tensor acceleration on TPU v5e during inference. The training length remains unchanged.
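
The forward-pass quantization in QAT is commonly implemented as "fake quantization" with a straight-through estimator, as in the minimal JAX sketch below. The names are illustrative and this is not the AQT API.

```python
import jax
import jax.numpy as jnp

def fake_quant_int8(x):
    # Symmetric, per-tensor quantize-dequantize ("fake quant").
    scale = jnp.maximum(jnp.max(jnp.abs(x)) / 127.0, 1e-8)
    xq = jnp.clip(jnp.round(x / scale), -127, 127) * scale
    # Straight-through estimator: the forward pass uses xq, the backward pass
    # sees the identity, so gradients are not blocked by the rounding.
    return x + jax.lax.stop_gradient(xq - x)

def dense_qat_forward(x, w):
    # Both tensor-op inputs are fake-quantized, so inference can later run a
    # true INT8 matmul with the same numerics.
    return fake_quant_int8(x) @ fake_quant_int8(w)

x = jnp.ones((2, 4))
w = jnp.ones((4, 3))
grads = jax.grad(lambda w: dense_qat_forward(x, w).sum())(w)  # gradients flow via the STE
```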

QT takes quantization a step further. Not only is the forward pass quantized, but so is the backward pass (gradient backpropagation). This preserves all the benefits of QAT while also accelerating the training itself.
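
Conceptually, QT replaces the matmuls in both the forward pass and the gradient computation with quantized ones. The sketch below uses jax.custom_vjp to quantize the backpropagation matmuls as well; it is a simplified illustration with hypothetical helper names, not the AQT implementation.

```python
import jax
import jax.numpy as jnp
from jax import lax

def q8(t):
    # Symmetric, per-tensor INT8 quantization; returns INT8 values and the scale.
    s = jnp.maximum(jnp.max(jnp.abs(t)) / 127.0, 1e-8)
    return jnp.clip(jnp.round(t / s), -127, 127).astype(jnp.int8), s

def int8_matmul(a, b):
    aq, sa = q8(a)
    bq, sb = q8(b)
    # INT8 x INT8 contraction with INT32 accumulation, then rescale to float.
    acc = lax.dot(aq, bq, preferred_element_type=jnp.int32)
    return acc.astype(jnp.float32) * sa * sb

@jax.custom_vjp
def qt_matmul(x, w):
    return int8_matmul(x, w)

def _qt_fwd(x, w):
    return int8_matmul(x, w), (x, w)

def _qt_bwd(res, g):
    x, w = res
    # The backpropagation matmuls are quantized as well.
    return int8_matmul(g, w.T), int8_matmul(x.T, g)

qt_matmul.defvjp(_qt_fwd, _qt_bwd)

x = jnp.ones((4, 8))
w = jnp.ones((8, 2))
dx, dw = jax.grad(lambda x, w: qt_matmul(x, w).sum(), argnums=(0, 1))(x, w)
```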

Even with the right algorithm, getting QT to work on real hardware in production can be hard due to the software complexity and computational overheads of quantization. Until now, QT (including backpropagation quantization) was largely confined to research papers. The open-source AQT library hides this software and algorithmic complexity, allowing any production model owner to benefit from QT, increasing the value of TPU v5e to users, and considerably simplifying QT research.

Introducing Accurate Quantized Training (AQT) library

We’re excited to introduce the open-source Accurate Quantized Training (AQT) library, which provides the software support needed for easy tensor operation quantization in JAX.

The main goal of the AQT library is to simultaneously provide:

  • improved training performance in production

  • improved model quality with no hand-tuning

  • a simple and flexible API that simultaneously serves production use and quantization research

For more information, consult the AQT README.md.
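
The README documents AQT's actual configuration and injection API. As a rough illustration of the injection pattern, the sketch below swaps a hand-rolled INT8 dot_general into a Flax Dense layer. Here int8_dot_general and QuantDense are hypothetical stand-ins rather than AQT classes, the example covers only the forward pass, and it assumes a recent Flax version whose nn.Dense accepts a dot_general override.

```python
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import lax

def int8_dot_general(lhs, rhs, dimension_numbers, precision=None,
                     preferred_element_type=None):
    # Stand-in for a quantized dot_general: quantize both operands, contract
    # in INT8 with INT32 accumulation, then rescale the result to float.
    def q8(t):
        s = jnp.maximum(jnp.max(jnp.abs(t)) / 127.0, 1e-8)
        return jnp.clip(jnp.round(t / s), -127, 127).astype(jnp.int8), s
    lq, ls = q8(lhs)
    rq, rs = q8(rhs)
    out = lax.dot_general(lq, rq, dimension_numbers,
                          preferred_element_type=jnp.int32)
    return out.astype(jnp.float32) * ls * rs

class QuantDense(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        # Flax's Dense exposes a dot_general hook; AQT injects its own
        # quantized implementation in a similar way (see the AQT README.md
        # for the real API).
        return nn.Dense(self.features, dot_general=int8_dot_general)(x)

x = jnp.ones((2, 16))
params = QuantDense(features=32).init(jax.random.PRNGKey(0), x)
y = QuantDense(features=32).apply(params, x)
```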

AQT INT8-mode delivers improved hardware performance

AQT has allowed us to achieve remarkable speed improvements in large language model (LLM) training. The numbers below show the BFLOAT16-to-INT8 step-time ratio measured on MaxText 16B training and on our MLPerf™ 3.1 results:

  • MaxText 16B training: 9,054 ms / 7,268 ms = 124%

  • MLPerf™ 3.1 GPT-3 175B training: 11,798 ms / 8,431 ms = 139%

Details of AQT configuration and MaxText model configuration can be found in the appendix. All runs used Google Cloud TPU v5e.

The MaxText experiments were run before we implemented an additional AQT optimization (local AQT) for MLPerf™ 3.1.

AQT delivers improved model quality

The quality difference between the AQT INT8 and BFLOAT16 models, measured as training loss deterioration, remains almost indistinguishable even over long training runs.

Measuring tiny model differences

To measure minuscule quantization-induced deterioration, one needs to remove other sources of noise from the training loss. We configured MaxText to train deterministically by controlling the randomness of model initialization and data generation, since the run-to-run variation caused by either of them is larger than the quantization-induced deterioration of the training loss.

Results

We measure quantization quality as the quantization-induced deterioration of the training loss, i.e., the difference between the training loss of the BFLOAT16 and INT8 models. This deterioration is 0.00133, less than 0.1% of the final training loss. Trading this negligible quality loss for a considerable training performance boost demonstrates the value of AQT INT8 training compared to unquantized BFLOAT16 training.

The plot shows the log loss of the quantized and unquantized 16B models (configuration details in the appendix).