Skip to content

Commit dcd00f9

Browse files
author
lbechberger
committed
added grid search
1 parent 428e765 commit dcd00f9

5 files changed

Lines changed: 77 additions & 6 deletions

File tree

code/application/application.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929
with open(args.dim_red_file, 'rb') as f_in:
3030
dimensionality_reduction = pickle.load(f_in)
3131
with open(args.classifier_file, 'rb') as f_in:
32-
classifier = pickle.load(f_in)
32+
classifier = pickle.load(f_in)["classifier"]
3333

3434
# chain them together into a single pipeline
3535
pipeline = make_pipeline(preprocessing, feature_extraction, dimensionality_reduction, classifier)

code/classification/classifier.sge

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,22 @@
1+
#!/bin/bash
2+
#$ -N classifier
3+
#$ -l mem=2G
4+
#$ -cwd
5+
#$ -pe default 2
6+
#$ -o $HOME
7+
#$ -e $HOME
8+
#$ -l h=*cippy*
9+
10+
export PATH="$HOME/miniconda/bin:$PATH"
11+
eval "$(conda shell.bash hook)"
12+
conda activate MLinPractice
13+
14+
# train classifier on training set
15+
echo " training"
16+
python -m code.classification.run_classifier data/dimensionality_reduction/training.pickle -e $*
17+
18+
# evaluate classifier on validation set
19+
echo " validation"
20+
python -m code.classification.run_classifier data/dimensionality_reduction/validation.pickle -i $*
21+
22+
conda deactivate

code/classification/grid_search.sh

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,28 @@
1+
#!/bin/bash
2+
3+
mkdir -p data/classification
4+
5+
# specify hyperparameter values
6+
values_of_k=("1 2 3 4 5 6 7 8 9 10")
7+
8+
9+
# different execution modes
10+
if [ $1 = local ]
11+
then
12+
echo "[local execution]"
13+
cmd="code/classification/classifier.sge"
14+
elif [ $1 = grid ]
15+
then
16+
echo "[grid execution]"
17+
cmd="qsub code/classification/classifier.sge"
18+
else
19+
echo "[ERROR! Argument not supported!]"
20+
exit 1
21+
fi
22+
23+
# do the grid search
24+
for k in $values_of_k
25+
do
26+
echo $k
27+
$cmd 'data/classification/clf_'"$k"'.pickle' --knn $k -s 42 --accuracy --kappa
28+
done

code/classification/run_classifier.py

Lines changed: 26 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from sklearn.preprocessing import StandardScaler
1515
from sklearn.neighbors import KNeighborsClassifier
1616
from sklearn.pipeline import make_pipeline
17+
from mlflow import log_metric, log_param, set_tracking_uri
1718

1819
# setting up CLI
1920
parser = argparse.ArgumentParser(description = "Classifier")
@@ -26,37 +27,54 @@
2627
parser.add_argument("--knn", type = int, help = "k nearest neighbor classifier with the specified value of k", default = None)
2728
parser.add_argument("-a", "--accuracy", action = "store_true", help = "evaluate using accuracy")
2829
parser.add_argument("-k", "--kappa", action = "store_true", help = "evaluate using Cohen's kappa")
30+
parser.add_argument("--log_folder", help = "where to log the mlflow results", default = "data/classification/mlflow")
2931
args = parser.parse_args()
3032

3133
# load data
3234
with open(args.input_file, 'rb') as f_in:
3335
data = pickle.load(f_in)
3436

37+
set_tracking_uri(args.log_folder)
38+
3539
if args.import_file is not None:
3640
# import a pre-trained classifier
3741
with open(args.import_file, 'rb') as f_in:
38-
classifier = pickle.load(f_in)
42+
input_dict = pickle.load(f_in)
43+
44+
classifier = input_dict["classifier"]
45+
for param, value in input_dict["params"].items():
46+
log_param(param, value)
47+
48+
log_param("dataset", "validation")
3949

4050
else: # manually set up a classifier
4151

4252
if args.majority:
4353
# majority vote classifier
4454
print(" majority vote classifier")
55+
log_param("classifier", "majority")
56+
params = {"classifier": "majority"}
4557
classifier = DummyClassifier(strategy = "most_frequent", random_state = args.seed)
4658

4759
elif args.frequency:
4860
# label frequency classifier
4961
print(" label frequency classifier")
62+
log_param("classifier", "frequency")
63+
params = {"classifier": "frequency"}
5064
classifier = DummyClassifier(strategy = "stratified", random_state = args.seed)
5165

5266

5367
elif args.knn is not None:
5468
print(" {0} nearest neighbor classifier".format(args.knn))
69+
log_param("classifier", "knn")
70+
log_param("k", args.knn)
71+
params = {"classifier": "knn", "k": args.knn}
5572
standardizer = StandardScaler()
56-
knn_classifier = KNeighborsClassifier(args.knn)
73+
knn_classifier = KNeighborsClassifier(args.knn, n_jobs = -1)
5774
classifier = make_pipeline(standardizer, knn_classifier)
5875

5976
classifier.fit(data["features"], data["labels"].ravel())
77+
log_param("dataset", "training")
6078

6179
# now classify the given data
6280
prediction = classifier.predict(data["features"])
@@ -66,13 +84,16 @@
6684
if args.accuracy:
6785
evaluation_metrics.append(("accuracy", accuracy_score))
6886
if args.kappa:
69-
evaluation_metrics.append(("Cohen's kappa", cohen_kappa_score))
87+
evaluation_metrics.append(("Cohen_kappa", cohen_kappa_score))
7088

7189
# compute and print them
7290
for metric_name, metric in evaluation_metrics:
73-
print(" {0}: {1}".format(metric_name, metric(data["labels"], prediction)))
91+
metric_value = metric(data["labels"], prediction)
92+
print(" {0}: {1}".format(metric_name, metric_value))
93+
log_metric(metric_name, metric_value)
7494

7595
# export the trained classifier if the user wants us to do so
7696
if args.export_file is not None:
97+
output_dict = {"classifier": classifier, "params": params}
7798
with open(args.export_file, 'wb') as f_out:
78-
pickle.dump(classifier, f_out)
99+
pickle.dump(output_dict, f_out)
-5 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)