Skip to content

Commit

Permalink
refactor(hardware): 为 Blob 之间的拷贝处理更多情况
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Nov 27, 2023
1 parent 0592ef0 commit 0d8739f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
14 changes: 10 additions & 4 deletions src/02hardware/include/hardware/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,22 @@ namespace refactor::hardware {

Device &_device;
void *_ptr;
size_t _size;

Blob(decltype(_device) device, size_t);

public:
~Blob();
// clang-format off
void copyFromHost(void const * ) const;
void copyFromHost(void const *, size_t) const;
void copyToHost(void *, size_t) const;
void copyFrom(Blob const &, size_t) const;
void copyTo(Blob const &, size_t) const;

void copyToHost (void * ) const;
void copyToHost (void *, size_t) const;
void copyFrom (Blob const & ) const;
void copyFrom (Blob const &, size_t) const;
void copyTo (Blob const & ) const;
void copyTo (Blob const &, size_t) const;
// clang-format on
constexpr void *get() const noexcept { return _ptr; }
template<class T> constexpr T *get() const noexcept { return static_cast<T *>(_ptr); }
constexpr operator void *() const noexcept { return get(); }
Expand Down
23 changes: 21 additions & 2 deletions src/02hardware/src/device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
namespace refactor::hardware {

Device::Blob::Blob(decltype(_device) device, size_t size)
: _device(device), _ptr(nullptr) {
: _device(device), _ptr(nullptr), _size(size) {
_device.setContext();
_ptr = _device._mem->malloc(size);
}
Expand All @@ -13,24 +13,43 @@ namespace refactor::hardware {
_device.setContext();
_device._mem->free(std::exchange(_ptr, nullptr));
}
void Device::Blob::copyFromHost(void const *ptr) const {
_device.setContext();
_device._mem->copyHD(_ptr, ptr, _size);
}
void Device::Blob::copyFromHost(void const *ptr, size_t size) const {
ASSERT(size <= _size, "size too large");
_device.setContext();
_device._mem->copyHD(_ptr, ptr, size);
}
void Device::Blob::copyToHost(void *ptr) const {
_device.setContext();
_device._mem->copyDH(ptr, _ptr, _size);
}
void Device::Blob::copyToHost(void *ptr, size_t size) const {
ASSERT(size <= _size, "size too large");
_device.setContext();
_device._mem->copyDH(ptr, _ptr, size);
}
void Device::Blob::copyFrom(Blob const &rhs) const {
copyFrom(rhs, rhs._size);
}
void Device::Blob::copyFrom(Blob const &rhs, size_t size) const {
_device.setContext();
ASSERT(size <= rhs._size && size <= _size, "size too large");
if (_device._mem == rhs._device._mem) {
_device.setContext();
_device._mem->copyDD(_ptr, rhs._ptr, size);
} else if (rhs._device.type() == Device::Type::Cpu) {
copyFromHost(rhs._ptr, size);
} else {
std::vector<uint8_t> tmp(size);
rhs.copyToHost(tmp.data(), size);
copyFromHost(tmp.data(), size);
}
}
void Device::Blob::copyTo(Blob const &rhs) const {
rhs.copyFrom(*this);
}
void Device::Blob::copyTo(Blob const &rhs, size_t size) const {
rhs.copyFrom(*this, size);
}
Expand Down

0 comments on commit 0d8739f

Please sign in to comment.