Skip to content

Fix DataParallel issues for GP-VAE, FILM, and FITS on multi-GPU setups#819

Draft
Claude wants to merge 2 commits into
devfrom
claude/fix-gp-vae-fail-multi-cuda
Draft

Fix DataParallel issues for GP-VAE, FILM, and FITS on multi-GPU setups#819
Claude wants to merge 2 commits into
devfrom
claude/fix-gp-vae-fail-multi-cuda

Conversation

@Claude

@Claude Claude AI commented Mar 16, 2026

Copy link
Copy Markdown

Multi-GPU training with torch.nn.DataParallel fails for GP-VAE, FILM, and FITS models with runtime errors related to distribution initialization, complex tensor operations, and einsum dimension mismatches.

Changes

GP-VAE (pypots/nn/modules/gpvae/backbone.py)

  • Pre-compute kernel matrices in __init__ and register as buffer via register_buffer("prior_covariance", ...)
  • Create prior distribution fresh per forward pass in new _get_prior() method instead of caching
  • Eliminates "lazy wrapper should be called at most once" error from shared distribution state across GPU replicas

FITS (pypots/nn/modules/fits/backbone.py)

  • Remove .to(torch.cfloat) from Linear layer initialization
  • Split complex FFT outputs into real/imaginary components, apply Linear transformations separately, then recombine with torch.complex()
  • Resolves "t() expects tensor with <= 2 dimensions" error from complex dtype parameters

FILM (pypots/nn/modules/film/layers.py)

  • Decompose complex einsum operations: (a+bi)(c+di) = (ac-bd) + (ad+bc)i
  • Apply einsum to real/imaginary parts separately, then recombine
  • Fixes "einsum() subscript mismatch" error in DataParallel context

Technical rationale

DataParallel replicates models to each GPU but struggles with:

  1. Shared mutable state (distribution objects)
  2. Non-standard parameter dtypes (complex Linear layers)
  3. Complex tensor operations in certain contexts

Solution uses buffers for device-aware tensors and explicit real/imaginary arithmetic, consistent with patterns from PR #633 (Koopa, USGAN, CRLI fixes).

Usage

# Single GPU (existing behavior unchanged)
model = GPVAE(n_steps, n_features, latent_size, device="cuda:0")

# Multiple GPUs (now works)
model = GPVAE(n_steps, n_features, latent_size, device=["cuda:0", "cuda:1"])
Original prompt

This section details on the original issue you should resolve

<issue_title>FITS/FILM/GP-VAE fail when running on multiple CUDA devices</issue_title>
<issue_description>### 1. System Info

v0.11

2. Information

  • The official example scripts
  • My own created scripts

3. Reproduction

  • pypots.clustering.crli
  • pypots.imputation.usgan
  • pypots.imputation.koopa
  • pypots.imputation.film
  • pypots.imputation.gpvae
  • pypots.imputation.fits
  • pypots.forecasting.fits

4. Expected behavior

For pypots.forecasting.fits and pypots.imputation.fits we have

E       RuntimeError: Caught RuntimeError in replica 0 on device 1.
E       Original Traceback (most recent call last):
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
E           output = module(*input, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
E           return self._call_impl(*args, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
E           return forward_call(*args, **kwargs)
E         File "/home/wdudu/PyPOTS_dev/pypots/forecasting/fits/core.py", line 68, in forward
E           enc_out = self.backbone(enc_out)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
E           return self._call_impl(*args, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
E           return forward_call(*args, **kwargs)
E         File "/home/wdudu/PyPOTS_dev/pypots/nn/modules/fits/backbone.py", line 63, in forward
E           low_specxy_ = self.freq_upsampler(low_specx.permute(0, 2, 1))
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
E           return self._call_impl(*args, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
E           return forward_call(*args, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
E           return F.linear(input, self.weight, self.bias)
E       RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D

For pypots.imputation.film we have

E       RuntimeError: Caught RuntimeError in replica 0 on device 1.
E       Original Traceback (most recent call last):
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
E           output = module(*input, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
E           return self._call_impl(*args, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
E           return forward_call(*args, **kwargs)
E         File "/home/wdudu/PyPOTS_dev/pypots/imputation/film/core.py", line 65, in forward
E           backbone_output = self.backbone(X_embedding)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
E           return self._call_impl(*args, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
E           return forward_call(*args, **kwargs)
E         File "/home/wdudu/PyPOTS_dev/pypots/nn/modules/film/backbone.py", line 65, in forward
E           out1 = self.spec_conv_1[i](x_in_c)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
E           return self._call_impl(*args, **kwargs)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
E           return forward_call(*args, **kwargs)
E         File "/home/wdudu/PyPOTS_dev/pypots/nn/modules/film/layers.py", line 128, in forward
E           out_ft[:, :, :, : self.modes2] = torch.einsum("bjix,iox->bjox", a, self.weights1)
E         File "/home/wdudu/.conda/envs/ml/lib/python3.10/site-packages/torch/functional.py", line 380, in einsum
E           return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
E       RuntimeError: einsum(): the number of subscripts in the equation (3) does not match the number of dimensions (4) for operand 1 and no ellipsis was given

For pypots.imputation.gpvae we have

E       RuntimeError: Caught RuntimeError in replica 1 on device 2.
E       Original Traceback (most recent call last):
E         File...

</details>



<!-- START COPILOT CODING AGENT SUFFIX -->

- Fixes WenjieDu/PyPOTS#632

- GP-VAE: Pre-compute kernel matrices as buffers and create prior distribution on-the-fly to avoid DataParallel replication issues
- FITS: Handle complex numbers by splitting into real/imaginary parts instead of using .to(torch.cfloat) on Linear layers
- FILM: Handle complex einsum operations by splitting into real/imaginary parts to avoid DataParallel issues with complex tensors

Co-authored-by: WenjieDu <17807970+WenjieDu@users.noreply.github.com>
@Claude Claude AI changed the title [WIP] [FITS-123] Fix GP-VAE fail when running on multiple CUDA devices Fix DataParallel issues for GP-VAE, FILM, and FITS on multi-GPU setups Mar 16, 2026
@Claude Claude AI requested a review from WenjieDu March 16, 2026 16:51
@github-actions

Copy link
Copy Markdown

This pull request had no activity for 14 days. It will be closed in 1 week unless there is some new activity.

@github-actions github-actions Bot added the stale label Mar 31, 2026
@github-actions github-actions Bot closed this Apr 7, 2026
@WenjieDu WenjieDu reopened this Apr 7, 2026
@WenjieDu WenjieDu added keep Keep this issue away from being stale. and removed stale labels Apr 7, 2026
@WenjieDu WenjieDu changed the base branch from main to dev April 26, 2026 07:36
@WenjieDu WenjieDu closed this Apr 26, 2026
@WenjieDu WenjieDu reopened this Apr 26, 2026
@sonarqubecloud

Copy link
Copy Markdown

@coveralls

coveralls commented Apr 26, 2026

Copy link
Copy Markdown
Collaborator

Coverage Report for CI Build 24951327727

Coverage decreased (-0.03%) to 79.961%

Details

  • Coverage decreased (-0.03%) from the base build.
  • Patch coverage: 13 uncovered changes across 3 files (28 of 41 lines covered, 68.29%).
  • No coverage regressions found.

Uncovered Changes

File Changed Covered %
pypots/nn/modules/film/layers.py 15 8 53.33%
pypots/nn/modules/fits/backbone.py 9 5 55.56%
pypots/nn/modules/gpvae/backbone.py 17 15 88.24%

Coverage Regressions

No coverage regressions found.


Coverage Stats

Coverage Status
Relevant Lines: 19028
Covered Lines: 15215
Line Coverage: 79.96%
Coverage Strength: 1.6 hits per line

💛 - Coveralls

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This pull request addresses multi-GPU (torch.nn.DataParallel) runtime failures in the GP-VAE, FITS, and FiLM model backbones by removing non-replicable state, avoiding complex-valued module parameters, and rewriting complex operations into real/imag arithmetic.

Changes:

  • GP-VAE: precomputes and buffers the GP prior covariance and constructs a fresh prior distribution each forward pass.
  • FITS: removes complex-typed nn.Linear parameters and applies the upsampling transform to real/imag parts separately.
  • FiLM: rewrites complex einsum operations as real/imag einsum + recombination.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.

File Description
pypots/nn/modules/gpvae/backbone.py Precomputes GP prior covariance as a buffer and creates the prior distribution per-forward for DataParallel safety.
pypots/nn/modules/fits/backbone.py Reworks FFT upsampling to avoid complex Linear parameters by operating on real/imag parts.
pypots/nn/modules/film/layers.py Splits complex einsum into real/imag computations to avoid DataParallel einsum issues.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +56 to +61
low_specx_i = low_specx[:, :, i] # Shape: (batch, dominance_freq)
# Split into real and imaginary parts
real_part = self.freq_upsampler[i](low_specx_i.real)
imag_part = self.freq_upsampler[i](low_specx_i.imag)
# Recombine into complex tensor
low_specxy_[:, :, i] = torch.complex(real_part, imag_part)
Comment on lines +66 to +71
low_specx_permuted = low_specx.permute(0, 2, 1)
# Split into real and imaginary parts
real_part = self.freq_upsampler(low_specx_permuted.real)
imag_part = self.freq_upsampler(low_specx_permuted.imag)
# Recombine and permute back
low_specxy_ = torch.complex(real_part, imag_part).permute(0, 2, 1)
Comment on lines +121 to +122
# Register as buffer so it will be moved to the correct device by DataParallel
self.register_buffer("prior_covariance", kernel_matrix_tiled)
Comment on lines +145 to 148
return torch.distributions.MultivariateNormal(
loc=torch.zeros(self.latent_dim, self.time_length, device=device),
covariance_matrix=kernel_matrix_tiled.to(device),
covariance_matrix=self.prior_covariance.to(device),
)
Comment on lines +110 to 112
# Register index as a buffer to ensure it's properly handled by DataParallel
self.register_buffer("index_buffer", torch.tensor(self.index, dtype=torch.long))

Comment on lines 139 to +142
a = x_ft[:, :, :, : self.modes2]
out_ft[:, :, :, : self.modes2] = torch.einsum("bjix,iox->bjox", a, self.weights1)
# Handle complex einsum by splitting into real and imaginary parts
# to avoid issues with DataParallel
a_real = a.real
@WenjieDu

Copy link
Copy Markdown
Owner

@claude[agent] Now it's your turn. Resolve the above suggestions from Copilot.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

keep Keep this issue away from being stale.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants