-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuild.rs
More file actions
116 lines (97 loc) · 3.65 KB
/
build.rs
File metadata and controls
116 lines (97 loc) · 3.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
use burn_onnx::{LoadStrategy, ModelGen};
use std::path::Path;
/// Downloads `url` to `path` unless `path` already exists.
///
/// The response body is first written to a sibling `<path>.part` file and only renamed onto
/// `path` once the whole body has been written. Because the early `path.exists()` check treats
/// any existing file as complete, writing straight to `path` could leave a truncated file (or,
/// without the status check, an HTTP error page) that would never be re-downloaded.
///
/// # Panics
/// Panics, failing the build, if any of the following happens:
/// - the request fails;
/// - the server answers with a non-success HTTP status;
/// - the body cannot be read;
/// - the file cannot be written or renamed.
fn download_file_if_necessary<P: AsRef<Path>>(url: &str, path: P) {
    let path = path.as_ref();
    if path.exists() {
        return;
    }
    println!("url {}", url);

    // `get` succeeds for 4xx/5xx responses; `error_for_status` turns those into errors so an
    // error page is never saved as the requested file.
    let response = reqwest::blocking::get(url)
        .and_then(|response| response.error_for_status())
        .unwrap_or_else(|err| panic!("failed to download {}: {}", url, err));
    let bytes = response
        .bytes()
        .unwrap_or_else(|err| panic!("failed to read response body from {}: {}", url, err));

    // Write to a temporary sibling, then rename, so `path` only ever appears fully written.
    let mut partial = path.as_os_str().to_owned();
    partial.push(".part");
    let partial = std::path::PathBuf::from(partial);
    std::fs::write(&partial, bytes)
        .unwrap_or_else(|err| panic!("failed to write {}: {}", partial.display(), err));
    std::fs::rename(&partial, path).unwrap_or_else(|err| {
        panic!(
            "failed to move {} to {}: {}",
            partial.display(),
            path.display(),
            err
        )
    });
}
/// Runs burn_onnx's `ModelGen` on the ONNX file at `path`, writing its output into `out_dir`.
///
/// Uses `LoadStrategy::Embedded`; presumably this embeds the weights in the generated code.
/// TODO confirm against the burn_onnx docs.
///
/// # Panics
/// Panics if `path` is not valid UTF-8.
fn burn_onnx_converter<P: AsRef<Path>>(path: P, out_dir: &str) {
    let mut model_gen = ModelGen::new();
    let onnx_file = path.as_ref().to_str().unwrap();
    model_gen
        .input(onnx_file)
        .out_dir(out_dir)
        .load_strategy(LoadStrategy::Embedded)
        .run_from_script();
}
/// Ensures `dataset/one_face.jpg` exists.
///
/// If the file is missing, deepface's `tests/dataset/img1.jpg` sample image is downloaded to
/// that path.
fn test_files() {
    let dataset_dir = Path::new("dataset");
    std::fs::create_dir_all(dataset_dir).unwrap();

    let destination = dataset_dir.join("one_face.jpg");
    download_file_if_necessary(
        "https://raw.githubusercontent.com/serengil/deepface/refs/heads/master/tests/dataset/img1.jpg",
        destination,
    );
}
/// Downloads and converts every model in `weights` whose Cargo feature is enabled.
///
/// `weights` holds `(url, file name)` pairs. Each enabled file is stored in `dir`. If it has an
/// `.onnx` extension, it is then passed to `burn_onnx_converter` with `dir` as the output
/// directory.
///
/// A model counts as enabled when either `CARGO_FEATURE_<STEM>` or
/// `CARGO_FEATURE_<CATEGORY>_<STEM>` is set. Both parts are upper-cased, with `-` mapped to `_`.
/// `main` registers the category-prefixed names with `rerun-if-env-changed`
/// (e.g. `CARGO_FEATURE_DETECTION_YUNET`). The file stem alone yields the bare form
/// (e.g. `CARGO_FEATURE_YUNET`). Checking only one spelling means models silently never get
/// fetched if Cargo.toml uses the other one.
/// NOTE(review): confirm which spelling the crate's features actually use and drop the other.
///
/// # Panics
/// Panics if `dir` cannot be created or is not UTF-8, or if a download or conversion fails.
fn fetch_enabled_models(dir: &Path, category: &str, weights: &[(&str, &str)]) {
    std::fs::create_dir_all(dir).unwrap();
    let out_dir = dir.to_str().unwrap();
    let category = category.replace('-', "_").to_uppercase();

    for &(url, filename) in weights {
        let stem = Path::new(filename)
            .file_stem()
            .and_then(|stem| stem.to_str())
            .expect("weight file names must have a UTF-8 stem")
            .replace('-', "_")
            .to_uppercase();
        let feature_vars = [
            format!("CARGO_FEATURE_{}", stem),
            format!("CARGO_FEATURE_{}_{}", category, stem),
        ];
        if !feature_vars.iter().any(|var| std::env::var(var).is_ok()) {
            continue;
        }

        let file = dir.join(filename);
        // `download_file_if_necessary` prints the URL itself when it actually downloads.
        download_file_if_necessary(url, &file);

        let is_onnx = matches!(file.extension().and_then(|ext| ext.to_str()), Some("onnx"));
        if file.exists() && is_onnx {
            burn_onnx_converter(&file, out_dir);
        }
    }
}

/// Fetches and converts the enabled detection models (CenterFace, YuNet) into
/// `models/detection`.
fn detection_models() {
    const WEIGHTS: [(&str, &str); 2] = [
        (
            "https://github.com/A2va/deepface-rs/releases/download/v0.0/centerface.onnx",
            "centerface.onnx",
        ),
        (
            "https://github.com/A2va/deepface-rs/releases/download/v0.0/yunet.onnx",
            "yunet.onnx",
        ),
    ];
    fetch_enabled_models(Path::new("models/detection"), "detection", &WEIGHTS);
}

/// Fetches and converts the enabled recognition models (DeepID, FaceNet512) into
/// `models/recognition`.
fn recognition_models() {
    const WEIGHTS: [(&str, &str); 2] = [
        (
            "https://github.com/A2va/deepface-rs/releases/download/v0.0/deepid.onnx",
            "deepid.onnx",
        ),
        (
            "https://github.com/A2va/deepface-rs/releases/download/v0.0/facenet512.onnx",
            "facenet512.onnx",
        ),
    ];
    fetch_enabled_models(Path::new("models/recognition"), "recognition", &WEIGHTS);
}
/// Build-script entry point.
///
/// Declares the rerun triggers, then fetches the sample image and the detection and recognition
/// model weights.
fn main() {
    const WATCHED_FEATURE_VARS: [&str; 4] = [
        "CARGO_FEATURE_DETECTION_CENTERFACE",
        "CARGO_FEATURE_DETECTION_YUNET",
        "CARGO_FEATURE_RECOGNITION_DEEPID",
        "CARGO_FEATURE_RECOGNITION_FACENET512",
    ];

    println!("cargo:rerun-if-changed=build.rs");
    for feature_var in WATCHED_FEATURE_VARS {
        println!("cargo:rerun-if-env-changed={}", feature_var);
    }

    test_files();
    detection_models();
    recognition_models();
}