Fix DataParallel issues for GP-VAE, FILM, and FITS on multi-GPU setups #819
Claude wants to merge 2 commits into
- GP-VAE: Pre-compute kernel matrices as buffers and create the prior distribution on the fly to avoid DataParallel replication issues
- FITS: Handle complex numbers by splitting into real/imaginary parts instead of using `.to(torch.cfloat)` on Linear layers
- FILM: Handle complex einsum operations by splitting into real/imaginary parts to avoid DataParallel issues with complex tensors

Co-authored-by: WenjieDu <17807970+WenjieDu@users.noreply.github.com>
This pull request had no activity for 14 days. It will be closed in 1 week unless there is some new activity.
Coverage Report for CI Build 24951327727
Coverage decreased (-0.03%) to 79.961%
Coverage Regressions: No coverage regressions found.
💛 - Coveralls
Pull request overview
This pull request addresses multi-GPU (torch.nn.DataParallel) runtime failures in the GP-VAE, FITS, and FiLM model backbones by removing non-replicable state, avoiding complex-valued module parameters, and rewriting complex operations into real/imag arithmetic.
Changes:
- GP-VAE: precomputes and buffers the GP prior covariance and constructs a fresh prior distribution each forward pass.
- FITS: removes complex-typed `nn.Linear` parameters and applies the upsampling transform to real/imag parts separately.
- FiLM: rewrites complex `einsum` operations as real/imag `einsum` + recombination.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| pypots/nn/modules/gpvae/backbone.py | Precomputes GP prior covariance as a buffer and creates the prior distribution per-forward for DataParallel safety. |
| pypots/nn/modules/fits/backbone.py | Reworks FFT upsampling to avoid complex Linear parameters by operating on real/imag parts. |
| pypots/nn/modules/film/layers.py | Splits complex einsum into real/imag computations to avoid DataParallel einsum issues. |
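The real/imag split described for FITS can be illustrated without PyTorch. The sketch below is only an illustration under stated assumptions: `linear` and `upsample_complex` are hypothetical helper names, plain Python lists stand in for tensors, and the weights are made-up values. It applies one real-valued affine map to the real and imaginary parts separately and recombines them, which appears to be the pattern the reworked upsampler follows.

```python
def linear(x, weight, bias):
    """Real-valued affine map y = W x + b on plain Python lists (hypothetical helper)."""
    return [
        sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b
        for row, b in zip(weight, bias)
    ]


def upsample_complex(spec, weight, bias):
    """Apply the same real-valued map to real and imaginary parts, then recombine."""
    real_part = linear([z.real for z in spec], weight, bias)
    imag_part = linear([z.imag for z in spec], weight, bias)
    return [complex(r, i) for r, i in zip(real_part, imag_part)]


# Made-up example: 2 input frequency bins mapped to 3 output bins.
spec = [1 + 2j, 3 - 1j]
weight = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
bias = [0.0, 0.0, 1.0]

upsampled = upsample_complex(spec, weight, bias)
print(upsampled)
```

Note that in this sketch the bias is added to both the real and the imaginary part, so the result is not identical to a single layer with complex-valued weights and bias; whether that difference matters for the model is outside what this example shows.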
pypots/nn/modules/fits/backbone.py

    low_specx_i = low_specx[:, :, i]  # Shape: (batch, dominance_freq)
    # Split into real and imaginary parts
    real_part = self.freq_upsampler[i](low_specx_i.real)
    imag_part = self.freq_upsampler[i](low_specx_i.imag)
    # Recombine into complex tensor
    low_specxy_[:, :, i] = torch.complex(real_part, imag_part)
pypots/nn/modules/fits/backbone.py

    low_specx_permuted = low_specx.permute(0, 2, 1)
    # Split into real and imaginary parts
    real_part = self.freq_upsampler(low_specx_permuted.real)
    imag_part = self.freq_upsampler(low_specx_permuted.imag)
    # Recombine and permute back
    low_specxy_ = torch.complex(real_part, imag_part).permute(0, 2, 1)
pypots/nn/modules/gpvae/backbone.py

    # Register as buffer so it will be moved to the correct device by DataParallel
    self.register_buffer("prior_covariance", kernel_matrix_tiled)
pypots/nn/modules/gpvae/backbone.py

      return torch.distributions.MultivariateNormal(
          loc=torch.zeros(self.latent_dim, self.time_length, device=device),
    -     covariance_matrix=kernel_matrix_tiled.to(device),
    +     covariance_matrix=self.prior_covariance.to(device),
      )
pypots/nn/modules/film/layers.py

    # Register index as a buffer to ensure it's properly handled by DataParallel
    self.register_buffer("index_buffer", torch.tensor(self.index, dtype=torch.long))
pypots/nn/modules/film/layers.py

      a = x_ft[:, :, :, : self.modes2]
    - out_ft[:, :, :, : self.modes2] = torch.einsum("bjix,iox->bjox", a, self.weights1)
    + # Handle complex einsum by splitting into real and imaginary parts
    + # to avoid issues with DataParallel
    + a_real = a.real
@claude[agent] Now it's your turn. Resolve the above suggestions from Copilot.



Multi-GPU training with `torch.nn.DataParallel` fails for GP-VAE, FILM, and FITS models with runtime errors related to distribution initialization, complex tensor operations, and einsum dimension mismatches.

Changes

GP-VAE (`pypots/nn/modules/gpvae/backbone.py`)
- Pre-compute kernel matrices in `__init__` and register as buffer via `register_buffer("prior_covariance", ...)`
- Create the prior distribution on the fly in the `_get_prior()` method instead of caching

FITS (`pypots/nn/modules/fits/backbone.py`)
- Remove `.to(torch.cfloat)` from Linear layer initialization
- Apply the upsampler to the real and imaginary parts separately and recombine with `torch.complex()`

FILM (`pypots/nn/modules/film/layers.py`)
- Split complex einsum into real and imaginary parts using `(a+bi)(c+di) = (ac-bd) + (ad+bc)i`

Technical rationale
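DataParallel replicates models to each GPU but struggles with non-replicable module state (such as a cached prior distribution), complex-valued module parameters, and complex einsum operations.
The solution uses buffers for device-aware tensors and explicit real/imaginary arithmetic, consistent with patterns from PR #633 (Koopa, USGAN, CRLI fixes).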
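The explicit real/imaginary arithmetic can be checked in isolation. The following is a minimal sketch, not the code from this pull request: it uses Python's built-in `complex` type and nested lists instead of tensors, the function names are hypothetical, and it contracts over a single index `i` (a reduced stand-in for the `"bjix,iox->bjox"` einsum). It computes the product via `(a+bi)(c+di) = (ac-bd) + (ad+bc)i` and compares it against direct complex multiplication.

```python
def contract_by_parts(a, w):
    """out[o] = sum_i a[i] * w[i][o], computed with real arithmetic only."""
    n_in, n_out = len(w), len(w[0])
    out = []
    for o in range(n_out):
        # Real part: ac - bd, summed over the contracted index i.
        real = sum(a[i].real * w[i][o].real - a[i].imag * w[i][o].imag for i in range(n_in))
        # Imaginary part: ad + bc, summed over the contracted index i.
        imag = sum(a[i].real * w[i][o].imag + a[i].imag * w[i][o].real for i in range(n_in))
        out.append(complex(real, imag))
    return out


def contract_direct(a, w):
    """Reference result using native complex multiplication."""
    return [sum(a[i] * w[i][o] for i in range(len(w))) for o in range(len(w[0]))]


# Made-up inputs: 3 input channels contracted into 2 output channels.
a = [1 + 2j, -0.5 + 0.25j, 3 - 1j]
w = [[0.5 - 1j, 2 + 0j], [1 + 1j, -1 + 0.5j], [0 + 2j, 0.25 - 0.75j]]

by_parts = contract_by_parts(a, w)
direct = contract_direct(a, w)
max_error = max(abs(x - y) for x, y in zip(by_parts, direct))
print(max_error)
```

The two results should agree up to floating-point rounding, which is why the split form can replace a complex contraction without changing the math.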
Usage
Original prompt
This section details the original issue you should resolve
<issue_title>FITS/FILM/GP-VAE fail when running on multiple CUDA devices</issue_title>
<issue_description>### 1. System Info
v0.11
### 2. Information
### 3. Reproduction
### 4. Expected behavior
For pypots.forecasting.fits and pypots.imputation.fits we have
For pypots.imputation.film we have
For pypots.imputation.gpvae we have