Skip to content

Commit 1e956fe

Browse files
authored
Merge pull request #40 from mski-iksm/feature/add_export_item_embedding
Added item embedding exporting function
2 parents 928a5ce + 243178e commit 1e956fe

2 files changed

Lines changed: 59 additions & 7 deletions

File tree

redshells/model/graph_convolutional_matrix_completion.py

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -296,8 +296,11 @@ def predict_with_new_items(self, user_ids: List, item_ids: List, additional_data
296296

297297
def get_user_feature(self, user_ids: List, item_ids: List, additional_dataset: GcmcDataset, with_user_embedding: bool = True) -> np.ndarray:
298298
dataset = self.graph_dataset.add_dataset(additional_dataset, add_item=True)
299-
return self._get_user_feature(
300-
user_ids=user_ids, item_ids=item_ids, with_user_embedding=with_user_embedding, graph=self.graph, dataset=dataset, session=self.session)
299+
return self._get_feature(user_ids=user_ids, item_ids=item_ids, with_user_embedding=with_user_embedding, graph=self.graph, dataset=dataset, session=self.session, feature='user')
300+
301+
def get_item_feature(self, user_ids: List, item_ids: List, additional_dataset: GcmcDataset, with_user_embedding: bool = True) -> np.ndarray:
    """Evaluate the item encoder for the given user/item pairs.

    The additional dataset is combined with ``self.graph_dataset`` through
    ``add_dataset(..., add_item=True)`` and the combined dataset is handed to
    ``_get_feature`` together with the model's graph and session.

    Args:
        user_ids: User ids, passed through to ``_get_feature`` unchanged.
        item_ids: Item ids, passed through to ``_get_feature`` unchanged.
        additional_dataset: Extra data to combine with the fitted dataset.
        with_user_embedding: Forwarded to ``_get_feature`` as-is.

    Returns:
        Whatever ``_get_feature`` yields for ``feature='item'``.
        NOTE(review): presumably one embedding row per (user, item) pair;
        confirm against ``_get_feature``.
    """
    merged_dataset = self.graph_dataset.add_dataset(additional_dataset, add_item=True)
    return self._get_feature(
        user_ids=user_ids,
        item_ids=item_ids,
        with_user_embedding=with_user_embedding,
        graph=self.graph,
        dataset=merged_dataset,
        session=self.session,
        feature='item',
    )
301304

302305
@classmethod
303306
def _predict(cls, user_ids: List, item_ids: List, with_user_embedding, graph: GraphConvolutionalMatrixCompletionGraph, dataset: GcmcGraphDataset,
@@ -320,9 +323,9 @@ def _predict(cls, user_ids: List, item_ids: List, with_user_embedding, graph: Gr
320323
return predictions
321324

322325
@classmethod
323-
def _get_user_feature(cls, user_ids: List, item_ids: List, with_user_embedding,
324-
graph: GraphConvolutionalMatrixCompletionGraph, dataset: GcmcGraphDataset,
325-
session: tf.Session) -> np.ndarray:
326+
def _get_feature(cls, user_ids: List, item_ids: List, with_user_embedding,
327+
graph: GraphConvolutionalMatrixCompletionGraph, dataset: GcmcGraphDataset,
328+
session: tf.Session, feature: str) -> np.ndarray:
326329
if graph is None:
327330
RuntimeError('Please call fit first.')
328331

@@ -335,9 +338,10 @@ def _get_user_feature(cls, user_ids: List, item_ids: List, with_user_embedding,
335338
input_data = dict(user=user_indices, item=item_indices, user_feature_indices=user_feature_indices,
336339
item_feature_indices=item_feature_indices)
337340
feed_dict = cls._feed_dict(input_data, graph, dataset, rating_adjacency_matrix, is_train=False)
341+
encoder_map = dict(user=graph.user_encoder, item=graph.item_encoder)
338342
with session.as_default():
339-
user_feature = session.run(graph.user_encoder, feed_dict=feed_dict)
340-
return user_feature
343+
feature = session.run(encoder_map[feature], feed_dict=feed_dict)
344+
return feature
341345

342346
@staticmethod
343347
def _feed_dict(input_data, graph, graph_dataset, rating_adjacency_matrix, dropout_rate: float = 0.0, learning_rate: float = 0.0, is_train: bool = True):
@@ -382,6 +386,12 @@ def get_user_feature_with_new_items(self, item_ids: List, additional_dataset: Gc
382386
indices = [i for i in range(len(users)) if i % len(item_ids) == 0]
383387
return user_ids, user_feature[indices]
384388

389+
def get_item_feature_with_new_items(self, item_ids: List, additional_dataset: GcmcDataset, with_user_embedding: bool = True) -> tuple:
    """Return item embeddings for ``item_ids``, allowing items from ``additional_dataset``.

    Args:
        item_ids: Ids of the items whose embeddings are requested.
        additional_dataset: Extra data forwarded to ``get_item_feature``.
        with_user_embedding: Forwarded to ``get_item_feature`` as-is.

    Returns:
        A 2-tuple ``(items, item_feature)``. ``items`` is ``item_ids`` as a
        tuple, in the original order; ``item_feature`` is the value returned
        by ``get_item_feature`` for those items.

    Raises:
        ValueError: If ``item_ids`` is empty.
    """
    items = tuple(item_ids)
    if not items:
        # The previous zip(*[...]) unpacking also raised ValueError here, but
        # with an opaque "not enough values to unpack" message.
        raise ValueError('item_ids must contain at least one item id.')

    # get_item_feature takes parallel user/item sequences, so every item is
    # paired with the first user of the fitted dataset.
    # NOTE(review): presumably the choice of user does not influence the item
    # encoder output -- confirm against the graph definition.
    user_id = self.graph_dataset.user_ids[0]
    users = (user_id,) * len(items)

    item_feature = self.get_item_feature(
        user_ids=users,
        item_ids=items,
        additional_dataset=additional_dataset,
        with_user_embedding=with_user_embedding,
    )
    return items, item_feature
394+
385395
def _make_graph(self) -> GraphConvolutionalMatrixCompletionGraph:
386396
return GraphConvolutionalMatrixCompletionGraph(
387397
n_rating=self.graph_dataset.n_rating,

test/model/test_graph_convolutional_matrix_completion.py

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,48 @@ def test_get_user_feature_with_new_items(self, dummy_get_user_feature):
116116
self.assertEqual(len(user_features[0]), n_users)
117117
self.assertEqual(user_features[1].shape, (n_users, n_user_embed_dimension))
118118

119+
def test_get_item_feature_with_new_items(self):
    """Item embeddings can be exported for both known and newly added items."""
    n_users = 101
    n_items = 233
    n_data = 3007

    # Build the training interactions from two generated sparse matrices;
    # the second one is scaled by 2 before they are summed.
    first_matrix = _make_sparse_matrix(n_users, n_items, n_data)
    second_matrix = 2 * _make_sparse_matrix(n_users, n_items, n_data)
    interactions = (first_matrix + second_matrix).tocoo()

    train_item_features = [{item: np.array([item]) for item in range(n_items)}]
    train_dataset = GcmcDataset(interactions.row, interactions.col, interactions.data, item_features=train_item_features)
    graph_dataset = GcmcGraphDataset(train_dataset, test_size=0.1)

    encoder_hidden_size = 100
    encoder_size = 100
    model = GraphConvolutionalMatrixCompletion(
        graph_dataset=graph_dataset,
        encoder_hidden_size=encoder_hidden_size,
        encoder_size=encoder_size,
        scope_name='GraphConvolutionalMatrixCompletionGraph',
        batch_size=1024,
        epoch_size=10,
        learning_rate=0.01,
        dropout_rate=0.7,
        normalization_type='symmetric')
    model.fit()

    # Additional interactions: every listed item gets the feature [999].
    # Items 240 and 243 are both rated 1 by user 3.
    extra_user_ids = [90, 62, 3, 3]
    extra_item_ids = [11, 236, 240, 243]
    extra_ratings = [1, 2, 1, 1]
    extra_item_features = {item: np.array([999]) for item in extra_item_ids}
    additional_dataset = GcmcDataset(
        np.array(extra_user_ids),
        np.array(extra_item_ids),
        np.array(extra_ratings),
        item_features=[extra_item_features])

    # Item ids whose embeddings are requested: the additional ones plus a
    # few ids from the training range.
    target_item_ids = extra_item_ids + [12, 13, 17, 55]

    result = model.get_item_feature_with_new_items(item_ids=target_item_ids, additional_dataset=additional_dataset)

    # The result is an (item ids, embedding matrix) pair.
    self.assertEqual(len(result), 2)
    self.assertEqual(list(result[0]), target_item_ids)
    self.assertEqual(result[1].shape, (len(target_item_ids), encoder_size))

    # 240 and 243 share the same feature, user and rating, so their
    # embeddings must coincide.
    embedding_by_item = dict(zip(*result))
    np.testing.assert_almost_equal(embedding_by_item[240], embedding_by_item[243])
160+
119161

120162
if __name__ == '__main__':
121163
unittest.main()

0 commit comments

Comments
 (0)