@@ -296,8 +296,11 @@ def predict_with_new_items(self, user_ids: List, item_ids: List, additional_data
296296
297297 def get_user_feature (self , user_ids : List , item_ids : List , additional_dataset : GcmcDataset , with_user_embedding : bool = True ) -> np .ndarray :
298298 dataset = self .graph_dataset .add_dataset (additional_dataset , add_item = True )
299- return self ._get_user_feature (
300- user_ids = user_ids , item_ids = item_ids , with_user_embedding = with_user_embedding , graph = self .graph , dataset = dataset , session = self .session )
299+ return self ._get_feature (user_ids = user_ids , item_ids = item_ids , with_user_embedding = with_user_embedding , graph = self .graph , dataset = dataset , session = self .session , feature = 'user' )
300+
301+ def get_item_feature (self , user_ids : List , item_ids : List , additional_dataset : GcmcDataset , with_user_embedding : bool = True ) -> np .ndarray :
302+ dataset = self .graph_dataset .add_dataset (additional_dataset , add_item = True )
303+ return self ._get_feature (user_ids = user_ids , item_ids = item_ids , with_user_embedding = with_user_embedding , graph = self .graph , dataset = dataset , session = self .session , feature = 'item' )
301304
302305 @classmethod
303306 def _predict (cls , user_ids : List , item_ids : List , with_user_embedding , graph : GraphConvolutionalMatrixCompletionGraph , dataset : GcmcGraphDataset ,
@@ -320,9 +323,9 @@ def _predict(cls, user_ids: List, item_ids: List, with_user_embedding, graph: Gr
320323 return predictions
321324
322325 @classmethod
323- def _get_user_feature (cls , user_ids : List , item_ids : List , with_user_embedding ,
324- graph : GraphConvolutionalMatrixCompletionGraph , dataset : GcmcGraphDataset ,
325- session : tf .Session ) -> np .ndarray :
326+ def _get_feature (cls , user_ids : List , item_ids : List , with_user_embedding ,
327+ graph : GraphConvolutionalMatrixCompletionGraph , dataset : GcmcGraphDataset ,
328+ session : tf .Session , feature : str ) -> np .ndarray :
326329 if graph is None :
327330 RuntimeError ('Please call fit first.' )
328331
@@ -335,9 +338,10 @@ def _get_user_feature(cls, user_ids: List, item_ids: List, with_user_embedding,
335338 input_data = dict (user = user_indices , item = item_indices , user_feature_indices = user_feature_indices ,
336339 item_feature_indices = item_feature_indices )
337340 feed_dict = cls ._feed_dict (input_data , graph , dataset , rating_adjacency_matrix , is_train = False )
341+ encoder_map = dict (user = graph .user_encoder , item = graph .item_encoder )
338342 with session .as_default ():
339- user_feature = session .run (graph . user_encoder , feed_dict = feed_dict )
340- return user_feature
343+ feature = session .run (encoder_map [ feature ] , feed_dict = feed_dict )
344+ return feature
341345
342346 @staticmethod
343347 def _feed_dict (input_data , graph , graph_dataset , rating_adjacency_matrix , dropout_rate : float = 0.0 , learning_rate : float = 0.0 , is_train : bool = True ):
@@ -382,6 +386,12 @@ def get_user_feature_with_new_items(self, item_ids: List, additional_dataset: Gc
382386 indices = [i for i in range (len (users )) if i % len (item_ids ) == 0 ]
383387 return user_ids , user_feature [indices ]
384388
389+ def get_item_feature_with_new_items (self , item_ids : List , additional_dataset : GcmcDataset , with_user_embedding : bool = True ) -> pd .DataFrame :
390+ user_id = self .graph_dataset .user_ids [0 ]
391+ users , items = zip (* [(user_id , item_id ) for item_id in item_ids ])
392+ item_feature = self .get_item_feature (user_ids = users , item_ids = items , additional_dataset = additional_dataset , with_user_embedding = with_user_embedding )
393+ return items , item_feature
394+
385395 def _make_graph (self ) -> GraphConvolutionalMatrixCompletionGraph :
386396 return GraphConvolutionalMatrixCompletionGraph (
387397 n_rating = self .graph_dataset .n_rating ,
0 commit comments