Skip to content
Snippets Groups Projects
Commit cf579ecc authored by Dan Guest's avatar Dan Guest Committed by Frank Winklmeier
Browse files

H5 Writer: allow double -> half

H5 Writer: allow double -> half
parent c2c71ddd
No related branches found
No related tags found
1 merge request!66500H5 Writer: allow double -> half
......@@ -16,16 +16,26 @@
namespace H5Utils {
namespace internal {
H5::DataType halfPrecisionFloat(int ebias = 15);
template <typename T>
H5::DataType getCompressedType(Compression comp) {
if constexpr (std::is_floating_point<T>::value) {
switch (comp) {
case Compression::STANDARD: return H5Traits<T>::type;
case Compression::HALF_PRECISION: return halfPrecisionFloat();
case Compression::HALF_PRECISION_LARGE: return halfPrecisionFloat(5);
default: throw std::logic_error("unknown float compression");
}
}
if (comp != Compression::STANDARD) {
throw std::logic_error("compression not supported for this type");
}
return H5Traits<T>::type;
}
template <>
H5::DataType getCompressedType<float>(Compression comp);
}
}
} // end internal
} // end H5Utils
#endif
......@@ -5,33 +5,24 @@
#include "HDF5Utils/CompressedTypes.h"
namespace {
H5::DataType halfPrecisionFloat(int ebias = 15) {
// start with native float
H5::FloatType type(H5Tcopy(H5::PredType::NATIVE_FLOAT.getId()));
// These definitions are copied from h5py, see:
//
// https://github.com/h5py/h5py/blob/596748d52c351258c851bb56c8df1c25d3673110/h5py/h5t.pyx#L212-L217
//
type.setFields(15, 10, 5, 0, 10);
type.setSize(2);
type.setEbias(ebias);
return type;
}
}
namespace H5Utils {
namespace internal {
template <>
H5::DataType getCompressedType<float>(Compression comp) {
switch (comp) {
case Compression::STANDARD: return H5Traits<float>::type;
case Compression::HALF_PRECISION: return halfPrecisionFloat();
case Compression::HALF_PRECISION_LARGE: return halfPrecisionFloat(5);
default: throw std::logic_error("unknown float compression");
}
H5::DataType halfPrecisionFloat(int ebias) {
// start with native float
H5::FloatType type(H5Tcopy(H5::PredType::NATIVE_FLOAT.getId()));
// These definitions are copied from h5py, see:
//
// https://github.com/h5py/h5py/blob/596748d52c351258c851bb56c8df1c25d3673110/h5py/h5t.pyx#L212-L217
//
type.setFields(15, 10, 5, 0, 10);
type.setSize(2);
type.setEbias(ebias);
return type;
}
}
......
......@@ -11,10 +11,12 @@ struct out_t
double dtype;
float ftype;
char ctype;
short stype;
int itype;
long ltype;
long long lltype;
unsigned char uctype;
unsigned short ustype;
unsigned int uitype;
unsigned long ultype;
unsigned long long ulltype;
......@@ -24,20 +26,20 @@ using consumer_t = H5Utils::Consumers<const out_t&>;
consumer_t getConsumers() {
consumer_t consumers;
consumers.add(
"half",
[](const out_t& o) -> float { return o.ftype; },
0,
H5Utils::Compression::HALF_PRECISION);
auto half = H5Utils::Compression::HALF_PRECISION;
consumers.add("half" , [](const out_t& o) { return o.ftype; }, 0, half);
consumers.add("dhalf", [](const out_t& o) { return o.dtype; }, 0, half);
#define ADD(NAME) consumers.add(#NAME, [](const out_t& o){ return o.NAME;}, 0)
ADD(ftype);
ADD(dtype);
ADD(btype);
ADD(ctype);
ADD(stype);
ADD(itype);
ADD(ltype);
ADD(lltype);
ADD(uctype);
ADD(ustype);
ADD(uitype);
ADD(ultype);
ADD(ulltype);
......@@ -57,10 +59,12 @@ std::vector<out_t> getOutputs(int offset, size_t length, float factor) {
out.dtype = factored;
out.ftype = factored;
out.ctype = shifted;
out.stype = shifted;
out.itype = shifted;
out.ltype = shifted;
out.lltype = shifted;
out.uctype = shifted;
out.ustype = shifted;
out.uitype = shifted;
out.ultype = shifted;
out.ulltype = shifted;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment