Skip to content

[Test] Seemed to pass CI in this form, need to verify #18189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: sycl
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions sycl/include/sycl/platform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,16 @@ inline namespace _V1 {
// Forward declaration
class device;
class context;
class platform;

template <backend BackendName, class SyclObjectT>
auto get_native(const SyclObjectT &Obj)
-> backend_return_t<BackendName, SyclObjectT>;
namespace detail {
class platform_impl;
template <class T>
std::enable_if_t<std::is_same_v<T, platform>, platform>
createSyclObjFromImpl(platform_impl &);

/// Allows to enable/disable "Default Context" extension
///
Expand Down Expand Up @@ -231,6 +235,9 @@ class __SYCL_EXPORT platform : public detail::OwnerLessBase<platform> {
template <class Obj>
friend const decltype(Obj::impl) &
detail::getSyclObjImpl(const Obj &SyclObject);
template <class T>
friend std::enable_if_t<std::is_same_v<T, platform>, platform>
detail::createSyclObjFromImpl(detail::platform_impl &);

template <backend BackendName, class SyclObjectT>
friend auto get_native(const SyclObjectT &Obj)
Expand Down
8 changes: 4 additions & 4 deletions sycl/source/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ __SYCL_EXPORT device make_device(ur_native_handle_t NativeHandle,
NativeHandle, Adapter->getUrAdapter(), nullptr, &UrDevice);

// Construct the SYCL device from UR device.
auto Platform = platform_impl::getPlatformFromUrDevice(UrDevice, Adapter);
auto &Platform = platform_impl::getPlatformFromUrDevice(UrDevice, Adapter);
return detail::createSyclObjFromImpl<device>(
Platform->getOrMakeDeviceImpl(UrDevice, Platform));
Platform.getOrMakeDeviceImpl(UrDevice, Platform));
}

__SYCL_EXPORT context make_context(ur_native_handle_t NativeHandle,
Expand Down Expand Up @@ -288,9 +288,9 @@ make_kernel_bundle(ur_native_handle_t NativeHandle,
std::transform(
ProgramDevices.begin(), ProgramDevices.end(), std::back_inserter(Devices),
[&Adapter](const auto &Dev) {
auto Platform =
platform_impl &Platform =
detail::platform_impl::getPlatformFromUrDevice(Dev, Adapter);
auto DeviceImpl = Platform->getOrMakeDeviceImpl(Dev, Platform);
auto DeviceImpl = Platform.getOrMakeDeviceImpl(Dev, Platform);
return createSyclObjFromImpl<device>(DeviceImpl);
});

Expand Down
4 changes: 2 additions & 2 deletions sycl/source/backend/level_zero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ using namespace sycl::detail;
__SYCL_EXPORT device make_device(const platform &Platform,
ur_native_handle_t NativeHandle) {
const auto &Adapter = ur::getAdapter<backend::ext_oneapi_level_zero>();
const auto &PlatformImpl = getSyclObjImpl(Platform);
platform_impl &PlatformImpl = *getSyclObjImpl(Platform).get();
// Create UR device first.
ur_device_handle_t UrDevice;
Adapter->call<UrApiKind::urDeviceCreateWithNativeHandle>(
NativeHandle, Adapter->getUrAdapter(), nullptr, &UrDevice);

return detail::createSyclObjFromImpl<device>(
PlatformImpl->getOrMakeDeviceImpl(UrDevice, PlatformImpl));
PlatformImpl.getOrMakeDeviceImpl(UrDevice, PlatformImpl));
}

} // namespace ext::oneapi::level_zero::detail
Expand Down
7 changes: 4 additions & 3 deletions sycl/source/detail/allowlist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,9 @@ void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,

// Get platform's backend and put it to DeviceDesc
DeviceDescT DeviceDesc;
auto PlatformImpl = platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter);
backend Backend = PlatformImpl->getBackend();
platform_impl &PlatformImpl =
platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter);
backend Backend = PlatformImpl.getBackend();

for (const auto &SyclBe : getSyclBeMap()) {
if (SyclBe.second == Backend) {
Expand All @@ -395,7 +396,7 @@ void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,

int InsertIDx = 0;
for (ur_device_handle_t Device : UrDevices) {
auto DeviceImpl = PlatformImpl->getOrMakeDeviceImpl(Device, PlatformImpl);
auto DeviceImpl = PlatformImpl.getOrMakeDeviceImpl(Device, PlatformImpl);
// get DeviceType value and put it to DeviceDesc
ur_device_type_t UrDevType = UR_DEVICE_TYPE_ALL;
Adapter->call<UrApiKind::urDeviceGetInfo>(
Expand Down
9 changes: 4 additions & 5 deletions sycl/source/detail/buffer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,11 @@ buffer_impl::getNativeVector(backend BackendName) const {
// doesn't have context and platform
if (!Ctx)
continue;
const PlatformImplPtr &Platform = Ctx->getPlatformImpl();
assert(Platform && "Platform must be present for device context");
if (Platform->getBackend() != BackendName)
const platform_impl &Platform = Ctx->getPlatformImpl();
if (Platform.getBackend() != BackendName)
continue;

auto Adapter = Platform->getAdapter();
auto Adapter = Platform.getAdapter();

ur_native_handle_t Handle = 0;
// When doing buffer interop we don't know what device the memory should be
Expand All @@ -94,7 +93,7 @@ buffer_impl::getNativeVector(backend BackendName) const {
&Handle);
Handles.push_back(Handle);

if (Platform->getBackend() == backend::opencl) {
if (Platform.getBackend() == backend::opencl) {
__SYCL_OCL_CALL(clRetainMemObject, ur::cast<cl_mem>(Handle));
}
}
Expand Down
18 changes: 9 additions & 9 deletions sycl/source/detail/context_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ context_impl::context_impl(const device &Device, async_handler AsyncHandler,
const property_list &PropList)
: MOwnedByRuntime(true), MAsyncHandler(AsyncHandler), MDevices(1, Device),
MContext(nullptr),
MPlatform(detail::getSyclObjImpl(Device.get_platform())),
MPlatform(detail::getSyclObjImpl(Device.get_platform()).get()),
MPropList(PropList), MSupportBufferLocationByDevices(NotChecked) {
verifyProps(PropList);
MKernelProgramCache.setContextPtr(this);
Expand All @@ -41,10 +41,10 @@ context_impl::context_impl(const std::vector<sycl::device> Devices,
async_handler AsyncHandler,
const property_list &PropList)
: MOwnedByRuntime(true), MAsyncHandler(AsyncHandler), MDevices(Devices),
MContext(nullptr), MPlatform(), MPropList(PropList),
MSupportBufferLocationByDevices(NotChecked) {
MContext(nullptr),
MPlatform(detail::getSyclObjImpl(MDevices[0].get_platform()).get()),
MPropList(PropList), MSupportBufferLocationByDevices(NotChecked) {
verifyProps(PropList);
MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform());
std::vector<ur_device_handle_t> DeviceIds;
for (const auto &D : MDevices) {
if (D.has(aspect::ext_oneapi_is_composite)) {
Expand Down Expand Up @@ -77,7 +77,7 @@ context_impl::context_impl(ur_context_handle_t UrContext,
MDevices(DeviceList), MContext(UrContext), MPlatform(),
MSupportBufferLocationByDevices(NotChecked) {
if (!MDevices.empty()) {
MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform());
MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform()).get();
} else {
std::vector<ur_device_handle_t> DeviceIds;
uint32_t DevicesNum = 0;
Expand All @@ -96,13 +96,13 @@ context_impl::context_impl(ur_context_handle_t UrContext,
make_error_code(errc::invalid),
"No devices in the provided device list and native context.");

std::shared_ptr<detail::platform_impl> Platform =
platform_impl &Platform =
platform_impl::getPlatformFromUrDevice(DeviceIds[0], Adapter);
for (ur_device_handle_t Dev : DeviceIds) {
MDevices.emplace_back(createSyclObjFromImpl<device>(
Platform->getOrMakeDeviceImpl(Dev, Platform)));
Platform.getOrMakeDeviceImpl(Dev, Platform)));
}
MPlatform = Platform;
MPlatform = &Platform;
}
// TODO catch an exception and put it to list of asynchronous exceptions
// getAdapter() will be the same as the Adapter passed. This should be taken
Expand Down Expand Up @@ -158,7 +158,7 @@ uint32_t context_impl::get_info<info::context::reference_count>() const {
this->getAdapter());
}
template <> platform context_impl::get_info<info::context::platform>() const {
return createSyclObjFromImpl<platform>(MPlatform);
return createSyclObjFromImpl<platform>(*MPlatform);
}
template <>
std::vector<sycl::device>
Expand Down
8 changes: 5 additions & 3 deletions sycl/source/detail/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ inline namespace _V1 {
// Forward declaration
class device;
namespace detail {
using PlatformImplPtr = std::shared_ptr<detail::platform_impl>;
class context_impl {
public:
/// Constructs a context_impl using a single SYCL devices.
Expand Down Expand Up @@ -89,8 +88,10 @@ class context_impl {
/// \return the Adapter associated with the platform of this context.
const AdapterPtr &getAdapter() const { return MPlatform->getAdapter(); }

// TODO: Think more about `const`
/// \return the PlatformImpl associated with this context.
const PlatformImplPtr &getPlatformImpl() const { return MPlatform; }
const platform_impl &getPlatformImpl() const { return *MPlatform; }
platform_impl &getPlatformImpl() { return *MPlatform; }

/// Queries this context for information.
///
Expand Down Expand Up @@ -257,7 +258,8 @@ class context_impl {
async_handler MAsyncHandler;
std::vector<device> MDevices;
ur_context_handle_t MContext;
PlatformImplPtr MPlatform;
// TODO: Make it a reference instead, but that needs a bit more refactoring:
platform_impl *MPlatform = nullptr;
property_list MPropList;
CachedLibProgramsT MCachedLibPrograms;
std::mutex MCachedLibProgramsMutex;
Expand Down
12 changes: 6 additions & 6 deletions sycl/source/detail/device_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ namespace detail {

/// Constructs a SYCL device instance using the provided
/// UR device instance.
device_impl::device_impl(ur_device_handle_t Device, PlatformImplPtr Platform)
: MDevice(Device), MPlatform(Platform),
device_impl::device_impl(ur_device_handle_t Device, platform_impl &Platform)
: MDevice(Device), MPlatform(&Platform),
MDeviceHostBaseTime(std::make_pair(0, 0)) {
const AdapterPtr &Adapter = Platform->getAdapter();
const AdapterPtr &Adapter = Platform.getAdapter();

// TODO catch an exception and put it to list of asynchronous exceptions
Adapter->call<UrApiKind::urDeviceGetInfo>(
MDevice, UR_DEVICE_INFO_TYPE, sizeof(ur_device_type_t), &MType, nullptr);

// No need to set MRootDevice when MAlwaysRootDevice is true
if (!Platform->MAlwaysRootDevice) {
if (!Platform.MAlwaysRootDevice) {
// TODO catch an exception and put it to list of asynchronous exceptions
Adapter->call<UrApiKind::urDeviceGetInfo>(
MDevice, UR_DEVICE_INFO_PARENT_DEVICE, sizeof(ur_device_handle_t),
Expand Down Expand Up @@ -74,7 +74,7 @@ cl_device_id device_impl::get() const {
}

platform device_impl::get_platform() const {
return createSyclObjFromImpl<platform>(MPlatform);
return createSyclObjFromImpl<platform>(*MPlatform);
}

template <typename Param>
Expand Down Expand Up @@ -177,7 +177,7 @@ std::vector<device> device_impl::create_sub_devices(
std::for_each(SubDevices.begin(), SubDevices.end(),
[&res, this](const ur_device_handle_t &a_ur_device) {
device sycl_device = detail::createSyclObjFromImpl<device>(
MPlatform->getOrMakeDeviceImpl(a_ur_device, MPlatform));
MPlatform->getOrMakeDeviceImpl(a_ur_device, *MPlatform));
res.push_back(sycl_device);
});
return res;
Expand Down
9 changes: 4 additions & 5 deletions sycl/source/detail/device_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,13 @@ namespace detail {

// Forward declaration
class platform_impl;
using PlatformImplPtr = std::shared_ptr<platform_impl>;

// TODO: Make code thread-safe
class device_impl {
public:
/// Constructs a SYCL device instance using the provided
/// UR device instance.
explicit device_impl(ur_device_handle_t Device, PlatformImplPtr Platform);
explicit device_impl(ur_device_handle_t Device, platform_impl &Platform);

~device_impl();

Expand Down Expand Up @@ -278,9 +277,9 @@ class device_impl {
/// Get the backend of this device
backend getBackend() const { return MPlatform->getBackend(); }

// TODO: const-correctness
/// @brief Get the platform impl serving this device
/// @return PlatformImplPtr
const PlatformImplPtr &getPlatformImpl() const { return MPlatform; }
platform_impl &getPlatformImpl() const { return *MPlatform; }

/// Get device info string
std::string get_device_info_string(ur_device_info_t InfoCode) const;
Expand All @@ -292,7 +291,7 @@ class device_impl {
ur_device_handle_t MDevice = 0;
ur_device_type_t MType;
ur_device_handle_t MRootDevice = nullptr;
PlatformImplPtr MPlatform;
platform_impl *MPlatform = nullptr;
bool MUseNativeAssert = false;
mutable std::string MDeviceName;
mutable std::once_flag MDeviceNameFlag;
Expand Down
13 changes: 6 additions & 7 deletions sycl/source/detail/device_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
namespace sycl {
inline namespace _V1 {
namespace detail {

inline std::vector<memory_order>
readMemoryOrderBitfield(ur_memory_order_capability_flags_t bits) {
std::vector<memory_order> result;
Expand Down Expand Up @@ -1171,9 +1170,9 @@ template <> struct get_device_info_impl<device, info::device::parent_device> {
throw exception(make_error_code(errc::invalid),
"No parent for device because it is not a subdevice");

const auto &Platform = Dev.getPlatformImpl();
platform_impl &Platform = Dev.getPlatformImpl();
return createSyclObjFromImpl<device>(
Platform->getOrMakeDeviceImpl(result, Platform));
Platform.getOrMakeDeviceImpl(result, Platform));
}
};

Expand Down Expand Up @@ -1337,10 +1336,10 @@ struct get_device_info_impl<
ext::oneapi::experimental::info::device::component_devices>::value,
ResultSize, Devs.data(), nullptr);
std::vector<sycl::device> Result;
const auto &Platform = Dev.getPlatformImpl();
platform_impl &Platform = Dev.getPlatformImpl();
for (const auto &d : Devs)
Result.push_back(createSyclObjFromImpl<device>(
Platform->getOrMakeDeviceImpl(d, Platform)));
Platform.getOrMakeDeviceImpl(d, Platform)));

return Result;
}
Expand All @@ -1363,9 +1362,9 @@ struct get_device_info_impl<
sizeof(Result), &Result, nullptr);

if (Result) {
const auto &Platform = Dev.getPlatformImpl();
platform_impl &Platform = Dev.getPlatformImpl();
return createSyclObjFromImpl<device>(
Platform->getOrMakeDeviceImpl(Result, Platform));
Platform.getOrMakeDeviceImpl(Result, Platform));
}
throw sycl::exception(make_error_code(errc::invalid),
"A component with aspect::ext_oneapi_is_component "
Expand Down
6 changes: 3 additions & 3 deletions sycl/source/detail/global_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ ProgramManager &GlobalHandler::getProgramManager() {
return PM;
}

std::unordered_map<PlatformImplPtr, ContextImplPtr> &
std::unordered_map<platform_impl *, ContextImplPtr> &
GlobalHandler::getPlatformToDefaultContextCache() {
// The optimization with static reference is not done because
// there are public methods of the GlobalHandler
Expand All @@ -207,8 +207,8 @@ Sync &GlobalHandler::getSync() {
return sync;
}

std::vector<PlatformImplPtr> &GlobalHandler::getPlatformCache() {
static std::vector<PlatformImplPtr> &PlatformCache =
std::vector<std::shared_ptr<platform_impl>> &GlobalHandler::getPlatformCache() {
static std::vector<std::shared_ptr<platform_impl>> &PlatformCache =
getOrCreate(MPlatformCache);
return PlatformCache;
}
Expand Down
9 changes: 4 additions & 5 deletions sycl/source/detail/global_handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class ods_target_list;
class XPTIRegistry;
class ThreadPool;

using PlatformImplPtr = std::shared_ptr<platform_impl>;
using ContextImplPtr = std::shared_ptr<context_impl>;
using AdapterPtr = std::shared_ptr<Adapter>;

Expand Down Expand Up @@ -60,9 +59,9 @@ class GlobalHandler {
bool isSchedulerAlive() const;
ProgramManager &getProgramManager();
Sync &getSync();
std::vector<PlatformImplPtr> &getPlatformCache();
std::vector<std::shared_ptr<platform_impl>> &getPlatformCache();

std::unordered_map<PlatformImplPtr, ContextImplPtr> &
std::unordered_map<platform_impl *, ContextImplPtr> &
getPlatformToDefaultContextCache();

std::mutex &getPlatformToDefaultContextCacheMutex();
Expand Down Expand Up @@ -118,8 +117,8 @@ class GlobalHandler {
InstWithLock<Scheduler> MScheduler;
InstWithLock<ProgramManager> MProgramManager;
InstWithLock<Sync> MSync;
InstWithLock<std::vector<PlatformImplPtr>> MPlatformCache;
InstWithLock<std::unordered_map<PlatformImplPtr, ContextImplPtr>>
InstWithLock<std::vector<std::shared_ptr<platform_impl>>> MPlatformCache;
InstWithLock<std::unordered_map<platform_impl *, ContextImplPtr>>
MPlatformToDefaultContextCache;
InstWithLock<std::mutex> MPlatformToDefaultContextCacheMutex;
InstWithLock<std::mutex> MPlatformMapMutex;
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/kernel_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ void kernel_impl::checkIfValidForNumArgsInfoQuery() const {
}

void kernel_impl::enableUSMIndirectAccess() const {
if (!MContext->getPlatformImpl()->supports_usm())
if (!MContext->getPlatformImpl().supports_usm())
return;

// Some UR Adapters (like OpenCL) require this call to enable USM
Expand Down
Loading
Loading