165165#-- - Extracted from init to load the code related to the loading of files.
166166#-- - Splitted download into download_model and download_tokens.
167167#-- - Splitted save into save_model and save_tokens.
168+ #--
169+ #-- - 19/09/2023 Lyaaaaa
170+ #-- - Updated _set_model_parameters to set all the parameters only for the
171+ #-- generators (except low_memory_mode which is used by the translator too).
172+ #-- - Removed some log from _load as they are repeating themself.
173+ #-- - Updated _download_tokens to directly use self._model_name.
174+ #-- - Fixed an error in _download_model. It used model_name which doesn't
175+ #-- exist. Now it uses self._model_name.
176+ #-- - Updated the logs in _load_model and _load_tokens.
168177#------------------------------------------------------------------------------
169178
170179from transformers import AutoModelForCausalLM , AutoModelForSeq2SeqLM , AutoTokenizer
@@ -230,29 +239,30 @@ def _set_model_parameters(self, p_parameters : dict):
230239 if self ._low_memory_mode == None :
231240 self ._low_memory_mode = p_parameters ["low_memory_mode" ]
232241
233- if self ._limit_memory == False :
234- self ._max_memory = None
235- elif self ._limit_memory == None and p_parameters ["limit_memory" ] == True :
236- self ._max_memory = {0 : p_parameters ["max_memory" ]["0" ],
237- "cpu" : p_parameters ["max_memory" ]["cpu" ]}
242+ if self ._model_type == Model_Type .GENERATION .value :
243+ if self ._limit_memory == False :
244+ self ._max_memory = None
245+ elif self ._limit_memory == None and p_parameters ["limit_memory" ] == True :
246+ self ._max_memory = {0 : p_parameters ["max_memory" ]["0" ],
247+ "cpu" : p_parameters ["max_memory" ]["cpu" ]}
238248
239- if self ._allow_offload == True :
240- self .create_offload_folder ()
241- elif self ._allow_offload == None and p_parameters ["allow_offload" ] == True :
242- self .create_offload_folder ()
249+ if self ._allow_offload == True :
250+ self .create_offload_folder ()
251+ elif self ._allow_offload == None and p_parameters ["allow_offload" ] == True :
252+ self .create_offload_folder ()
243253
244254
245- if self ._allow_download == None :
246- self ._allow_download = p_parameters ["allow_download" ]
255+ if self ._allow_download == None :
256+ self ._allow_download = p_parameters ["allow_download" ]
247257
248- if self ._device_map == None :
249- self ._device_map = p_parameters ["device_map" ]
258+ if self ._device_map == None :
259+ self ._device_map = p_parameters ["device_map" ]
250260
251- if self ._torch_dtype == None :
252- self ._torch_dtype = Torch_Dtypes .dtypes .value [p_parameters ["torch_dtype" ]]
261+ if self ._torch_dtype == None :
262+ self ._torch_dtype = Torch_Dtypes .dtypes .value [p_parameters ["torch_dtype" ]]
253263
254- if self ._offload_dict == None :
255- self ._offload_dict = p_parameters ["offload_dict" ]
264+ if self ._offload_dict == None :
265+ self ._offload_dict = p_parameters ["offload_dict" ]
256266
257267#------------------------------------------------------------------------------
258268#--
@@ -264,7 +274,6 @@ def _load(self):
264274 if self ._allow_download == True :
265275 self ._download_tokens ()
266276 else :
267- logger .log .error ("Couldn't load the tokens files." )
268277 logger .log .info ("Downloading with the server is disabled" )
269278 else :
270279 logger .log .info ("Tokens successfully loaded from local files" )
@@ -280,7 +289,6 @@ def _load(self):
280289 if self ._allow_download == True :
281290 self ._download_model ()
282291 else :
283- logger .log .error ("Couldn't load the model " + self ._model_name )
284292 logger .log .info ("Downloading with the server is disabled." )
285293 else :
286294 logger .log .info ("Model successfully loaded from local files" )
@@ -295,7 +303,7 @@ def _load_tokens(self):
295303 try :
296304 self ._Tokenizer = AutoTokenizer .from_pretrained (self ._tokenizers_path )
297305 except Exception as e :
298- logger .log .error ("Token file in ' " + self ._tokenizers_path + "' not found." )
306+ logger .log .error ("Error loading tokens in " + self ._tokenizers_path )
299307 logger .log .error (e )
300308 return False
301309
@@ -320,7 +328,7 @@ def _load_model(self):
320328 self ._Model = AutoModelForCausalLM .from_pretrained (self ._model_path ,
321329 ** args )
322330 except Exception as e :
323- logger .log .error ("An unexpected error happened while loading the model" )
331+ logger .log .error ("Error loading the model " + self . _model_name )
324332 logger .log .error (e )
325333 return False
326334
@@ -360,9 +368,8 @@ def _save_model(self):
360368#--
361369#------------------------------------------------------------------------------
362370 def _download_tokens (self ):
363- model_name = self ._model_name
364371 logger .log .info ("Trying to download the tokenizer..." )
365- self ._Tokenizer = AutoTokenizer .from_pretrained (model_name ,
372+ self ._Tokenizer = AutoTokenizer .from_pretrained (self . _model_name ,
366373 cache_dir = "cache" ,
367374 resume_download = True )
368375 self ._save_tokens ()
@@ -374,11 +381,11 @@ def _download_tokens(self):
374381 def _download_model (self ):
375382 logger .log .info ("Trying to download the model..." )
376383 if self ._model_type == Model_Type .GENERATION .value :
377- self ._Model = AutoModelForCausalLM .from_pretrained (model_name ,
384+ self ._Model = AutoModelForCausalLM .from_pretrained (self . _model_name ,
378385 cache_dir = "cache" ,
379386 resume_download = True )
380387 elif self ._model_type == Model_Type .TRANSLATION .value :
381- self ._Model = AutoModelForSeq2SeqLM .from_pretrained (model_name ,
388+ self ._Model = AutoModelForSeq2SeqLM .from_pretrained (self . _model_name ,
382389 cache_dir = "cache" ,
383390 resume_download = True )
384391 self ._save_model ()
0 commit comments