Skip to content

Commit

Permalink
Fix numpy dtype conversion in TensorType
Browse files Browse the repository at this point in the history
TensorType.dtype must be a string, so the code
has been changed from `self.dtype = np.dtype(dtype).type`,
where the right-hand side is of type `np.generic`, to
`self.dtype = str(np.dtype(dtype))`, where the right-hand
side is a string that satisfies:

`self.dtype == str(np.dtype(self.dtype))`

This doesn't change the behavior of `np.array(..., dtype=self.dtype)`
etc.
  • Loading branch information
brendan-m-murphy committed Feb 14, 2025
1 parent 175ea07 commit e7b728f
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions pytensor/tensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Literal, Optional

import numpy as np
import numpy.typing as npt

import pytensor
from pytensor import scalar as ps
Expand Down Expand Up @@ -69,7 +70,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):

def __init__(
self,
dtype: str | np.dtype,
dtype: str | npt.DTypeLike,
shape: Iterable[bool | int | None] | None = None,
name: str | None = None,
broadcastable: Iterable[bool] | None = None,
Expand Down Expand Up @@ -101,11 +102,11 @@ def __init__(
if str(dtype) == "floatX":
self.dtype = config.floatX
else:
if np.dtype(dtype).type is None:
try:
self.dtype = str(np.dtype(dtype))
except TypeError:
raise TypeError(f"Invalid dtype: {dtype}")

self.dtype = np.dtype(dtype).name

def parse_bcast_and_shape(s):
if isinstance(s, bool | np.bool_):
return 1 if s else None
Expand Down

0 comments on commit e7b728f

Please sign in to comment.