Skip to content
Snippets Groups Projects
Verified commit 6006e28c authored by Frank Sauerburger
Browse files

Improve confusion_matrix example

parent 4dc5e43e
No related branches found
No related tags found
No related merge requests found
This diff is collapsed.
......@@ -148,43 +148,42 @@ def hist(dataframe, variable, bins, stacks, selection=None,
stack_props = props[i_stack]
process_kwds[kwd] = stack_props[i_process % len(stack_props)]
if "histtype" in process_kwds:
if process_kwds["histtype"] == "points":
del process_kwds['histtype']
defaults = {
'markersize': 4,
'fmt': 'o'
}
defaults.update(process_kwds)
process_kwds = defaults
n, _ = np.histogram(
variable(dataframe[sel(dataframe)]),
bins=bins, range=range,
weights=weight(dataframe[sel(dataframe)]))
bin_centers = (bins[1:] + bins[:-1]) / 2
bin_widths = bins[1:] - bins[:-1]
axes.errorbar(bin_centers, bottom + n, np.sqrt(n), bin_widths / 2,
label=process.label,
**process_kwds)
else:
n, _, _ = axes.hist(
variable(dataframe[sel(dataframe)]),
bins=bins, range=range,
bottom=bottom,
label=process.label,
weights=weight(dataframe[sel(dataframe)]),
**process_kwds)
if "histtype" in process_kwds and process_kwds["histtype"] == "points":
del process_kwds['histtype']
defaults = {
'markersize': 4,
'fmt': 'o'
}
defaults.update(process_kwds)
process_kwds = defaults
n, _ = np.histogram(
variable(dataframe[sel(dataframe)]),
bins=bins, range=range,
weights=weight(dataframe[sel(dataframe)]))
bin_centers = (bins[1:] + bins[:-1]) / 2
bin_widths = bins[1:] - bins[:-1]
axes.errorbar(bin_centers, bottom + n, np.sqrt(n), bin_widths / 2,
label=process.label,
**process_kwds)
else:
n, _, _ = axes.hist(
variable(dataframe[sel(dataframe)]),
bins=bins, range=range,
bottom=bottom,
label=process.label,
weights=weight(dataframe[sel(dataframe)]),
**process_kwds)
bottom += n
axes.set_xlim((bins.min(), bins.max()))
axes.set_ylim((0, axes.get_ylim()[1] * 1.4))
axes.legend(frameon=False)
axes.set_ylim((0, axes.get_ylim()[1] * 1.6))
axes.legend(frameon=False, loc=1)
if variable.unit is not None:
axes.set_xlabel("%s in %s" % (variable.name, variable.unit),
......@@ -241,21 +240,24 @@ def confusion_matrix(df, x_processes, y_processes, x_label, y_label,
elif axes is None:
axes = figure.subplots()
data = {x_label: [], y_label: [], 'z': []}
y_processes.reverse()
data = np.zeros((len(y_processes), len(x_processes)))
for i_x, x_process in enumerate(x_processes):
x_df = df[x_process.selection(df)]
total_weight = weight(x_df).sum()
for i_y, y_process in enumerate(y_processes):
x_y_df = x_df[y_process.selection(x_df)]
data[x_label].append(x_process.label)
data[y_label].append(y_process.label)
data['z'].append(weight(x_y_df).sum() / total_weight)
data[i_y][i_x] = weight(x_y_df).sum() / total_weight
data = pd.DataFrame(data)
data = data.pivot(y_label, x_label, "z")
data = pd.DataFrame(data,
columns=[p.label for p in x_processes],
index=[p.label for p in y_processes])
sns.heatmap(data, **dict(vmin=0, vmax=1, cmap="Greens", ax=axes,
cbar_kws={
'label': "$P($%s$|$%s$)$" % (y_label, x_label)
}
), **kwds)
axes.set_xlabel(x_label)
axes.set_ylabel(y_label)
return data
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.