RFC: Add a hook for detecting task switches. #39994
base: master
Conversation
In TimerOutputs.jl there's been some discussion about thread safety and the ability to time different tasks (KristofferC/TimerOutputs.jl#80). Perhaps something like this could be used for that, and in that case, allowing multiple callbacks seems desirable.
OK, I made it a list of hooks.
Certain libraries are configured using global or thread-local state instead of passing handles to every function. CUDA, for example, has a `cudaSetDevice` function that binds a device to the current thread for all future API calls. This is at odds with Julia's task-based concurrency, which presents an execution environment that's local to the current task (e.g., in the case of CUDA, using a different device). This PR adds a hook mechanism that can be used to detect task switches, and synchronize Julia's task-local environment with the library's global or thread-local state.
This seems quite intrusive to me. Would you rather have this, or a very-fast-access task-local pointer? And/or, we could make this hook task-specific rather than global. As it is, this could be called for millions of unrelated task switches.
Yeah, it's pretty intrusive... The alternative (and current approach) seems even worse though: before every CUDA-related API call or operation, check whether the task-local state matches the global one. Furthermore, querying CUDA's global state takes tens of nanoseconds, which is too slow to perform before every API call, so we cache that in a thread-local buffer. All that is what leads to the hot mess that is https://github.com/JuliaGPU/CUDA.jl/blob/master/src/state.jl. With the task switch hook, at least we only pay that cost when switching tasks. And in that hook we can check if the switched-to task has any CUDA state in its task-local storage, and only conditionally set up CUDA's global state.
I'd very much like that, but it would still require comparing the task-local state to the global CUDA state (or its per-thread cached counterpart) on every API call & operation, which is pretty expensive and fragile.
I considered that, but we generally don't know which tasks are going to be performing GPU computations. A user can just fire up a task, import CUDA.jl, and perform API calls. Maybe there's another solution; I've been staring at this approach for a while now, so feel free to suggest other ideas.
I don't think we'll promise to maintain a single-threaded, cooperative, non-migrating scheduler, so this seems like an infeasible approach. I agree we should improve the performance of task-local-storage access. Let me see if there are some easy things we can do to improve that anyway.
```c
@@ -507,6 +518,17 @@ JL_DLLEXPORT void jl_switch(void)
        jl_error("cannot switch to task running on another thread");
    }

    if (jl_task_switch_hooks) {
```
Do we need a lock (or some atomics) here? If so, wouldn't it increase the overhead of the task switch even if there are no hooks?
If CUDA.jl wants to manage global state coupled to OS threads, we need an ("unsafe") API asking not to task switch/migrate within a given scope, don't we? But then, if we have such an API, I guess I'm missing what's necessary beyond
Looking at some semi-realistic applications, the number of API calls is orders of magnitude larger than the number of task switches, so from a performance PoV it seems better to pay that cost when switching tasks rather than checking the state on every API call. EDIT: concrete example, doing an AlphaZero.jl run:
The hook could be extended to inform about thread migration, so why is this unfeasible?
But I do want to switch tasks, since that's useful for overlapping computation on a GPU (by asynchronously submitting work from different tasks), or for working with multiple devices.
IIUC, it sounds like you need two things. One is thread-local storage for interacting with external libraries, and another is what I call context variables (#35833) for tracking asynchronous events across tasks. I don't know if that's enough for GPU, but, in general, it'd be great if we could orthogonalize the API.
Adding another use case that this PR could help improve: Dagger.jl estimates the cost of tasks being scheduled by timing them with a form of |
FWIW, the overhead of calling a no-op hook is around 5ns (I compared against saving the function and doing a `jl_call1`, which took around 30ns).