dq.set_device
set_device(device: Literal['cpu', 'gpu', 'tpu'], index: int = 0)
Configure the default device that JAX uses to place arrays and run computations.
Equivalent JAX syntax
This function is equivalent to
jax.config.update('jax_default_device', jax.devices(device)[index])
See JAX documentation on devices.
Parameters
- device (string 'cpu', 'gpu', or 'tpu') – Default device.
- index – Index of the device to use, defaults to 0.
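As a minimal sketch of the equivalence stated above, the snippet below uses only the plain JAX call, so it runs without the library installed. It assumes a reasonably recent JAX version and picks 'cpu' because that backend should be available on any install; the variable names are illustrative only. Under the documented equivalence, `dq.set_device('cpu')` would have the same effect as the `jax.config.update` line.

```python
import jax
import jax.numpy as jnp

# Pick the first CPU device (index 0), mirroring the defaults
# set_device('cpu', index=0) would use according to the docs above.
device = jax.devices('cpu')[0]

# Plain-JAX form of the documented equivalence.
jax.config.update('jax_default_device', device)

# Arrays created after the update are placed on the default device.
x = jnp.ones(3)
print(device.platform, x.devices())
```

To target another backend, replace 'cpu' with 'gpu' or 'tpu'. `jax.devices` is expected to raise an error if the requested backend is not present on the machine.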