Skip to content

Commit c03f46f

Browse files
jhuber6 authored and llvmbot committed
[Offload] Properly guard modifications to the RPC device array (#126790)
Summary: Deallocating an RPC device can sometimes fail if the RPC server is still running. This happens when the device array is modified while the server thread is still checking it. This patch adds a mutex to guard modifications to the array. (cherry picked from commit baf7a3c)
1 parent 553185b commit c03f46f

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

offload/plugins-nextgen/common/include/RPC.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ struct RPCServerTy {
7272
/// Array of associated devices. These must be alive as long as the server is.
7373
std::unique_ptr<plugin::GenericDeviceTy *[]> Devices;
7474

75+
/// Mutex that guards accesses to the buffers and device array.
76+
std::mutex BufferMutex{};
77+
7578
/// A helper class for running the user thread that handles the RPC interface.
7679
/// Because we only need to check the RPC server while any kernels are
7780
/// working, we track submission / completion events to allow the thread to
@@ -90,6 +93,9 @@ struct RPCServerTy {
9093
std::condition_variable CV;
9194
std::mutex Mutex;
9295

96+
/// A reference to the main server's mutex.
97+
std::mutex &BufferMutex;
98+
9399
/// A reference to all the RPC interfaces that the server is handling.
94100
llvm::ArrayRef<void *> Buffers;
95101

@@ -98,9 +104,9 @@ struct RPCServerTy {
98104

99105
/// Initialize the worker thread to run in the background.
100106
ServerThread(void *Buffers[], plugin::GenericDeviceTy *Devices[],
101-
size_t Length)
102-
: Running(false), NumUsers(0), CV(), Mutex(), Buffers(Buffers, Length),
103-
Devices(Devices, Length) {}
107+
size_t Length, std::mutex &BufferMutex)
108+
: Running(false), NumUsers(0), CV(), Mutex(), BufferMutex(BufferMutex),
109+
Buffers(Buffers, Length), Devices(Devices, Length) {}
104110

105111
~ServerThread() { assert(!Running && "Thread not shut down explicitly\n"); }
106112

offload/plugins-nextgen/common/src/RPC.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ void RPCServerTy::ServerThread::run() {
131131
Lock.unlock();
132132
while (NumUsers.load(std::memory_order_relaxed) > 0 &&
133133
Running.load(std::memory_order_relaxed)) {
134+
std::lock_guard<decltype(Mutex)> Lock(BufferMutex);
134135
for (const auto &[Buffer, Device] : llvm::zip_equal(Buffers, Devices)) {
135136
if (!Buffer || !Device)
136137
continue;
@@ -149,7 +150,7 @@ RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin)
149150
Devices(std::make_unique<plugin::GenericDeviceTy *[]>(
150151
Plugin.getNumDevices())),
151152
Thread(new ServerThread(Buffers.get(), Devices.get(),
152-
Plugin.getNumDevices())) {}
153+
Plugin.getNumDevices(), BufferMutex)) {}
153154

154155
llvm::Error RPCServerTy::startThread() {
155156
Thread->startThread();
@@ -190,13 +191,15 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
190191
if (auto Err = Device.dataSubmit(ClientGlobal.getPtr(), &client,
191192
sizeof(rpc::Client), nullptr))
192193
return Err;
194+
std::lock_guard<decltype(BufferMutex)> Lock(BufferMutex);
193195
Buffers[Device.getDeviceId()] = RPCBuffer;
194196
Devices[Device.getDeviceId()] = &Device;
195197

196198
return Error::success();
197199
}
198200

199201
Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
202+
std::lock_guard<decltype(BufferMutex)> Lock(BufferMutex);
200203
Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST);
201204
Buffers[Device.getDeviceId()] = nullptr;
202205
Devices[Device.getDeviceId()] = nullptr;

0 commit comments

Comments
 (0)