@@ -151,58 +151,48 @@ def load_csv(
151151
def list_image_classes(data_dir: str) -> List[str]:
    """
    Deprecated stub for the former folder-per-class class listing.

    The image pipeline in this module reads a flat directory of images
    described by a ``labels.jsonl`` file, so there are no class subfolders to
    enumerate any more. The function is kept so that existing imports still
    resolve, but calling it always fails loudly instead of returning a
    misleading result.

    Params:
    -------

    data_dir : str -> Ignored; retained only to keep the historical signature.

    return : List[str] -> Never returns.

    raises RuntimeError: Always, to prevent accidental use of the old layout.
    """
    message = (
        "Folder-per-class structure is no longer supported. "
        "Use labels.jsonl with a flat image directory."
    )
    raise RuntimeError(message)
174163
175164
def count_images(data_dir: str) -> int:
    """
    Count the usable labeled images listed in ``labels.jsonl`` of a flat directory.

    The result is used to derive the number of training steps per epoch, so a
    record is counted only when it is actually usable as a training sample:

    - the line parses as a JSON object (blank, malformed, or non-object lines
      are skipped),
    - its ``"image"`` value has a supported extension
      (``.jpg``, ``.jpeg``, ``.png``, ``.bmp``, ``.gif``, ``.ppm``),
    - the referenced file exists inside ``data_dir``,
    - its ``"point"`` object provides both ``x_px`` and ``y_px`` (the same
      requirement ``make_image_dataset`` applies before keeping a record).

    Params:
    -------

    data_dir : str -> Directory containing the images and ``labels.jsonl``.

    return : int -> Number of usable labeled images (always >= 1).

    raises RuntimeError: If ``labels.jsonl`` is missing from ``data_dir``.
    raises RuntimeError: If no record in ``labels.jsonl`` is usable.
    """
    labels_path = os.path.join(data_dir, "labels.jsonl")
    if not os.path.isfile(labels_path):
        raise RuntimeError(f"labels.jsonl not found in: {data_dir}")

    exts = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".ppm"}
    total = 0
    with open(labels_path, "r", encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                # Best effort: one corrupt line must not abort the whole count.
                continue
            if not isinstance(obj, dict):
                # Valid JSON that is not an object (list, number, ...) has no
                # "image" key; skip it instead of failing on ``.get``.
                continue

            name = str(obj.get("image", "")).strip()
            if not name:
                continue
            _, ext = os.path.splitext(name.lower())
            if ext not in exts:
                continue

            # Records without a complete point are dropped by the dataset
            # builder, so counting them would overstate the dataset size.
            point = obj.get("point")
            if not isinstance(point, dict):
                continue
            if point.get("x_px") is None or point.get("y_px") is None:
                continue

            if os.path.isfile(os.path.join(data_dir, name)):
                total += 1

    if total == 0:
        raise RuntimeError("No labeled images found (labels.jsonl present but matched zero files).")
    return total
207197
208198
@@ -213,19 +203,101 @@ def make_image_dataset(
213203 shuffle : bool = True ,
214204 input_context : Optional [tf .distribute .InputContext ] = None ,
215205) -> tf .data .Dataset :
216- """Create a tf.data.Dataset from a folder-per-class directory."""
217- ds = tf .keras .utils .image_dataset_from_directory (
218- data_dir ,
219- labels = "inferred" ,
220- label_mode = "int" ,
221- image_size = image_size ,
222- batch_size = batch_size ,
223- shuffle = shuffle ,
224- seed = 1337 ,
225- )
206+ """
207+ Create a tf.data.Dataset for regression on (x_px, y_px) from a flat folder of images
208+ and a labels.jsonl file.
209+
210+ - labels.jsonl format (per line):
211+ {"image": "<file>", "point": {"x_px": <float>, "y_px": <float>},
212+ "image_size": {"width": <int>, "height": <int>}}
213+
214+ Targets are automatically scaled from original pixel coordinates to the
215+ provided resized image_size so the model predicts pixels in the resized
216+ space (not normalized). This keeps the target in pixels as requested while
217+ matching the actual tensor shape given to the model.
218+ """
219+ labels_path = os .path .join (data_dir , "labels.jsonl" )
220+ if not os .path .isfile (labels_path ):
221+ raise RuntimeError (f"labels.jsonl not found in: { data_dir } " )
222+
223+ img_h , img_w = int (image_size [0 ]), int (image_size [1 ])
224+
225+ filepaths : List [str ] = []
226+ targets : List [List [float ]] = []
227+
228+ exts = {".jpg" , ".jpeg" , ".png" , ".bmp" , ".gif" , ".ppm" }
229+ with open (labels_path , "r" , encoding = "utf-8" ) as fh :
230+ for line in fh :
231+ line = line .strip ()
232+ if not line :
233+ continue
234+ try :
235+ obj = json .loads (line )
236+ except Exception :
237+ continue
238+ name = str (obj .get ("image" , "" )).strip ()
239+ if not name :
240+ continue
241+ _ , ext = os .path .splitext (name .lower ())
242+ if ext not in exts :
243+ continue
244+ full_path = os .path .join (data_dir , name )
245+ if not os .path .isfile (full_path ):
246+ continue
247+
248+ point = obj .get ("point" ) or {}
249+ x_px = point .get ("x_px" )
250+ y_px = point .get ("y_px" )
251+ if x_px is None or y_px is None :
252+ continue
253+
254+ # Is not required because no matter what, the output must be the pixel in original size
255+ # img_size = obj.get("image_size") or {}
256+ # ow = img_size.get("width")
257+ # oh = img_size.get("height")
258+ # # If original sizes are missing, fall back to assuming the same as resize
259+ # if not ow or not oh:
260+ # ow, oh = img_w, img_h
261+ #
262+ # # Scale pixel coordinates from original image space to the resized space
263+ # sx = float(img_w) / float(ow)
264+ # sy = float(img_h) / float(oh)
265+ # tx = float(x_px) * sx
266+ # ty = float(y_px) * sy
267+
268+ filepaths .append (full_path )
269+ targets .append ([x_px , y_px ])
270+
271+ if not filepaths :
272+ raise RuntimeError ("No valid labeled images were parsed from labels.jsonl" )
273+
274+ # Optionally shuffle at the file list level for better randomness pre-epoch
275+ if shuffle :
276+ rng = np .random .default_rng (1337 )
277+ idx = np .arange (len (filepaths ))
278+ rng .shuffle (idx )
279+ filepaths = [filepaths [i ] for i in idx ]
280+ targets = [targets [i ] for i in idx ]
281+
282+ fp_ds = tf .data .Dataset .from_tensor_slices (filepaths )
283+ y_ds = tf .data .Dataset .from_tensor_slices (tf .convert_to_tensor (targets , dtype = tf .float32 ))
284+ ds = tf .data .Dataset .zip ((fp_ds , y_ds ))
285+
286+ def _load_and_preprocess (path , y ):
287+ img = tf .io .read_file (path )
288+ img = tf .image .decode_image (img , channels = 3 , expand_animations = False )
289+ img = tf .image .resize (img , [img_h , img_w ])
290+ img = tf .cast (img , tf .float32 ) / 255.0
291+ return img , y
292+
293+ ds = ds .map (_load_and_preprocess , num_parallel_calls = tf .data .AUTOTUNE )
294+
226295 if input_context is not None :
227296 ds = ds .shard (input_context .num_input_pipelines , input_context .input_pipeline_id )
228- ds = ds .repeat ().prefetch (tf .data .AUTOTUNE )
297+
298+ if shuffle :
299+ ds = ds .shuffle (buffer_size = min (10000 , len (filepaths )))
300+ ds = ds .batch (batch_size ).repeat ().prefetch (tf .data .AUTOTUNE )
229301 return ds
230302
231303
@@ -237,10 +309,10 @@ def build_deep_model(input_dim: int, num_classes: int) -> tf.keras.Model:
237309 model = tf .keras .Sequential (
238310 [
239311 tf .keras .layers .Input (shape = (input_dim ,)),
240- tf .keras .layers .Dense (64 , activation = "relu" ),
241- tf .keras .layers .Dense (32 , activation = "relu" ),
242312 tf .keras .layers .Dense (16 , activation = "relu" ),
243- tf .keras .layers .Dense (num_classes , activation = "softmax" ),
313+ tf .keras .layers .Dense (32 , activation = "relu" ),
314+ tf .keras .layers .Dense (64 , activation = "relu" ),
315+ tf .keras .layers .Dense (num_classes , activation = "softmax" ), # Softmax because multiclass classification
244316 ]
245317 )
246318 model .compile (
@@ -251,25 +323,32 @@ def build_deep_model(input_dim: int, num_classes: int) -> tf.keras.Model:
251323 return model
252324
253325
def build_cnn_model(input_shape: Tuple[int, int, int], num_outputs: int = 2) -> tf.keras.Model:
    """
    Build and compile a small convolutional regressor.

    Architecture: three ``Conv2D`` stages (32, 64 and 128 filters, 3x3 kernels,
    ``same`` padding), each followed by a ``PReLU`` activation. The first two
    stages are down-sampled with ``MaxPooling2D``; the last one is reduced with
    ``GlobalAveragePooling2D``. A ``Dense(128, relu)`` layer and a linear
    ``Dense(num_outputs)`` head produce the prediction.

    The network contains no rescaling layer, so inputs are consumed as given.
    The model summary is printed as a side effect of building.

    Params:
    -------

    input_shape : Tuple[int, int, int] -> Shape of one input image, passed to ``Input``.

    num_outputs : int -> Number of regressed values; defaults to 2 for an (x_px, y_px) pair.

    return : tf.keras.Model -> Model compiled with Adam (lr=1e-3), MSE loss, and
             ``mae`` / ``mse`` metrics.
    """
    layers = tf.keras.layers
    conv_filters = (32, 64, 128)
    last_stage = len(conv_filters) - 1

    stack = [layers.Input(shape=input_shape)]
    for stage, filters in enumerate(conv_filters):
        stack.append(layers.Conv2D(filters, 3, padding="same"))
        stack.append(layers.PReLU())
        # Intermediate stages halve the spatial size; the final stage collapses
        # it entirely so the dense head is independent of the input resolution.
        if stage == last_stage:
            stack.append(layers.GlobalAveragePooling2D())
        else:
            stack.append(layers.MaxPooling2D())
    stack.append(layers.Dense(128, activation="relu"))
    stack.append(layers.Dense(num_outputs, activation="linear"))

    model = tf.keras.Sequential(stack)

    model.summary()

    adam = tf.keras.optimizers.Adam(learning_rate=1e-3)
    mse_loss = tf.keras.losses.MeanSquaredError()
    regression_metrics = [
        tf.keras.metrics.MeanAbsoluteError(name="mae"),
        tf.keras.metrics.MeanSquaredError(name="mse"),
    ]
    model.compile(optimizer=adam, loss=mse_loss, metrics=regression_metrics)
    return model
275354
@@ -580,18 +659,12 @@ def run_image_training(
580659 chief_port : int = 2223 ,
581660) -> None :
582661 """
583- Train a CNN model on an image dataset organized as folder-per-class.
662+ Train a CNN regressor to predict (x_px, y_px) in pixels using a flat image
663+ directory and labels.jsonl.
584664 """
585665 os .makedirs (output_dir , exist_ok = True )
586666
587- classes = list_image_classes (data_dir )
588- num_classes = len (classes )
589667 input_shape = (img_height , img_width , 3 )
590-
591- # Save label map
592- with open (os .path .join (output_dir , "label_map.json" ), "w" , encoding = "utf-8" ) as fh :
593- json .dump ({int (i ): s for i , s in enumerate (classes )}, fh , ensure_ascii = False , indent = 2 )
594-
595668 steps_per_epoch = max (1 , count_images (data_dir ) // batch_size )
596669
597670 if use_parameter_server and (worker_replicas > 0 ):
@@ -609,16 +682,17 @@ def per_worker_dataset_fn(input_context: Optional[tf.distribute.InputContext] =
609682 data_dir = data_dir ,
610683 image_size = (img_height , img_width ),
611684 batch_size = batch_size ,
612- shuffle = True ,
685+ shuffle = False ,
613686 input_context = input_context ,
614687 )
615688
616689 with strategy .scope ():
617- model = build_cnn_model (input_shape , num_classes )
690+ model = build_cnn_model (input_shape , num_outputs = 2 )
618691 optimizer = tf .keras .optimizers .Adam (learning_rate = 1e-4 )
619- loss_obj = tf .keras .losses .SparseCategoricalCrossentropy ()
620- train_acc = tf .keras .metrics .SparseCategoricalAccuracy ()
621- train_loss = tf .keras .metrics .Mean ()
692+ loss_obj = tf .keras .losses .MeanSquaredError ()
693+ train_mae = tf .keras .metrics .MeanAbsoluteError (name = "mae" )
694+ train_mse = tf .keras .metrics .MeanSquaredError (name = "mse" )
695+ train_loss = tf .keras .metrics .Mean (name = "loss" )
622696
623697 coordinator = tf .distribute .coordinator .ClusterCoordinator (strategy )
624698 per_worker_ds = coordinator .create_per_worker_dataset (per_worker_dataset_fn )
@@ -627,22 +701,24 @@ def per_worker_dataset_fn(input_context: Optional[tf.distribute.InputContext] =
627701 @tf .function
628702 def per_worker_train_step (iterator ):
629703 def step_fn (inputs ):
630- features , labels = inputs
704+ features , labels = inputs # labels shape: (None, 2)
631705 with tf .GradientTape () as tape :
632- logits = model (features , training = True )
633- loss = loss_obj (labels , logits )
706+ preds = model (features , training = True )
707+ loss = loss_obj (labels , preds )
634708 loss += tf .add_n (model .losses ) if model .losses else 0.0
635709 grads = tape .gradient (loss , model .trainable_variables )
636710 optimizer .apply_gradients (zip (grads , model .trainable_variables ))
637- train_acc .update_state (labels , logits )
711+ train_mae .update_state (labels , preds )
712+ train_mse .update_state (labels , preds )
638713 train_loss .update_state (loss )
639714 return loss
640715
641716 return strategy .run (step_fn , args = (next (iterator ),))
642717
643718 for epoch in range (epochs ):
644719 print (f"Starting epoch { epoch + 1 } /{ epochs } ..." )
645- train_acc .reset_state ()
720+ train_mae .reset_state ()
721+ train_mse .reset_state ()
646722 train_loss .reset_state ()
647723
648724 futures = []
@@ -652,39 +728,42 @@ def step_fn(inputs):
652728 coordinator .join ()
653729
654730 print (
655- f"Epoch { epoch + 1 } - loss: { train_loss .result ().numpy ():.4f} - accuracy : { train_acc .result ().numpy ():.4f} "
731+ f"Epoch { epoch + 1 } - loss: { train_loss .result ().numpy ():.4f} - mae : { train_mae . result (). numpy ():.4f } - mse: { train_mse .result ().numpy ():.4f} "
656732 )
657733
658- history = type ("_H" , (), {"history" : {"accuracy" : [train_acc .result ().numpy ()]}})()
734+ # Keras History-like
735+ history = type ("_H" , (), {"history" : {"mae" : [train_mae .result ().numpy ()], "mse" : [train_mse .result ().numpy ()], "loss" : [train_loss .result ().numpy ()]}})()
659736 else :
660737 print ("Running single-process image training." )
661738 ds = make_image_dataset (
662739 data_dir = data_dir ,
663740 image_size = (img_height , img_width ),
664741 batch_size = batch_size ,
665- shuffle = True ,
742+ shuffle = False ,
666743 input_context = None ,
667744 )
668- model = build_cnn_model (input_shape , num_classes )
745+ model = build_cnn_model (input_shape , num_outputs = 2 )
669746 history = model .fit (ds , epochs = epochs , steps_per_epoch = steps_per_epoch )
670747
671748 save_path = os .path .join (output_dir , "model.keras" )
672749 model .save (save_path )
673750 print (f"Model saved to: { save_path } " )
674751
675- final_acc = history .history .get ("accuracy" , [None ])[- 1 ]
676- print (f"Final training accuracy: { final_acc } " )
752+ final_mae = history .history .get ("mae" , [None ])[- 1 ]
753+ final_mse = history .history .get ("mse" , [None ])[- 1 ]
754+ final_loss = history .history .get ("loss" , [None ])[- 1 ]
755+ print (f"Final training - loss: { final_loss } , mae: { final_mae } , mse: { final_mse } " )
677756
678757
679758def parse_args (argv : List [str ]):
680759 parser = argparse .ArgumentParser (description = "Train TF Keras model on CSV or images (folder-per-class) with optional ParameterServerStrategy" )
681- parser .add_argument ("--data-path" , default = os .environ .get ("DATA_PATH" , "/app/infra/local/mysql-database/datasets/image-datasets/flower_photos " ), help = "Path to CSV or image root directory" )
760+ parser .add_argument ("--data-path" , default = os .environ .get ("DATA_PATH" , "/app/infra/local/mysql-database/datasets/image-datasets/laser-spots " ), help = "Path to CSV or image root directory" )
682761 parser .add_argument ("--data-url" , default = os .environ .get ("DATA_URL" , "/app/infra/local/mysql-database/datasets/csvs/health.csv" ), help = "HTTP(S) URL to CSV (used inside cluster if path not mounted)" )
683762 parser .add_argument ("--data-is-images" , action = "store_false" , help = "Treat data-path as folder-per-class image dataset" )
684763 parser .add_argument ("--img-height" , type = int , default = int (os .environ .get ("IMG_HEIGHT" , "180" )), help = "Image height for resizing" )
685764 parser .add_argument ("--img-width" , type = int , default = int (os .environ .get ("IMG_WIDTH" , "180" )), help = "Image width for resizing" )
686765 parser .add_argument ("--output-dir" , default = os .environ .get ("OUTPUT_DIR" , "./tf-model" ))
687- parser .add_argument ("--epochs" , type = int , default = int (os .environ .get ("EPOCHS" , "10 " )))
766+ parser .add_argument ("--epochs" , type = int , default = int (os .environ .get ("EPOCHS" , "3 " )))
688767 parser .add_argument ("--batch-size" , type = int , default = int (os .environ .get ("BATCH_SIZE" , "64" )))
689768 parser .add_argument ("--use-ps" , action = "store_true" , help = "Enable ParameterServerStrategy coordinator mode" )
690769 parser .add_argument ("--worker-replicas" , type = int , default = int (os .environ .get ("WORKER_REPLICAS" , "2" )))
0 commit comments