Calling the .get_c() method on the embedder instance created in the example

embedder = CPCProtEmbedding(model)

throws a CUDA out of memory error, because gradients are being tracked during the forward pass:
def get_c(self, data, return_mask = False):
    z, mask = self.get_z(data, return_mask=True)
    if self.parallel:
        # workaround for accessing model attributes when DataParallel
        c = self.cpc(data, return_early='c')
    else:
        c = self.cpc.get_c(z)
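Why this runs out of memory: when a module's parameters require gradients (the PyTorch default), every forward pass outside no_grad records an autograd graph, and that graph keeps the intermediate activations alive for as long as the output tensor is referenced. Embedding many sequences and holding on to the outputs therefore piles up one graph per batch. The sketch below uses a small hypothetical stand-in model (TinyEncoder, not part of CPCProt) to show the difference between the two modes.

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for the CPCProt model; not the real architecture."""

    def __init__(self, vocab_size: int = 21, hidden: int = 32) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.rnn = nn.GRU(hidden, hidden, batch_first=True)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Per-position context vectors, shape (batch, length, hidden)
        out, _ = self.rnn(self.embed(tokens))
        return out


model = TinyEncoder().eval()  # eval() does NOT turn off gradient tracking
tokens = torch.randint(0, 21, (4, 50))  # 4 fake sequences of length 50

# Default mode: the output carries a grad_fn, so the graph and the
# activations it saved stay alive while c_tracked is referenced.
c_tracked = model(tokens)
print(c_tracked.requires_grad, c_tracked.grad_fn is not None)  # True True

# Under no_grad no graph is recorded, so intermediates can be freed at once.
with torch.no_grad():
    c_free = model(tokens)
print(c_free.requires_grad, c_free.grad_fn is None)  # False True
```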
Wrapping the forward pass in a "with torch.no_grad():" block worked for me:
def get_c(self, data, return_mask = False):
    z, mask = self.get_z(data, return_mask=True)
    with torch.no_grad():
        if self.parallel:
            # workaround for accessing model attributes when DataParallel
            c = self.cpc(data, return_early='c')
        else:
            c = self.cpc.get_c(z)
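If patching the library is not an option, the same effect can likely be had from the calling side by wrapping the whole call. That also covers the get_z() encoder pass, which the patched method still runs outside no_grad. A minimal sketch under stated assumptions: FakeEmbedder and embed_all are hypothetical names, with FakeEmbedder standing in for CPCProtEmbedding (mirroring only its get_z/get_c split) so the snippet runs without the pretrained weights.

```python
from typing import Iterable, List

import torch
import torch.nn as nn


class FakeEmbedder:
    """Hypothetical stand-in for CPCProtEmbedding, mirroring get_z/get_c."""

    def __init__(self, hidden: int = 32) -> None:
        self.enc = nn.Embedding(21, hidden)  # plays the role of the z encoder
        self.ar = nn.GRU(hidden, hidden, batch_first=True)  # plays the role of the context network

    def get_z(self, data: torch.Tensor) -> torch.Tensor:
        return self.enc(data)

    def get_c(self, data: torch.Tensor) -> torch.Tensor:
        # Deliberately has no no_grad inside, like the unpatched method
        c, _ = self.ar(self.get_z(data))
        return c


def embed_all(embedder, batches: Iterable[torch.Tensor]) -> List[torch.Tensor]:
    """Hypothetical helper: run get_c on every batch without building graphs."""
    outputs = []
    with torch.no_grad():  # covers both the get_z and the get_c forward passes
        for batch in batches:
            # On a GPU, .cpu() copies each result to host memory so the
            # stored embeddings do not accumulate on the device.
            outputs.append(embedder.get_c(batch).cpu())
    return outputs


embedder = FakeEmbedder()
batches = [torch.randint(0, 21, (2, 40)) for _ in range(3)]
results = embed_all(embedder, batches)
print(len(results), results[0].shape, results[0].requires_grad)
# 3 torch.Size([2, 40, 32]) False
```

On PyTorch 1.9 or later, torch.inference_mode() can be used in place of torch.no_grad() here and is slightly cheaper, as long as the outputs are never fed back into a graph that needs gradients.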