|
23 | 23 | #include "ImfTiledMisc.h" |
24 | 24 | #include "ImfTiledOutputPart.h" |
25 | 25 |
|
| 26 | +#include <Imath/half.h> |
| 27 | + |
26 | 28 | #include <chrono> |
| 29 | +#include <cmath> |
27 | 30 | #include <ctime> |
| 31 | +#include <limits> |
28 | 32 | #include <list> |
29 | 33 | #include <stdexcept> |
30 | 34 | #include <vector> |
31 | 35 | #include <sys/stat.h> |
32 | 36 |
|
33 | 37 | using namespace OPENEXR_IMF_NAMESPACE; |
34 | 38 | using IMATH_NAMESPACE::Box2i; |
| 39 | +using IMATH_NAMESPACE::half; |
35 | 40 |
|
36 | 41 | using std::cerr; |
37 | 42 | using namespace std::chrono; |
@@ -968,7 +973,8 @@ exrmetrics ( |
968 | 973 | bool write, |
969 | 974 | bool reread, |
970 | 975 | PixelMode pixelMode, |
971 | | - bool verbose) |
| 976 | + bool verbose, |
| 977 | + bool computeMSE) |
972 | 978 | { |
973 | 979 |
|
974 | 980 | if (verbose) |
@@ -1154,6 +1160,86 @@ exrmetrics ( |
1154 | 1160 | else { metrics.outputFileSize = fileSize; } |
1155 | 1161 | } |
1156 | 1162 |
|
| 1163 | + // |
| 1164 | + // compute arcsinh-space MSE for half-float channels vs. re-read data |
| 1165 | + // |
| 1166 | + if (computeMSE && write && reread) |
| 1167 | + { |
| 1168 | + bool anyHalfChannel = false; |
| 1169 | + |
| 1170 | + for (size_t p = 0; p < parts.size (); ++p) |
| 1171 | + { |
| 1172 | + string type = outHeaders[p].type (); |
| 1173 | + if (type != SCANLINEIMAGE && type != TILEDIMAGE) continue; |
| 1174 | + |
| 1175 | + Box2i dw = outHeaders[p].dataWindow (); |
| 1176 | + uint64_t width = dw.max.x + 1 - dw.min.x; |
| 1177 | + uint64_t height = dw.max.y + 1 - dw.min.y; |
| 1178 | + |
| 1179 | + double sumSq = 0.0; |
| 1180 | + uint64_t count = 0; |
| 1181 | + int channelIndex = 0; |
| 1182 | + |
| 1183 | + for (ChannelList::ConstIterator i = |
| 1184 | + outHeaders[p].channels ().begin (); |
| 1185 | + i != outHeaders[p].channels ().end (); |
| 1186 | + ++i, ++channelIndex) |
| 1187 | + { |
| 1188 | + if (i.channel ().type != HALF) continue; |
| 1189 | + anyHalfChannel = true; |
| 1190 | + |
| 1191 | + uint64_t pixelsInChannel = |
| 1192 | + (width / i.channel ().xSampling) * |
| 1193 | + (height / i.channel ().ySampling); |
| 1194 | + |
| 1195 | + const half* orig = nullptr; |
| 1196 | + const half* rereadPx = nullptr; |
| 1197 | + |
| 1198 | + if (type == SCANLINEIMAGE) |
| 1199 | + { |
| 1200 | + orig = reinterpret_cast<const half*> ( |
| 1201 | + parts[p].readBuf.scanlinePixelData[channelIndex].data ()); |
| 1202 | + rereadPx = reinterpret_cast<const half*> ( |
| 1203 | + parts[p].rereadBuf.scanlinePixelData[channelIndex].data ()); |
| 1204 | + } |
| 1205 | + else |
| 1206 | + { |
| 1207 | + if (parts[p].readBuf.tilePixelData.empty () || |
| 1208 | + parts[p].rereadBuf.tilePixelData.empty ()) |
| 1209 | + continue; |
| 1210 | + orig = reinterpret_cast<const half*> ( |
| 1211 | + parts[p].readBuf.tilePixelData[0][channelIndex].data ()); |
| 1212 | + rereadPx = reinterpret_cast<const half*> ( |
| 1213 | + parts[p].rereadBuf.tilePixelData[0][channelIndex].data ()); |
| 1214 | + } |
| 1215 | + |
| 1216 | + for (uint64_t px = 0; px < pixelsInChannel; ++px) |
| 1217 | + { |
| 1218 | + float a = static_cast<float> (orig[px]); |
| 1219 | + float b = static_cast<float> (rereadPx[px]); |
| 1220 | + if (std::isfinite (a) && std::isfinite (b)) |
| 1221 | + { |
| 1222 | + double diff = std::asinh (double (a) / 0.0000001) - |
| 1223 | + std::asinh (double (b) / 0.0000001); |
| 1224 | + sumSq += diff * diff; |
| 1225 | + ++count; |
| 1226 | + } |
| 1227 | + } |
| 1228 | + } |
| 1229 | + |
| 1230 | + metrics.stats[p].mseCount = count; |
| 1231 | + metrics.stats[p].mse = |
| 1232 | + count > 0 ? sumSq / count |
| 1233 | + : std::numeric_limits<double>::quiet_NaN (); |
| 1234 | + } |
| 1235 | + |
| 1236 | + if (!anyHalfChannel) |
| 1237 | + { |
| 1238 | + cerr << "warning: --mse requires half-float channels, " |
| 1239 | + "but none found in output image\n"; |
| 1240 | + } |
| 1241 | + } |
| 1242 | + |
1157 | 1243 | // |
1158 | 1244 | // sum across all parts |
1159 | 1245 | // |
@@ -1197,6 +1283,25 @@ exrmetrics ( |
1197 | 1283 | } |
1198 | 1284 | } |
1199 | 1285 |
|
| 1286 | + // accumulate MSE as weighted average across parts (by sample count) |
| 1287 | + if (computeMSE && write && reread) |
| 1288 | + { |
| 1289 | + double totalSum = 0.0; |
| 1290 | + uint64_t totalCount = 0; |
| 1291 | + for (size_t i = 0; i < metrics.stats.size (); ++i) |
| 1292 | + { |
| 1293 | + if (metrics.stats[i].mseCount > 0) |
| 1294 | + { |
| 1295 | + totalSum += metrics.stats[i].mse * metrics.stats[i].mseCount; |
| 1296 | + totalCount += metrics.stats[i].mseCount; |
| 1297 | + } |
| 1298 | + } |
| 1299 | + metrics.totalStats.mseCount = totalCount; |
| 1300 | + metrics.totalStats.mse = |
| 1301 | + totalCount > 0 ? totalSum / totalCount |
| 1302 | + : std::numeric_limits<double>::quiet_NaN (); |
| 1303 | + } |
| 1304 | + |
1200 | 1305 | if (verbose) { cerr << endl; } |
1201 | 1306 | return metrics; |
1202 | 1307 | } |
0 commit comments