dq.tracemm
tracemm(x: ArrayLike, y: ArrayLike) -> Array
Return the trace of a matrix multiplication using a fast implementation.
The trace is computed as sum(x * y.T), where * is the element-wise product, instead of trace(x @ y), where @ is the matrix product. Indeed, we have:
\[
\tr{xy} = \sum_i (xy)_{ii}
= \sum_{i,j} x_{ij} y_{ji}
= \sum_{i,j} x_{ij} (y^\intercal)_{ij}
= \sum_{i,j} (x * y^\intercal)_{ij}
\]
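A quick numerical check of the identity above, written with plain jax.numpy rather than dq.tracemm itself. The random 4×4 matrices and the seed are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

# Two random 4x4 matrices (size and seed chosen arbitrarily for the check).
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (4, 4))
y = jax.random.normal(key_y, (4, 4))

# Right-hand side of the derivation: sum_{i,j} (x * y^T)_{ij}.
fast = jnp.sum(x * y.T)
# Left-hand side: Tr(xy), which first forms the full matrix product.
naive = jnp.trace(x @ y)

print(fast, naive)
```

Both values agree up to float32 rounding, since they differ only in the order in which the same n² products are summed.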
Note
The resulting time complexity for \(n\times n\) matrices is \(\mathcal{O}(n^2)\), instead of \(\mathcal{O}(n^3)\) for the naive formula, which computes every entry of x @ y even though only its diagonal contributes to the trace.
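To make the complexity claim concrete, the pure-Python sketch below counts scalar multiplications for both formulas. The helper names trace_naive and trace_fast are hypothetical, and the count ignores additions and vectorization, so it illustrates the asymptotic scaling rather than the actual runtime of dq.tracemm.

```python
def trace_naive(x, y):
    # Forms every entry of the product x @ y, then keeps only the diagonal.
    n = len(x)
    mults = 0
    prod = [[0] * n for _ in range(n)]
    for i in range(n):
        for k in range(n):
            for j in range(n):
                prod[i][k] += x[i][j] * y[j][k]
                mults += 1
    return sum(prod[i][i] for i in range(n)), mults


def trace_fast(x, y):
    # Sums x_ij * y_ji directly, i.e. the entries of x * y^T.
    n = len(x)
    total = 0
    mults = 0
    for i in range(n):
        for j in range(n):
            total += x[i][j] * y[j][i]
            mults += 1
    return total, mults


n = 20
# Arbitrary integer test matrices, so both traces agree exactly.
x = [[i + j for j in range(n)] for i in range(n)]
y = [[i - 2 * j for j in range(n)] for i in range(n)]

print(trace_naive(x, y))  # same trace, 8000 multiplications (n**3)
print(trace_fast(x, y))   # same trace, 400 multiplications (n**2)
```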
Parameters
- x (array_like of shape (..., n, n)) – Left operand of the product x @ y.
- y (array_like of shape (..., n, n)) – Right operand of the product x @ y.
Returns
(array of shape (...)) Trace of x @ y.
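The shapes above mean that leading dimensions are treated as a batch: each (n, n) slice gives one trace. The sketch below uses a hypothetical helper, tracemm_batched, not the dynamiqs source. It transposes only the last two axes with jnp.swapaxes instead of relying on .T, so that each matrix in the batch is transposed on its own, and it sums over those two axes only.

```python
import jax.numpy as jnp


def tracemm_batched(x, y):
    # Hypothetical helper mirroring the documented shapes: (..., n, n) -> (...).
    # swapaxes transposes each matrix in the batch; the sum runs over the
    # last two axes only, leaving the batch dimensions intact.
    return (x * jnp.swapaxes(y, -1, -2)).sum(axis=(-2, -1))


# A batch of two 3x3 matrices, and a matching batch of all-ones matrices.
x = jnp.arange(2 * 3 * 3, dtype=jnp.float32).reshape(2, 3, 3)
y = jnp.ones((2, 3, 3))

out = tracemm_batched(x, y)
print(out.shape)  # (2,)
# Compare with the naive batched formula, tracing over the last two axes.
print(out, jnp.trace(x @ y, axis1=-2, axis2=-1))
```

With y made of ones, each trace reduces to the sum of the entries of the corresponding slice of x.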
Examples
>>> import jax.numpy as jnp
>>> import dynamiqs as dq
>>> x = jnp.ones((3, 3))
>>> y = jnp.ones((3, 3))
>>> dq.tracemm(x, y)
Array(9., dtype=float32)