Skip to content
Snippets Groups Projects
Commit 87d60e4b authored by Miroslav Saur's avatar Miroslav Saur
Browse files

Merge branch 'mveghel-flattenfix' into '2024-patches'

Fix sigmoid finite answer and flattening SIMD MLPs

See merge request lhcb/LHCb!4570
parents 0676ff36 dbecd0fc
No related branches found
No related tags found
2 merge requests!4575Synchronize master branch with 2024-patches,!4570Fix sigmoid finite answer and flattening SIMD MLPs
Pipeline #7470785 passed
......@@ -74,6 +74,17 @@ inline auto approx_exp( F x ) {
return z;
}
template <typename F>
inline auto approx_sigmoid( F x ) {
// NOTE: this form works best with approx exp (used here)
// 1 / ( 1 + exp( - w ) ) has discontinuities
auto ex = approx_exp( x );
auto z = ex / ( 1 + ex );
// ensure finite answer
z = select( x > 88.f, 1., z );
return z;
}
namespace SIMDWrapper {
namespace details {
......@@ -95,6 +106,10 @@ namespace SIMDWrapper {
inline auto exp( T x ) {
return approx_exp( x );
}
template <typename T>
inline auto sigmoid( T x ) {
return approx_sigmoid( x );
}
namespace simd_traits {
template <typename T>
......
......@@ -164,12 +164,8 @@ namespace LHCb::VectorizedML {
template <typename Range>
constexpr auto operator()( Range&& input ) const {
Vec<FType, nInput> out;
std::transform( input.m.begin(), input.m.end(), out.m.begin(), [&]( const auto& w ) {
// NOTE: this form works best with approx exp (used here)
// 1 / ( 1 + exp( - w ) ) has discontinuities
auto ew = exp( w );
return ew / ( 1 + ew );
} );
std::transform( input.m.begin(), input.m.end(), out.m.begin(),
[&]( const auto& w ) { return SIMDWrapper::sigmoid( w ); } );
return out;
}
......
......@@ -138,8 +138,10 @@ class Sequence(nn.Module):
]
xvals = [i / len(diffs) for i in range(len(diffs))]
spline = UnivariateSpline(xvals, diffs)
new_diffs = spline(xvals)
# ensure bounds by normalization
new_diffs = np.array(spline(xvals))
# ensure bounds by normalization, make sure we have non-zero positive values
for i in range(len(new_diffs)):
if new_diffs[i] < 0: new_diffs[i] == -new_diffs[i]
new_diffs = new_diffs / sum(new_diffs)
new_edges = [0]
for diff in new_diffs:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment