dq.set_matmul_precision
set_matmul_precision(matmul_precision: Literal['low', 'high', 'highest'])
Configure the default precision for matrix multiplications on GPUs and TPUs.
Some devices allow trading off accuracy for speed when performing matrix multiplications (matmul). Three options are available:
- 'low' reduces matmul precision to bfloat16 (fastest but least accurate),
- 'high' reduces matmul precision to bfloat16_3x or tensorfloat32 if available (faster than 'highest' but less accurate),
- 'highest' keeps matmul precision at float32 or float64 as applicable (slowest but most accurate, default setting).
Equivalent JAX syntax
This function is equivalent to setting jax_default_matmul_precision in jax.config. See the JAX documentation on matmul precision and the JAX documentation on the different available options.
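For reference, here is a hedged sketch of setting the same flag directly in JAX. The string values used here ('float32' and 'bfloat16') are values JAX accepts for this flag; which exact string this function passes for each of its three options is not stated on this page, so treat the correspondence as an assumption.

```python
import jax
import jax.numpy as jnp

# Set the global default matmul precision through jax.config.
jax.config.update("jax_default_matmul_precision", "float32")
global_setting = jax.config.jax_default_matmul_precision

# JAX also offers a scoped override via a context manager.
with jax.default_matmul_precision("bfloat16"):
    a = jnp.ones((4, 4), dtype=jnp.float32)
    b = a @ a  # inputs and output stay float32; only the internal precision may drop

# Outside the block, the global default applies again.
restored_setting = jax.config.jax_default_matmul_precision
```

On CPU the precision setting typically has no visible effect; the difference shows up on GPUs and TPUs, where lower-precision passes are available in hardware.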
Parameters
- matmul_precision (string 'low', 'high', or 'highest') – Default precision for matrix multiplications on GPUs and TPUs.