Skip to content

Commit 243252b

Browse files
committed
Refactor train_tf_ps.py for regression on labeled image datasets.
- Activation Function to PRelu. - CNN dataset shuffle disabled. - Deprecate folder-per-class structure; implement `labels.jsonl` support for flat image directories. - Replace classification logic with CNN regressor for predicting pixel coordinates `(x_px, y_px)`. - Disabled pixel normalization for the rescaling. - Revise dataset loading to handle `labels.jsonl` and ensure targets align with resized image dimensions. - Update training pipeline with regression-specific metrics, optimizer, and model design. - Streamline docstrings, argument parsing, and default settings.
1 parent 9812d26 commit 243252b

1 file changed

Lines changed: 167 additions & 88 deletions

File tree

workloads/raw-tf/train_tf_ps.py

Lines changed: 167 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -151,58 +151,48 @@ def load_csv(
151151

152152
def list_image_classes(data_dir: str) -> List[str]:
153153
"""
154-
Lists image classes by finding subdirectories within the given directory.
155-
Each subdirectory is considered as representing one class.
156-
Ensures the provided directory exists and contains subdirectories following a 'folder per class' structure.
154+
Deprecated: Folder-per-class image structure is no longer supported.
157155
158-
Params:
159-
-------
160-
161-
data_dir : str -> Path to the directory containing class subfolders.
162-
163-
return : List[str] ->List of class names, represented as subfolder names.
164-
165-
raises RuntimeError: If the specified directory does not exist or is not a directory.
166-
raises RuntimeError: If no subdirectories (class folders) are found in the provided directory.
156+
This project now expects a flat directory of images with a labels.jsonl file
157+
providing pixel coordinates for each image. This function is kept for
158+
backward compatibility but will always raise to prevent accidental use.
167159
"""
168-
if not os.path.isdir(data_dir):
169-
raise RuntimeError(f"'{data_dir}' is not a directory")
170-
classes = [d for d in sorted(os.listdir(data_dir)) if os.path.isdir(os.path.join(data_dir, d))]
171-
if not classes:
172-
raise RuntimeError("No class subfolders found. Expected 'folder per class' structure.")
173-
return classes
160+
raise RuntimeError(
161+
"Folder-per-class structure is no longer supported. Use labels.jsonl with a flat image directory."
162+
)
174163

175164

176165
def count_images(data_dir: str) -> int:
177166
"""
178-
Counts the total number of image files in the given directory, including all its
179-
subdirectories, based on specific file extensions. Supported file extensions are:
180-
``.jpg``, ``.jpeg``, ``.png``, ``.bmp``, ``.gif``, ``.ppm``.
181-
182-
The main implementation of the ``count_images`` function is for the step size in the model training.
183-
'steps_per_epoch = max(1, _count_images(data_dir) // batch_size)'
167+
Count labeled images using labels.jsonl in a flat directory.
184168
185-
Raises a ``RuntimeError`` if no images are found under the provided directory.
186-
187-
Params:
188-
-------
189-
190-
data_dir : str -> The root directory containing subdirectories of image classes
191-
192-
return : int -> The total count of image files found in the directory
193-
194-
raises RuntimeError: If no images are found in the provided directory
169+
Only counts entries that both exist on disk and have a supported image
170+
extension.
195171
"""
172+
labels_path = os.path.join(data_dir, "labels.jsonl")
173+
if not os.path.isfile(labels_path):
174+
raise RuntimeError(f"labels.jsonl not found in: {data_dir}")
196175
exts = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".ppm"}
197176
total = 0
198-
for clss in list_image_classes(data_dir):
199-
cls_dir = os.path.join(data_dir, clss)
200-
for name in os.listdir(cls_dir):
177+
with open(labels_path, "r", encoding="utf-8") as fh:
178+
for line in fh:
179+
line = line.strip()
180+
if not line:
181+
continue
182+
try:
183+
obj = json.loads(line)
184+
except Exception:
185+
continue
186+
name = str(obj.get("image", "")).strip()
187+
if not name:
188+
continue
201189
_, ext = os.path.splitext(name.lower())
202-
if ext in exts:
190+
if ext not in exts:
191+
continue
192+
if os.path.isfile(os.path.join(data_dir, name)):
203193
total += 1
204194
if total == 0:
205-
raise RuntimeError("No images found under the provided directory.")
195+
raise RuntimeError("No labeled images found (labels.jsonl present but matched zero files).")
206196
return total
207197

208198

@@ -213,19 +203,101 @@ def make_image_dataset(
213203
shuffle: bool = True,
214204
input_context: Optional[tf.distribute.InputContext] = None,
215205
) -> tf.data.Dataset:
216-
"""Create a tf.data.Dataset from a folder-per-class directory."""
217-
ds = tf.keras.utils.image_dataset_from_directory(
218-
data_dir,
219-
labels="inferred",
220-
label_mode="int",
221-
image_size=image_size,
222-
batch_size=batch_size,
223-
shuffle=shuffle,
224-
seed=1337,
225-
)
206+
"""
207+
Create a tf.data.Dataset for regression on (x_px, y_px) from a flat folder of images
208+
and a labels.jsonl file.
209+
210+
- labels.jsonl format (per line):
211+
{"image": "<file>", "point": {"x_px": <float>, "y_px": <float>},
212+
"image_size": {"width": <int>, "height": <int>}}
213+
214+
Targets are automatically scaled from original pixel coordinates to the
215+
provided resized image_size so the model predicts pixels in the resized
216+
space (not normalized). This keeps the target in pixels as requested while
217+
matching the actual tensor shape given to the model.
218+
"""
219+
labels_path = os.path.join(data_dir, "labels.jsonl")
220+
if not os.path.isfile(labels_path):
221+
raise RuntimeError(f"labels.jsonl not found in: {data_dir}")
222+
223+
img_h, img_w = int(image_size[0]), int(image_size[1])
224+
225+
filepaths: List[str] = []
226+
targets: List[List[float]] = []
227+
228+
exts = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".ppm"}
229+
with open(labels_path, "r", encoding="utf-8") as fh:
230+
for line in fh:
231+
line = line.strip()
232+
if not line:
233+
continue
234+
try:
235+
obj = json.loads(line)
236+
except Exception:
237+
continue
238+
name = str(obj.get("image", "")).strip()
239+
if not name:
240+
continue
241+
_, ext = os.path.splitext(name.lower())
242+
if ext not in exts:
243+
continue
244+
full_path = os.path.join(data_dir, name)
245+
if not os.path.isfile(full_path):
246+
continue
247+
248+
point = obj.get("point") or {}
249+
x_px = point.get("x_px")
250+
y_px = point.get("y_px")
251+
if x_px is None or y_px is None:
252+
continue
253+
254+
# Is not required because no matter what, the output must be the pixel in original size
255+
# img_size = obj.get("image_size") or {}
256+
# ow = img_size.get("width")
257+
# oh = img_size.get("height")
258+
# # If original sizes are missing, fall back to assuming the same as resize
259+
# if not ow or not oh:
260+
# ow, oh = img_w, img_h
261+
#
262+
# # Scale pixel coordinates from original image space to the resized space
263+
# sx = float(img_w) / float(ow)
264+
# sy = float(img_h) / float(oh)
265+
# tx = float(x_px) * sx
266+
# ty = float(y_px) * sy
267+
268+
filepaths.append(full_path)
269+
targets.append([x_px, y_px])
270+
271+
if not filepaths:
272+
raise RuntimeError("No valid labeled images were parsed from labels.jsonl")
273+
274+
# Optionally shuffle at the file list level for better randomness pre-epoch
275+
if shuffle:
276+
rng = np.random.default_rng(1337)
277+
idx = np.arange(len(filepaths))
278+
rng.shuffle(idx)
279+
filepaths = [filepaths[i] for i in idx]
280+
targets = [targets[i] for i in idx]
281+
282+
fp_ds = tf.data.Dataset.from_tensor_slices(filepaths)
283+
y_ds = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(targets, dtype=tf.float32))
284+
ds = tf.data.Dataset.zip((fp_ds, y_ds))
285+
286+
def _load_and_preprocess(path, y):
287+
img = tf.io.read_file(path)
288+
img = tf.image.decode_image(img, channels=3, expand_animations=False)
289+
img = tf.image.resize(img, [img_h, img_w])
290+
img = tf.cast(img, tf.float32) / 255.0
291+
return img, y
292+
293+
ds = ds.map(_load_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
294+
226295
if input_context is not None:
227296
ds = ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)
228-
ds = ds.repeat().prefetch(tf.data.AUTOTUNE)
297+
298+
if shuffle:
299+
ds = ds.shuffle(buffer_size=min(10000, len(filepaths)))
300+
ds = ds.batch(batch_size).repeat().prefetch(tf.data.AUTOTUNE)
229301
return ds
230302

231303

@@ -237,10 +309,10 @@ def build_deep_model(input_dim: int, num_classes: int) -> tf.keras.Model:
237309
model = tf.keras.Sequential(
238310
[
239311
tf.keras.layers.Input(shape=(input_dim,)),
240-
tf.keras.layers.Dense(64, activation="relu"),
241-
tf.keras.layers.Dense(32, activation="relu"),
242312
tf.keras.layers.Dense(16, activation="relu"),
243-
tf.keras.layers.Dense(num_classes, activation="softmax"),
313+
tf.keras.layers.Dense(32, activation="relu"),
314+
tf.keras.layers.Dense(64, activation="relu"),
315+
tf.keras.layers.Dense(num_classes, activation="softmax"), # Softmax because multiclass classification
244316
]
245317
)
246318
model.compile(
@@ -251,25 +323,32 @@ def build_deep_model(input_dim: int, num_classes: int) -> tf.keras.Model:
251323
return model
252324

253325

254-
def build_cnn_model(input_shape: Tuple[int, int, int], num_classes: int) -> tf.keras.Model:
326+
def build_cnn_model(input_shape: Tuple[int, int, int], num_outputs: int = 2) -> tf.keras.Model:
327+
"""Build a simple CNN regressor that predicts (x_px, y_px) in resized pixels."""
255328
model = tf.keras.Sequential(
256329
[
257330
tf.keras.layers.Input(shape=input_shape),
258-
tf.keras.layers.Rescaling(1.0 / 255.0),
259-
tf.keras.layers.Conv2D(32, 3, activation="relu", padding="same"),
331+
tf.keras.layers.Conv2D(32, 3, padding="same"),
332+
tf.keras.layers.PReLU(),
260333
tf.keras.layers.MaxPooling2D(),
261-
tf.keras.layers.Conv2D(64, 3, activation="relu", padding="same"),
334+
tf.keras.layers.Conv2D(64, 3, padding="same"),
335+
tf.keras.layers.PReLU(),
262336
tf.keras.layers.MaxPooling2D(),
263-
tf.keras.layers.Conv2D(128, 3, activation="relu", padding="same"),
337+
tf.keras.layers.Conv2D(128, 3, padding="same"),
338+
tf.keras.layers.PReLU(),
264339
tf.keras.layers.GlobalAveragePooling2D(),
265-
tf.keras.layers.Dense(64, activation="relu"),
266-
tf.keras.layers.Dense(num_classes, activation="softmax"),
340+
# tf.keras.layers.Flatten(),
341+
tf.keras.layers.Dense(128, activation="relu"),
342+
tf.keras.layers.Dense(num_outputs, activation="linear"),
267343
]
268344
)
345+
346+
model.summary()
347+
269348
model.compile(
270349
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
271-
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
272-
metrics=["accuracy"],
350+
loss=tf.keras.losses.MeanSquaredError(),
351+
metrics=[tf.keras.metrics.MeanAbsoluteError(name="mae"), tf.keras.metrics.MeanSquaredError(name="mse")],
273352
)
274353
return model
275354

@@ -580,18 +659,12 @@ def run_image_training(
580659
chief_port: int = 2223,
581660
) -> None:
582661
"""
583-
Train a CNN model on an image dataset organized as folder-per-class.
662+
Train a CNN regressor to predict (x_px, y_px) in pixels using a flat image
663+
directory and labels.jsonl.
584664
"""
585665
os.makedirs(output_dir, exist_ok=True)
586666

587-
classes = list_image_classes(data_dir)
588-
num_classes = len(classes)
589667
input_shape = (img_height, img_width, 3)
590-
591-
# Save label map
592-
with open(os.path.join(output_dir, "label_map.json"), "w", encoding="utf-8") as fh:
593-
json.dump({int(i): s for i, s in enumerate(classes)}, fh, ensure_ascii=False, indent=2)
594-
595668
steps_per_epoch = max(1, count_images(data_dir) // batch_size)
596669

597670
if use_parameter_server and (worker_replicas > 0):
@@ -609,16 +682,17 @@ def per_worker_dataset_fn(input_context: Optional[tf.distribute.InputContext] =
609682
data_dir=data_dir,
610683
image_size=(img_height, img_width),
611684
batch_size=batch_size,
612-
shuffle=True,
685+
shuffle=False,
613686
input_context=input_context,
614687
)
615688

616689
with strategy.scope():
617-
model = build_cnn_model(input_shape, num_classes)
690+
model = build_cnn_model(input_shape, num_outputs=2)
618691
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
619-
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
620-
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
621-
train_loss = tf.keras.metrics.Mean()
692+
loss_obj = tf.keras.losses.MeanSquaredError()
693+
train_mae = tf.keras.metrics.MeanAbsoluteError(name="mae")
694+
train_mse = tf.keras.metrics.MeanSquaredError(name="mse")
695+
train_loss = tf.keras.metrics.Mean(name="loss")
622696

623697
coordinator = tf.distribute.coordinator.ClusterCoordinator(strategy)
624698
per_worker_ds = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
@@ -627,22 +701,24 @@ def per_worker_dataset_fn(input_context: Optional[tf.distribute.InputContext] =
627701
@tf.function
628702
def per_worker_train_step(iterator):
629703
def step_fn(inputs):
630-
features, labels = inputs
704+
features, labels = inputs # labels shape: (None, 2)
631705
with tf.GradientTape() as tape:
632-
logits = model(features, training=True)
633-
loss = loss_obj(labels, logits)
706+
preds = model(features, training=True)
707+
loss = loss_obj(labels, preds)
634708
loss += tf.add_n(model.losses) if model.losses else 0.0
635709
grads = tape.gradient(loss, model.trainable_variables)
636710
optimizer.apply_gradients(zip(grads, model.trainable_variables))
637-
train_acc.update_state(labels, logits)
711+
train_mae.update_state(labels, preds)
712+
train_mse.update_state(labels, preds)
638713
train_loss.update_state(loss)
639714
return loss
640715

641716
return strategy.run(step_fn, args=(next(iterator),))
642717

643718
for epoch in range(epochs):
644719
print(f"Starting epoch {epoch+1}/{epochs}...")
645-
train_acc.reset_state()
720+
train_mae.reset_state()
721+
train_mse.reset_state()
646722
train_loss.reset_state()
647723

648724
futures = []
@@ -652,39 +728,42 @@ def step_fn(inputs):
652728
coordinator.join()
653729

654730
print(
655-
f"Epoch {epoch+1} - loss: {train_loss.result().numpy():.4f} - accuracy: {train_acc.result().numpy():.4f}"
731+
f"Epoch {epoch+1} - loss: {train_loss.result().numpy():.4f} - mae: {train_mae.result().numpy():.4f} - mse: {train_mse.result().numpy():.4f}"
656732
)
657733

658-
history = type("_H", (), {"history": {"accuracy": [train_acc.result().numpy()]}})()
734+
# Keras History-like
735+
history = type("_H", (), {"history": {"mae": [train_mae.result().numpy()], "mse": [train_mse.result().numpy()], "loss": [train_loss.result().numpy()]}})()
659736
else:
660737
print("Running single-process image training.")
661738
ds = make_image_dataset(
662739
data_dir=data_dir,
663740
image_size=(img_height, img_width),
664741
batch_size=batch_size,
665-
shuffle=True,
742+
shuffle=False,
666743
input_context=None,
667744
)
668-
model = build_cnn_model(input_shape, num_classes)
745+
model = build_cnn_model(input_shape, num_outputs=2)
669746
history = model.fit(ds, epochs=epochs, steps_per_epoch=steps_per_epoch)
670747

671748
save_path = os.path.join(output_dir, "model.keras")
672749
model.save(save_path)
673750
print(f"Model saved to: {save_path}")
674751

675-
final_acc = history.history.get("accuracy", [None])[-1]
676-
print(f"Final training accuracy: {final_acc}")
752+
final_mae = history.history.get("mae", [None])[-1]
753+
final_mse = history.history.get("mse", [None])[-1]
754+
final_loss = history.history.get("loss", [None])[-1]
755+
print(f"Final training - loss: {final_loss}, mae: {final_mae}, mse: {final_mse}")
677756

678757

679758
def parse_args(argv: List[str]):
680759
parser = argparse.ArgumentParser(description="Train TF Keras model on CSV or images (folder-per-class) with optional ParameterServerStrategy")
681-
parser.add_argument("--data-path", default=os.environ.get("DATA_PATH", "/app/infra/local/mysql-database/datasets/image-datasets/flower_photos"), help="Path to CSV or image root directory")
760+
parser.add_argument("--data-path", default=os.environ.get("DATA_PATH", "/app/infra/local/mysql-database/datasets/image-datasets/laser-spots"), help="Path to CSV or image root directory")
682761
parser.add_argument("--data-url", default=os.environ.get("DATA_URL", "/app/infra/local/mysql-database/datasets/csvs/health.csv"), help="HTTP(S) URL to CSV (used inside cluster if path not mounted)")
683762
parser.add_argument("--data-is-images", action="store_false", help="Treat data-path as folder-per-class image dataset")
684763
parser.add_argument("--img-height", type=int, default=int(os.environ.get("IMG_HEIGHT", "180")), help="Image height for resizing")
685764
parser.add_argument("--img-width", type=int, default=int(os.environ.get("IMG_WIDTH", "180")), help="Image width for resizing")
686765
parser.add_argument("--output-dir", default=os.environ.get("OUTPUT_DIR", "./tf-model"))
687-
parser.add_argument("--epochs", type=int, default=int(os.environ.get("EPOCHS", "10")))
766+
parser.add_argument("--epochs", type=int, default=int(os.environ.get("EPOCHS", "3")))
688767
parser.add_argument("--batch-size", type=int, default=int(os.environ.get("BATCH_SIZE", "64")))
689768
parser.add_argument("--use-ps", action="store_true", help="Enable ParameterServerStrategy coordinator mode")
690769
parser.add_argument("--worker-replicas", type=int, default=int(os.environ.get("WORKER_REPLICAS", "2")))

0 commit comments

Comments
 (0)