Skip to content

Commit 3fd8d1c

Browse files
committed
Enhance train_tf_ps.py with CNN architecture updates, and dataset resizing.
- Update Conv2D kernel sizes to 5 for improved feature extraction. - Increase epochs to 100 and adjust batch size to 32 for better training. - Add TensorBoard callback for training visualization (commented and in progress). - Modify default image resizing dimensions to 256x320. - Implement GPU memory growth configuration to prevent OOM errors. - Plot Mean Absolute Error (MAE) during training for insights. - Expand `.gitignore` to exclude Python artifacts, logs, and additional dataset files.
1 parent bbd9752 commit 3fd8d1c

2 files changed

Lines changed: 73 additions & 11 deletions

File tree

.gitignore

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,50 @@
1-
.idea
1+
# Terraform
22
**/.gcp/**/*.json
33
*.terraform*
44
*tfstate*
5-
6-
*config-kind-in-container*
75
*output/
6+
7+
#Datasets and models
8+
*.keras
89
*image-datasets
10+
11+
# Python
12+
__pycache__/
13+
develop-eggs/
14+
dist/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
*.egg-info/
24+
.installed.cfg
25+
*.egg
26+
27+
# Virtual Environment
28+
.env
29+
.venv
30+
env/
31+
venv/
32+
ENV/
33+
34+
# IDE
35+
.idea
36+
.vscode/
37+
38+
# Logs and databases
39+
*.log
40+
*.sqlite
41+
*.db
42+
43+
# Build and documentation
44+
build/
45+
docs/_build/
46+
.coverage
47+
coverage.xml
48+
*.cover
49+
htmlcov/
50+
*config-kind-in-container*

workloads/raw-tf/train_tf_ps.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,15 @@
2323
"""
2424
import argparse
2525
import csv
26+
import datetime
2627
import io
2728
import json
2829
import os
2930
import sys
3031
from typing import List, Tuple, Optional
3132

33+
import matplotlib.pyplot as plt
34+
3235
# TensorFlow-specific setting to reduce the amount of logging output, hiding INFO and WARNING
3336
# and only showing ERROR.
3437
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
@@ -328,13 +331,13 @@ def build_cnn_model(input_shape: Tuple[int, int, int], num_outputs: int = 2, fla
328331
model = tf.keras.Sequential(
329332
[
330333
tf.keras.layers.Input(shape=input_shape),
331-
tf.keras.layers.Conv2D(32, 3, padding="same"),
334+
tf.keras.layers.Conv2D(32, 5, padding="same"),
332335
tf.keras.layers.PReLU(),
333336
tf.keras.layers.MaxPooling2D(),
334-
tf.keras.layers.Conv2D(64, 3, padding="same"),
337+
tf.keras.layers.Conv2D(64, 5, padding="same"),
335338
tf.keras.layers.PReLU(),
336339
tf.keras.layers.MaxPooling2D(),
337-
tf.keras.layers.Conv2D(128, 3, padding="same"),
340+
tf.keras.layers.Conv2D(128, 5, padding="same"),
338341
tf.keras.layers.PReLU(),
339342
tf.keras.layers.Flatten() if flat else tf.keras.layers.GlobalAveragePooling2D(),
340343
tf.keras.layers.Dense(2592, activation="relu") if flat else tf.keras.layers.Dense(128, activation="relu"),
@@ -743,7 +746,12 @@ def step_fn(inputs):
743746
input_context=None,
744747
)
745748
model = build_cnn_model(input_shape, num_outputs=2, flat=flat_layer,)
746-
history = model.fit(ds, epochs=epochs, steps_per_epoch=steps_per_epoch)
749+
# log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
750+
# tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
751+
history = model.fit(ds, epochs=epochs, steps_per_epoch=steps_per_epoch, callbacks=[tensorboard_callback])
752+
plt.plot(history.history['mae'])
753+
plt.xlabel('epoch')
754+
plt.show()
747755

748756
save_path = os.path.join(output_dir, "model.keras")
749757
model.save(save_path)
@@ -760,11 +768,11 @@ def parse_args(argv: List[str]):
760768
parser.add_argument("--data-path", default=os.environ.get("DATA_PATH", "/app/infra/local/mysql-database/datasets/image-datasets/laser-spots"), help="Path to CSV or image root directory")
761769
parser.add_argument("--data-url", default=os.environ.get("DATA_URL", "/app/infra/local/mysql-database/datasets/csvs/health.csv"), help="HTTP(S) URL to CSV (used inside cluster if path not mounted)")
762770
parser.add_argument("--data-is-images", action="store_false", help="Treat data-path as folder-per-class image dataset")
763-
parser.add_argument("--img-height", type=int, default=int(os.environ.get("IMG_HEIGHT", "180")), help="Image height for resizing")
764-
parser.add_argument("--img-width", type=int, default=int(os.environ.get("IMG_WIDTH", "180")), help="Image width for resizing")
771+
parser.add_argument("--img-height", type=int, default=int(os.environ.get("IMG_HEIGHT", "256")), help="Image height for resizing")
772+
parser.add_argument("--img-width", type=int, default=int(os.environ.get("IMG_WIDTH", "320")), help="Image width for resizing")
765773
parser.add_argument("--output-dir", default=os.environ.get("OUTPUT_DIR", "./tf-model"))
766-
parser.add_argument("--epochs", type=int, default=int(os.environ.get("EPOCHS", "3")))
767-
parser.add_argument("--batch-size", type=int, default=int(os.environ.get("BATCH_SIZE", "64")))
774+
parser.add_argument("--epochs", type=int, default=int(os.environ.get("EPOCHS", "100")))
775+
parser.add_argument("--batch-size", type=int, default=int(os.environ.get("BATCH_SIZE", "32")))
768776
parser.add_argument("--use-ps", action="store_true", help="Enable ParameterServerStrategy coordinator mode")
769777
parser.add_argument("--worker-replicas", type=int, default=int(os.environ.get("WORKER_REPLICAS", "2")))
770778
parser.add_argument("--ps-replicas", type=int, default=int(os.environ.get("PS_REPLICAS", "1")))
@@ -777,6 +785,18 @@ def parse_args(argv: List[str]):
777785

778786

779787
if __name__ == "__main__":
788+
# Configure GPU to prevent out-of-memory errors
789+
gpus = tf.config.list_physical_devices("GPU")
790+
if gpus:
791+
try:
792+
# Allow multiple devices to be used
793+
tf.config.set_soft_device_placement(True)
794+
# Restrict TensorFlow to only allocate memory as needed
795+
for gpu in gpus:
796+
tf.config.experimental.set_memory_growth(gpu, True)
797+
except RuntimeError as e:
798+
# Memory growth must be set before GPUs have been initialized
799+
print(e)
780800
args = parse_args(sys.argv[1:])
781801
input("Press enter to continue...")
782802
# Resolve data source

0 commit comments

Comments
 (0)