-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
113 lines (101 loc) · 4.52 KB
/
Copy pathutils.py
File metadata and controls
113 lines (101 loc) · 4.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# -*- coding: utf-8 -*-
import dgl
import torch
import networkx as nx
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
def build_graph(args, type, node):
    """Build a DGL graph with ``node`` nodes from a tab-separated edge file.

    The edge file is selected by ``type`` and opened at
    ``f'{args.data_dir}graph/<file name>'``. ``args.data_dir`` is
    concatenated as-is, so it must already end with a path separator.
    Every non-blank line of the file is read as ``src<TAB>dst`` with both
    fields parsed by ``int``.

    Args:
        args: Object exposing a ``data_dir`` string attribute.
        type: Name of the relation to load; one of ``'direct'``,
            ``'undirect'``, ``'k_from_e'``, ``'e_from_k'``, ``'u_from_e'``
            or ``'e_from_u'``. The name shadows the builtin but is kept so
            existing callers that pass it by keyword keep working.
        node: Number of nodes added to the graph before any edge is added.

    Returns:
        The graph. For ``'undirect'`` every listed edge is added in both
        directions; for every other type edges are added exactly as listed.
        If the file contains no edges, the graph holds only the nodes.

    Raises:
        ValueError: If ``type`` is not a supported relation name, or a
            field on a non-blank line is not an integer.
        IndexError: If a non-blank line has fewer than two tab-separated
            fields.
    """
    # Relation name -> (edge file under '<data_dir>graph/', mirror edges?).
    edge_sources = {
        'direct': ('K_Directed.txt', False),
        'undirect': ('K_Undirected.txt', True),
        'k_from_e': ('k_from_e.txt', False),
        'e_from_k': ('e_from_k.txt', False),
        'u_from_e': ('u_from_e.txt', False),
        'e_from_u': ('e_from_u.txt', False),
    }
    if type not in edge_sources:
        # Fail loudly instead of silently returning None to the caller.
        raise ValueError(
            f'unknown graph type {type!r}; '
            f'expected one of {sorted(edge_sources)}'
        )
    file_name, mirror_edges = edge_sources[type]

    g = dgl.DGLGraph()
    g.add_nodes(node)

    edge_list = []
    with open(f'{args.data_dir}graph/{file_name}', 'r') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            fields = line.rstrip('\r\n').split('\t')
            edge_list.append((int(fields[0]), int(fields[1])))

    # zip(*[]) yields nothing to unpack, so an empty file needs its own path.
    if not edge_list:
        return g

    # Transpose [(src, dst), ...] into a tuple of sources and one of targets.
    src, dst = zip(*edge_list)
    g.add_edges(src, dst)
    if mirror_edges:
        # Edges are directional in DGL; add the reverse to make them
        # bi-directional.
        g.add_edges(dst, src)
    return g
def construct_local_map(args):
    """Build every relation graph and return them in a dict.

    Args:
        args: Object exposing ``n_knowledge``, ``n_exer`` and ``n_stu``
            (used as node counts) plus whatever ``build_graph`` reads from
            it. Presumably these are the numbers of knowledge concepts,
            exercises and students -- confirm against the caller.

    Returns:
        A dict with the keys ``'directed_g'``, ``'undirected_g'``,
        ``'k_from_e'``, ``'e_from_k'``, ``'u_from_e'`` and ``'e_from_u'``,
        in that order, each mapped to the graph returned by ``build_graph``.
    """
    graphs = {}

    # Graphs sized by n_knowledge only.
    for key, kind in (('directed_g', 'direct'), ('undirected_g', 'undirect')):
        graphs[key] = build_graph(args, kind, args.n_knowledge)

    # Graphs sized by n_knowledge + n_exer; dict key equals the graph type.
    for kind in ('k_from_e', 'e_from_k'):
        graphs[kind] = build_graph(args, kind, args.n_knowledge + args.n_exer)

    # Graphs sized by n_stu + n_exer; dict key equals the graph type.
    for kind in ('u_from_e', 'e_from_u'):
        graphs[kind] = build_graph(args, kind, args.n_stu + args.n_exer)

    return graphs
def construct_relation_graph(args):
    """Build every relation graph and return them in a dict.

    This function produces exactly what ``construct_local_map`` produces,
    so it delegates to it instead of repeating the same six
    ``build_graph`` calls; the two entry points can no longer drift apart.

    Args:
        args: Object exposing ``n_knowledge``, ``n_exer`` and ``n_stu``
            (used as node counts) plus whatever ``build_graph`` reads from
            it.

    Returns:
        A dict with the keys ``'directed_g'``, ``'undirected_g'``,
        ``'k_from_e'``, ``'e_from_k'``, ``'u_from_e'`` and ``'e_from_u'``,
        each mapped to a freshly built graph.
    """
    return construct_local_map(args)
def compute_loss(embedding_1, embedding_2, input_ids, temperature):
    """Contrastive loss between two embedding tables for a batch of rows.

    Both tables are L2-normalised along ``dim=1``. For every id in
    ``input_ids`` the positive score is the dot product of that row taken
    from each table; it is contrasted against the dot products of the
    first-table row with every row of the second table.

    Args:
        embedding_1: First embedding table; rows are selected on dim 0 and
            normalised on dim 1 (assumes shape ``(num_rows, dim)`` --
            confirm against the caller).
        embedding_2: Second embedding table, same layout as
            ``embedding_1``.
        input_ids: Index (e.g. a 1-D integer tensor) selecting the batch
            rows from both tables.
        temperature: Divisor applied to the score differences before the
            ``logsumexp``.

    Returns:
        A scalar tensor: the mean over the batch of
        ``logsumexp((all_scores - positive_score) / temperature)``.
    """
    view_a = F.normalize(embedding_1, dim=1)
    view_b = F.normalize(embedding_2, dim=1)

    anchors = view_a[input_ids]
    positives = view_b[input_ids]

    # Similarity of each anchor with its own counterpart in the other table.
    positive_scores = (anchors * positives).sum(dim=-1)
    # Similarity of each anchor with every row of the other table.
    all_scores = torch.matmul(anchors, view_b.transpose(0, 1))

    # Subtracting the positive score per row makes the logsumexp equal to
    # -log softmax probability of the positive pair.
    shifted = all_scores - positive_scores.unsqueeze(1)
    return torch.logsumexp(shifted / temperature, dim=1).mean()