diff --git a/quickstats/plots/likelihood_2D_plot.py b/quickstats/plots/likelihood_2D_plot.py index 35344c3058258d17461b7e1c2e098095ffd86951..ec2100ff0b3df677a98bb2c67ee142ed4e239768 100644 --- a/quickstats/plots/likelihood_2D_plot.py +++ b/quickstats/plots/likelihood_2D_plot.py @@ -109,13 +109,14 @@ class Likelihood2DPlot(AbstractPlot): if show_colormesh: cmap = config['cmap'] - ax.pcolormesh(X, Y, Z, cmap=cmap, shading='auto') - cp = ax.contour(X, Y, Z, levels=levels, colors=colors, + im = ax.pcolormesh(X, Y, Z, cmap=cmap, shading='auto') + import matplotlib.pyplot as plt + plt.colorbar(im, ax=ax, **config['colorbar']).set_label(**config['colorbar_label']) + if levels: + cp = ax.contour(X, Y, Z, levels=levels, colors=colors, linestyles=linestyles, linewidths=3) - self.contours['contours'].append(cp) - self.contours['levels'].append(levels) - if clabel_size is not None: - ax.clabel(cp, inline=True, fontsize=clabel_size) + if clabel_size is not None: + ax.clabel(cp, inline=True, fontsize=clabel_size) custom_handles = [Line2D([0], [0], color=color, linestyle=linestyle, lw=3, label=label) \ for color, key, label, linestyle in \ zip(config['sigma_colors'], config['sigma_levels'], config['sigma_names'], config['sigma_linestyles'])] @@ -128,7 +129,7 @@ class Likelihood2DPlot(AbstractPlot): ylabel: Optional[str] = "", zlabel: Optional[str] = "$-2\Delta ln(L)$", ymax: float = 5, ymin: float = -5, xmin: Optional[float] = -10, xmax: Optional[float] = 10, clabel_size=None, draw_sm_line: bool = False, draw_bestfit:bool=True, - show_colormesh=False): + show_colormesh=False, show_legend=True): ax = self.draw_frame() self.contours = {'keys': [], 'contours': [], 'levels': []} if isinstance(self.data_map, pd.DataFrame): @@ -197,8 +198,9 @@ class Likelihood2DPlot(AbstractPlot): ax.hlines(sm_values[1], xmin=0, xmax=1, zorder=0, transform=transform, **sm_line_styles) - handles, labels = self.get_legend_handles_labels() - ax.legend(handles, labels, **self.styles['legend']) + if show_legend: + handles, labels = self.get_legend_handles_labels() + ax.legend(handles, labels, **self.styles['legend']) self.draw_axis_components(ax, xlabel=xlabel, ylabel=ylabel) self.set_axis_range(ax, xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax)