Skip to content

Commit

Permalink
Remove duplicated Inv Op
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jun 22, 2023
1 parent e58bd91 commit 49bf1fe
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 33 deletions.
13 changes: 0 additions & 13 deletions pytensor/link/numba/dispatch/nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Det,
Eig,
Eigh,
Inv,
MatrixInverse,
MatrixPinv,
QRFull,
Expand Down Expand Up @@ -125,18 +124,6 @@ def eigh(x):
return eigh


@numba_funcify.register(Inv)
def numba_funcify_Inv(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)

@numba_basic.numba_njit(inline="always")
def inv(x):
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)

return inv


@numba_funcify.register(MatrixInverse)
def numba_funcify_MatrixInverse(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
Expand Down
21 changes: 1 addition & 20 deletions pytensor/tensor/nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,6 @@ def pinv(x, hermitian=False):
return MatrixPinv(hermitian=hermitian)(x)


class Inv(Op):
"""Computes the inverse of one or more matrices."""

def make_node(self, x):
x = as_tensor_variable(x)
return Apply(self, [x], [x.type()])

def perform(self, node, inputs, outputs):
(x,) = inputs
(z,) = outputs
z[0] = np.linalg.inv(x).astype(x.dtype)

def infer_shape(self, fgraph, node, shapes):
return shapes


inv = Inv()


class MatrixInverse(Op):
r"""Computes the inverse of a matrix :math:`A`.
Expand Down Expand Up @@ -169,7 +150,7 @@ def infer_shape(self, fgraph, node, shapes):
return shapes


matrix_inverse = MatrixInverse()
inv = matrix_inverse = MatrixInverse()


def matrix_dot(*args):
Expand Down

0 comments on commit 49bf1fe

Please sign in to comment.