dq.set_matmul_precision

set_matmul_precision(matmul_precision: Literal['low', 'high', 'highest'])

Configure the default precision for matrix multiplications on GPUs and TPUs.

Some devices allow trading off accuracy for speed when performing matrix multiplications (matmul). Three options are available:

  • 'low' reduces matmul precision to bfloat16 (fastest but least accurate),
  • 'high' reduces matmul precision to bfloat16_3x or tensorfloat32 if available (faster but less accurate),
  • 'highest' keeps matmul precision at float32 or float64 as applicable (slowest but most accurate, default setting).
Note

This setting applies only to single-precision matrices (float32 or complex64).

Equivalent JAX syntax

This function is equivalent to setting jax_default_matmul_precision in jax.config. See JAX documentation on matmul precision and JAX documentation on the different available options.

Parameters

  • matmul_precision (string 'low', 'high', or 'highest') –

    Default precision for matrix multiplications on GPUs and TPUs.