Skip to content

Commit 5011ecd

Browse files
committed
Add MSE to exrmetrics
Signed-off-by: Pierre-Anthony Lemieux <pal@sandflow.com>
1 parent 6dfc45b commit 5011ecd

3 files changed

Lines changed: 141 additions & 5 deletions

File tree

src/bin/exrmetrics/exrmetrics.cpp

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,20 @@
2323
#include "ImfTiledMisc.h"
2424
#include "ImfTiledOutputPart.h"
2525

26+
#include <Imath/half.h>
27+
2628
#include <chrono>
29+
#include <cmath>
2730
#include <ctime>
31+
#include <limits>
2832
#include <list>
2933
#include <stdexcept>
3034
#include <vector>
3135
#include <sys/stat.h>
3236

3337
using namespace OPENEXR_IMF_NAMESPACE;
3438
using IMATH_NAMESPACE::Box2i;
39+
using IMATH_NAMESPACE::half;
3540

3641
using std::cerr;
3742
using namespace std::chrono;
@@ -968,7 +973,8 @@ exrmetrics (
968973
bool write,
969974
bool reread,
970975
PixelMode pixelMode,
971-
bool verbose)
976+
bool verbose,
977+
bool computeMSE)
972978
{
973979

974980
if (verbose)
@@ -1154,6 +1160,86 @@ exrmetrics (
11541160
else { metrics.outputFileSize = fileSize; }
11551161
}
11561162

1163+
//
1164+
// compute arcsinh-space MSE for half-float channels vs. re-read data
1165+
//
1166+
if (computeMSE && write && reread)
1167+
{
1168+
bool anyHalfChannel = false;
1169+
1170+
for (size_t p = 0; p < parts.size (); ++p)
1171+
{
1172+
string type = outHeaders[p].type ();
1173+
if (type != SCANLINEIMAGE && type != TILEDIMAGE) continue;
1174+
1175+
Box2i dw = outHeaders[p].dataWindow ();
1176+
uint64_t width = dw.max.x + 1 - dw.min.x;
1177+
uint64_t height = dw.max.y + 1 - dw.min.y;
1178+
1179+
double sumSq = 0.0;
1180+
uint64_t count = 0;
1181+
int channelIndex = 0;
1182+
1183+
for (ChannelList::ConstIterator i =
1184+
outHeaders[p].channels ().begin ();
1185+
i != outHeaders[p].channels ().end ();
1186+
++i, ++channelIndex)
1187+
{
1188+
if (i.channel ().type != HALF) continue;
1189+
anyHalfChannel = true;
1190+
1191+
uint64_t pixelsInChannel =
1192+
(width / i.channel ().xSampling) *
1193+
(height / i.channel ().ySampling);
1194+
1195+
const half* orig = nullptr;
1196+
const half* rereadPx = nullptr;
1197+
1198+
if (type == SCANLINEIMAGE)
1199+
{
1200+
orig = reinterpret_cast<const half*> (
1201+
parts[p].readBuf.scanlinePixelData[channelIndex].data ());
1202+
rereadPx = reinterpret_cast<const half*> (
1203+
parts[p].rereadBuf.scanlinePixelData[channelIndex].data ());
1204+
}
1205+
else
1206+
{
1207+
if (parts[p].readBuf.tilePixelData.empty () ||
1208+
parts[p].rereadBuf.tilePixelData.empty ())
1209+
continue;
1210+
orig = reinterpret_cast<const half*> (
1211+
parts[p].readBuf.tilePixelData[0][channelIndex].data ());
1212+
rereadPx = reinterpret_cast<const half*> (
1213+
parts[p].rereadBuf.tilePixelData[0][channelIndex].data ());
1214+
}
1215+
1216+
for (uint64_t px = 0; px < pixelsInChannel; ++px)
1217+
{
1218+
float a = static_cast<float> (orig[px]);
1219+
float b = static_cast<float> (rereadPx[px]);
1220+
if (std::isfinite (a) && std::isfinite (b))
1221+
{
1222+
double diff = std::asinh (double (a) / 0.0000001) -
1223+
std::asinh (double (b) / 0.0000001);
1224+
sumSq += diff * diff;
1225+
++count;
1226+
}
1227+
}
1228+
}
1229+
1230+
metrics.stats[p].mseCount = count;
1231+
metrics.stats[p].mse =
1232+
count > 0 ? sumSq / count
1233+
: std::numeric_limits<double>::quiet_NaN ();
1234+
}
1235+
1236+
if (!anyHalfChannel)
1237+
{
1238+
cerr << "warning: --mse requires half-float channels, "
1239+
"but none found in output image\n";
1240+
}
1241+
}
1242+
11571243
//
11581244
// sum across all parts
11591245
//
@@ -1197,6 +1283,25 @@ exrmetrics (
11971283
}
11981284
}
11991285

1286+
// accumulate MSE as weighted average across parts (by sample count)
1287+
if (computeMSE && write && reread)
1288+
{
1289+
double totalSum = 0.0;
1290+
uint64_t totalCount = 0;
1291+
for (size_t i = 0; i < metrics.stats.size (); ++i)
1292+
{
1293+
if (metrics.stats[i].mseCount > 0)
1294+
{
1295+
totalSum += metrics.stats[i].mse * metrics.stats[i].mseCount;
1296+
totalCount += metrics.stats[i].mseCount;
1297+
}
1298+
}
1299+
metrics.totalStats.mseCount = totalCount;
1300+
metrics.totalStats.mse =
1301+
totalCount > 0 ? totalSum / totalCount
1302+
: std::numeric_limits<double>::quiet_NaN ();
1303+
}
1304+
12001305
if (verbose) { cerr << endl; }
12011306
return metrics;
12021307
}

src/bin/exrmetrics/exrmetrics.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "ImfCompression.h"
1717

18+
#include <limits>
1819
#include <stdint.h>
1920

2021
#include <vector>
@@ -56,6 +57,12 @@ struct partStats
5657
std::vector<double>
5758
rereadPerf; // for deep, times reading the sample count, otherwise times reading the entire data
5859

60+
// arcsinh-space MSE for half-float channels (original vs. re-read after compression)
61+
// mean((asinh(a/1e-7) - asinh(b/1e-7))^2) over all finite half samples
62+
// NaN if not computed
63+
double mse = std::numeric_limits<double>::quiet_NaN ();
64+
uint64_t mseCount = 0;
65+
5966
partSizeData sizeData;
6067
};
6168

@@ -77,6 +84,7 @@ fileMetrics exrmetrics (
7784
bool write,
7885
bool reread,
7986
PixelMode pixelMode,
80-
bool verbose);
87+
bool verbose,
88+
bool computeMSE = false);
8189

8290
#endif

src/bin/exrmetrics/main.cpp

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <list>
2121
#include <vector>
2222

23+
#include <cmath>
2324
#include <math.h>
2425
#include <stdlib.h>
2526
#include <string.h>
@@ -89,6 +90,8 @@ usageMessage (ostream& stream, const char* program_name, bool verbose = false)
8990
" --csv print output in csv mode. If passes>1, show median timing\n"
9091
" default is JSON mode\n"
9192
" --passes num write and re-read file num times (default 1)\n"
93+
" --mse compute MSE for half-float channels in log space over all finite\n"
94+
" samples; compares original vs. re-read after compression\n"
9295
"\n"
9396
" -h, --help print this message\n"
9497
" -v output progress messages\n"
@@ -118,6 +121,7 @@ struct options
118121
bool outputSizeData = true;
119122
bool verbose = false;
120123
bool csv = false;
124+
bool computeMSE = false;
121125
std::vector<PixelMode> pixelModes;
122126
std::vector<OPENEXR_IMF_NAMESPACE::Compression> compressions;
123127

@@ -366,6 +370,11 @@ jsonStats (
366370
out << ",\n";
367371
out << " \"output size\": " << run.metrics.outputFileSize;
368372
}
373+
if (!std::isnan (run.metrics.totalStats.mse))
374+
{
375+
out << ",\n";
376+
out << " \"mse\": " << run.metrics.totalStats.mse;
377+
}
369378
if (timing)
370379
{
371380
out << ",\n";
@@ -419,6 +428,7 @@ csvStats (ostream& out, list<runData>& data, bool outputSizeData, int timing)
419428
}
420429
out << ",compression,pixel mode";
421430
if (outputSizeData) { out << ",output size"; }
431+
out << ",mse";
422432
if (timing & options::TIME_READ)
423433
{
424434
out << ",count read time";
@@ -455,6 +465,11 @@ csvStats (ostream& out, list<runData>& data, bool outputSizeData, int timing)
455465
out << ',' << compName << ',' << modeName (run.mode);
456466

457467
if (outputSizeData) { out << ',' << run.metrics.outputFileSize; }
468+
if (!std::isnan (run.metrics.totalStats.mse))
469+
{
470+
out << ',' << run.metrics.totalStats.mse;
471+
}
472+
else { out << ",---"; }
458473
if (timing & options::TIME_READ)
459474
{
460475
if (run.metrics.totalStats.sizeData.isDeep)
@@ -530,10 +545,13 @@ main (int argc, char** argv)
530545
opts.level,
531546
opts.passes,
532547
opts.outFile || opts.outputSizeData ||
533-
opts.timing & options::TIME_WRITE,
534-
opts.timing & options::TIME_REREAD,
548+
opts.timing & options::TIME_WRITE ||
549+
opts.computeMSE,
550+
opts.timing & options::TIME_REREAD ||
551+
opts.computeMSE,
535552
mode,
536-
opts.verbose);
553+
opts.verbose,
554+
opts.computeMSE);
537555
data.push_back (d);
538556
}
539557
}
@@ -870,6 +888,11 @@ options::parse (int argc, char* argv[])
870888
outputSizeData = false;
871889
i += 1;
872890
}
891+
else if (!strcmp (argv[i], "--mse"))
892+
{
893+
computeMSE = true;
894+
i += 1;
895+
}
873896
else if (!strcmp (argv[i], "-i"))
874897
{
875898
if (i > argc - 2)

0 commit comments

Comments
 (0)