Skip to content
This repository was archived by the owner on Dec 2, 2025. It is now read-only.

Commit 98443ab

Browse files
Merge pull request #24 from LyaaaaaGames/Develop
Release 1.3.0
2 parents d7042be + 1d7c1e0 commit 98443ab

11 files changed

Lines changed: 448 additions & 186 deletions

conda_config.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,7 @@ dependencies:
99
- websockets=10.4
1010
- transformers=4.27
1111
- sentencepiece
12-
- accelerate
12+
- accelerate
13+
- torchvision
14+
- torchaudio
15+
- cpuonly

conda_config_cuda.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ name: aidventure
22
channels:
33
- conda-forge
44
- pytorch
5+
- nvidia
56
dependencies:
67
- pip
78
- python=3.9.7
89
- pytorch
10+
- torchvision
11+
- torchaudio
12+
- pytorch-cuda=11.8
913
- websockets=10.4
1014
- transformers=4.27
1115
- sentencepiece
12-
- accelerate
13-
- torchvision
14-
- cudatoolkit
16+
- accelerate

server/config.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#--
1010
#-- Implementation Notes (Leave empty if nothing to say):
1111
#-- - This is the config file used by the server.
12+
#-- - The settings here have priority over the client's settings.
13+
#-- Setting them to None will give the priority to the client.
1214
#--
1315
#-- Anticipated changes (Leave empty if nothing to say):
1416
#-- -
@@ -29,15 +31,49 @@
2931
#--
3032
#-- 09/11/2022 Lyaaaaa
3133
#-- - Set LOG_LEVEL default value back to INFO
34+
#--
35+
#-- 04/05/2022 Lyaaaaa
36+
#-- - Added a new section "Models". This section contains settings for the
37+
#-- Model class.
38+
#--
39+
#-- 05/05/2022 Lyaaaaa
40+
#-- - Import torch_dtype to support the usage of an enum for the dtypes.
41+
#-- - Added OFFLOAD_DICT to the settings. When True, it avoids RAM peak when
42+
#-- loading a model.
43+
#--
44+
#-- 18/09/2023 Lyaaaaa
45+
#-- - LOG_FILEMODE default value is now "a" again. The log file is now
46+
#-- manually deleted to avoid losing logs.
3247
#---------------------------------------------------------------------------
3348

3449
import logging
50+
from torch_dtype import Torch_Dtypes
3551

3652
# Network
3753
HOST = "0.0.0.0"
3854
PORT = 9999
3955

4056
# Logs
4157
LOG_FILENAME = "server_logs.text"
42-
LOG_FILEMODE = "w"
58+
LOG_FILEMODE = "a"
4359
LOG_LEVEL = logging.INFO
60+
61+
# Models.
62+
#See possible values here: https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.from_pretrained
63+
64+
TOKENIZERS_PATH = "models/"
65+
MODELS_PATH = "models/"
66+
DEFAULT_MODEL = "EleutherAI/gpt-neo-125M"
67+
ALLOW_DOWNLOAD = None # True/False/None. If True, the server will download AI's files.
68+
ALLOW_OFFLOAD = None # True/False/None
69+
OFFLOAD_FOLDER = "offload-" # Prefix to the temp folder.
70+
LOW_CPU_MEM_USAGE = None # True/False/None
71+
LIMIT_MEMORY = None # True/False/None
72+
OFFLOAD_DICT = None # True/False/None
73+
74+
# https://huggingface.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map
75+
# MAX_MEMORY must be a dict. E.G {0: "30GB", 1: "46GB", [x: "yMB/yGB"], "cpu": "20000MB"}. x is a gpu.
76+
MAX_MEMORY = None # None/dict/See documentation
77+
DEVICE_MAP = None # None/see documentation
78+
TORCH_DTYPE = None # "Auto"/None/torch.dtype/See torch_dtype.py for more info.
79+

server/downloader.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

server/generator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,17 @@
2525
#-- - p_parameters aren't sent into generate() anymore. They are now given
2626
#-- to a GenerationConfig object which is an attribute (generation_config)
2727
#-- of the Model. generate() automatically uses these config.
28+
#--
29+
#-- - 05/05/2023 Lyaaaaa
30+
#-- - The condition for moving the inputs to the gpu is now "is_cuda_available"
31+
#-- and not checking the is_gpu_enabled attribute anymore.
32+
#-- - Import logger to display a log when loading the inputs in the gpu.
33+
#-- - Called _empty_gpu_cache after the generation. This releases some memory.
2834
#------------------------------------------------------------------------------
2935

3036
from model import Model
3137
from transformers import GenerationConfig
38+
import logger
3239

3340
class Generator(Model):
3441

@@ -44,12 +51,14 @@ def generate_text(self,
4451
model_input = p_memory + p_context + p_prompt
4552
model_input = self._Tokenizer(model_input, return_tensors = "pt")
4653

47-
if self.is_gpu_enabled:
54+
if self.is_cuda_available:
55+
logger.log.info("Loading inputs to GPU")
4856
model_input.to("cuda")
4957

5058
self._Model.generation_config = GenerationConfig(**p_parameters)
5159

5260
model_output = self._Model.generate(**model_input)
5361
generated_text = self._Tokenizer.decode(model_output[0], skip_special_tokens=True)
5462

63+
self._empty_gpu_cache()
5564
return generated_text

server/logger.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,32 @@
1616
#-- Changelog:
1717
#-- 24/02/2022 Lyaaaaa
1818
#-- - Created the file.
19+
#--
20+
#-- 18/09/2023 Lyaaaaa
21+
#-- - Added delete_log_file function.
22+
#-- - Updated init_logger to call delete_log_file.
1923
#---------------------------------------------------------------------------
2024
import logging
2125
import config
26+
import os
2227

2328
log = None
2429

2530
def init_logger():
2631
global log
32+
delete_log_file()
2733
logging.basicConfig(filename = config.LOG_FILENAME,
2834
filemode = config.LOG_FILEMODE,
2935
format = '%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
3036
datefmt = '%H:%M:%S')
3137
log = logging.getLogger("AIdventure_Server")
3238
log.setLevel(config.LOG_LEVEL)
3339
log.addHandler(logging.StreamHandler())
40+
41+
42+
#------------------------------------------------------------------------------
43+
#
44+
#------------------------------------------------------------------------------
45+
def delete_log_file():
46+
if os.path.exists(config.LOG_FILENAME):
47+
os.remove(config.LOG_FILENAME)

0 commit comments

Comments
 (0)