Skip to content

Commit

Permalink
✨Supported create
Browse files Browse the repository at this point in the history
  • Loading branch information
carefree0910 committed Jan 24, 2024
1 parent 70b2316 commit 7f85d2b
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions core/toolkit/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,13 +587,20 @@ def __init__(
name: str,
dtype: Union[type, np.dtype],
shape: Union[List[int], Tuple[int, ...]],
*,
create: bool = True,
data: Optional[np.ndarray] = None,
):
d_size = np.dtype(dtype).itemsize * np.prod(shape).item()
self.name = name
self.dtype = dtype
self.shape = shape
self._shm = SharedMemory(create=True, size=int(round(d_size)), name=name)
if create:
d_size = np.dtype(dtype).itemsize * np.prod(shape).item()
self._shm = SharedMemory(name, create=True, size=int(round(d_size)))
else:
if data is not None:
raise ValueError("`data` should not be provided when `create` is False")
self._shm = SharedMemory(name)
self.value = np.ndarray(shape=shape, dtype=dtype, buffer=self._shm.buf)
if data is not None:
self.value[:] = data[:]
Expand All @@ -607,7 +614,7 @@ def destroy(self) -> None:

@classmethod
def from_data(cls, data: np.ndarray) -> "SharedArray":
return cls(random_hash()[:16], data.dtype, data.shape, data)
return cls(random_hash()[:16], data.dtype, data.shape, data=data)


def to_labels(logits: np.ndarray, threshold: Optional[float] = None) -> np.ndarray:
Expand Down

0 comments on commit 7f85d2b

Please sign in to comment.