Commit f2ab65b0 by Ben Clayton

Vulkan: Fix use-after-destruction of VkFence event

Remove sw::TaskEvents and sw::WaitGroup, replace this with sw::CountedEvent. See b/173784261 for details. Fixes: b/173784261 Change-Id: I21fb69c810558a1929bba5cc46f106d9d4e51c4b Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/50628 Kokoro-Result: kokoro <noreply+kokoro@google.com> Reviewed-by: 's avatarChris Forbes <chrisforbes@google.com> Reviewed-by: 's avatarNicolas Capens <nicolascapens@google.com> Tested-by: 's avatarBen Clayton <bclayton@google.com> Presubmit-Ready: Ben Clayton <bclayton@google.com>
parent de9e6494
...@@ -181,7 +181,7 @@ void Renderer::operator delete(void *mem) ...@@ -181,7 +181,7 @@ void Renderer::operator delete(void *mem)
} }
void Renderer::draw(const sw::Context *context, VkIndexType indexType, unsigned int count, int baseVertex, void Renderer::draw(const sw::Context *context, VkIndexType indexType, unsigned int count, int baseVertex,
TaskEvents *events, int instanceID, int viewID, void *indexBuffer, const VkExtent3D &framebufferExtent, CountedEvent *events, int instanceID, int viewID, void *indexBuffer, const VkExtent3D &framebufferExtent,
PushConstantStorage const &pushConstants, bool update) PushConstantStorage const &pushConstants, bool update)
{ {
if(count == 0) { return; } if(count == 0) { return; }
...@@ -421,7 +421,7 @@ void DrawCall::setup() ...@@ -421,7 +421,7 @@ void DrawCall::setup()
if(events) if(events)
{ {
events->start(); events->add();
} }
} }
...@@ -429,7 +429,7 @@ void DrawCall::teardown() ...@@ -429,7 +429,7 @@ void DrawCall::teardown()
{ {
if(events) if(events)
{ {
events->finish(); events->done();
events = nullptr; events = nullptr;
} }
......
...@@ -43,11 +43,11 @@ class PipelineLayout; ...@@ -43,11 +43,11 @@ class PipelineLayout;
namespace sw { namespace sw {
class CountedEvent;
struct DrawCall; struct DrawCall;
class PixelShader; class PixelShader;
class VertexShader; class VertexShader;
struct Task; struct Task;
class TaskEvents;
class Resource; class Resource;
struct Constants; struct Constants;
...@@ -172,7 +172,7 @@ struct DrawCall ...@@ -172,7 +172,7 @@ struct DrawCall
vk::ImageView *stencilBuffer; vk::ImageView *stencilBuffer;
vk::DescriptorSet::Array descriptorSetObjects; vk::DescriptorSet::Array descriptorSetObjects;
const vk::PipelineLayout *pipelineLayout; const vk::PipelineLayout *pipelineLayout;
TaskEvents *events; sw::CountedEvent *events;
vk::Query *occlusionQuery; vk::Query *occlusionQuery;
...@@ -210,7 +210,7 @@ public: ...@@ -210,7 +210,7 @@ public:
bool hasOcclusionQuery() const { return occlusionQuery != nullptr; } bool hasOcclusionQuery() const { return occlusionQuery != nullptr; }
void draw(const sw::Context *context, VkIndexType indexType, unsigned int count, int baseVertex, void draw(const sw::Context *context, VkIndexType indexType, unsigned int count, int baseVertex,
TaskEvents *events, int instanceID, int viewID, void *indexBuffer, const VkExtent3D &framebufferExtent, CountedEvent *events, int instanceID, int viewID, void *indexBuffer, const VkExtent3D &framebufferExtent,
PushConstantStorage const &pushConstants, bool update = true); PushConstantStorage const &pushConstants, bool update = true);
// Viewport & Clipper // Viewport & Clipper
......
...@@ -68,6 +68,11 @@ target_compile_options(vk_system ...@@ -68,6 +68,11 @@ target_compile_options(vk_system
${ROOT_PROJECT_COMPILE_OPTIONS} ${ROOT_PROJECT_COMPILE_OPTIONS}
) )
target_link_libraries(vk_system
PUBLIC
marl
)
target_link_options(vk_system target_link_options(vk_system
PUBLIC PUBLIC
${SWIFTSHADER_LINK_FLAGS} ${SWIFTSHADER_LINK_FLAGS}
......
...@@ -22,103 +22,85 @@ ...@@ -22,103 +22,85 @@
#ifndef sw_Synchronization_hpp #ifndef sw_Synchronization_hpp
#define sw_Synchronization_hpp #define sw_Synchronization_hpp
#include "Debug.hpp"
#include <assert.h> #include <assert.h>
#include <chrono> #include <chrono>
#include <condition_variable> #include <condition_variable>
#include <queue> #include <queue>
#include "marl/event.h"
#include "marl/mutex.h" #include "marl/mutex.h"
#include "marl/waitgroup.h"
namespace sw { namespace sw {
// TaskEvents is an interface for notifying when tasks begin and end. // CountedEvent is an event that is signalled when the internal counter is
// Tasks can be nested and/or overlapping. // decremented and reaches zero.
// TaskEvents is used for task queue synchronization. // The counter is incremented with calls to add() and decremented with calls to
class TaskEvents // done().
class CountedEvent
{ {
public: public:
// start() is called before a task begins. // Constructs the CountedEvent with the initial signalled state set to the
virtual void start() = 0; // provided value.
// finish() is called after a task ends. finish() must only be called after CountedEvent(bool signalled = false)
// a corresponding call to start(). : ev(marl::Event::Mode::Manual, signalled)
virtual void finish() = 0; {}
// complete() is a helper for calling start() followed by finish().
inline void complete() // add() increments the internal counter.
// add() must not be called when the event is already signalled.
void add() const
{ {
start(); ASSERT(!ev.isSignalled());
finish(); wg.add();
} }
protected: // done() decrements the internal counter, signalling the event if the new
virtual ~TaskEvents() = default; // counter value is zero.
}; // done() must not be called when the event is already signalled.
void done() const
// WaitGroup is a synchronization primitive that allows you to wait for
// collection of asynchronous tasks to finish executing.
// Call add() before each task begins, and then call done() when after each task
// is finished.
// At the same time, wait() can be used to block until all tasks have finished.
// WaitGroup takes its name after Golang's sync.WaitGroup.
class WaitGroup : public TaskEvents
{
public:
// add() begins a new task.
void add()
{ {
marl::lock lock(mutex); ASSERT(!ev.isSignalled());
++count_; if(wg.done())
{
ev.signal();
}
} }
// done() is called when a task of the WaitGroup has been completed. // reset() clears the signal state.
// Returns true if there are no more tasks currently running in the // done() must not be called when the internal counter is non-zero.
// WaitGroup. void reset() const
bool done()
{ {
marl::lock lock(mutex); ev.clear();
assert(count_ > 0);
--count_;
if(count_ == 0)
{
condition.notify_all();
}
return count_ == 0;
} }
// wait() blocks until all the tasks have been finished. // signalled() returns the current signal state.
void wait() bool signalled() const
{ {
marl::lock lock(mutex); return ev.isSignalled();
lock.wait(condition, [this]() REQUIRES(mutex) { return count_ == 0; });
} }
// wait() blocks until all the tasks have been finished or the timeout // wait() waits until the event is signalled.
// has been reached, returning true if all tasks have been completed, or void wait() const
// false if the timeout has been reached.
template<class CLOCK, class DURATION>
bool wait(const std::chrono::time_point<CLOCK, DURATION> &timeout)
{ {
marl::lock lock(mutex); ev.wait();
return condition.wait_until(lock, timeout, [this]() REQUIRES(mutex) { return count_ == 0; });
} }
// count() returns the number of times add() has been called without a call // wait() waits until the event is signalled or the timeout is reached.
// to done(). // If the timeout was reached, then wait() return false.
// Note: No lock is held after count() returns, so the count may immediately template<class CLOCK, class DURATION>
// change after returning. bool wait(const std::chrono::time_point<CLOCK, DURATION> &timeout) const
int32_t count()
{ {
marl::lock lock(mutex); return ev.wait_until(timeout);
return count_;
} }
// TaskEvents compliance // event() returns the internal marl event.
void start() override { add(); } const marl::Event &event() { return ev; }
void finish() override { done(); }
private: private:
marl::mutex mutex; const marl::WaitGroup wg;
int32_t count_ GUARDED_BY(mutex) = 0; const marl::Event ev;
std::condition_variable condition;
}; };
// Chan is a thread-safe FIFO queue of type T. // Chan is a thread-safe FIFO queue of type T.
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "VkDescriptorSet.hpp" #include "VkDescriptorSet.hpp"
#include "VkObject.hpp" #include "VkObject.hpp"
#include "Device/Context.hpp" #include "Device/Context.hpp"
#include "System/Synchronization.hpp"
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -27,7 +28,6 @@ namespace sw { ...@@ -27,7 +28,6 @@ namespace sw {
class Context; class Context;
class Renderer; class Renderer;
class TaskEvents;
} // namespace sw } // namespace sw
...@@ -150,7 +150,7 @@ public: ...@@ -150,7 +150,7 @@ public:
}; };
sw::Renderer *renderer = nullptr; sw::Renderer *renderer = nullptr;
sw::TaskEvents *events = nullptr; sw::CountedEvent *events = nullptr;
RenderPass *renderPass = nullptr; RenderPass *renderPass = nullptr;
Framebuffer *renderPassFramebuffer = nullptr; Framebuffer *renderPassFramebuffer = nullptr;
std::array<PipelineState, vk::VK_PIPELINE_BIND_POINT_RANGE_SIZE> pipelineState; std::array<PipelineState, vk::VK_PIPELINE_BIND_POINT_RANGE_SIZE> pipelineState;
......
...@@ -246,7 +246,7 @@ VkResult Device::waitForFences(uint32_t fenceCount, const VkFence *pFences, VkBo ...@@ -246,7 +246,7 @@ VkResult Device::waitForFences(uint32_t fenceCount, const VkFence *pFences, VkBo
marl::containers::vector<marl::Event, 8> events; marl::containers::vector<marl::Event, 8> events;
for(uint32_t i = 0; i < fenceCount; i++) for(uint32_t i = 0; i < fenceCount; i++)
{ {
events.push_back(Cast(pFences[i])->getEvent()); events.push_back(Cast(pFences[i])->getCountedEvent()->event());
} }
auto any = marl::Event::any(events.begin(), events.end()); auto any = marl::Event::any(events.begin(), events.end());
......
...@@ -18,17 +18,13 @@ ...@@ -18,17 +18,13 @@
#include "VkObject.hpp" #include "VkObject.hpp"
#include "System/Synchronization.hpp" #include "System/Synchronization.hpp"
#include "marl/containers.h"
#include "marl/event.h"
#include "marl/waitgroup.h"
namespace vk { namespace vk {
class Fence : public Object<Fence, VkFence>, public sw::TaskEvents class Fence : public Object<Fence, VkFence>
{ {
public: public:
Fence(const VkFenceCreateInfo *pCreateInfo, void *mem) Fence(const VkFenceCreateInfo *pCreateInfo, void *mem)
: event(marl::Event::Mode::Manual, (pCreateInfo->flags & VK_FENCE_CREATE_SIGNALED_BIT) != 0) : counted_event(std::make_shared<sw::CountedEvent>((pCreateInfo->flags & VK_FENCE_CREATE_SIGNALED_BIT) != 0))
{} {}
static size_t ComputeRequiredAllocationSize(const VkFenceCreateInfo *pCreateInfo) static size_t ComputeRequiredAllocationSize(const VkFenceCreateInfo *pCreateInfo)
...@@ -38,49 +34,38 @@ public: ...@@ -38,49 +34,38 @@ public:
void reset() void reset()
{ {
event.clear(); counted_event->reset();
}
void complete()
{
counted_event->add();
counted_event->done();
} }
VkResult getStatus() VkResult getStatus()
{ {
return event.isSignalled() ? VK_SUCCESS : VK_NOT_READY; return counted_event->signalled() ? VK_SUCCESS : VK_NOT_READY;
} }
VkResult wait() VkResult wait()
{ {
event.wait(); counted_event->wait();
return VK_SUCCESS; return VK_SUCCESS;
} }
template<class CLOCK, class DURATION> template<class CLOCK, class DURATION>
VkResult wait(const std::chrono::time_point<CLOCK, DURATION> &timeout) VkResult wait(const std::chrono::time_point<CLOCK, DURATION> &timeout)
{ {
return event.wait_until(timeout) ? VK_SUCCESS : VK_TIMEOUT; return counted_event->wait(timeout) ? VK_SUCCESS : VK_TIMEOUT;
}
const marl::Event &getEvent() const { return event; }
// TaskEvents compliance
void start() override
{
ASSERT(!event.isSignalled());
wg.add();
} }
void finish() override const std::shared_ptr<sw::CountedEvent> &getCountedEvent() const { return counted_event; };
{
ASSERT(!event.isSignalled());
if(wg.done())
{
event.signal();
}
}
private: private:
Fence(const Fence &) = delete; Fence(const Fence &) = delete;
marl::WaitGroup wg; const std::shared_ptr<sw::CountedEvent> counted_event;
const marl::Event event;
}; };
static inline Fence *Cast(VkFence object) static inline Fence *Cast(VkFence object)
......
...@@ -102,11 +102,10 @@ VkResult Queue::submit(uint32_t submitCount, const VkSubmitInfo *pSubmits, Fence ...@@ -102,11 +102,10 @@ VkResult Queue::submit(uint32_t submitCount, const VkSubmitInfo *pSubmits, Fence
Task task; Task task;
task.submitCount = submitCount; task.submitCount = submitCount;
task.pSubmits = DeepCopySubmitInfo(submitCount, pSubmits); task.pSubmits = DeepCopySubmitInfo(submitCount, pSubmits);
task.events = fence; if(fence)
if(task.events)
{ {
task.events->start(); task.events = fence->getCountedEvent();
task.events->add();
} }
pending.put(task); pending.put(task);
...@@ -132,7 +131,7 @@ void Queue::submitQueue(const Task &task) ...@@ -132,7 +131,7 @@ void Queue::submitQueue(const Task &task)
{ {
CommandBuffer::ExecutionState executionState; CommandBuffer::ExecutionState executionState;
executionState.renderer = renderer.get(); executionState.renderer = renderer.get();
executionState.events = task.events; executionState.events = task.events.get();
for(uint32_t j = 0; j < submitInfo.commandBufferCount; j++) for(uint32_t j = 0; j < submitInfo.commandBufferCount; j++)
{ {
vk::Cast(submitInfo.pCommandBuffers[j])->submit(executionState); vk::Cast(submitInfo.pCommandBuffers[j])->submit(executionState);
...@@ -155,7 +154,7 @@ void Queue::submitQueue(const Task &task) ...@@ -155,7 +154,7 @@ void Queue::submitQueue(const Task &task)
// TODO: fix renderer signaling so that work submitted separately from (but before) a fence // TODO: fix renderer signaling so that work submitted separately from (but before) a fence
// is guaranteed complete by the time the fence signals. // is guaranteed complete by the time the fence signals.
renderer->synchronize(); renderer->synchronize();
task.events->finish(); task.events->done();
} }
} }
...@@ -187,14 +186,14 @@ void Queue::taskLoop(marl::Scheduler *scheduler) ...@@ -187,14 +186,14 @@ void Queue::taskLoop(marl::Scheduler *scheduler)
VkResult Queue::waitIdle() VkResult Queue::waitIdle()
{ {
// Wait for task queue to flush. // Wait for task queue to flush.
sw::WaitGroup wg; auto event = std::make_shared<sw::CountedEvent>();
wg.add(); event->add(); // done() is called at the end of submitQueue()
Task task; Task task;
task.events = &wg; task.events = event;
pending.put(task); pending.put(task);
wg.wait(); event->wait();
garbageCollect(); garbageCollect();
......
...@@ -65,7 +65,7 @@ private: ...@@ -65,7 +65,7 @@ private:
{ {
uint32_t submitCount = 0; uint32_t submitCount = 0;
VkSubmitInfo *pSubmits = nullptr; VkSubmitInfo *pSubmits = nullptr;
sw::TaskEvents *events = nullptr; std::shared_ptr<sw::CountedEvent> events;
enum Type enum Type
{ {
......
...@@ -21,12 +21,14 @@ test("swiftshader_system_unittests") { ...@@ -21,12 +21,14 @@ test("swiftshader_system_unittests") {
"//testing/gmock", "//testing/gmock",
"//testing/gtest", "//testing/gtest",
"../../src/System", "../../src/System",
"../../third_party/marl:Marl",
] ]
sources = [ sources = [
"//gpu/swiftshader_tests_main.cc", "//gpu/swiftshader_tests_main.cc",
"LRUCacheTests.cpp", "LRUCacheTests.cpp",
"unittests.cpp", "unittests.cpp",
"SynchronizationTests.cpp",
] ]
include_dirs = [ include_dirs = [
......
...@@ -26,6 +26,7 @@ set(SYSTEM_UNIT_TESTS_SRC_FILES ...@@ -26,6 +26,7 @@ set(SYSTEM_UNIT_TESTS_SRC_FILES
LRUCacheTests.cpp LRUCacheTests.cpp
main.cpp main.cpp
unittests.cpp unittests.cpp
SynchronizationTests.cpp
) )
add_executable(system-unittests add_executable(system-unittests
......
// Copyright 2020 The SwiftShader Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "System/Synchronization.hpp"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <thread>
TEST(EventCounter, ConstructUnsignalled)
{
sw::CountedEvent ev;
ASSERT_FALSE(ev.signalled());
}
TEST(EventCounter, ConstructSignalled)
{
sw::CountedEvent ev(true);
ASSERT_TRUE(ev.signalled());
}
TEST(EventCounter, Reset)
{
sw::CountedEvent ev(true);
ev.reset();
ASSERT_FALSE(ev.signalled());
}
TEST(EventCounter, AddUnsignalled)
{
sw::CountedEvent ev;
ev.add();
ASSERT_FALSE(ev.signalled());
}
TEST(EventCounter, AddDoneUnsignalled)
{
sw::CountedEvent ev;
ev.add();
ev.done();
ASSERT_TRUE(ev.signalled());
}
TEST(EventCounter, Wait)
{
sw::CountedEvent ev;
bool b = false;
ev.add();
auto t = std::thread([=, &b] {
b = true;
ev.done();
});
ev.wait();
ASSERT_TRUE(b);
t.join();
}
TEST(EventCounter, WaitNoTimeout)
{
sw::CountedEvent ev;
bool b = false;
ev.add();
auto t = std::thread([=, &b] {
b = true;
ev.done();
});
ASSERT_TRUE(ev.wait(std::chrono::system_clock::now() + std::chrono::seconds(10)));
ASSERT_TRUE(b);
t.join();
}
TEST(EventCounter, WaitTimeout)
{
sw::CountedEvent ev;
ev.add();
ASSERT_FALSE(ev.wait(std::chrono::system_clock::now() + std::chrono::milliseconds(1)));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment