Skip to content
Snippets Groups Projects
Verified Commit 297f0f15 authored by Frank Sauerburger's avatar Frank Sauerburger
Browse files

Implement class normalizer function

parent 5059a563
Branches 14-create-class-normalizer 6-implement-meta-model
No related tags found
No related merge requests found
......@@ -378,6 +378,28 @@ class EstimatorNormalizer(Normalizer):
width = pd.read_hdf(path, os.path.join(key, "width"))
return cls(None, center=center, width=width)
def normalize_category_weights(df, categories, weight='weight'):
"""
The categorical weight normalizer acts on the weight variable only. The
returned dataframe will satisfy the following conditions:
- The sum of weights of all events is equal to the total number of
entries.
- The sum of weights of a category is equal to the total number of entries
divided by the number of classes. Therefore the sum of weights of two
categories are equal.
- The relative weights within a category are unchanged.
"""
df_out = df[:]
w_norm = np.empty(len(df))
for category in categories:
idx = category(df)
w_norm[idx] = df[idx][weight].sum()
df_out[weight] = df_out[weight] / w_norm * len(df) / len(categories)
return df_out
class HepNet:
"""
......
......@@ -7,8 +7,10 @@ import math
import pandas as pd
from nnfwtbn.model import CrossValidator, ClassicalCV, MixedCV, \
Normalizer, EstimatorNormalizer
Normalizer, EstimatorNormalizer, \
normalize_category_weights
from nnfwtbn.variable import Variable
from nnfwtbn.cut import Cut
class StubCrossValidator(CrossValidator):
def select_slice(self, df, slice_i):
......@@ -620,3 +622,63 @@ class EstimatorNormalizerTestCase(unittest.TestCase):
os.close(fd)
os.remove(path)
self.assertTrue(norm1 == norm2)
class CategoricalWeightNormalizerTestCase(unittest.TestCase):
"""
Test the implementation of normalize_category_weights.
"""
def generate_df(self):
"""
Generate toy dataframe.
"""
return pd.DataFrame({
"x": [9, 10, 10, 12, 12, 13],
"weight": [0.1, 0.2, 0.3, 1.4, 1.8, 1],
"alt_weight": [1.1, 1.2, 1.3, 2.4, 2.8, 2],
"fpid": [1, 2, 1, 2, 1, 3],
})
def test_alternative_weight(self):
"""
Check that the constructor normalized the classes using an alternative
weight variables.
"""
df = self.generate_df()
categories = [Cut(lambda d: d.fpid == 1),
Cut(lambda d: d.fpid == 2),
Cut(lambda d: d.fpid == 3)]
df = normalize_category_weights(df, categories,
weight='alt_weight')
c1, c2, c3 = categories
self.assertAlmostEqual(df.alt_weight.sum(), len(df))
self.assertAlmostEqual(df[c1(df)].alt_weight.sum(), 2)
self.assertAlmostEqual(df[c2(df)].alt_weight.sum(), 2)
self.assertAlmostEqual(df[c3(df)].alt_weight.sum(), 2)
def test_main(self):
"""
Check that the constructor normalized the classes.
"""
df = self.generate_df()
categories = [Cut(lambda d: d.fpid == 1),
Cut(lambda d: d.fpid == 2),
Cut(lambda d: d.fpid == 3)]
df = normalize_category_weights(df, categories)
c1, c2, c3 = categories
self.assertAlmostEqual(df.weight.sum(), len(df))
self.assertAlmostEqual(df[c1(df)].weight.sum(), 2)
self.assertAlmostEqual(df[c2(df)].weight.sum(), 2)
self.assertAlmostEqual(df[c3(df)].weight.sum(), 2)
self.assertAlmostEqual(df.weight[2] / df.weight[0], 3)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment