-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtime_aware_pe.py
More file actions
46 lines (41 loc) · 1.67 KB
/
Copy pathtime_aware_pe.py
File metadata and controls
46 lines (41 loc) · 1.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
from utils import fix_length
class TAPE(nn.Module):
    """Time-aware positional encoding (TAPE).

    Adds a sinusoidal positional encoding to ``x`` in which the step between
    consecutive positions is ``1 + interval / mean_interval``, where
    ``interval`` is the elapsed time between consecutive timestamps of a
    sequence. Dropout is then applied to the sum.
    """

    def __init__(self, dropout, device):
        """
        Args:
            dropout: Dropout probability, passed to ``nn.Dropout``.
            device: Stored as ``self.device`` for backward compatibility.
                ``forward`` creates its intermediate tensors on the devices
                of its inputs instead, so the module keeps working after
                ``.to()`` or when replicated onto other devices.
        """
        super().__init__()
        self.device = device
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, time, data_size):
        """Add the time-aware positional encoding to ``x`` and apply dropout.

        Args:
            x: Tensor of shape ``(b, n, d)``. It is NOT modified in place.
            time: Tensor of shape ``(b, n)`` with one timestamp per position.
                Integer timestamps are accepted: their differences are taken
                exactly in the integer dtype and then converted to float.
            data_size: Iterable of per-sequence valid lengths, presumably
                one per batch row. Each length ``e`` becomes a row of ``e``
                ones that ``fix_length`` turns into a ``(b, n)`` mask
                (assumed to zero-pad to length ``n`` -- TODO confirm).

        Returns:
            Tensor of shape ``(b, n, d)``: ``dropout(x + tape)``.
        """
        b, n, d = x.shape

        # (b, n); entries equal to 0 are treated as padding below.
        mask = [torch.ones(e, dtype=torch.float32) for e in data_size]
        mask = fix_length(mask, 1, n, "exclude padding term").to(time.device)

        # (b, n): time elapsed since the previous step. The first step is
        # compared with itself, so its interval is 0.
        prev_time = time.clone()
        prev_time[:, 1:] = time[:, :-1]
        interval = time - prev_time
        if not torch.is_floating_point(interval):
            # Exact integer difference first, then float: the normalisation
            # below performs true division, which integer tensors can't hold.
            interval = interval.float()
        interval = interval.masked_fill(mask == 0, 0.0)

        # (b, 1): mean interval over each sequence's (length - 1) steps.
        # A zero sum is replaced by 1 and the count is floored at 1 so the
        # division below never divides by zero.
        sum_interval = interval.sum(dim=-1, keepdim=True)
        sum_interval = sum_interval.masked_fill(sum_interval == 0, 1)
        num_interval = (mask.sum(dim=-1, keepdim=True) - 1).clamp(min=1)
        avg_interval = sum_interval / num_interval
        interval = interval / avg_interval

        # (b, n): pos[:, 0] = 1 and pos[:, k] = pos[:, k - 1] + interval[:, k] + 1,
        # i.e. a running sum of per-step increments, done with a single
        # cumsum instead of a Python loop over n.
        step = interval + 1
        step[:, :1] = 1.0  # slice (not [:, 0]) so n == 0 does not raise
        pos = step.cumsum(dim=1)
        # Padding gets position 0, i.e. sin(0)=0 on even and cos(0)=1 on odd
        # channels of the encoding.
        pos = pos.masked_fill(mask == 0, 0.0)

        # (ceil(d/2),): sinusoidal frequencies exp(-2i * ln(10000) / d).
        div_term = torch.exp(
            torch.arange(0, d, 2, device=x.device) * -(math.log(10000.0) / d)
        )
        # (b, n, ceil(d/2))
        angles = pos.unsqueeze(-1) * div_term

        # (b, n, d): sin on even channels, cos on odd channels. When d is odd
        # there is one fewer odd channel, so the last frequency is dropped
        # for the cos half.
        tape = torch.zeros_like(x)
        tape[:, :, 0::2] = torch.sin(angles)
        tape[:, :, 1::2] = torch.cos(angles[:, :, : d // 2])

        # Out-of-place add: the caller's ``x`` is left untouched, and leaf
        # tensors that require grad are not hit by an in-place op.
        return self.dropout(x + tape)