Skip to content
This repository was archived by the owner on Dec 2, 2025. It is now read-only.

Commit 980effc

Browse files
Merge pull request #22 from LyaaaaaGames/Fix_Translator_Not_Loading
Fixed an error with the translator and some typos
2 parents 12e3d25 + d6975fa commit 980effc

2 files changed

Lines changed: 47 additions & 35 deletions

File tree

server/model.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,15 @@
165165
#-- - Extracted from init to load the code related to the loading of files.
166166
#-- - Splitted download into download_model and download_tokens.
167167
#-- - Splitted save into save_model and save_tokens.
168+
#--
169+
#-- - 19/09/2023 Lyaaaaa
170+
#-- - Updated _set_model_parameters to set all the parameters only for the
171+
#-- generators (except low_memory_mode which is used by the translator too).
172+
#-- - Removed some log from _load as they are repeating themself.
173+
#-- - Updated _download_tokens to directly use self._model_name.
174+
#-- - Fixed an error in _download_model. It used model_name which doesn't
175+
#-- exist. Now it uses self._model_name.
176+
#-- - Updated the logs in _load_model and _load_tokens.
168177
#------------------------------------------------------------------------------
169178

170179
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
@@ -230,29 +239,30 @@ def _set_model_parameters(self, p_parameters : dict):
230239
if self._low_memory_mode == None:
231240
self._low_memory_mode = p_parameters["low_memory_mode"]
232241

233-
if self._limit_memory == False:
234-
self._max_memory = None
235-
elif self._limit_memory == None and p_parameters["limit_memory"] == True:
236-
self._max_memory = {0 : p_parameters["max_memory"]["0"],
237-
"cpu" : p_parameters["max_memory"]["cpu"]}
242+
if self._model_type == Model_Type.GENERATION.value:
243+
if self._limit_memory == False:
244+
self._max_memory = None
245+
elif self._limit_memory == None and p_parameters["limit_memory"] == True:
246+
self._max_memory = {0 : p_parameters["max_memory"]["0"],
247+
"cpu" : p_parameters["max_memory"]["cpu"]}
238248

239-
if self._allow_offload == True:
240-
self.create_offload_folder()
241-
elif self._allow_offload == None and p_parameters["allow_offload"] == True:
242-
self.create_offload_folder()
249+
if self._allow_offload == True:
250+
self.create_offload_folder()
251+
elif self._allow_offload == None and p_parameters["allow_offload"] == True:
252+
self.create_offload_folder()
243253

244254

245-
if self._allow_download == None:
246-
self._allow_download = p_parameters["allow_download"]
255+
if self._allow_download == None:
256+
self._allow_download = p_parameters["allow_download"]
247257

248-
if self._device_map == None:
249-
self._device_map = p_parameters["device_map"]
258+
if self._device_map == None:
259+
self._device_map = p_parameters["device_map"]
250260

251-
if self._torch_dtype == None:
252-
self._torch_dtype = Torch_Dtypes.dtypes.value[p_parameters["torch_dtype"]]
261+
if self._torch_dtype == None:
262+
self._torch_dtype = Torch_Dtypes.dtypes.value[p_parameters["torch_dtype"]]
253263

254-
if self._offload_dict == None:
255-
self._offload_dict = p_parameters["offload_dict"]
264+
if self._offload_dict == None:
265+
self._offload_dict = p_parameters["offload_dict"]
256266

257267
#------------------------------------------------------------------------------
258268
#--
@@ -264,7 +274,6 @@ def _load(self):
264274
if self._allow_download == True:
265275
self._download_tokens()
266276
else:
267-
logger.log.error("Couldn't load the tokens files.")
268277
logger.log.info("Downloading with the server is disabled")
269278
else:
270279
logger.log.info("Tokens successfully loaded from local files")
@@ -280,7 +289,6 @@ def _load(self):
280289
if self._allow_download == True:
281290
self._download_model()
282291
else:
283-
logger.log.error("Couldn't load the model " + self._model_name)
284292
logger.log.info("Downloading with the server is disabled.")
285293
else:
286294
logger.log.info("Model successfully loaded from local files")
@@ -295,7 +303,7 @@ def _load_tokens(self):
295303
try:
296304
self._Tokenizer = AutoTokenizer.from_pretrained(self._tokenizers_path)
297305
except Exception as e:
298-
logger.log.error("Token file in '" + self._tokenizers_path + "' not found.")
306+
logger.log.error("Error loading tokens in " + self._tokenizers_path)
299307
logger.log.error(e)
300308
return False
301309

@@ -320,7 +328,7 @@ def _load_model(self):
320328
self._Model = AutoModelForCausalLM.from_pretrained(self._model_path,
321329
**args)
322330
except Exception as e:
323-
logger.log.error("An unexpected error happened while loading the model")
331+
logger.log.error("Error loading the model " + self._model_name)
324332
logger.log.error(e)
325333
return False
326334

@@ -360,9 +368,8 @@ def _save_model(self):
360368
#--
361369
#------------------------------------------------------------------------------
362370
def _download_tokens(self):
363-
model_name = self._model_name
364371
logger.log.info("Trying to download the tokenizer...")
365-
self._Tokenizer = AutoTokenizer.from_pretrained(model_name,
372+
self._Tokenizer = AutoTokenizer.from_pretrained(self._model_name,
366373
cache_dir = "cache",
367374
resume_download = True)
368375
self._save_tokens()
@@ -374,11 +381,11 @@ def _download_tokens(self):
374381
def _download_model(self):
375382
logger.log.info("Trying to download the model...")
376383
if self._model_type == Model_Type.GENERATION.value:
377-
self._Model = AutoModelForCausalLM.from_pretrained(model_name,
384+
self._Model = AutoModelForCausalLM.from_pretrained(self._model_name,
378385
cache_dir = "cache",
379386
resume_download = True)
380387
elif self._model_type == Model_Type.TRANSLATION.value:
381-
self._Model = AutoModelForSeq2SeqLM.from_pretrained(model_name,
388+
self._Model = AutoModelForSeq2SeqLM.from_pretrained(self._model_name,
382389
cache_dir = "cache",
383390
resume_download = True)
384391
self._save_model()

server/server.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@
9696
#-- - Replaced the print here and there by logger.log.info
9797
#-- - Updated the exceptions handlers in handler to display the error.
9898
#-- - Updated shutdown_server to receive an exit code and to display a message.
99+
#--
100+
#-- - 19/09/2023 Lyaaaaa
101+
#-- - Fixed a syntax error in handler.
102+
#-- - Updated handle_request to define parameters differently depending of
103+
#-- the model type (generator or translator).
99104
#------------------------------------------------------------------------------
100105

101106
import asyncio
@@ -135,7 +140,7 @@ async def handler(p_websocket, path):
135140
await p_websocket.send(data_to_send)
136141

137142
except websockets.exceptions.ConnectionClosed as e:
138-
logger.info(e)
143+
logger.info.error(e)
139144
exit_code = 0
140145
shutdown_server(exit_code)
141146

@@ -175,16 +180,15 @@ def handle_request(p_websocket, p_data : dict):
175180
shutdown_server()
176181

177182
elif request == Request.LOAD_MODEL.value:
178-
parameters = {"low_memory_mode" : p_data['low_memory_mode'],
179-
"allow_offload" : p_data['allow_offload'],
180-
"limit_memory" : p_data['limit_memory'],
181-
"max_memory" : p_data['max_memory'],
182-
"allow_download" : p_data['allow_download'],
183-
"device_map" : p_data['device_map'],
184-
"torch_dtype" : p_data['torch_dtype'],
185-
"offload_dict" : p_data['offload_dict'],}
186-
187183
if p_data["model_type"] == Model_Type.GENERATION.value:
184+
parameters = {"low_memory_mode" : p_data['low_memory_mode'],
185+
"allow_offload" : p_data['allow_offload'],
186+
"limit_memory" : p_data['limit_memory'],
187+
"max_memory" : p_data['max_memory'],
188+
"allow_download" : p_data['allow_download'],
189+
"device_map" : p_data['device_map'],
190+
"torch_dtype" : p_data['torch_dtype'],
191+
"offload_dict" : p_data['offload_dict'],}
188192
del generator
189193
logger.log.debug("loading generator")
190194
model_name = p_data['model_name']
@@ -195,6 +199,7 @@ def handle_request(p_websocket, p_data : dict):
195199
logger.log.info("Is CUDA available: " + format(generator.is_cuda_available))
196200

197201
elif p_data["model_type"] == Model_Type.TRANSLATION.value:
202+
parameters = {"low_memory_mode" : p_data['low_memory_mode']}
198203
logger.log.debug("loading translator")
199204
model_name = p_data["to_eng_model"]
200205
to_eng_translator = Translator(model_name,

0 commit comments

Comments
 (0)