Commit baca10b7 by John Plate Committed by Angle LUCI CQ

CL: Remove object cache and fix multi-threading

All CL front end objects used to be cached to be able to determine if an object has been created by the front end to check its validity. The validity is now checked with the existence of an intrinsic value (the dispatch table pointer), which is consistent with the patterns found in Mesa and clvk (though clvk uses a magic value). This allows the removal of all cached objects. The cached objects were stored with std::unique_ptr. These are now gone and all remaining pointers are now custom intrinsic reference count pointers. Also remove global lock which causes deadlocks, e.g. when CL API is called from a separate thread to unlock a blocking call with a user event. Most of the front end is constant and already thread-safe. The ref count is also thread-safe now (atomic). A few remaining locks will follow. Without the global lock it was now possible to make the API reentrant, and to remove the workaround with the Khronos ICD loader to skip ANGLE's OpenCL library. Bug: angleproject:6001 Change-Id: I7d3b52db9011a02cb7ea9ebdeb6e22c4c702ef5b Reviewed-on: https://chromium-review.googlesource.com/c/angle/angle/+/2927395 Commit-Queue: John Plate <jplate@google.com> Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org> Reviewed-by: 's avatarCody Northrop <cnorthrop@google.com>
parent 076974a9
......@@ -31,83 +31,81 @@ using ContextErrorCB = void(CL_CALLBACK *)(const char *errinfo,
using EventCB = void(CL_CALLBACK *)(cl_event event, cl_int event_command_status, void *user_data);
template <typename CLObjectType>
template <typename T = void>
struct Dispatch
{
constexpr Dispatch(const cl_icd_dispatch &dispatch) : mDispatch(&dispatch)
{
static_assert(
std::is_standard_layout<CLObjectType>::value && offsetof(CLObjectType, mDispatch) == 0u,
"Not ICD compatible");
}
~Dispatch() = default;
Dispatch() : mDispatch(sDispatch) {}
const cl_icd_dispatch &getDispatch() const { return *mDispatch; }
bool isValid() const { return mDispatch == sDispatch; }
constexpr const cl_icd_dispatch &getDispatch() { return *mDispatch; }
static bool IsValid(const Dispatch *p) { return p != nullptr && p->isValid(); }
private:
static const cl_icd_dispatch *sDispatch;
protected:
// This has to be the first member to be OpenCL ICD compatible
const cl_icd_dispatch *const mDispatch;
};
} // namespace cl
template <typename T>
const cl_icd_dispatch *Dispatch<T>::sDispatch = nullptr;
struct _cl_platform_id : public cl::Dispatch<_cl_platform_id>
template <typename NativeObjectType>
struct NativeObject : public Dispatch<>
{
constexpr _cl_platform_id(const cl_icd_dispatch &dispatch)
: cl::Dispatch<_cl_platform_id>(dispatch)
{}
~_cl_platform_id() = default;
};
NativeObject()
{
static_assert(std::is_standard_layout<NativeObjectType>::value &&
offsetof(NativeObjectType, mDispatch) == 0u,
"Not ICD compatible");
}
struct _cl_device_id : public cl::Dispatch<_cl_device_id>
{
constexpr _cl_device_id(const cl_icd_dispatch &dispatch) : cl::Dispatch<_cl_device_id>(dispatch)
{}
~_cl_device_id() = default;
};
template <typename T>
T &cast()
{
return static_cast<T &>(*this);
}
struct _cl_context : public cl::Dispatch<_cl_context>
{
constexpr _cl_context(const cl_icd_dispatch &dispatch) : cl::Dispatch<_cl_context>(dispatch) {}
~_cl_context() = default;
};
template <typename T>
const T &cast() const
{
return static_cast<const T &>(*this);
}
struct _cl_command_queue : public cl::Dispatch<_cl_command_queue>
{
constexpr _cl_command_queue(const cl_icd_dispatch &dispatch)
: cl::Dispatch<_cl_command_queue>(dispatch)
{}
~_cl_command_queue() = default;
};
NativeObjectType *getNative() { return static_cast<NativeObjectType *>(this); }
struct _cl_mem : public cl::Dispatch<_cl_mem>
{
constexpr _cl_mem(const cl_icd_dispatch &dispatch) : cl::Dispatch<_cl_mem>(dispatch) {}
~_cl_mem() = default;
static NativeObjectType *CastNative(NativeObjectType *p) { return p; }
};
struct _cl_program : public cl::Dispatch<_cl_program>
{
constexpr _cl_program(const cl_icd_dispatch &dispatch) : cl::Dispatch<_cl_program>(dispatch) {}
~_cl_program() = default;
};
} // namespace cl
struct _cl_kernel : public cl::Dispatch<_cl_kernel>
{
constexpr _cl_kernel(const cl_icd_dispatch &dispatch) : cl::Dispatch<_cl_kernel>(dispatch) {}
~_cl_kernel() = default;
};
struct _cl_platform_id : public cl::NativeObject<_cl_platform_id>
{};
struct _cl_event : public cl::Dispatch<_cl_event>
{
constexpr _cl_event(const cl_icd_dispatch &dispatch) : cl::Dispatch<_cl_event>(dispatch) {}
~_cl_event() = default;
};
struct _cl_device_id : public cl::NativeObject<_cl_device_id>
{};
struct _cl_sampler : public cl::Dispatch<_cl_sampler>
{
constexpr _cl_sampler(const cl_icd_dispatch &dispatch) : cl::Dispatch<_cl_sampler>(dispatch) {}
~_cl_sampler() = default;
};
struct _cl_context : public cl::NativeObject<_cl_context>
{};
struct _cl_command_queue : public cl::NativeObject<_cl_command_queue>
{};
struct _cl_mem : public cl::NativeObject<_cl_mem>
{};
struct _cl_program : public cl::NativeObject<_cl_program>
{};
struct _cl_kernel : public cl::NativeObject<_cl_kernel>
{};
struct _cl_event : public cl::NativeObject<_cl_event>
{};
struct _cl_sampler : public cl::NativeObject<_cl_sampler>
{};
#endif // ANGLECL_H_
......@@ -10,7 +10,7 @@
"scripts/entry_point_packed_gl_enums.json":
"4f7b43863a5e61991bba4010db463679",
"scripts/generate_entry_points.py":
"a2675710977baeb42b33ec28f2a5ed3b",
"225ece78a0a952b688bf2de8e578454e",
"scripts/gl.xml":
"2a73a58a7e26d8676a2c0af6d528cae6",
"scripts/gl_angle_ext.xml":
......@@ -130,7 +130,7 @@
"src/libGLESv2/egl_stubs_autogen.h":
"6439daa350c1663e71dd0af37dcc91df",
"src/libGLESv2/entry_points_cl_autogen.cpp":
"23810a61bbc63bd908d96549b033dce9",
"da26d01245459eb04827b8e0d7f0e58a",
"src/libGLESv2/entry_points_cl_autogen.h":
"dde2f94c3004874a7da995dae69da811",
"src/libGLESv2/entry_points_egl_autogen.cpp":
......
......@@ -256,8 +256,6 @@ TEMPLATE_EGL_ENTRY_POINT_WITH_RETURN = """\
TEMPLATE_CL_ENTRY_POINT_NO_RETURN = """\
void CL_API_CALL cl{name}({params})
{{
ANGLE_SCOPED_GLOBAL_LOCK();
CL_EVENT({name}, "{format_params}"{comma_if_needed}{pass_params});
{packed_gl_enum_conversions}
......@@ -270,10 +268,7 @@ void CL_API_CALL cl{name}({params})
TEMPLATE_CL_ENTRY_POINT_WITH_RETURN_ERROR = """\
cl_int CL_API_CALL cl{name}({params})
{{
ANGLE_SCOPED_GLOBAL_LOCK();
{initialization}
{{{initialization}
CL_EVENT({name}, "{format_params}"{comma_if_needed}{pass_params});
{packed_gl_enum_conversions}
......@@ -286,10 +281,7 @@ cl_int CL_API_CALL cl{name}({params})
TEMPLATE_CL_ENTRY_POINT_WITH_ERRCODE_RET = """\
{return_type} CL_API_CALL cl{name}({params})
{{
ANGLE_SCOPED_GLOBAL_LOCK();
{initialization}
{{{initialization}
CL_EVENT({name}, "{format_params}"{comma_if_needed}{pass_params});
{packed_gl_enum_conversions}
......@@ -310,10 +302,7 @@ TEMPLATE_CL_ENTRY_POINT_WITH_ERRCODE_RET = """\
TEMPLATE_CL_ENTRY_POINT_WITH_RETURN_POINTER = """\
{return_type} CL_API_CALL cl{name}({params})
{{
ANGLE_SCOPED_GLOBAL_LOCK();
{initialization}
{{{initialization}
CL_EVENT({name}, "{format_params}"{comma_if_needed}{pass_params});
{packed_gl_enum_conversions}
......@@ -1003,7 +992,6 @@ LIBCL_SOURCE_INCLUDES = """\
#include "libANGLE/validationCL_autogen.h"
#include "libGLESv2/cl_stubs_autogen.h"
#include "libGLESv2/entry_points_cl_utils.h"
#include "libGLESv2/global_state.h"
"""
TEMPLATE_EVENT_COMMENT = """\
......@@ -1564,7 +1552,7 @@ def format_entry_point_def(api, command_node, cmd_name, proto, params, is_explic
pass_params = [param_print_argument(command_node, param) for param in params]
format_params = [param_format_string(param) for param in params]
return_type = proto[:-len(cmd_name)].strip()
initialization = "InitBackEnds(%s);" % INIT_DICT[cmd_name] if cmd_name in INIT_DICT else ""
initialization = "InitBackEnds(%s);\n" % INIT_DICT[cmd_name] if cmd_name in INIT_DICT else ""
event_comment = TEMPLATE_EVENT_COMMENT if cmd_name in NO_EVENT_MARKER_EXCEPTIONS_LIST else ""
name_lower_no_suffix = strip_suffix(api, cmd_name[2:3].lower() + cmd_name[3:])
......
......@@ -7,8 +7,6 @@
#include "libANGLE/CLBuffer.h"
#include "libANGLE/CLContext.h"
namespace cl
{
......@@ -20,8 +18,7 @@ cl_mem Buffer::createSubBuffer(MemFlags flags,
cl_int &errorCode)
{
const cl_buffer_region &region = *static_cast<const cl_buffer_region *>(createInfo);
return mContext->createMemory(new Buffer(*this, flags, region.origin, region.size, errorCode),
errorCode);
return Object::Create<Buffer>(errorCode, *this, flags, region.origin, region.size);
}
Buffer::Buffer(Context &context,
......
......@@ -40,7 +40,7 @@ class Buffer final : public Memory
Buffer(Buffer &parent, MemFlags flags, size_t offset, size_t size, cl_int &errorCode);
friend class Context;
friend class Object;
};
inline cl_mem_object_type Buffer::getType() const
......@@ -60,8 +60,7 @@ inline bool Buffer::isRegionValid(const cl_buffer_region &region) const
inline bool Buffer::IsValid(const _cl_mem *buffer)
{
return Memory::IsValid(buffer) &&
static_cast<const Memory *>(buffer)->getType() == CL_MEM_OBJECT_BUFFER;
return Memory::IsValid(buffer) && buffer->cast<Memory>().getType() == CL_MEM_OBJECT_BUFFER;
}
} // namespace cl
......
......@@ -9,7 +9,6 @@
#include "libANGLE/CLContext.h"
#include "libANGLE/CLDevice.h"
#include "libANGLE/CLPlatform.h"
#include <cstring>
......@@ -24,21 +23,12 @@ CommandQueue::~CommandQueue()
}
}
bool CommandQueue::release()
{
const bool released = removeRef();
if (released)
{
mContext->destroyCommandQueue(this);
}
return released;
}
cl_int CommandQueue::getInfo(CommandQueueInfo name,
size_t valueSize,
void *value,
size_t *valueSizeRet) const
{
cl_uint valUInt = 0u;
void *valPointer = nullptr;
const void *copyValue = nullptr;
size_t copySize = 0u;
......@@ -46,18 +36,19 @@ cl_int CommandQueue::getInfo(CommandQueueInfo name,
switch (name)
{
case CommandQueueInfo::Context:
valPointer = static_cast<cl_context>(mContext.get());
valPointer = mContext->getNative();
copyValue = &valPointer;
copySize = sizeof(valPointer);
break;
case CommandQueueInfo::Device:
valPointer = static_cast<cl_device_id>(mDevice.get());
valPointer = mDevice->getNative();
copyValue = &valPointer;
copySize = sizeof(valPointer);
break;
case CommandQueueInfo::ReferenceCount:
copyValue = getRefCountPtr();
copySize = sizeof(*getRefCountPtr());
valUInt = getRefCount();
copyValue = &valUInt;
copySize = sizeof(valUInt);
break;
case CommandQueueInfo::Properties:
copyValue = &mProperties;
......@@ -72,7 +63,7 @@ cl_int CommandQueue::getInfo(CommandQueueInfo name,
copySize = sizeof(mSize);
break;
case CommandQueueInfo::DeviceDefault:
valPointer = static_cast<cl_command_queue>(mDevice->mDefaultCommandQueue);
valPointer = CommandQueue::CastNative(mDevice->mDefaultCommandQueue);
copyValue = &valPointer;
copySize = sizeof(valPointer);
break;
......@@ -123,23 +114,14 @@ cl_int CommandQueue::setProperty(CommandQueueProperties properties,
return result;
}
bool CommandQueue::IsValid(const _cl_command_queue *commandQueue)
{
const Platform::PtrList &platforms = Platform::GetPlatforms();
return std::find_if(platforms.cbegin(), platforms.cend(), [=](const PlatformPtr &platform) {
return platform->hasCommandQueue(commandQueue);
}) != platforms.cend();
}
CommandQueue::CommandQueue(Context &context,
Device &device,
CommandQueueProperties properties,
cl_int &errorCode)
: _cl_command_queue(context.getDispatch()),
mContext(&context),
: mContext(&context),
mDevice(&device),
mProperties(properties),
mImpl(context.mImpl->createCommandQueue(*this, errorCode))
mImpl(context.getImpl().createCommandQueue(*this, errorCode))
{}
CommandQueue::CommandQueue(Context &context,
......@@ -148,13 +130,12 @@ CommandQueue::CommandQueue(Context &context,
CommandQueueProperties properties,
cl_uint size,
cl_int &errorCode)
: _cl_command_queue(context.getDispatch()),
mContext(&context),
: mContext(&context),
mDevice(&device),
mPropArray(std::move(propArray)),
mProperties(properties),
mSize(size),
mImpl(context.mImpl->createCommandQueue(*this, errorCode))
mImpl(context.getImpl().createCommandQueue(*this, errorCode))
{
if (mProperties.isSet(CL_QUEUE_ON_DEVICE_DEFAULT))
{
......
......@@ -20,7 +20,6 @@ namespace cl
class CommandQueue final : public _cl_command_queue, public Object
{
public:
using PtrList = std::list<CommandQueuePtr>;
using PropArray = std::vector<cl_queue_properties>;
static constexpr cl_uint kNoSize = std::numeric_limits<cl_uint>::max();
......@@ -34,9 +33,6 @@ class CommandQueue final : public _cl_command_queue, public Object
bool hasSize() const;
cl_uint getSize() const;
void retain() noexcept;
bool release();
cl_int getInfo(CommandQueueInfo name,
size_t valueSize,
void *value,
......@@ -46,8 +42,6 @@ class CommandQueue final : public _cl_command_queue, public Object
cl_bool enable,
cl_command_queue_properties *oldProperties);
static bool IsValid(const _cl_command_queue *commandQueue);
private:
CommandQueue(Context &context,
Device &device,
......@@ -61,14 +55,14 @@ class CommandQueue final : public _cl_command_queue, public Object
cl_uint size,
cl_int &errorCode);
const ContextRefPtr mContext;
const DeviceRefPtr mDevice;
const ContextPtr mContext;
const DevicePtr mDevice;
const PropArray mPropArray;
CommandQueueProperties mProperties;
const cl_uint mSize = kNoSize;
const rx::CLCommandQueueImpl::Ptr mImpl;
friend class Context;
friend class Object;
};
inline const Context &CommandQueue::getContext() const
......@@ -96,11 +90,6 @@ inline cl_uint CommandQueue::getSize() const
return mSize;
}
inline void CommandQueue::retain() noexcept
{
addRef();
}
} // namespace cl
#endif // LIBANGLE_CLCOMMANDQUEUE_H_
......@@ -9,12 +9,8 @@
#ifndef LIBANGLE_CLCONTEXT_H_
#define LIBANGLE_CLCONTEXT_H_
#include "libANGLE/CLCommandQueue.h"
#include "libANGLE/CLDevice.h"
#include "libANGLE/CLEvent.h"
#include "libANGLE/CLMemory.h"
#include "libANGLE/CLProgram.h"
#include "libANGLE/CLSampler.h"
#include "libANGLE/CLPlatform.h"
#include "libANGLE/renderer/CLContextImpl.h"
namespace cl
......@@ -23,31 +19,21 @@ namespace cl
class Context final : public _cl_context, public Object
{
public:
using PtrList = std::list<ContextPtr>;
using PropArray = std::vector<cl_context_properties>;
~Context() override;
const Platform &getPlatform() const noexcept;
const DevicePtrs &getDevices() const;
bool hasDevice(const _cl_device_id *device) const;
const DeviceRefs &getDevices() const;
template <typename T = rx::CLContextImpl>
T &getImpl() const;
bool supportsImages() const;
bool supportsIL() const;
bool supportsBuiltInKernel(const std::string &name) const;
bool hasCommandQueue(const _cl_command_queue *commandQueue) const;
bool hasMemory(const _cl_mem *memory) const;
bool hasSampler(const _cl_sampler *sampler) const;
bool hasProgram(const _cl_program *program) const;
bool hasKernel(const _cl_kernel *kernel) const;
bool hasEvent(const _cl_event *event) const;
EventRefPtr findEvent(const EventPredicate &eventPredicate) const;
void retain() noexcept;
bool release();
cl_int getInfo(ContextInfo name, size_t valueSize, void *value, size_t *valueSizeRet) const;
cl_command_queue createCommandQueue(cl_device_id device,
......@@ -120,7 +106,6 @@ class Context final : public _cl_context, public Object
cl_int waitForEvents(cl_uint numEvents, const cl_event *eventList);
static bool IsValid(const _cl_context *context);
static bool IsValidAndVersionOrNewer(const _cl_context *context, cl_uint major, cl_uint minor);
static void CL_CALLBACK ErrorCallback(const char *errinfo,
......@@ -131,7 +116,7 @@ class Context final : public _cl_context, public Object
private:
Context(Platform &platform,
PropArray &&properties,
DeviceRefs &&devices,
DevicePtrs &&devices,
ContextErrorCB notify,
void *userData,
bool userSync,
......@@ -145,38 +130,14 @@ class Context final : public _cl_context, public Object
bool userSync,
cl_int &errorCode);
cl_command_queue createCommandQueue(CommandQueue *commandQueue, cl_int errorCode);
cl_mem createMemory(Memory *memory, cl_int errorCode);
cl_sampler createSampler(Sampler *sampler, cl_int errorCode);
cl_program createProgram(Program *program, cl_int errorCode);
cl_event createEvent(Event *event, cl_int errorCode);
void destroyCommandQueue(CommandQueue *commandQueue);
void destroyMemory(Memory *memory);
void destroySampler(Sampler *sampler);
void destroyProgram(Program *program);
void destroyEvent(Event *event);
Platform &mPlatform;
const PropArray mProperties;
const ContextErrorCB mNotify;
void *const mUserData;
const rx::CLContextImpl::Ptr mImpl;
const DeviceRefs mDevices;
CommandQueue::PtrList mCommandQueues;
Memory::PtrList mMemories;
Sampler::PtrList mSamplers;
Program::PtrList mPrograms;
Event::PtrList mEvents;
friend class Buffer;
friend class CommandQueue;
friend class Event;
friend class Memory;
friend class Platform;
friend class Program;
friend class Sampler;
const DevicePtrs mDevices;
friend class Object;
};
inline const Platform &Context::getPlatform() const noexcept
......@@ -184,89 +145,51 @@ inline const Platform &Context::getPlatform() const noexcept
return mPlatform;
}
inline const DevicePtrs &Context::getDevices() const
{
return mDevices;
}
inline bool Context::hasDevice(const _cl_device_id *device) const
{
return std::find_if(mDevices.cbegin(), mDevices.cend(), [=](const DeviceRefPtr &ptr) {
return std::find_if(mDevices.cbegin(), mDevices.cend(), [=](const DevicePtr &ptr) {
return ptr.get() == device;
}) != mDevices.cend();
}
inline const DeviceRefs &Context::getDevices() const
template <typename T>
inline T &Context::getImpl() const
{
return mDevices;
return static_cast<T &>(*mImpl);
}
inline bool Context::supportsImages() const
{
return (std::find_if(mDevices.cbegin(), mDevices.cend(), [](const DeviceRefPtr &ptr) {
return (std::find_if(mDevices.cbegin(), mDevices.cend(), [](const DevicePtr &ptr) {
return ptr->getInfo().mImageSupport == CL_TRUE;
}) != mDevices.cend());
}
inline bool Context::supportsIL() const
{
return (std::find_if(mDevices.cbegin(), mDevices.cend(), [](const DeviceRefPtr &ptr) {
return (std::find_if(mDevices.cbegin(), mDevices.cend(), [](const DevicePtr &ptr) {
return !ptr->getInfo().mIL_Version.empty();
}) != mDevices.cend());
}
inline bool Context::supportsBuiltInKernel(const std::string &name) const
{
return (std::find_if(mDevices.cbegin(), mDevices.cend(), [&](const DeviceRefPtr &ptr) {
return (std::find_if(mDevices.cbegin(), mDevices.cend(), [&](const DevicePtr &ptr) {
return ptr->supportsBuiltInKernel(name);
}) != mDevices.cend());
}
inline bool Context::hasCommandQueue(const _cl_command_queue *commandQueue) const
{
return std::find_if(mCommandQueues.cbegin(), mCommandQueues.cend(),
[=](const CommandQueuePtr &ptr) { return ptr.get() == commandQueue; }) !=
mCommandQueues.cend();
}
inline bool Context::hasMemory(const _cl_mem *memory) const
{
return std::find_if(mMemories.cbegin(), mMemories.cend(), [=](const MemoryPtr &ptr) {
return ptr.get() == memory;
}) != mMemories.cend();
}
inline bool Context::hasSampler(const _cl_sampler *sampler) const
{
return std::find_if(mSamplers.cbegin(), mSamplers.cend(), [=](const SamplerPtr &ptr) {
return ptr.get() == sampler;
}) != mSamplers.cend();
}
inline bool Context::hasProgram(const _cl_program *program) const
{
return std::find_if(mPrograms.cbegin(), mPrograms.cend(), [=](const ProgramPtr &ptr) {
return ptr.get() == program;
}) != mPrograms.cend();
}
inline bool Context::hasKernel(const _cl_kernel *kernel) const
{
return std::find_if(mPrograms.cbegin(), mPrograms.cend(), [=](const ProgramPtr &ptr) {
return ptr->hasKernel(kernel);
}) != mPrograms.cend();
}
inline bool Context::hasEvent(const _cl_event *event) const
{
return std::find_if(mEvents.cbegin(), mEvents.cend(),
[=](const EventPtr &ptr) { return ptr.get() == event; }) != mEvents.cend();
}
inline EventRefPtr Context::findEvent(const EventPredicate &eventPredicate) const
{
const auto eventIt = std::find_if(mEvents.cbegin(), mEvents.cend(), eventPredicate);
return EventRefPtr(eventIt != mEvents.cend() ? eventIt->get() : nullptr);
}
inline void Context::retain() noexcept
inline bool Context::IsValidAndVersionOrNewer(const _cl_context *context,
cl_uint major,
cl_uint minor)
{
addRef();
return IsValid(context) &&
context->cast<Context>().getPlatform().isVersionOrNewer(major, minor);
}
} // namespace cl
......
......@@ -14,13 +14,7 @@
namespace cl
{
Device::~Device()
{
if (isRoot())
{
removeRef();
}
}
Device::~Device() = default;
bool Device::supportsBuiltInKernel(const std::string &name) const
{
......@@ -47,20 +41,6 @@ bool Device::supportsBuiltInKernel(const std::string &name) const
return false;
}
bool Device::release()
{
if (isRoot())
{
return false;
}
const bool released = removeRef();
if (released)
{
mParent->destroySubDevice(this);
}
return released;
}
cl_int Device::getInfo(DeviceInfo name, size_t valueSize, void *value, size_t *valueSizeRet) const
{
static_assert(std::is_same<cl_uint, cl_bool>::value &&
......@@ -321,18 +301,19 @@ cl_int Device::getInfo(DeviceInfo name, size_t valueSize, void *value, size_t *v
// Handle all mapped values
case DeviceInfo::Platform:
valPointer = static_cast<cl_platform_id>(&mPlatform);
valPointer = mPlatform.getNative();
copyValue = &valPointer;
copySize = sizeof(valPointer);
break;
case DeviceInfo::ParentDevice:
valPointer = static_cast<cl_device_id>(mParent.get());
valPointer = mParent->getNative();
copyValue = &valPointer;
copySize = sizeof(valPointer);
break;
case DeviceInfo::ReferenceCount:
copyValue = getRefCountPtr();
copySize = sizeof(*getRefCountPtr());
valUInt = isRoot() ? 1u : getRefCount();
copyValue = &valUInt;
copySize = sizeof(valUInt);
break;
default:
......@@ -373,63 +354,37 @@ cl_int Device::createSubDevices(const cl_device_partition_property *properties,
{
numDevices = 0u;
}
DevicePtrList subDeviceList;
const cl_int result =
mImpl->createSubDevices(*this, properties, numDevices, subDeviceList, numDevicesRet);
if (result == CL_SUCCESS)
rx::CLDeviceImpl::CreateFuncs subDeviceCreateFuncs;
const cl_int errorCode =
mImpl->createSubDevices(properties, numDevices, subDeviceCreateFuncs, numDevicesRet);
if (errorCode == CL_SUCCESS)
{
cl::DeviceType type = mInfo.mType;
type.clear(CL_DEVICE_TYPE_DEFAULT);
DevicePtrs devices;
devices.reserve(subDeviceCreateFuncs.size());
while (!subDeviceCreateFuncs.empty())
{
for (const DevicePtr &subDevice : subDeviceList)
devices.emplace_back(new Device(mPlatform, this, type, subDeviceCreateFuncs.front()));
if (!devices.back()->mInfo.isValid())
{
*subDevices++ = subDevice.get();
return CL_INVALID_VALUE;
}
mSubDevices.splice(mSubDevices.cend(), std::move(subDeviceList));
subDeviceCreateFuncs.pop_front();
}
return result;
}
DevicePtr Device::CreateDevice(Platform &platform,
Device *parent,
DeviceType type,
const CreateImplFunc &createImplFunc)
{
DevicePtr device(new Device(platform, parent, type, createImplFunc));
return device->mInfo.isValid() ? std::move(device) : DevicePtr{};
}
bool Device::IsValid(const _cl_device_id *device)
{
const Platform::PtrList &platforms = Platform::GetPlatforms();
return std::find_if(platforms.cbegin(), platforms.cend(), [=](const PlatformPtr &platform) {
return platform->hasDevice(device);
}) != platforms.cend();
for (DevicePtr &subDevice : devices)
{
*subDevices++ = subDevice.release();
}
}
return errorCode;
}
Device::Device(Platform &platform,
Device *parent,
DeviceType type,
const CreateImplFunc &createImplFunc)
: _cl_device_id(platform.getDispatch()),
mPlatform(platform),
mParent(parent),
mImpl(createImplFunc(*this)),
mInfo(mImpl->createInfo(type))
const rx::CLDeviceImpl::CreateFunc &createFunc)
: mPlatform(platform), mParent(parent), mImpl(createFunc(*this)), mInfo(mImpl->createInfo(type))
{}
void Device::destroySubDevice(Device *device)
{
auto deviceIt = mSubDevices.cbegin();
while (deviceIt != mSubDevices.cend() && deviceIt->get() != device)
{
++deviceIt;
}
if (deviceIt != mSubDevices.cend())
{
mSubDevices.erase(deviceIt);
}
else
{
ERR() << "Sub-device not found";
}
}
} // namespace cl
......@@ -20,26 +20,21 @@ namespace cl
class Device final : public _cl_device_id, public Object
{
public:
using CreateImplFunc = std::function<rx::CLDeviceImpl::Ptr(const cl::Device &)>;
~Device() override;
Platform &getPlatform() noexcept;
const Platform &getPlatform() const noexcept;
bool isRoot() const noexcept;
template <typename T>
template <typename T = rx::CLDeviceImpl>
T &getImpl() const;
const rx::CLDeviceImpl::Info &getInfo() const;
cl_version getVersion() const;
bool isVersionOrNewer(cl_uint major, cl_uint minor) const;
bool hasSubDevice(const _cl_device_id *device) const;
bool supportsBuiltInKernel(const std::string &name) const;
void retain() noexcept;
bool release();
cl_int getInfoUInt(DeviceInfo name, cl_uint *value) const;
cl_int getInfoULong(DeviceInfo name, cl_ulong *value) const;
......@@ -50,29 +45,19 @@ class Device final : public _cl_device_id, public Object
cl_device_id *subDevices,
cl_uint *numDevicesRet);
static DevicePtr CreateDevice(Platform &platform,
Device *parent,
DeviceType type,
const CreateImplFunc &createImplFunc);
static bool IsValid(const _cl_device_id *device);
static bool IsValidAndVersionOrNewer(const _cl_device_id *device, cl_uint major, cl_uint minor);
static bool IsValidType(DeviceType type);
private:
Device(Platform &platform,
Device *parent,
DeviceType type,
const CreateImplFunc &createImplFunc);
void destroySubDevice(Device *device);
const rx::CLDeviceImpl::CreateFunc &createFunc);
Platform &mPlatform;
const DeviceRefPtr mParent;
const DevicePtr mParent;
const rx::CLDeviceImpl::Ptr mImpl;
const rx::CLDeviceImpl::Info mInfo;
DevicePtrList mSubDevices;
CommandQueue *mDefaultCommandQueue = nullptr;
friend class CommandQueue;
......@@ -91,7 +76,7 @@ inline const Platform &Device::getPlatform() const noexcept
inline bool Device::isRoot() const noexcept
{
return !mParent;
return mParent == nullptr;
}
template <typename T>
......@@ -105,24 +90,14 @@ inline const rx::CLDeviceImpl::Info &Device::getInfo() const
return mInfo;
}
inline bool Device::isVersionOrNewer(cl_uint major, cl_uint minor) const
{
return mInfo.mVersion >= CL_MAKE_VERSION(major, minor, 0u);
}
inline bool Device::hasSubDevice(const _cl_device_id *device) const
inline cl_version Device::getVersion() const
{
return std::find_if(mSubDevices.cbegin(), mSubDevices.cend(), [=](const DevicePtr &ptr) {
return ptr.get() == device || ptr->hasSubDevice(device);
}) != mSubDevices.cend();
return mInfo.mVersion;
}
inline void Device::retain() noexcept
inline bool Device::isVersionOrNewer(cl_uint major, cl_uint minor) const
{
if (!isRoot())
{
addRef();
}
return mInfo.mVersion >= CL_MAKE_VERSION(major, minor, 0u);
}
inline cl_int Device::getInfoUInt(DeviceInfo name, cl_uint *value) const
......@@ -135,13 +110,6 @@ inline cl_int Device::getInfoULong(DeviceInfo name, cl_ulong *value) const
return mImpl->getInfoULong(name, value);
}
inline bool Device::IsValidAndVersionOrNewer(const _cl_device_id *device,
cl_uint major,
cl_uint minor)
{
return IsValid(device) && static_cast<const Device *>(device)->isVersionOrNewer(major, minor);
}
inline bool Device::IsValidType(DeviceType type)
{
return type.get() <= CL_DEVICE_TYPE_CUSTOM || type == CL_DEVICE_TYPE_ALL;
......
......@@ -7,8 +7,8 @@
#include "libANGLE/CLEvent.h"
#include "libANGLE/CLCommandQueue.h"
#include "libANGLE/CLContext.h"
#include "libANGLE/CLPlatform.h"
#include <cstring>
......@@ -17,16 +17,6 @@ namespace cl
Event::~Event() = default;
bool Event::release()
{
const bool released = removeRef();
if (released)
{
mContext->destroyEvent(this);
}
return released;
}
void Event::callback(cl_int commandStatus)
{
ASSERT(commandStatus >= 0 && commandStatus < 3);
......@@ -34,6 +24,11 @@ void Event::callback(cl_int commandStatus)
{
data.first(this, commandStatus, data.second);
}
// This event can be released after the callback was called.
if (release())
{
delete this;
}
}
cl_int Event::setUserEventStatus(cl_int executionStatus)
......@@ -49,6 +44,7 @@ cl_int Event::setUserEventStatus(cl_int executionStatus)
cl_int Event::getInfo(EventInfo name, size_t valueSize, void *value, size_t *valueSizeRet) const
{
cl_int execStatus = 0;
cl_uint valUInt = 0u;
void *valPointer = nullptr;
const void *copyValue = nullptr;
size_t copySize = 0u;
......@@ -56,7 +52,7 @@ cl_int Event::getInfo(EventInfo name, size_t valueSize, void *value, size_t *val
switch (name)
{
case EventInfo::CommandQueue:
valPointer = static_cast<cl_command_queue>(mCommandQueue.get());
valPointer = mCommandQueue->getNative();
copyValue = &valPointer;
copySize = sizeof(valPointer);
break;
......@@ -65,8 +61,9 @@ cl_int Event::getInfo(EventInfo name, size_t valueSize, void *value, size_t *val
copySize = sizeof(mCommandType);
break;
case EventInfo::ReferenceCount:
copyValue = getRefCountPtr();
copySize = sizeof(*getRefCountPtr());
valUInt = getRefCount();
copyValue = &valUInt;
copySize = sizeof(valUInt);
break;
case EventInfo::CommandExecutionStatus:
{
......@@ -80,7 +77,7 @@ cl_int Event::getInfo(EventInfo name, size_t valueSize, void *value, size_t *val
break;
}
case EventInfo::Context:
valPointer = static_cast<cl_context>(mContext.get());
valPointer = mContext->getNative();
copyValue = &valPointer;
copySize = sizeof(valPointer);
break;
......@@ -113,36 +110,21 @@ cl_int Event::setCallback(cl_int commandExecCallbackType, EventCB pfnNotify, voi
// Only when required register a single callback with the back end for each callback type.
if (mCallbacks[commandExecCallbackType].empty())
{
const cl_int errorCode = mImpl->setCallback(commandExecCallbackType);
const cl_int errorCode = mImpl->setCallback(*this, commandExecCallbackType);
if (errorCode != CL_SUCCESS)
{
return errorCode;
}
// This event has to be retained until the callback is called.
retain();
}
mCallbacks[commandExecCallbackType].emplace_back(pfnNotify, userData);
return CL_SUCCESS;
}
bool Event::IsValid(const _cl_event *event)
{
const Platform::PtrList &platforms = Platform::GetPlatforms();
return std::find_if(platforms.cbegin(), platforms.cend(), [=](const PlatformPtr &platform) {
return platform->hasEvent(event);
}) != platforms.cend();
}
bool Event::IsValidAndVersionOrNewer(const _cl_event *event, cl_uint major, cl_uint minor)
{
const Platform::PtrList &platforms = Platform::GetPlatforms();
return std::find_if(platforms.cbegin(), platforms.cend(), [=](const PlatformPtr &platform) {
return platform->isVersionOrNewer(major, minor) && platform->hasEvent(event);
}) != platforms.cend();
}
Event::Event(Context &context, cl_int &errorCode)
: _cl_event(context.getDispatch()),
mContext(&context),
mImpl(context.mImpl->createUserEvent(*this, errorCode)),
: mContext(&context),
mImpl(context.getImpl().createUserEvent(*this, errorCode)),
mCommandType(CL_COMMAND_USER)
{}
......
......@@ -20,22 +20,17 @@ namespace cl
class Event final : public _cl_event, public Object
{
public:
using PtrList = std::list<EventPtr>;
~Event() override;
Context &getContext();
const Context &getContext() const;
const CommandQueueRefPtr &getCommandQueue() const;
const CommandQueuePtr &getCommandQueue() const;
cl_command_type getCommandType() const;
bool wasStatusChanged() const;
template <typename T>
template <typename T = rx::CLEventImpl>
T &getImpl() const;
void retain() noexcept;
bool release();
void callback(cl_int commandStatus);
cl_int setUserEventStatus(cl_int executionStatus);
......@@ -44,16 +39,13 @@ class Event final : public _cl_event, public Object
cl_int setCallback(cl_int commandExecCallbackType, EventCB pfnNotify, void *userData);
static bool IsValid(const _cl_event *event);
static bool IsValidAndVersionOrNewer(const _cl_event *event, cl_uint major, cl_uint minor);
private:
using CallbackData = std::pair<EventCB, void *>;
Event(Context &context, cl_int &errorCode);
const ContextRefPtr mContext;
const CommandQueueRefPtr mCommandQueue;
const ContextPtr mContext;
const CommandQueuePtr mCommandQueue;
const rx::CLEventImpl::Ptr mImpl;
const cl_command_type mCommandType;
......@@ -64,7 +56,7 @@ class Event final : public _cl_event, public Object
"OpenCL command execution status values are not as assumed");
std::array<std::vector<CallbackData>, 3u> mCallbacks;
friend class Context;
friend class Object;
};
inline Context &Event::getContext()
......@@ -77,7 +69,7 @@ inline const Context &Event::getContext() const
return *mContext;
}
inline const CommandQueueRefPtr &Event::getCommandQueue() const
inline const CommandQueuePtr &Event::getCommandQueue() const
{
return mCommandQueue;
}
......@@ -98,11 +90,6 @@ inline T &Event::getImpl() const
return static_cast<T &>(*mImpl);
}
inline void Event::retain() noexcept
{
addRef();
}
} // namespace cl
#endif // LIBANGLE_CLEVENT_H_
......@@ -59,7 +59,7 @@ cl_int Image::getInfo(ImageInfo name, size_t valueSize, void *value, size_t *val
copySize = sizeof(mDesc.arraySize);
break;
case ImageInfo::Buffer:
valPointer = static_cast<cl_mem>(mParent.get());
valPointer = Memory::CastNative(mParent.get());
copyValue = &valPointer;
copySize = sizeof(valPointer);
break;
......@@ -101,7 +101,7 @@ bool Image::IsValid(const _cl_mem *image)
{
return false;
}
switch (static_cast<const Memory *>(image)->getType())
switch (image->cast<Memory>().getType())
{
case CL_MEM_OBJECT_IMAGE1D:
case CL_MEM_OBJECT_IMAGE2D:
......
......@@ -40,7 +40,7 @@ class Image final : public Memory
const cl_image_format mFormat;
const ImageDescriptor mDesc;
friend class Context;
friend class Object;
};
inline cl_mem_object_type Image::getType() const
......
......@@ -7,7 +7,7 @@
#include "libANGLE/CLKernel.h"
#include "libANGLE/CLPlatform.h"
#include "libANGLE/CLContext.h"
#include "libANGLE/CLProgram.h"
#include <cstring>
......@@ -17,18 +17,9 @@ namespace cl
Kernel::~Kernel() = default;
bool Kernel::release()
{
const bool released = removeRef();
if (released)
{
mProgram->destroyKernel(this);
}
return released;
}
cl_int Kernel::getInfo(KernelInfo name, size_t valueSize, void *value, size_t *valueSizeRet) const
{
cl_uint valUInt = 0u;
void *valPointer = nullptr;
const void *copyValue = nullptr;
size_t copySize = 0u;
......@@ -44,16 +35,17 @@ cl_int Kernel::getInfo(KernelInfo name, size_t valueSize, void *value, size_t *v
copySize = sizeof(mInfo.mNumArgs);
break;
case KernelInfo::ReferenceCount:
copyValue = getRefCountPtr();
copySize = sizeof(*getRefCountPtr());
valUInt = getRefCount();
copyValue = &valUInt;
copySize = sizeof(valUInt);
break;
case KernelInfo::Context:
valPointer = static_cast<cl_context>(&mProgram->getContext());
valPointer = mProgram->getContext().getNative();
copyValue = &valPointer;
copySize = sizeof(valPointer);
break;
case KernelInfo::Program:
valPointer = static_cast<cl_program>(mProgram.get());
valPointer = mProgram->getNative();
copyValue = &valPointer;
copySize = sizeof(valPointer);
break;
......@@ -94,7 +86,7 @@ cl_int Kernel::getWorkGroupInfo(cl_device_id device,
size_t index = 0u;
if (device != nullptr)
{
const DeviceRefs &devices = mProgram->getContext().getDevices();
const DevicePtrs &devices = mProgram->getContext().getDevices();
while (index < devices.size() && devices[index].get() != device)
{
++index;
......@@ -215,34 +207,14 @@ cl_int Kernel::getArgInfo(cl_uint argIndex,
return CL_SUCCESS;
}
bool Kernel::IsValid(const _cl_kernel *kernel)
{
const Platform::PtrList &platforms = Platform::GetPlatforms();
return std::find_if(platforms.cbegin(), platforms.cend(), [=](const PlatformPtr &platform) {
return platform->hasKernel(kernel);
}) != platforms.cend();
}
bool Kernel::IsValidAndVersionOrNewer(const _cl_kernel *kernel, cl_uint major, cl_uint minor)
{
const Platform::PtrList &platforms = Platform::GetPlatforms();
return std::find_if(platforms.cbegin(), platforms.cend(), [=](const PlatformPtr &platform) {
return platform->isVersionOrNewer(major, minor) && platform->hasKernel(kernel);
}) != platforms.cend();
}
Kernel::Kernel(Program &program, const char *name, cl_int &errorCode)
: _cl_kernel(program.getDispatch()),
mProgram(&program),
mImpl(program.mImpl->createKernel(*this, name, errorCode)),
: mProgram(&program),
mImpl(program.getImpl().createKernel(*this, name, errorCode)),
mInfo(mImpl ? mImpl->createInfo(errorCode) : rx::CLKernelImpl::Info{})
{}
Kernel::Kernel(Program &program, const CreateImplFunc &createImplFunc, cl_int &errorCode)
: _cl_kernel(program.getDispatch()),
mProgram(&program),
mImpl(createImplFunc(*this)),
mInfo(mImpl->createInfo(errorCode))
Kernel::Kernel(Program &program, const rx::CLKernelImpl::CreateFunc &createFunc, cl_int &errorCode)
: mProgram(&program), mImpl(createFunc(*this)), mInfo(mImpl->createInfo(errorCode))
{}
} // namespace cl
......@@ -17,9 +17,6 @@ namespace cl
class Kernel final : public _cl_kernel, public Object
{
public:
using PtrList = std::list<KernelPtr>;
using CreateImplFunc = std::function<rx::CLKernelImpl::Ptr(const cl::Kernel &)>;
~Kernel() override;
const Program &getProgram() const;
......@@ -28,9 +25,6 @@ class Kernel final : public _cl_kernel, public Object
template <typename T>
T &getImpl() const;
void retain() noexcept;
bool release();
cl_int getInfo(KernelInfo name, size_t valueSize, void *value, size_t *valueSizeRet) const;
cl_int getWorkGroupInfo(cl_device_id device,
......@@ -45,17 +39,15 @@ class Kernel final : public _cl_kernel, public Object
void *value,
size_t *valueSizeRet) const;
static bool IsValid(const _cl_kernel *kernel);
static bool IsValidAndVersionOrNewer(const _cl_kernel *kernel, cl_uint major, cl_uint minor);
private:
Kernel(Program &program, const char *name, cl_int &errorCode);
Kernel(Program &program, const CreateImplFunc &createImplFunc, cl_int &errorCode);
Kernel(Program &program, const rx::CLKernelImpl::CreateFunc &createFunc, cl_int &errorCode);
const ProgramRefPtr mProgram;
const ProgramPtr mProgram;
const rx::CLKernelImpl::Ptr mImpl;
const rx::CLKernelImpl::Info mInfo;
friend class Object;
friend class Program;
};
......@@ -75,11 +67,6 @@ inline T &Kernel::getImpl() const
return static_cast<T &>(*mImpl);
}
inline void Kernel::retain() noexcept
{
addRef();
}
} // namespace cl
#endif // LIBANGLE_CLKERNEL_H_
......@@ -9,7 +9,6 @@
#include "libANGLE/CLBuffer.h"
#include "libANGLE/CLContext.h"
#include "libANGLE/CLPlatform.h"
#include <cstring>
......@@ -18,16 +17,6 @@ namespace cl
Memory::~Memory() = default;
bool Memory::release()
{
const bool released = removeRef();
if (released)
{
mContext->destroyMemory(this);
}
return released;
}
cl_int Memory::getInfo(MemInfo name, size_t valueSize, void *value, size_t *valueSizeRet) const
{
static_assert(
......@@ -63,16 +52,17 @@ cl_int Memory::getInfo(MemInfo name, size_t valueSize, void *value, size_t *valu
copySize = sizeof(mMapCount);
break;
case MemInfo::ReferenceCount:
copyValue = getRefCountPtr();
copySize = sizeof(*getRefCountPtr());
valUInt = getRefCount();
copyValue = &valUInt;
copySize = sizeof(valUInt);
break;
case MemInfo::Context:
valPointer = static_cast<cl_context>(mContext.get());
valPointer = mContext->getNative();
copyValue = &valPointer;
copySize = sizeof(valPointer);
break;
case MemInfo::AssociatedMemObject:
valPointer = static_cast<cl_mem>(mParent.get());
valPointer = Memory::CastNative(mParent.get());
copyValue = &valPointer;
copySize = sizeof(valPointer);
break;
......@@ -113,14 +103,6 @@ cl_int Memory::getInfo(MemInfo name, size_t valueSize, void *value, size_t *valu
return CL_SUCCESS;
}
bool Memory::IsValid(const _cl_mem *memory)
{
const Platform::PtrList &platforms = Platform::GetPlatforms();
return std::find_if(platforms.cbegin(), platforms.cend(), [=](const PlatformPtr &platform) {
return platform->hasMemory(memory);
}) != platforms.cend();
}
Memory::Memory(const Buffer &buffer,
Context &context,
PropArray &&properties,
......@@ -128,12 +110,11 @@ Memory::Memory(const Buffer &buffer,
size_t size,
void *hostPtr,
cl_int &errorCode)
: _cl_mem(context.getDispatch()),
mContext(&context),
: mContext(&context),
mProperties(std::move(properties)),
mFlags(flags),
mHostPtr(flags.isSet(CL_MEM_USE_HOST_PTR) ? hostPtr : nullptr),
mImpl(context.mImpl->createBuffer(buffer, size, hostPtr, errorCode)),
mImpl(context.getImpl().createBuffer(buffer, size, hostPtr, errorCode)),
mSize(size)
{}
......@@ -143,8 +124,7 @@ Memory::Memory(const Buffer &buffer,
size_t offset,
size_t size,
cl_int &errorCode)
: _cl_mem(parent.getDispatch()),
mContext(parent.mContext),
: mContext(parent.mContext),
mFlags(flags),
mHostPtr(parent.mHostPtr != nullptr ? static_cast<char *>(parent.mHostPtr) + offset
: nullptr),
......@@ -163,13 +143,12 @@ Memory::Memory(const Image &image,
Memory *parent,
void *hostPtr,
cl_int &errorCode)
: _cl_mem(context.getDispatch()),
mContext(&context),
: mContext(&context),
mProperties(std::move(properties)),
mFlags(flags),
mHostPtr(flags.isSet(CL_MEM_USE_HOST_PTR) ? hostPtr : nullptr),
mParent(parent),
mImpl(context.mImpl->createImage(image, format, desc, hostPtr, errorCode)),
mImpl(context.getImpl().createImage(image, format, desc, hostPtr, errorCode)),
mSize(mImpl ? mImpl->getSize(errorCode) : 0u)
{}
......
......@@ -18,7 +18,6 @@ namespace cl
class Memory : public _cl_mem, public Object
{
public:
using PtrList = std::list<MemoryPtr>;
using PropArray = std::vector<cl_mem_properties>;
~Memory() override;
......@@ -29,16 +28,11 @@ class Memory : public _cl_mem, public Object
const PropArray &getProperties() const;
MemFlags getFlags() const;
void *getHostPtr() const;
const MemoryRefPtr &getParent() const;
const MemoryPtr &getParent() const;
size_t getOffset() const;
void retain() noexcept;
bool release();
cl_int getInfo(MemInfo name, size_t valueSize, void *value, size_t *valueSizeRet) const;
static bool IsValid(const _cl_mem *memory);
protected:
Memory(const Buffer &buffer,
Context &context,
......@@ -65,11 +59,11 @@ class Memory : public _cl_mem, public Object
void *hostPtr,
cl_int &errorCode);
const ContextRefPtr mContext;
const ContextPtr mContext;
const PropArray mProperties;
const MemFlags mFlags;
void *const mHostPtr = nullptr;
const MemoryRefPtr mParent;
const MemoryPtr mParent;
const size_t mOffset = 0u;
const rx::CLMemoryImpl::Ptr mImpl;
const size_t mSize;
......@@ -100,7 +94,7 @@ inline void *Memory::getHostPtr() const
return mHostPtr;
}
inline const MemoryRefPtr &Memory::getParent() const
inline const MemoryPtr &Memory::getParent() const
{
return mParent;
}
......@@ -110,11 +104,6 @@ inline size_t Memory::getOffset() const
return mOffset;
}
inline void Memory::retain() noexcept
{
addRef();
}
} // namespace cl
#endif // LIBANGLE_CLMEMORY_H_
......@@ -10,6 +10,8 @@
namespace cl
{
Object::Object() : mRefCount(0u) {}
Object::~Object()
{
if (mRefCount != 0u)
......
......@@ -8,9 +8,10 @@
#ifndef LIBANGLE_CLOBJECT_H_
#define LIBANGLE_CLOBJECT_H_
#include "libANGLE/CLtypes.h"
#include "libANGLE/renderer/CLtypes.h"
#include "libANGLE/Debug.h"
#include <atomic>
namespace cl
{
......@@ -18,15 +19,14 @@ namespace cl
class Object
{
public:
Object() = default;
Object();
virtual ~Object();
cl_uint getRefCount() { return mRefCount; }
const cl_uint *getRefCountPtr() const { return &mRefCount; }
cl_uint getRefCount() const noexcept { return mRefCount; }
protected:
void addRef() noexcept { ++mRefCount; }
bool removeRef()
void retain() noexcept { ++mRefCount; }
bool release()
{
if (mRefCount == 0u)
{
......@@ -36,8 +36,15 @@ class Object
return --mRefCount == 0u;
}
template <typename T, typename... Args>
static T *Create(cl_int &errorCode, Args &&... args)
{
RefPointer<T> object(new T(std::forward<Args>(args)..., errorCode));
return errorCode == CL_SUCCESS ? object.release() : nullptr;
}
private:
cl_uint mRefCount = 1u;
std::atomic<cl_uint> mRefCount;
};
} // namespace cl
......
......@@ -7,7 +7,9 @@
#include "libANGLE/CLPlatform.h"
#include <cstdint>
#include "libANGLE/CLContext.h"
#include "libANGLE/CLDevice.h"
#include <cstring>
namespace cl
......@@ -37,7 +39,7 @@ Context::PropArray ParseContextProperties(const cl_context_properties *propertie
switch (*propIt++)
{
case CL_CONTEXT_PLATFORM:
platform = reinterpret_cast<Platform *>(*propIt++);
platform = &reinterpret_cast<cl_platform_id>(*propIt++)->cast<Platform>();
break;
case CL_CONTEXT_INTEROP_USER_SYNC:
userSync = *propIt++ != CL_FALSE;
......@@ -58,10 +60,7 @@ Context::PropArray ParseContextProperties(const cl_context_properties *propertie
} // namespace
Platform::~Platform()
{
removeRef();
}
Platform::~Platform() = default;
cl_int Platform::getInfo(PlatformInfo name,
size_t valueSize,
......@@ -166,12 +165,27 @@ cl_int Platform::getDeviceIDs(DeviceType deviceType,
return CL_SUCCESS;
}
void Platform::CreatePlatform(const cl_icd_dispatch &dispatch, const CreateImplFunc &createImplFunc)
void Platform::Initialize(const cl_icd_dispatch &dispatch,
rx::CLPlatformImpl::CreateFuncs &&createFuncs)
{
PlatformPtr platform(new Platform(dispatch, createImplFunc));
if (platform->mInfo.isValid() && !platform->mDevices.empty())
PlatformPtrs &platforms = GetPointers();
ASSERT(Dispatch::sDispatch == nullptr && platforms.empty());
if (Dispatch::sDispatch != nullptr || !platforms.empty())
{
ERR() << "Already initialized";
return;
}
Dispatch::sDispatch = &dispatch;
platforms.reserve(createFuncs.size());
while (!createFuncs.empty())
{
GetList().emplace_back(std::move(platform));
platforms.emplace_back(new Platform(createFuncs.front()));
if (!platforms.back()->mInfo.isValid() || platforms.back()->mDevices.empty())
{
platforms.pop_back();
}
createFuncs.pop_front();
}
}
......@@ -179,16 +193,16 @@ cl_int Platform::GetPlatformIDs(cl_uint numEntries,
cl_platform_id *platforms,
cl_uint *numPlatforms)
{
const PtrList &platformList = GetPlatforms();
const PlatformPtrs &availPlatforms = GetPlatforms();
if (numPlatforms != nullptr)
{
*numPlatforms = static_cast<cl_uint>(platformList.size());
*numPlatforms = static_cast<cl_uint>(availPlatforms.size());
}
if (platforms != nullptr)
{
cl_uint entry = 0u;
auto platformIt = platformList.cbegin();
while (entry < numEntries && platformIt != platformList.cend())
auto platformIt = availPlatforms.cbegin();
while (entry < numEntries && platformIt != availPlatforms.cend())
{
platforms[entry++] = (*platformIt++).get();
}
......@@ -207,15 +221,14 @@ cl_context Platform::CreateContext(const cl_context_properties *properties,
bool userSync = false;
Context::PropArray propArray = ParseContextProperties(properties, platform, userSync);
ASSERT(platform != nullptr);
DeviceRefs refDevices;
DevicePtrs devs;
devs.reserve(numDevices);
while (numDevices-- != 0u)
{
refDevices.emplace_back(static_cast<Device *>(*devices++));
devs.emplace_back(&(*devices++)->cast<Device>());
}
return platform->createContext(
new Context(*platform, std::move(propArray), std::move(refDevices), notify, userData,
userSync, errorCode),
errorCode);
return Object::Create<Context>(errorCode, *platform, std::move(propArray), std::move(devs),
notify, userData, userSync);
}
cl_context Platform::CreateContextFromType(const cl_context_properties *properties,
......@@ -228,44 +241,31 @@ cl_context Platform::CreateContextFromType(const cl_context_properties *properti
bool userSync = false;
Context::PropArray propArray = ParseContextProperties(properties, platform, userSync);
ASSERT(platform != nullptr);
return platform->createContext(new Context(*platform, std::move(propArray), deviceType, notify,
userData, userSync, errorCode),
errorCode);
return Object::Create<Context>(errorCode, *platform, std::move(propArray), deviceType, notify,
userData, userSync);
}
Platform::Platform(const cl_icd_dispatch &dispatch, const CreateImplFunc &createImplFunc)
: _cl_platform_id(dispatch),
mImpl(createImplFunc(*this)),
Platform::Platform(const rx::CLPlatformImpl::CreateFunc &createFunc)
: mImpl(createFunc(*this)),
mInfo(mImpl->createInfo()),
mDevices(mImpl->createDevices(*this))
mDevices(createDevices(mImpl->createDevices()))
{}
cl_context Platform::createContext(Context *context, cl_int errorCode)
DevicePtrs Platform::createDevices(rx::CLDeviceImpl::CreateDatas &&createDatas)
{
mContexts.emplace_back(context);
if (errorCode != CL_SUCCESS)
DevicePtrs devices;
devices.reserve(createDatas.size());
while (!createDatas.empty())
{
mContexts.back()->release();
return nullptr;
}
return mContexts.back().get();
}
void Platform::destroyContext(Context *context)
{
auto contextIt = mContexts.cbegin();
while (contextIt != mContexts.cend() && contextIt->get() != context)
devices.emplace_back(
new Device(*this, nullptr, createDatas.front().first, createDatas.front().second));
if (!devices.back()->mInfo.isValid())
{
++contextIt;
devices.pop_back();
}
if (contextIt != mContexts.cend())
{
mContexts.erase(contextIt);
}
else
{
ERR() << "Context not found";
createDatas.pop_front();
}
return devices;
}
constexpr char Platform::kVendor[];
......
......@@ -9,8 +9,7 @@
#ifndef LIBANGLE_CLPLATFORM_H_
#define LIBANGLE_CLPLATFORM_H_
#include "libANGLE/CLContext.h"
#include "libANGLE/CLDevice.h"
#include "libANGLE/CLObject.h"
#include "libANGLE/renderer/CLPlatformImpl.h"
#include "anglebase/no_destructor.h"
......@@ -21,25 +20,15 @@ namespace cl
class Platform final : public _cl_platform_id, public Object
{
public:
using PtrList = std::list<PlatformPtr>;
using CreateImplFunc = std::function<rx::CLPlatformImpl::Ptr(const cl::Platform &)>;
~Platform() override;
const rx::CLPlatformImpl::Info &getInfo() const;
cl_version getVersion() const;
bool isVersionOrNewer(cl_uint major, cl_uint minor) const;
bool hasDevice(const _cl_device_id *device) const;
const DevicePtrList &getDevices() const;
bool hasContext(const _cl_context *context) const;
bool hasCommandQueue(const _cl_command_queue *commandQueue) const;
bool hasMemory(const _cl_mem *memory) const;
bool hasSampler(const _cl_sampler *sampler) const;
bool hasProgram(const _cl_program *program) const;
bool hasKernel(const _cl_kernel *kernel) const;
bool hasEvent(const _cl_event *event) const;
const DevicePtrs &getDevices() const;
EventRefPtr findEvent(const EventPredicate &eventPredicate) const;
template <typename T = rx::CLPlatformImpl>
T &getImpl() const;
cl_int getInfo(PlatformInfo name, size_t valueSize, void *value, size_t *valueSizeRet) const;
......@@ -48,8 +37,8 @@ class Platform final : public _cl_platform_id, public Object
cl_device_id *devices,
cl_uint *numDevices) const;
static void CreatePlatform(const cl_icd_dispatch &dispatch,
const CreateImplFunc &createImplFunc);
static void Initialize(const cl_icd_dispatch &dispatch,
rx::CLPlatformImpl::CreateFuncs &&createFuncs);
static cl_int GetPlatformIDs(cl_uint numEntries,
cl_platform_id *platforms,
......@@ -68,34 +57,26 @@ class Platform final : public _cl_platform_id, public Object
void *userData,
cl_int &errorCode);
static const PtrList &GetPlatforms();
static const PlatformPtrs &GetPlatforms();
static Platform *GetDefault();
static Platform *CastOrDefault(cl_platform_id platform);
static bool IsValid(const _cl_platform_id *platform);
static bool IsValidOrDefault(const _cl_platform_id *platform);
static EventRefPtr FindEvent(const EventPredicate &eventPredicate);
static constexpr const char *GetVendor();
private:
Platform(const cl_icd_dispatch &dispatch, const CreateImplFunc &createImplFunc);
explicit Platform(const rx::CLPlatformImpl::CreateFunc &createFunc);
cl_context createContext(Context *context, cl_int errorCode);
void destroyContext(Context *context);
DevicePtrs createDevices(rx::CLDeviceImpl::CreateDatas &&createDatas);
static PtrList &GetList();
static PlatformPtrs &GetPointers();
const rx::CLPlatformImpl::Ptr mImpl;
const rx::CLPlatformImpl::Info mInfo;
const DevicePtrList mDevices;
Context::PtrList mContexts;
const DevicePtrs mDevices;
static constexpr char kVendor[] = "ANGLE";
static constexpr char kIcdSuffix[] = "ANGLE";
friend class Context;
};
inline const rx::CLPlatformImpl::Info &Platform::getInfo() const
......@@ -103,128 +84,53 @@ inline const rx::CLPlatformImpl::Info &Platform::getInfo() const
return mInfo;
}
inline bool Platform::isVersionOrNewer(cl_uint major, cl_uint minor) const
inline cl_version Platform::getVersion() const
{
return mInfo.mVersion >= CL_MAKE_VERSION(major, minor, 0u);
return mInfo.mVersion;
}
inline bool Platform::hasDevice(const _cl_device_id *device) const
inline bool Platform::isVersionOrNewer(cl_uint major, cl_uint minor) const
{
return std::find_if(mDevices.cbegin(), mDevices.cend(), [=](const DevicePtr &ptr) {
return ptr.get() == device || ptr->hasSubDevice(device);
}) != mDevices.cend();
return mInfo.mVersion >= CL_MAKE_VERSION(major, minor, 0u);
}
inline const DevicePtrList &Platform::getDevices() const
inline const DevicePtrs &Platform::getDevices() const
{
return mDevices;
}
inline bool Platform::hasContext(const _cl_context *context) const
{
return std::find_if(mContexts.cbegin(), mContexts.cend(), [=](const ContextPtr &ptr) {
return ptr.get() == context;
}) != mContexts.cend();
}
inline bool Platform::hasCommandQueue(const _cl_command_queue *commandQueue) const
{
return std::find_if(mContexts.cbegin(), mContexts.cend(), [=](const ContextPtr &ptr) {
return ptr->hasCommandQueue(commandQueue);
}) != mContexts.cend();
}
inline bool Platform::hasMemory(const _cl_mem *memory) const
{
return std::find_if(mContexts.cbegin(), mContexts.cend(), [=](const ContextPtr &ptr) {
return ptr->hasMemory(memory);
}) != mContexts.cend();
}
inline bool Platform::hasSampler(const _cl_sampler *sampler) const
template <typename T>
inline T &Platform::getImpl() const
{
return std::find_if(mContexts.cbegin(), mContexts.cend(), [=](const ContextPtr &ptr) {
return ptr->hasSampler(sampler);
}) != mContexts.cend();
return static_cast<T &>(*mImpl);
}
inline bool Platform::hasProgram(const _cl_program *program) const
inline PlatformPtrs &Platform::GetPointers()
{
return std::find_if(mContexts.cbegin(), mContexts.cend(), [=](const ContextPtr &ptr) {
return ptr->hasProgram(program);
}) != mContexts.cend();
static angle::base::NoDestructor<PlatformPtrs> sPointers;
return *sPointers;
}
inline bool Platform::hasKernel(const _cl_kernel *kernel) const
inline const PlatformPtrs &Platform::GetPlatforms()
{
return std::find_if(mContexts.cbegin(), mContexts.cend(), [=](const ContextPtr &ptr) {
return ptr->hasKernel(kernel);
}) != mContexts.cend();
}
inline bool Platform::hasEvent(const _cl_event *event) const
{
return std::find_if(mContexts.cbegin(), mContexts.cend(), [=](const ContextPtr &ptr) {
return ptr->hasEvent(event);
}) != mContexts.cend();
}
inline EventRefPtr Platform::findEvent(const EventPredicate &eventPredicate) const
{
EventRefPtr event;
auto contextIt = mContexts.cbegin();
while (contextIt != mContexts.cend() && event == nullptr)
{
event = (*contextIt++)->findEvent(eventPredicate);
}
return event;
}
inline Platform::PtrList &Platform::GetList()
{
static angle::base::NoDestructor<PtrList> sList;
return *sList;
}
inline const Platform::PtrList &Platform::GetPlatforms()
{
return GetList();
return GetPointers();
}
inline Platform *Platform::GetDefault()
{
return GetList().empty() ? nullptr : GetList().front().get();
return GetPlatforms().empty() ? nullptr : GetPlatforms().front().get();
}
inline Platform *Platform::CastOrDefault(cl_platform_id platform)
{
return platform != nullptr ? static_cast<Platform *>(platform) : GetDefault();
}
inline bool Platform::IsValid(const _cl_platform_id *platform)
{
const PtrList &platforms = GetPlatforms();
return std::find_if(platforms.cbegin(), platforms.cend(), [=](const PlatformPtr &ptr) {
return ptr.get() == platform;
}) != platforms.cend();
return platform != nullptr ? &platform->cast<Platform>() : GetDefault();
}
// Our CL implementation defines that a nullptr value chooses the platform that we provide as
// default, so this function returns true for a nullptr value if a default platform exists.
inline bool Platform::IsValidOrDefault(const _cl_platform_id *platform)
{
return platform != nullptr ? IsValid(platform) : GetDefault() != nullptr;
}
inline EventRefPtr Platform::FindEvent(const EventPredicate &eventPredicate)
{
EventRefPtr event;
auto platformIt = GetPlatforms().cbegin();
while (platformIt != GetPlatforms().cend() && event == nullptr)
{
event = (*platformIt++)->findEvent(eventPredicate);
}
return event;
return platform != nullptr ? platform->isValid() : GetDefault() != nullptr;
}
constexpr const char *Platform::GetVendor()
......
......@@ -17,16 +17,6 @@ namespace cl
Program::~Program() = default;
bool Program::release()
{
const bool released = removeRef();
if (released)
{
mContext->destroyProgram(this);
}
return released;
}
cl_int Program::getInfo(ProgramInfo name, size_t valueSize, void *value, size_t *valueSizeRet) const
{
static_assert(std::is_same<cl_uint, cl_bool>::value &&
......@@ -34,6 +24,7 @@ cl_int Program::getInfo(ProgramInfo name, size_t valueSize, void *value, size_t
std::is_same<cl_uint, cl_filter_mode>::value,
"OpenCL type mismatch");
std::vector<cl_device_id> devices;
std::vector<size_t> binarySizes;
std::vector<const unsigned char *> binaries;
cl_uint valUInt = 0u;
......@@ -44,11 +35,12 @@ cl_int Program::getInfo(ProgramInfo name, size_t valueSize, void *value, size_t
switch (name)
{
case ProgramInfo::ReferenceCount:
copyValue = getRefCountPtr();
copySize = sizeof(*getRefCountPtr());
valUInt = getRefCount();
copyValue = &valUInt;
copySize = sizeof(valUInt);
break;
case ProgramInfo::Context:
valPointer = static_cast<cl_context>(mContext.get());
valPointer = mContext->getNative();
copyValue = &valPointer;
copySize = sizeof(valPointer);
break;
......@@ -58,10 +50,13 @@ cl_int Program::getInfo(ProgramInfo name, size_t valueSize, void *value, size_t
copySize = sizeof(valUInt);
break;
case ProgramInfo::Devices:
static_assert(sizeof(decltype(mDevices)::value_type) == sizeof(Device *),
"DeviceRefs has wrong element size");
copyValue = mDevices.data();
copySize = mDevices.size() * sizeof(decltype(mDevices)::value_type);
devices.reserve(mDevices.size());
for (const DevicePtr &device : mDevices)
{
devices.emplace_back(device->getNative());
}
copyValue = devices.data();
copySize = devices.size() * sizeof(decltype(devices)::value_type);
break;
case ProgramInfo::Source:
copyValue = mSource.c_str();
......@@ -135,117 +130,70 @@ cl_int Program::getInfo(ProgramInfo name, size_t valueSize, void *value, size_t
cl_kernel Program::createKernel(const char *kernel_name, cl_int &errorCode)
{
return createKernel(new Kernel(*this, kernel_name, errorCode), errorCode);
}
cl_int Program::createKernel(const Kernel::CreateImplFunc &createImplFunc)
{
cl_int errorCode = CL_SUCCESS;
createKernel(new Kernel(*this, createImplFunc, errorCode), errorCode);
return errorCode;
return Object::Create<Kernel>(errorCode, *this, kernel_name);
}
cl_int Program::createKernels(cl_uint numKernels, cl_kernel *kernels, cl_uint *numKernelsRet)
{
cl_int errorCode = mImpl->createKernels(*this);
if (errorCode == CL_SUCCESS)
{
// CL_INVALID_VALUE if kernels is not NULL and
// num_kernels is less than the number of kernels in program.
if (kernels != nullptr && numKernels < mKernels.size())
if (kernels == nullptr)
{
errorCode = CL_INVALID_VALUE;
numKernels = 0u;
}
else
rx::CLKernelImpl::CreateFuncs createFuncs;
cl_int errorCode = mImpl->createKernels(numKernels, createFuncs, numKernelsRet);
if (errorCode == CL_SUCCESS)
{
if (kernels != nullptr)
KernelPtrs krnls;
krnls.reserve(createFuncs.size());
while (!createFuncs.empty())
{
for (const KernelPtr &kernel : mKernels)
krnls.emplace_back(new Kernel(*this, createFuncs.front(), errorCode));
if (errorCode != CL_SUCCESS)
{
*kernels++ = kernel.get();
return CL_INVALID_VALUE;
}
createFuncs.pop_front();
}
if (numKernelsRet != nullptr)
for (KernelPtr &kernel : krnls)
{
*numKernelsRet = static_cast<cl_uint>(mKernels.size());
}
*kernels++ = kernel.release();
}
}
return errorCode;
}
bool Program::IsValid(const _cl_program *program)
{
const Platform::PtrList &platforms = Platform::GetPlatforms();
return std::find_if(platforms.cbegin(), platforms.cend(), [=](const PlatformPtr &platform) {
return platform->hasProgram(program);
}) != platforms.cend();
}
Program::Program(Context &context, std::string &&source, cl_int &errorCode)
: _cl_program(context.getDispatch()),
mContext(&context),
: mContext(&context),
mDevices(context.getDevices()),
mImpl(context.mImpl->createProgramWithSource(*this, source, errorCode)),
mImpl(context.getImpl().createProgramWithSource(*this, source, errorCode)),
mSource(std::move(source))
{}
Program::Program(Context &context, const void *il, size_t length, cl_int &errorCode)
: _cl_program(context.getDispatch()),
mContext(&context),
: mContext(&context),
mDevices(context.getDevices()),
mIL(static_cast<const char *>(il), length),
mImpl(context.mImpl->createProgramWithIL(*this, il, length, errorCode)),
mImpl(context.getImpl().createProgramWithIL(*this, il, length, errorCode)),
mSource(mImpl ? mImpl->getSource(errorCode) : std::string{})
{}
Program::Program(Context &context,
DeviceRefs &&devices,
DevicePtrs &&devices,
Binaries &&binaries,
cl_int *binaryStatus,
cl_int &errorCode)
: _cl_program(context.getDispatch()),
mContext(&context),
: mContext(&context),
mDevices(std::move(devices)),
mImpl(context.mImpl->createProgramWithBinary(*this, binaries, binaryStatus, errorCode)),
mImpl(context.getImpl().createProgramWithBinary(*this, binaries, binaryStatus, errorCode)),
mSource(mImpl ? mImpl->getSource(errorCode) : std::string{}),
mBinaries(std::move(binaries))
{}
Program::Program(Context &context, DeviceRefs &&devices, const char *kernelNames, cl_int &errorCode)
: _cl_program(context.getDispatch()),
mContext(&context),
Program::Program(Context &context, DevicePtrs &&devices, const char *kernelNames, cl_int &errorCode)
: mContext(&context),
mDevices(std::move(devices)),
mImpl(context.mImpl->createProgramWithBuiltInKernels(*this, kernelNames, errorCode)),
mImpl(context.getImpl().createProgramWithBuiltInKernels(*this, kernelNames, errorCode)),
mSource(mImpl ? mImpl->getSource(errorCode) : std::string{})
{}
cl_kernel Program::createKernel(Kernel *kernel, cl_int errorCode)
{
mKernels.emplace_back(kernel);
if (errorCode != CL_SUCCESS)
{
mKernels.back()->release();
return nullptr;
}
return mKernels.back().get();
}
void Program::destroyKernel(Kernel *kernel)
{
auto kernelIt = mKernels.cbegin();
while (kernelIt != mKernels.cend() && kernelIt->get() != kernel)
{
++kernelIt;
}
if (kernelIt != mKernels.cend())
{
mKernels.erase(kernelIt);
}
else
{
ERR() << "Kernel not found";
}
}
} // namespace cl
......@@ -17,46 +17,35 @@ namespace cl
class Program final : public _cl_program, public Object
{
public:
using PtrList = std::list<ProgramPtr>;
~Program() override;
Context &getContext();
const Context &getContext() const;
const DeviceRefs &getDevices() const;
const Kernel::PtrList &getKernels() const;
bool hasKernel(const _cl_kernel *kernel) const;
const DevicePtrs &getDevices() const;
void retain() noexcept;
bool release();
template <typename T = rx::CLProgramImpl>
T &getImpl() const;
cl_int getInfo(ProgramInfo name, size_t valueSize, void *value, size_t *valueSizeRet) const;
cl_kernel createKernel(const char *kernel_name, cl_int &errorCode);
cl_int createKernel(const Kernel::CreateImplFunc &createImplFunc);
cl_int createKernels(cl_uint numKernels, cl_kernel *kernels, cl_uint *numKernelsRet);
static bool IsValid(const _cl_program *program);
cl_int createKernels(cl_uint numKernels, cl_kernel *kernels, cl_uint *numKernelsRet);
private:
Program(Context &context, std::string &&source, cl_int &errorCode);
Program(Context &context, const void *il, size_t length, cl_int &errorCode);
Program(Context &context,
DeviceRefs &&devices,
DevicePtrs &&devices,
Binaries &&binaries,
cl_int *binaryStatus,
cl_int &errorCode);
Program(Context &context, DeviceRefs &&devices, const char *kernelNames, cl_int &errorCode);
cl_kernel createKernel(Kernel *kernel, cl_int errorCode);
void destroyKernel(Kernel *kernel);
Program(Context &context, DevicePtrs &&devices, const char *kernelNames, cl_int &errorCode);
const ContextRefPtr mContext;
const DeviceRefs mDevices;
const ContextPtr mContext;
const DevicePtrs mDevices;
const std::string mIL;
const rx::CLProgramImpl::Ptr mImpl;
const std::string mSource;
......@@ -64,10 +53,8 @@ class Program final : public _cl_program, public Object
Binaries mBinaries;
size_t mNumKernels;
std::string mKernelNames;
Kernel::PtrList mKernels;
friend class Context;
friend class Kernel;
friend class Object;
};
inline Context &Program::getContext()
......@@ -80,26 +67,15 @@ inline const Context &Program::getContext() const
return *mContext;
}
inline const DeviceRefs &Program::getDevices() const
inline const DevicePtrs &Program::getDevices() const
{
return mDevices;
}
inline const Kernel::PtrList &Program::getKernels() const
{
return mKernels;
}
inline bool Program::hasKernel(const _cl_kernel *kernel) const
{
return std::find_if(mKernels.cbegin(), mKernels.cend(), [=](const KernelPtr &ptr) {
return ptr.get() == kernel;
}) != mKernels.cend();
}
inline void Program::retain() noexcept
template <typename T>
inline T &Program::getImpl() const
{
addRef();
return static_cast<T &>(*mImpl);
}
} // namespace cl
......
......@@ -28,9 +28,9 @@ class RefPointer
}
~RefPointer()
{
if (mCLObject != nullptr)
if (mCLObject != nullptr && mCLObject->release())
{
mCLObject->release();
delete mCLObject;
}
}
......
......@@ -8,7 +8,6 @@
#include "libANGLE/CLSampler.h"
#include "libANGLE/CLContext.h"
#include "libANGLE/CLPlatform.h"
#include <cstring>
......@@ -17,16 +16,6 @@ namespace cl
Sampler::~Sampler() = default;
bool Sampler::release()
{
const bool released = removeRef();
if (released)
{
mContext->destroySampler(this);
}
return released;
}
cl_int Sampler::getInfo(SamplerInfo name, size_t valueSize, void *value, size_t *valueSizeRet) const
{
static_assert(std::is_same<cl_uint, cl_addressing_mode>::value &&
......@@ -41,11 +30,12 @@ cl_int Sampler::getInfo(SamplerInfo name, size_t valueSize, void *value, size_t
switch (name)
{
case SamplerInfo::ReferenceCount:
copyValue = getRefCountPtr();
copySize = sizeof(*getRefCountPtr());
valUInt = getRefCount();
copyValue = &valUInt;
copySize = sizeof(valUInt);
break;
case SamplerInfo::Context:
valPointer = static_cast<cl_context>(mContext.get());
valPointer = mContext->getNative();
copyValue = &valPointer;
copySize = sizeof(valPointer);
break;
......@@ -91,27 +81,18 @@ cl_int Sampler::getInfo(SamplerInfo name, size_t valueSize, void *value, size_t
return CL_SUCCESS;
}
bool Sampler::IsValid(const _cl_sampler *sampler)
{
const Platform::PtrList &platforms = Platform::GetPlatforms();
return std::find_if(platforms.cbegin(), platforms.cend(), [=](const PlatformPtr &platform) {
return platform->hasSampler(sampler);
}) != platforms.cend();
}
Sampler::Sampler(Context &context,
PropArray &&properties,
cl_bool normalizedCoords,
AddressingMode addressingMode,
FilterMode filterMode,
cl_int &errorCode)
: _cl_sampler(context.getDispatch()),
mContext(&context),
: mContext(&context),
mProperties(std::move(properties)),
mNormalizedCoords(normalizedCoords),
mAddressingMode(addressingMode),
mFilterMode(filterMode),
mImpl(context.mImpl->createSampler(*this, errorCode))
mImpl(context.getImpl().createSampler(*this, errorCode))
{}
} // namespace cl
......@@ -17,7 +17,6 @@ namespace cl
class Sampler final : public _cl_sampler, public Object
{
public:
using PtrList = std::list<SamplerPtr>;
using PropArray = std::vector<cl_sampler_properties>;
~Sampler() override;
......@@ -28,13 +27,8 @@ class Sampler final : public _cl_sampler, public Object
AddressingMode getAddressingMode() const;
FilterMode getFilterMode() const;
void retain() noexcept;
bool release();
cl_int getInfo(SamplerInfo name, size_t valueSize, void *value, size_t *valueSizeRet) const;
static bool IsValid(const _cl_sampler *sampler);
private:
Sampler(Context &context,
PropArray &&properties,
......@@ -43,14 +37,14 @@ class Sampler final : public _cl_sampler, public Object
FilterMode filterMode,
cl_int &errorCode);
const ContextRefPtr mContext;
const ContextPtr mContext;
const PropArray mProperties;
const cl_bool mNormalizedCoords;
const AddressingMode mAddressingMode;
const FilterMode mFilterMode;
const rx::CLSamplerImpl::Ptr mImpl;
friend class Context;
friend class Object;
};
inline const Context &Sampler::getContext() const
......@@ -78,11 +72,6 @@ inline FilterMode Sampler::getFilterMode() const
return mFilterMode;
}
inline void Sampler::retain() noexcept
{
addRef();
}
} // namespace cl
#endif // LIBANGLE_CLSAMPLER_H_
......@@ -10,8 +10,10 @@
#include "libANGLE/CLBitField.h"
#include "libANGLE/CLRefPointer.h"
#include "libANGLE/Debug.h"
#include "common/PackedCLEnums_autogen.h"
#include "common/angleutils.h"
// Include frequently used standard headers
#include <algorithm>
......@@ -38,34 +40,24 @@ class Platform;
class Program;
class Sampler;
using CommandQueuePtr = std::unique_ptr<CommandQueue>;
using ContextPtr = std::unique_ptr<Context>;
using DevicePtr = std::unique_ptr<Device>;
using EventPtr = std::unique_ptr<Event>;
using KernelPtr = std::unique_ptr<Kernel>;
using MemoryPtr = std::unique_ptr<Memory>;
using ObjectPtr = std::unique_ptr<Object>;
using PlatformPtr = std::unique_ptr<Platform>;
using ProgramPtr = std::unique_ptr<Program>;
using SamplerPtr = std::unique_ptr<Sampler>;
using CommandQueuePtr = RefPointer<CommandQueue>;
using ContextPtr = RefPointer<Context>;
using DevicePtr = RefPointer<Device>;
using EventPtr = RefPointer<Event>;
using KernelPtr = RefPointer<Kernel>;
using MemoryPtr = RefPointer<Memory>;
using PlatformPtr = RefPointer<Platform>;
using ProgramPtr = RefPointer<Program>;
using SamplerPtr = RefPointer<Sampler>;
using CommandQueueRefPtr = RefPointer<CommandQueue>;
using ContextRefPtr = RefPointer<Context>;
using DeviceRefPtr = RefPointer<Device>;
using EventRefPtr = RefPointer<Event>;
using MemoryRefPtr = RefPointer<Memory>;
using ProgramRefPtr = RefPointer<Program>;
using DevicePtrList = std::list<DevicePtr>;
using DeviceRefs = std::vector<DeviceRefPtr>;
using EventRefs = std::vector<EventRefPtr>;
using DevicePtrs = std::vector<DevicePtr>;
using EventPtrs = std::vector<EventPtr>;
using KernelPtrs = std::vector<KernelPtr>;
using PlatformPtrs = std::vector<PlatformPtr>;
using Binary = std::vector<unsigned char>;
using Binaries = std::vector<Binary>;
using EventPredicate = std::function<bool(const EventPtr &)>;
struct ImageDescriptor
{
cl_mem_object_type type;
......
......@@ -9,7 +9,6 @@
#define LIBANGLE_RENDERER_CLCONTEXTIMPL_H_
#include "libANGLE/renderer/CLCommandQueueImpl.h"
#include "libANGLE/renderer/CLDeviceImpl.h"
#include "libANGLE/renderer/CLEventImpl.h"
#include "libANGLE/renderer/CLMemoryImpl.h"
#include "libANGLE/renderer/CLProgramImpl.h"
......@@ -26,7 +25,7 @@ class CLContextImpl : angle::NonCopyable
CLContextImpl(const cl::Context &context);
virtual ~CLContextImpl();
virtual cl::DeviceRefs getDevices(cl_int &errorCode) const = 0;
virtual cl::DevicePtrs getDevices(cl_int &errorCode) const = 0;
virtual CLCommandQueueImpl::Ptr createCommandQueue(const cl::CommandQueue &commandQueue,
cl_int &errorCode) = 0;
......@@ -64,7 +63,7 @@ class CLContextImpl : angle::NonCopyable
virtual CLEventImpl::Ptr createUserEvent(const cl::Event &event, cl_int &errorCode) = 0;
virtual cl_int waitForEvents(const cl::EventRefs &events) = 0;
virtual cl_int waitForEvents(const cl::EventPtrs &events) = 0;
protected:
const cl::Context &mContext;
......
......@@ -17,6 +17,10 @@ class CLDeviceImpl : angle::NonCopyable
{
public:
using Ptr = std::unique_ptr<CLDeviceImpl>;
using CreateFunc = std::function<Ptr(const cl::Device &)>;
using CreateFuncs = std::list<CreateFunc>;
using CreateData = std::pair<cl::DeviceType, CreateFunc>;
using CreateDatas = std::list<CreateData>;
struct Info
{
......@@ -72,10 +76,9 @@ class CLDeviceImpl : angle::NonCopyable
virtual cl_int getInfoStringLength(cl::DeviceInfo name, size_t *value) const = 0;
virtual cl_int getInfoString(cl::DeviceInfo name, size_t size, char *value) const = 0;
virtual cl_int createSubDevices(cl::Device &device,
const cl_device_partition_property *properties,
virtual cl_int createSubDevices(const cl_device_partition_property *properties,
cl_uint numDevices,
cl::DevicePtrList &subDeviceList,
CreateFuncs &createFuncs,
cl_uint *numDevicesRet) = 0;
protected:
......
......@@ -25,7 +25,7 @@ class CLEventImpl : angle::NonCopyable
virtual cl_int setUserEventStatus(cl_int executionStatus) = 0;
virtual cl_int setCallback(cl_int commandExecCallbackType) = 0;
virtual cl_int setCallback(cl::Event &event, cl_int commandExecCallbackType) = 0;
protected:
const cl::Event &mEvent;
......
......@@ -17,6 +17,8 @@ class CLKernelImpl : angle::NonCopyable
{
public:
using Ptr = std::unique_ptr<CLKernelImpl>;
using CreateFunc = std::function<Ptr(const cl::Kernel &)>;
using CreateFuncs = std::list<CreateFunc>;
struct WorkGroupInfo
{
......
......@@ -11,8 +11,6 @@
#include "libANGLE/renderer/CLContextImpl.h"
#include "libANGLE/renderer/CLDeviceImpl.h"
#include <tuple>
namespace rx
{
......@@ -20,6 +18,8 @@ class CLPlatformImpl : angle::NonCopyable
{
public:
using Ptr = std::unique_ptr<CLPlatformImpl>;
using CreateFunc = std::function<Ptr(const cl::Platform &)>;
using CreateFuncs = std::list<CreateFunc>;
struct Info
{
......@@ -48,10 +48,10 @@ class CLPlatformImpl : angle::NonCopyable
// For initialization only
virtual Info createInfo() const = 0;
virtual cl::DevicePtrList createDevices(cl::Platform &platform) const = 0;
virtual CLDeviceImpl::CreateDatas createDevices() const = 0;
virtual CLContextImpl::Ptr createContext(cl::Context &context,
const cl::DeviceRefs &devices,
const cl::DevicePtrs &devices,
bool userSync,
cl_int &errorCode) = 0;
......
......@@ -27,7 +27,9 @@ class CLProgramImpl : angle::NonCopyable
const char *name,
cl_int &errorCode) = 0;
virtual cl_int createKernels(cl::Program &program) = 0;
virtual cl_int createKernels(cl_uint numKernels,
CLKernelImpl::CreateFuncs &createFuncs,
cl_uint *numKernelsRet) = 0;
protected:
const cl::Program &mProgram;
......
......@@ -10,8 +10,6 @@
#include "libANGLE/CLtypes.h"
#include "common/angleutils.h"
namespace rx
{
......
......@@ -38,7 +38,6 @@ config("angle_cl_backend_config") {
angle_source_set("angle_cl_backend") {
sources = _cl_backend_sources
configs += [ "$angle_root/src/libOpenCL:opencl_library_name" ]
public_deps = [
"$angle_root:libANGLE_headers",
......
......@@ -7,8 +7,6 @@
#include "libANGLE/renderer/cl/CLCommandQueueCL.h"
#include "libANGLE/Debug.h"
namespace rx
{
......
......@@ -9,8 +9,6 @@
#ifndef LIBANGLE_RENDERER_CL_CLCOMMANDQUEUECL_H_
#define LIBANGLE_RENDERER_CL_CLCOMMANDQUEUECL_H_
#include "libANGLE/renderer/cl/cl_types.h"
#include "libANGLE/renderer/CLCommandQueueImpl.h"
namespace rx
......
......@@ -24,7 +24,6 @@
#include "libANGLE/CLPlatform.h"
#include "libANGLE/CLProgram.h"
#include "libANGLE/CLSampler.h"
#include "libANGLE/Debug.h"
namespace rx
{
......@@ -41,7 +40,7 @@ CLContextCL::~CLContextCL()
}
}
cl::DeviceRefs CLContextCL::getDevices(cl_int &errorCode) const
cl::DevicePtrs CLContextCL::getDevices(cl_int &errorCode) const
{
size_t valueSize = 0u;
errorCode = mNative->getDispatch().clGetContextInfo(mNative, CL_CONTEXT_DEVICES, 0u, nullptr,
......@@ -53,8 +52,9 @@ cl::DeviceRefs CLContextCL::getDevices(cl_int &errorCode) const
nativeDevices.data(), nullptr);
if (errorCode == CL_SUCCESS)
{
const cl::DevicePtrList &platformDevices = mContext.getPlatform().getDevices();
cl::DeviceRefs devices;
const cl::DevicePtrs &platformDevices = mContext.getPlatform().getDevices();
cl::DevicePtrs devices;
devices.reserve(nativeDevices.size());
for (cl_device_id nativeDevice : nativeDevices)
{
auto it = platformDevices.cbegin();
......@@ -72,13 +72,13 @@ cl::DeviceRefs CLContextCL::getDevices(cl_int &errorCode) const
ASSERT(false);
errorCode = CL_INVALID_DEVICE;
ERR() << "Device not found in platform list";
return cl::DeviceRefs{};
return cl::DevicePtrs{};
}
}
return devices;
}
}
return cl::DeviceRefs{};
return cl::DevicePtrs{};
}
CLCommandQueueImpl::Ptr CLContextCL::createCommandQueue(const cl::CommandQueue &commandQueue,
......@@ -140,7 +140,7 @@ CLMemoryImpl::Ptr CLContextCL::createImage(const cl::Image &image,
desc.height, desc.depth,
desc.arraySize, desc.rowPitch,
desc.slicePitch, desc.numMipLevels,
desc.numSamples, {static_cast<cl_mem>(image.getParent().get())}};
desc.numSamples, {cl::Memory::CastNative(image.getParent().get())}};
if (image.getProperties().empty())
{
......@@ -238,7 +238,7 @@ CLProgramImpl::Ptr CLContextCL::createProgramWithBinary(const cl::Program &progr
{
ASSERT(program.getDevices().size() == binaries.size());
std::vector<cl_device_id> nativeDevices;
for (const cl::DeviceRefPtr &device : program.getDevices())
for (const cl::DevicePtr &device : program.getDevices())
{
nativeDevices.emplace_back(device->getImpl<CLDeviceCL>().getNative());
}
......@@ -261,7 +261,7 @@ CLProgramImpl::Ptr CLContextCL::createProgramWithBuiltInKernels(const cl::Progra
cl_int &errorCode)
{
std::vector<cl_device_id> nativeDevices;
for (const cl::DeviceRefPtr &device : program.getDevices())
for (const cl::DevicePtr &device : program.getDevices())
{
nativeDevices.emplace_back(device->getImpl<CLDeviceCL>().getNative());
}
......@@ -278,11 +278,11 @@ CLEventImpl::Ptr CLContextCL::createUserEvent(const cl::Event &event, cl_int &er
return CLEventImpl::Ptr(nativeEvent != nullptr ? new CLEventCL(event, nativeEvent) : nullptr);
}
cl_int CLContextCL::waitForEvents(const cl::EventRefs &events)
cl_int CLContextCL::waitForEvents(const cl::EventPtrs &events)
{
std::vector<cl_event> nativeEvents;
nativeEvents.reserve(events.size());
for (const cl::EventRefPtr &event : events)
for (const cl::EventPtr &event : events)
{
nativeEvents.emplace_back(event->getImpl<CLEventCL>().getNative());
}
......
......@@ -8,8 +8,6 @@
#ifndef LIBANGLE_RENDERER_CL_CLCONTEXTCL_H_
#define LIBANGLE_RENDERER_CL_CLCONTEXTCL_H_
#include "libANGLE/renderer/cl/cl_types.h"
#include "libANGLE/renderer/CLContextImpl.h"
namespace rx
......@@ -21,7 +19,7 @@ class CLContextCL : public CLContextImpl
CLContextCL(const cl::Context &context, cl_context native);
~CLContextCL() override;
cl::DeviceRefs getDevices(cl_int &errorCode) const override;
cl::DevicePtrs getDevices(cl_int &errorCode) const override;
CLCommandQueueImpl::Ptr createCommandQueue(const cl::CommandQueue &commandQueue,
cl_int &errorCode) override;
......@@ -59,7 +57,7 @@ class CLContextCL : public CLContextImpl
CLEventImpl::Ptr createUserEvent(const cl::Event &event, cl_int &errorCode) override;
cl_int waitForEvents(const cl::EventRefs &events) override;
cl_int waitForEvents(const cl::EventPtrs &events) override;
private:
const cl_context mNative;
......
......@@ -7,11 +7,9 @@
#include "libANGLE/renderer/cl/CLDeviceCL.h"
#include "libANGLE/renderer/cl/CLPlatformCL.h"
#include "libANGLE/renderer/cl/cl_util.h"
#include "libANGLE/CLDevice.h"
#include "libANGLE/Debug.h"
namespace rx
{
......@@ -195,10 +193,9 @@ cl_int CLDeviceCL::getInfoString(cl::DeviceInfo name, size_t size, char *value)
nullptr);
}
cl_int CLDeviceCL::createSubDevices(cl::Device &device,
const cl_device_partition_property *properties,
cl_int CLDeviceCL::createSubDevices(const cl_device_partition_property *properties,
cl_uint numDevices,
cl::DevicePtrList &subDeviceList,
CreateFuncs &createFuncs,
cl_uint *numDevicesRet)
{
if (numDevices == 0u)
......@@ -208,27 +205,18 @@ cl_int CLDeviceCL::createSubDevices(cl::Device &device,
}
std::vector<cl_device_id> nativeSubDevices(numDevices, nullptr);
const cl_int result = mNative->getDispatch().clCreateSubDevices(
const cl_int errorCode = mNative->getDispatch().clCreateSubDevices(
mNative, properties, numDevices, nativeSubDevices.data(), nullptr);
if (result == CL_SUCCESS)
if (errorCode == CL_SUCCESS)
{
cl::DeviceType type = device.getInfo().mType;
type.clear(CL_DEVICE_TYPE_DEFAULT);
for (cl_device_id nativeSubDevice : nativeSubDevices)
{
const cl::Device::CreateImplFunc createImplFunc = [&](const cl::Device &device) {
createFuncs.emplace_back([=](const cl::Device &device) {
return Ptr(new CLDeviceCL(device, nativeSubDevice));
};
subDeviceList.emplace_back(
cl::Device::CreateDevice(device.getPlatform(), &device, type, createImplFunc));
if (!subDeviceList.back())
{
subDeviceList.clear();
return CL_INVALID_VALUE;
}
});
}
}
return result;
return errorCode;
}
CLDeviceCL::CLDeviceCL(const cl::Device &device, cl_device_id native)
......
......@@ -8,8 +8,6 @@
#ifndef LIBANGLE_RENDERER_CL_CLDEVICECL_H_
#define LIBANGLE_RENDERER_CL_CLDEVICECL_H_
#include "libANGLE/renderer/cl/cl_types.h"
#include "libANGLE/renderer/CLDeviceImpl.h"
namespace rx
......@@ -30,10 +28,9 @@ class CLDeviceCL : public CLDeviceImpl
cl_int getInfoStringLength(cl::DeviceInfo name, size_t *value) const override;
cl_int getInfoString(cl::DeviceInfo name, size_t size, char *value) const override;
cl_int createSubDevices(cl::Device &device,
const cl_device_partition_property *properties,
cl_int createSubDevices(const cl_device_partition_property *properties,
cl_uint numDevices,
cl::DevicePtrList &subDeviceList,
CreateFuncs &createFuncs,
cl_uint *numDevicesRet) override;
private:
......
......@@ -7,8 +7,7 @@
#include "libANGLE/renderer/cl/CLEventCL.h"
#include "libANGLE/CLPlatform.h"
#include "libANGLE/Debug.h"
#include "libANGLE/CLEvent.h"
namespace rx
{
......@@ -36,24 +35,15 @@ cl_int CLEventCL::setUserEventStatus(cl_int executionStatus)
return mNative->getDispatch().clSetUserEventStatus(mNative, executionStatus);
}
cl_int CLEventCL::setCallback(cl_int commandExecCallbackType)
cl_int CLEventCL::setCallback(cl::Event &event, cl_int commandExecCallbackType)
{
return mNative->getDispatch().clSetEventCallback(mNative, commandExecCallbackType, Callback,
nullptr);
&event);
}
void CLEventCL::Callback(cl_event event, cl_int commandStatus, void *userData)
{
const cl::EventRefPtr evt = cl::Platform::FindEvent(
[=](const cl::EventPtr &ptr) { return ptr->getImpl<CLEventCL>().getNative() == event; });
if (evt)
{
evt->callback(commandStatus);
}
else
{
WARN() << "Callback event not found";
}
static_cast<cl::Event *>(userData)->callback(commandStatus);
}
} // namespace rx
......@@ -27,7 +27,7 @@ class CLEventCL : public CLEventImpl
cl_int setUserEventStatus(cl_int executionStatus) override;
cl_int setCallback(cl_int commandExecCallbackType) override;
cl_int setCallback(cl::Event &event, cl_int commandExecCallbackType) override;
private:
static void CL_CALLBACK Callback(cl_event event, cl_int commandStatus, void *userData);
......
......@@ -8,7 +8,6 @@
#include "libANGLE/renderer/cl/CLMemoryCL.h"
#include "libANGLE/CLBuffer.h"
#include "libANGLE/Debug.h"
namespace rx
{
......
......@@ -8,8 +8,6 @@
#ifndef LIBANGLE_RENDERER_CL_CLMEMORYCL_H_
#define LIBANGLE_RENDERER_CL_CLMEMORYCL_H_
#include "libANGLE/renderer/cl/cl_types.h"
#include "libANGLE/renderer/CLMemoryImpl.h"
namespace rx
......
......@@ -11,12 +11,11 @@
#include "libANGLE/renderer/cl/CLDeviceCL.h"
#include "libANGLE/renderer/cl/cl_util.h"
#include "libANGLE/CLContext.h"
#include "libANGLE/CLDevice.h"
#include "libANGLE/CLPlatform.h"
#include "libANGLE/Debug.h"
#include "anglebase/no_destructor.h"
#include "common/angle_version.h"
#include "common/system_utils.h"
extern "C" {
#include "icd.h"
......@@ -298,9 +297,9 @@ CLPlatformImpl::Info CLPlatformCL::createInfo() const
return info;
}
cl::DevicePtrList CLPlatformCL::createDevices(cl::Platform &platform) const
CLDeviceImpl::CreateDatas CLPlatformCL::createDevices() const
{
cl::DevicePtrList devices;
CLDeviceImpl::CreateDatas createDatas;
// Fetch all regular devices. This does not include CL_DEVICE_TYPE_CUSTOM, which are not
// supported by the CL pass-through back end because they have no standard feature set.
......@@ -348,29 +347,23 @@ cl::DevicePtrList CLPlatformCL::createDevices(cl::Platform &platform) const
types[index].clear(CL_DEVICE_TYPE_DEFAULT);
}
const cl::Device::CreateImplFunc createImplFunc = [&](const cl::Device &device) {
return CLDeviceCL::Ptr(new CLDeviceCL(device, nativeDevices[index]));
};
devices.emplace_back(
cl::Device::CreateDevice(platform, nullptr, types[index], createImplFunc));
if (!devices.back())
{
devices.clear();
break;
}
cl_device_id nativeDevice = nativeDevices[index];
createDatas.emplace_back(types[index], [=](const cl::Device &device) {
return CLDeviceCL::Ptr(new CLDeviceCL(device, nativeDevice));
});
}
}
}
if (devices.empty())
if (createDatas.empty())
{
ERR() << "Failed to query CL devices";
}
return devices;
return createDatas;
}
CLContextImpl::Ptr CLPlatformCL::createContext(cl::Context &context,
const cl::DeviceRefs &devices,
const cl::DevicePtrs &devices,
bool userSync,
cl_int &errorCode)
{
......@@ -380,7 +373,7 @@ CLContextImpl::Ptr CLPlatformCL::createContext(cl::Context &context,
0};
std::vector<cl_device_id> nativeDevices;
for (const cl::DeviceRefPtr &device : devices)
for (const cl::DevicePtr &device : devices)
{
nativeDevices.emplace_back(device->getImpl<CLDeviceCL>().getNative());
}
......@@ -408,57 +401,26 @@ CLContextImpl::Ptr CLPlatformCL::createContextFromType(cl::Context &context,
: nullptr);
}
void CLPlatformCL::Initialize(const cl_icd_dispatch &dispatch, bool isIcd)
void CLPlatformCL::Initialize(CreateFuncs &createFuncs, bool isIcd)
{
// Using khrIcdInitialize() of the third party Khronos OpenCL ICD Loader to enumerate the
// available OpenCL implementations on the system. They will be stored in the singly linked
// list khrIcdVendors of the C struct KHRicdVendor.
if (khrIcdVendors != nullptr)
{
return;
}
// The absolute path to ANGLE's OpenCL library is needed and it is assumed here that
// it is in the same directory as the shared library which contains this CL back end.
std::string libPath = angle::GetModuleDirectory();
if (!libPath.empty() && libPath.back() != angle::GetPathSeparator())
{
libPath += angle::GetPathSeparator();
}
libPath += ANGLE_OPENCL_LIB_NAME;
libPath += '.';
libPath += angle::GetSharedLibraryExtension();
// Our OpenCL entry points are not reentrant, so we have to prevent khrIcdInitialize()
// from querying ANGLE's OpenCL library. We store a dummy entry with the library in the
// khrIcdVendors list, because the ICD Loader skips the libraries which are already in
// the list as it assumes they were already enumerated.
static angle::base::NoDestructor<KHRicdVendor> sVendorAngle({});
sVendorAngle->library = khrIcdOsLibraryLoad(libPath.c_str());
khrIcdVendors = sVendorAngle.get();
if (khrIcdVendors->library == nullptr)
{
WARN() << "Unable to load library \"" << libPath << "\"";
return;
}
// Using khrIcdInitialize() of the third party Khronos OpenCL ICD Loader to
// enumerate the available OpenCL implementations on the system. They will be
// stored in the singly linked list khrIcdVendors of the C struct KHRicdVendor.
khrIcdInitialize();
// After the enumeration we don't need ANGLE's OpenCL library any more,
// but we keep the dummy entry int the list to prevent another enumeration.
khrIcdOsLibraryUnload(khrIcdVendors->library);
khrIcdVendors->library = nullptr;
// Iterating through the singly linked list khrIcdVendors to create an ANGLE CL pass-through
// platform for each found ICD platform. Skipping our dummy entry that has an invalid platform.
// The ICD loader will also enumerate ANGLE's OpenCL library if it is registered. Our
// OpenCL entry points for the ICD enumeration are reentrant, but at this point of the
// initialization there are no platforms available, so our platforms will not be found.
// This is intended as this back end should only enumerate non-ANGLE implementations.
// Iterating through the singly linked list khrIcdVendors to create
// an ANGLE CL pass-through platform for each found ICD platform.
for (KHRicdVendor *vendorIt = khrIcdVendors; vendorIt != nullptr; vendorIt = vendorIt->next)
{
if (vendorIt->platform != nullptr)
{
const cl::Platform::CreateImplFunc createImplFunc = [&](const cl::Platform &platform) {
return Ptr(new CLPlatformCL(platform, vendorIt->platform));
};
cl::Platform::CreatePlatform(dispatch, createImplFunc);
}
cl_platform_id nativePlatform = vendorIt->platform;
createFuncs.emplace_back([=](const cl::Platform &platform) {
return Ptr(new CLPlatformCL(platform, nativePlatform));
});
}
}
......
......@@ -21,10 +21,10 @@ class CLPlatformCL : public CLPlatformImpl
cl_platform_id getNative();
Info createInfo() const override;
cl::DevicePtrList createDevices(cl::Platform &platform) const override;
CLDeviceImpl::CreateDatas createDevices() const override;
CLContextImpl::Ptr createContext(cl::Context &context,
const cl::DeviceRefs &devices,
const cl::DevicePtrs &devices,
bool userSync,
cl_int &errorCode) override;
......@@ -33,7 +33,7 @@ class CLPlatformCL : public CLPlatformImpl
bool userSync,
cl_int &errorCode) override;
static void Initialize(const cl_icd_dispatch &dispatch, bool isIcd);
static void Initialize(CreateFuncs &createFuncs, bool isIcd);
private:
CLPlatformCL(const cl::Platform &platform, cl_platform_id native);
......
......@@ -9,10 +9,6 @@
#include "libANGLE/renderer/cl/CLKernelCL.h"
#include "libANGLE/CLKernel.h"
#include "libANGLE/CLProgram.h"
#include "libANGLE/Debug.h"
namespace rx
{
......@@ -58,35 +54,25 @@ CLKernelImpl::Ptr CLProgramCL::createKernel(const cl::Kernel &kernel,
: nullptr);
}
cl_int CLProgramCL::createKernels(cl::Program &program)
cl_int CLProgramCL::createKernels(cl_uint numKernels,
CLKernelImpl::CreateFuncs &createFuncs,
cl_uint *numKernelsRet)
{
cl_uint numKernels = 0u;
cl_int errorCode =
mNative->getDispatch().clCreateKernelsInProgram(mNative, 0u, nullptr, &numKernels);
if (errorCode == CL_SUCCESS)
if (numKernels == 0u)
{
return mNative->getDispatch().clCreateKernelsInProgram(mNative, 0u, nullptr, numKernelsRet);
}
std::vector<cl_kernel> nativeKernels(numKernels, nullptr);
errorCode = mNative->getDispatch().clCreateKernelsInProgram(mNative, numKernels,
nativeKernels.data(), nullptr);
const cl_int errorCode = mNative->getDispatch().clCreateKernelsInProgram(
mNative, numKernels, nativeKernels.data(), nullptr);
if (errorCode == CL_SUCCESS)
{
for (cl_kernel nativeKernel : nativeKernels)
{
// Check that kernel has not already been created.
if (std::find_if(mProgram.getKernels().cbegin(), mProgram.getKernels().cend(),
[=](const cl::KernelPtr &ptr) {
return ptr->getImpl<CLKernelCL>().getNative() == nativeKernel;
}) == mProgram.getKernels().cend())
{
errorCode = program.createKernel([&](const cl::Kernel &kernel) {
createFuncs.emplace_back([=](const cl::Kernel &kernel) {
return CLKernelImpl::Ptr(new CLKernelCL(kernel, nativeKernel));
});
if (errorCode != CL_SUCCESS)
{
break;
}
}
}
}
}
return errorCode;
......
......@@ -27,7 +27,9 @@ class CLProgramCL : public CLProgramImpl
const char *name,
cl_int &errorCode) override;
cl_int createKernels(cl::Program &program) override;
cl_int createKernels(cl_uint numKernels,
CLKernelImpl::CreateFuncs &createFuncs,
cl_uint *numKernelsRet) override;
private:
const cl_program mNative;
......
......@@ -12,7 +12,6 @@
#include "anglebase/no_destructor.h"
#include <string>
#include <unordered_set>
#define ANGLE_SUPPORTED_OPENCL_EXTENSIONS "cl_khr_extended_versioning", "cl_khr_icd"
......
......@@ -47,10 +47,9 @@ cl_int CLDeviceVk::getInfoString(cl::DeviceInfo name, size_t size, char *value)
return CL_INVALID_VALUE;
}
cl_int CLDeviceVk::createSubDevices(cl::Device &device,
const cl_device_partition_property *properties,
cl_int CLDeviceVk::createSubDevices(const cl_device_partition_property *properties,
cl_uint numDevices,
cl::DevicePtrList &subDeviceList,
CreateFuncs &subDevices,
cl_uint *numDevicesRet)
{
return CL_INVALID_VALUE;
......
......@@ -29,10 +29,9 @@ class CLDeviceVk : public CLDeviceImpl
cl_int getInfoStringLength(cl::DeviceInfo name, size_t *value) const override;
cl_int getInfoString(cl::DeviceInfo name, size_t size, char *value) const override;
cl_int createSubDevices(cl::Device &device,
const cl_device_partition_property *properties,
cl_int createSubDevices(const cl_device_partition_property *properties,
cl_uint numDevices,
cl::DevicePtrList &subDeviceList,
CreateFuncs &subDevices,
cl_uint *numDevicesRet) override;
};
......
......@@ -56,23 +56,17 @@ CLPlatformImpl::Info CLPlatformVk::createInfo() const
return info;
}
cl::DevicePtrList CLPlatformVk::createDevices(cl::Platform &platform) const
CLDeviceImpl::CreateDatas CLPlatformVk::createDevices() const
{
cl::DeviceType type; // TODO(jplate) Fetch device type from Vulkan
cl::DevicePtrList devices;
const cl::Device::CreateImplFunc createImplFunc = [](const cl::Device &device) {
return CLDeviceVk::Ptr(new CLDeviceVk(device));
};
devices.emplace_back(cl::Device::CreateDevice(platform, nullptr, type, createImplFunc));
if (!devices.back())
{
devices.clear();
}
return devices;
CLDeviceImpl::CreateDatas createDatas;
createDatas.emplace_back(
type, [](const cl::Device &device) { return CLDeviceVk::Ptr(new CLDeviceVk(device)); });
return createDatas;
}
CLContextImpl::Ptr CLPlatformVk::createContext(cl::Context &context,
const cl::DeviceRefs &devices,
const cl::DevicePtrs &devices,
bool userSync,
cl_int &errorCode)
{
......@@ -89,12 +83,10 @@ CLContextImpl::Ptr CLPlatformVk::createContextFromType(cl::Context &context,
return contextImpl;
}
void CLPlatformVk::Initialize(const cl_icd_dispatch &dispatch)
void CLPlatformVk::Initialize(CreateFuncs &createFuncs)
{
const cl::Platform::CreateImplFunc createImplFunc = [](const cl::Platform &platform) {
return Ptr(new CLPlatformVk(platform));
};
cl::Platform::CreatePlatform(dispatch, createImplFunc);
createFuncs.emplace_back(
[](const cl::Platform &platform) { return Ptr(new CLPlatformVk(platform)); });
}
const std::string &CLPlatformVk::GetVersionString()
......
......@@ -19,10 +19,10 @@ class CLPlatformVk : public CLPlatformImpl
~CLPlatformVk() override;
Info createInfo() const override;
cl::DevicePtrList createDevices(cl::Platform &platform) const override;
CLDeviceImpl::CreateDatas createDevices() const override;
CLContextImpl::Ptr createContext(cl::Context &context,
const cl::DeviceRefs &devices,
const cl::DevicePtrs &devices,
bool userSync,
cl_int &errorCode) override;
......@@ -31,7 +31,7 @@ class CLPlatformVk : public CLPlatformImpl
bool userSync,
cl_int &errorCode) override;
static void Initialize(const cl_icd_dispatch &dispatch);
static void Initialize(CreateFuncs &createFuncs);
static constexpr cl_version GetVersion();
static const std::string &GetVersionString();
......
......@@ -9,6 +9,7 @@
#include "libGLESv2/cl_dispatch_table.h"
#include "libANGLE/CLPlatform.h"
#ifdef ANGLE_ENABLE_CL_PASSTHROUGH
# include "libANGLE/renderer/cl/CLPlatformCL.h"
#endif
......@@ -16,25 +17,51 @@
# include "libANGLE/renderer/vulkan/CLPlatformVk.h"
#endif
#include "anglebase/no_destructor.h"
#include <mutex>
namespace cl
{
void InitBackEnds(bool isIcd)
{
static bool initialized = false;
if (initialized)
enum struct State
{
Uninitialized,
Initializing,
Initialized
};
static State sState = State::Uninitialized;
// Fast thread-unsafe check first
if (sState == State::Initialized)
{
return;
}
static angle::base::NoDestructor<std::recursive_mutex> sMutex;
std::lock_guard<std::recursive_mutex> lock(*sMutex);
// Thread-safe check, return if initialized
// or if already initializing (re-entry from CL pass-through back end)
if (sState != State::Uninitialized)
{
return;
}
initialized = true;
sState = State::Initializing;
rx::CLPlatformImpl::CreateFuncs createFuncs;
#ifdef ANGLE_ENABLE_CL_PASSTHROUGH
rx::CLPlatformCL::Initialize(gCLIcdDispatchTable, isIcd);
rx::CLPlatformCL::Initialize(createFuncs, isIcd);
#endif
#ifdef ANGLE_ENABLE_VULKAN
rx::CLPlatformVk::Initialize(gCLIcdDispatchTable);
rx::CLPlatformVk::Initialize(createFuncs);
#endif
Platform::Initialize(gCLIcdDispatchTable, std::move(createFuncs));
sState = State::Initialized;
}
} // namespace cl
......@@ -8,21 +8,12 @@ import("../../gni/angle.gni")
assert(angle_enable_cl)
cl_library_name = "OpenCL_ANGLE"
if (is_win || is_linux) {
glesv2_path =
rebase_path(get_label_info("$angle_root:libGLESv2", "root_out_dir"))
}
config("opencl_library_name") {
if (is_win) {
defines = [ "ANGLE_OPENCL_LIB_NAME=\"" + cl_library_name + "\"" ]
} else {
defines = [ "ANGLE_OPENCL_LIB_NAME=\"lib" + cl_library_name + "\"" ]
}
}
angle_shared_library(cl_library_name) {
angle_shared_library("OpenCL_ANGLE") {
defines = [ "LIBCL_IMPLEMENTATION" ]
if (is_win) {
defines += [ "ANGLE_GLESV2_LIBRARY_PATH=\"" +
......@@ -56,5 +47,5 @@ angle_shared_library(cl_library_name) {
}
group("angle_cl") {
data_deps = [ ":$cl_library_name" ]
data_deps = [ ":OpenCL_ANGLE" ]
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment