1414from sklearn .preprocessing import StandardScaler
1515from sklearn .neighbors import KNeighborsClassifier
1616from sklearn .pipeline import make_pipeline
17+ from mlflow import log_metric , log_param , set_tracking_uri
1718
1819# setting up CLI
1920parser = argparse .ArgumentParser (description = "Classifier" )
2627parser .add_argument ("--knn" , type = int , help = "k nearest neighbor classifier with the specified value of k" , default = None )
2728parser .add_argument ("-a" , "--accuracy" , action = "store_true" , help = "evaluate using accuracy" )
2829parser .add_argument ("-k" , "--kappa" , action = "store_true" , help = "evaluate using Cohen's kappa" )
30+ parser .add_argument ("--log_folder" , help = "where to log the mlflow results" , default = "data/classification/mlflow" )
2931args = parser .parse_args ()
3032
3133# load data
3234with open (args .input_file , 'rb' ) as f_in :
3335 data = pickle .load (f_in )
3436
37+ set_tracking_uri (args .log_folder )
38+
3539if args .import_file is not None :
3640 # import a pre-trained classifier
3741 with open (args .import_file , 'rb' ) as f_in :
38- classifier = pickle .load (f_in )
42+ input_dict = pickle .load (f_in )
43+
44+ classifier = input_dict ["classifier" ]
45+ for param , value in input_dict ["params" ].items ():
46+ log_param (param , value )
47+
48+ log_param ("dataset" , "validation" )
3949
4050else : # manually set up a classifier
4151
4252 if args .majority :
4353 # majority vote classifier
4454 print (" majority vote classifier" )
55+ log_param ("classifier" , "majority" )
56+ params = {"classifier" : "majority" }
4557 classifier = DummyClassifier (strategy = "most_frequent" , random_state = args .seed )
4658
4759 elif args .frequency :
4860 # label frequency classifier
4961 print (" label frequency classifier" )
62+ log_param ("classifier" , "frequency" )
63+ params = {"classifier" : "frequency" }
5064 classifier = DummyClassifier (strategy = "stratified" , random_state = args .seed )
5165
5266
5367 elif args .knn is not None :
5468 print (" {0} nearest neighbor classifier" .format (args .knn ))
69+ log_param ("classifier" , "knn" )
70+ log_param ("k" , args .knn )
71+ params = {"classifier" : "knn" , "k" : args .knn }
5572 standardizer = StandardScaler ()
56- knn_classifier = KNeighborsClassifier (args .knn )
73+ knn_classifier = KNeighborsClassifier (args .knn , n_jobs = - 1 )
5774 classifier = make_pipeline (standardizer , knn_classifier )
5875
5976 classifier .fit (data ["features" ], data ["labels" ].ravel ())
77+ log_param ("dataset" , "training" )
6078
6179# now classify the given data
6280prediction = classifier .predict (data ["features" ])
6684if args .accuracy :
6785 evaluation_metrics .append (("accuracy" , accuracy_score ))
6886if args .kappa :
69- evaluation_metrics .append (("Cohen's kappa " , cohen_kappa_score ))
87+ evaluation_metrics .append (("Cohen_kappa " , cohen_kappa_score ))
7088
7189# compute and print them
7290for metric_name , metric in evaluation_metrics :
73- print (" {0}: {1}" .format (metric_name , metric (data ["labels" ], prediction )))
91+ metric_value = metric (data ["labels" ], prediction )
92+ print (" {0}: {1}" .format (metric_name , metric_value ))
93+ log_metric (metric_name , metric_value )
7494
7595# export the trained classifier if the user wants us to do so
7696if args .export_file is not None :
97+ output_dict = {"classifier" : classifier , "params" : params }
7798 with open (args .export_file , 'wb' ) as f_out :
78- pickle .dump (classifier , f_out )
99+ pickle .dump (output_dict , f_out )
0 commit comments