Configure GPU index for 'astra_cuda', select GPU currently used by PyTorch in OperatorModule #1546
jleuschn wants to merge 4 commits into
Hey @jleuschn, thanks a lot for your effort! From a coding point of view, it looks very good. However, I'm afraid that this solution is too hacky and not future-proof. That's mostly due to circumstances and not your fault. Let me give my reasoning here.
Okay, this has become way longer than I expected, and I realize it will not be a trivial change. I will look into it tonight and make a suggestion. But I have no way to test with multiple GPUs, so you would have to help me out, @jleuschn. Does that sound okay?
Thanks, @kohr-h, for checking the request and pointing out the issues above! While parameters and buffers are broadcast to the GPUs in PyTorch's [...] Yes, I can help out testing on multiple GPUs!
Good point about the replication thing! Hm, so the first call to [...] Regarding the larger question of whether it's worth the effort: currently I have my doubts. The whole thing is quite inefficient anyway, since each ray transform does the whole roundtrip CPU->GPU->CPU no matter what, and if it's wrapped into an [...]
Yes, I ran some speed tests. It seems that it only makes a difference if the GPUs are already heavily used by the rest of the network. TBH, I don't fully understand why, considering the mentioned chain. Maybe the chains are not in sync between the different GPUs, so some ray trafo runs in parallel to another layer?
Very good, thanks for doing the speed test! Indeed, the gain is not nothing, but certainly not what you would hope for when throwing N times the compute power at the problem. So I agree, for now it's not necessary to invest time, but it's good to know about this limitation and that we need to think about solutions at some point.
This pull request implements two new features:
- Configuring the GPU index for 'astra_cuda' via the 'gpu_index' property in RayTransformBase
- Selecting the GPU currently used by PyTorch in OperatorModule by setting 'gpu_index'

The second feature feels a little hacky, since it assumes a special role for the 'gpu_index' property, if it exists, on any Operator instance that is wrapped by the OperatorModule. However, this seems to me to be the least invasive way to implement this behaviour, since otherwise RayTransformBase would probably have to know about torch, e.g. by offering gpu_index = 'torch_current'.
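
The second feature could be sketched roughly as follows. This is only an illustration of the "set 'gpu_index' if the wrapped operator has it" idea, not the code from this pull request: `DummyRayTransform`, `DummyIdentity` and `apply_on_device` are hypothetical names, and the sketch avoids importing torch or ODL so that it runs on its own. In an actual OperatorModule, the device index would presumably be taken from the input tensor (e.g. `x.device.index`) or from `torch.cuda.current_device()`.

```python
class DummyRayTransform:
    """Hypothetical stand-in for an operator with a 'gpu_index' property."""

    def __init__(self, gpu_index=0):
        self.gpu_index = gpu_index

    def __call__(self, x):
        # Placeholder computation; a real ray transform would run on
        # the GPU selected by self.gpu_index.
        return [2.0 * v for v in x]


class DummyIdentity:
    """Hypothetical operator without a 'gpu_index' attribute."""

    def __call__(self, x):
        return list(x)


def apply_on_device(operator, x, device_index):
    """Evaluate operator(x), forwarding device_index first if supported.

    device_index is None for CPU inputs; in that case nothing is changed.
    Operators without a 'gpu_index' attribute are evaluated unmodified.
    """
    if device_index is not None and hasattr(operator, "gpu_index"):
        operator.gpu_index = device_index
    return operator(x)


ray_trafo = DummyRayTransform()
result = apply_on_device(ray_trafo, [1.0, 2.0], device_index=1)
identity_result = apply_on_device(DummyIdentity(), [1.0, 2.0], device_index=1)
```

Note that the sketch mutates the wrapped operator in place, which is the "special role" of 'gpu_index' described above: any operator that happens to expose an attribute of that name is affected, whether or not it is a ray transform.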