Source code for smefit.analyze.fisher

# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import colors
from mpl_toolkits.axes_grid1 import make_axes_locatable
from rich.progress import track

from .latex_tools import latex_packages
from .pca import impose_constrain


class FisherCalculator:
    """Computes and writes the Fisher information table, and plots heat map.

    Linear Fisher information depends only on the theoretical corrections,
    while quadratic information requires fit results.
    Parameter constraints are also taken into account.
    Only fitted degrees of freedom are shown in the tables.

    Parameters
    ----------
    coefficients: smefit.coefficients.CoefficientManager
        coefficient manager
    datasets: smefit.loader.DataTuple
        DataTuple object with all the data information
    compute_quad: bool
        if True the quadratic corrections are also updated with the
        constraints, which is required by :meth:`compute_quadratic`
    """

    def __init__(self, coefficients, datasets, compute_quad):
        self.coefficients = coefficients
        self.free_parameters = self.coefficients.free_parameters.index
        self.datasets = datasets

        # update eft corrections with the constraints
        # (the quadratic ones stay None unless explicitly requested)
        self.new_QuadraticCorrections = None
        if compute_quad:
            (
                self.new_LinearCorrections,
                self.new_QuadraticCorrections,
            ) = impose_constrain(self.datasets, self.coefficients, update_quad=True)
        else:
            self.new_LinearCorrections = impose_constrain(
                self.datasets, self.coefficients
            )

        self.lin_fisher = None
        self.quad_fisher = None
        # the summary tables are not filled by this class itself
        self.summary_table = None
        self.summary_HOtable = None

    def compute_linear(self):
        """Compute linear Fisher information.

        For each experiment the linear corrections and the inverse covariance
        matrix are sliced to the experiment data points, and the diagonal of
        ``sigma @ InvCov @ sigma.T`` is stored.
        The result is saved in ``self.lin_fisher`` (rows: experiments,
        columns: free parameters).
        """
        fisher_tab = []
        cnt = 0
        for ndat in self.datasets.NdataExp:
            idxs = slice(cnt, cnt + ndat)
            sigma = self.new_LinearCorrections[:, idxs]
            fisher_tab.append(
                np.diag(sigma @ self.datasets.InvCovMat[idxs, idxs] @ sigma.T)
            )
            cnt += ndat

        self.lin_fisher = pd.DataFrame(
            fisher_tab, index=self.datasets.ExpNames, columns=self.free_parameters
        )

    @staticmethod
    def _split_quadratic_corrections(quad_corrections):
        """Split the quadratic corrections into diagonal and off-diagonal parts.

        The input is never modified: ``numpy.diagonal`` returns a view, so
        both outputs are explicit copies.

        Parameters
        ----------
        quad_corrections: numpy.ndarray
            array whose first two axes have the same size

        Returns
        -------
        diag_corr: numpy.ndarray
            diagonal taken along the first two axes
            (the diagonal index becomes the last axis)
        off_diag_corr: numpy.ndarray
            copy of the input with the diagonal of the first two axes set to 0
        """
        quad_corrections = np.asarray(quad_corrections)
        diag_corr = np.diagonal(quad_corrections, axis1=0, axis2=1).copy()
        off_diag_corr = quad_corrections.copy()
        rows, cols = np.diag_indices(quad_corrections.shape[0])
        off_diag_corr[rows, cols, :] = 0
        return diag_corr, off_diag_corr

    def compute_quadratic(self, posterior_df, smeft_predictions):
        """Compute quadratic Fisher information.

        The result, summed with the linear Fisher information, is saved in
        ``self.quad_fisher`` (rows: experiments, columns: free parameters).
        If the linear table has not been computed yet it is computed first.

        Parameters
        ----------
        posterior_df: pandas.DataFrame
            fit posterior, it must contain a column for each free parameter
        smeft_predictions: numpy.ndarray
            theory predictions, averaged here along the first axis

        Raises
        ------
        ValueError
            if the quadratic corrections are not available, i.e. the object
            was built with ``compute_quad=False``
        """
        if self.new_QuadraticCorrections is None:
            raise ValueError(
                "Quadratic corrections are not available: "
                "build FisherCalculator with compute_quad=True."
            )
        # the final table is the sum of linear and quadratic contributions
        if self.lin_fisher is None:
            self.compute_linear()

        quad_fisher = []

        # compute some average values over the replicas
        # delta exp - th (n_dat)
        delta_th = self.datasets.Commondata - np.mean(smeft_predictions, axis=0)

        # c, c**2 mean (n_free_op)
        coeff_values = posterior_df[self.free_parameters].values
        n_rep = coeff_values.shape[0]
        c_mean = np.mean(coeff_values, axis=0)
        c2_mean = np.mean(coeff_values**2, axis=0)

        # squared quad corr: work on copies, so that the stored corrections
        # are left untouched and the diagonal part is not zeroed by accident
        diag_corr, off_diag_corr = self._split_quadratic_corrections(
            self.new_QuadraticCorrections
        )

        # additional tensors
        tmp = np.einsum("ri,ijk->rjk", coeff_values, off_diag_corr, optimize="optimal")
        A_all = np.mean(tmp, axis=0)  # (n_free_op, n_dat)
        B_all = (
            np.einsum("rj,rjk->jk", coeff_values, tmp, optimize="optimal") / n_rep
        )  # (n_free_op, n_dat)
        D_all = (
            np.einsum("rjk,rjl->jkl", tmp, tmp, optimize="optimal") / n_rep
        )  # (n_free_op, n_dat, n_dat)

        cnt = 0
        for ndat in track(
            self.datasets.NdataExp,
            description="[green]Computing quadratic Fisher information per dataset...",
        ):
            # slice the big matrices
            idxs = slice(cnt, cnt + ndat)
            quad_corr = diag_corr[idxs, :].T
            lin_corr = self.new_LinearCorrections[:, idxs]
            inv_corr = self.datasets.InvCovMat[idxs, idxs]
            delta = delta_th[idxs]
            A = A_all[:, idxs]
            B = B_all[:, idxs]
            D = D_all[:, idxs, idxs]

            # NOTE: 1-D terms are broadcast against the 2-D ones,
            # only the diagonal of the sum is kept below (n_free_op)
            fisher_row = (
                -quad_corr @ inv_corr @ delta.T
                - delta @ inv_corr @ quad_corr.T
                + lin_corr @ inv_corr @ A.T
                + A @ inv_corr @ lin_corr.T
                + 2
                * c_mean
                @ (
                    lin_corr @ inv_corr @ quad_corr.T
                    + quad_corr @ inv_corr @ lin_corr.T
                )
                + 2 * (B @ inv_corr @ quad_corr.T + quad_corr @ inv_corr @ B.T)
                + 4 * c2_mean @ quad_corr @ inv_corr @ quad_corr.T
                + np.einsum("ikl,kl -> i", D, inv_corr, optimize="optimal")
            )

            quad_fisher.append(np.diag(fisher_row))
            cnt += ndat

        self.quad_fisher = pd.DataFrame(
            np.asarray(quad_fisher) + self.lin_fisher.values,
            index=self.datasets.ExpNames,
            columns=self.free_parameters,
        )

    @staticmethod
    def normalize(table, norm, log):
        """Normalize a Pandas DataFrame to percentages.

        Parameters
        ----------
        table: pandas.DataFrame, None
            table to normalize
        norm: "data", "coeff"
            if "data" each row is divided by the sum of its entries,
            if "coeff" each column is divided by the sum of its entries
        log: bool
            presents the log of the Fisher if True,
            non positive entries are set to 0

        Returns
        -------
        pandas.DataFrame, None
            normalized table (in percent) with NaN replaced by 0,
            None if the input table is None or empty

        Raises
        ------
        ValueError
            if ``norm`` is neither "data" nor "coeff"
        """
        if table is None or table.empty:
            return None

        if norm == "data":
            axis_sum, axis_div = 1, 0
        elif norm == "coeff":
            axis_sum, axis_div = 0, 1
        else:
            raise ValueError(f"Invalid norm value: {norm}. Must be 'data' or 'coeff'.")

        table = table.div(table.sum(axis=axis_sum), axis=axis_div) * 100
        if log:
            # masked (non positive) entries become NaN and are zeroed below
            table = np.log(table[table > 0.0])
        return table.replace(np.nan, 0.0)

    def groupby_data(self, table, data_groups, norm, log):
        """Merge fisher per data group.

        Parameters
        ----------
        table: pandas.DataFrame
            Fisher table indexed by dataset name
        data_groups: pandas.Series or pandas.DataFrame
            object whose two index levels are the data group and the dataset
            name (they become the ``level_0`` and ``level_1`` columns
            after ``reset_index``)
        norm: "data", "coeff"
            normalization passed to :meth:`normalize`
        log: bool
            passed to :meth:`normalize`

        Returns
        -------
        pandas.DataFrame, None
            normalized table summed per data group
        """
        summary_table = pd.merge(
            data_groups.reset_index(), table, left_on="level_1", right_index=True
        )
        summary_table = summary_table.groupby("level_0").sum(numeric_only=True)
        summary_table.index.name = "data_group"
        return self.normalize(summary_table, norm, log)

    def write_grouped(self, coeff_config, data_groups, summary_only):
        """Write Fisher information tables in latex, both for grouped data and for summary.

        Parameters
        ----------
        coeff_config: pandas.Series
            latex names of the coefficients, indexed by
            (coefficient group, coefficient name)
        data_groups: pandas.Series
            dataset links, indexed by (data group, dataset name)
        summary_only: bool
            if True only the summary Fisher table for grouped data is written

        Returns
        -------
        list(str)
            list of the latex commands
        """
        lines = latex_packages()
        lines.extend(
            [
                r"\begin{document}",
                r"\begin{landscape}",
            ]
        )

        # fisher tables per data_group
        if not summary_only:
            for data_group, data_dict in data_groups.groupby(level=0):
                group_datasets = data_dict.index.get_level_values(1)
                temp_table = self.lin_fisher.loc[group_datasets]
                temp_HOtable = None
                if self.quad_fisher is not None:
                    temp_HOtable = self.quad_fisher.loc[group_datasets]
                lines.extend(
                    self._write(
                        temp_table,
                        temp_HOtable,
                        coeff_config,
                        data_dict.droplevel(0),
                        data_group,
                    )
                )

        # summary table, one column per data group
        lines.extend(
            self._write(
                self.summary_table,
                self.summary_HOtable,
                coeff_config,
            )
        )
        lines.append(r"\end{landscape}")
        return lines

    def _write(
        self, lin_fisher, quad_fisher, coeff_config, data_dict=None, data_group=None
    ):
        """Write Fisher information table in latex.

        Parameters
        ----------
        lin_fisher: pandas.DataFrame
            linear Fisher information table
        quad_fisher: pandas.DataFrame, None
            quadratic Fisher information table, None if linear only
        coeff_config: pandas.Series
            latex names of the coefficients, indexed by
            (coefficient group, coefficient name)
        data_dict: dict or pandas.Series, optional
            mapping from dataset name to its link,
            if None the index of ``lin_fisher`` is used for the header
        data_group: str, optional
            data group name

        Returns
        -------
        list(str)
            list of the latex commands
        """

        def color(value, thr_val=10):
            # entries above the threshold are highlighted in blue
            if value > thr_val:
                return ("blue", value)
            return ("black", value)

        n_datasets = lin_fisher.shape[0]
        lines = [
            r"\begin{table}[H]",
            r"\scriptsize",
            r"\centering",
            r"\begin{tabular}{|c|c|" + "c|" * n_datasets + "}",
            r"\hline",
            f"\\multicolumn{{2}}{{|c|}}{{}} "
            f"& \\multicolumn{{{n_datasets}}}{{c|}}{{Processes}} \\\\ \\hline",
        ]

        # header row
        temp = " Class & Coefficient "
        if data_dict is None:
            for dataset in lin_fisher.index:
                temp += f"& {{\\rm {dataset} }}"
        else:
            for dataset, link in data_dict.items():
                # the backslash of \rm must be escaped,
                # a bare "\r" would be a carriage return
                temp += f"& \\href{{{link}}}{{${{\\rm {dataset}}}$}}".replace(
                    "_", r"\_"
                )
        temp += r"\\ \hline"
        lines.append(temp)

        # loop on coeffs
        for coeff_group, coeff_dict in coeff_config.groupby(level=0):
            coeff_dict = coeff_dict.droplevel(0)
            last_coeff = coeff_dict.index[-1]
            lines.append(
                f"\\multirow{{{coeff_dict.shape[0]}}}{{*}}{{{coeff_group}}}"
            )
            for coeff, latex_name in coeff_dict.items():
                idx = np.where(coeff == self.free_parameters)[0][0]
                temp = f" & {latex_name}"

                # loop on columns of the latex table (rows of the Fisher table)
                for idj, fisher_values in enumerate(lin_fisher.values):
                    temp += r" & \textcolor{%s}{%.2f}" % color(fisher_values[idx])
                    if quad_fisher is not None:
                        temp += r"(\textcolor{%s}{%0.2f})" % color(
                            quad_fisher.iloc[idj, idx]
                        )

                # close the coefficient group with a full line
                temp += (
                    r"\\ \hline"
                    if coeff == last_coeff
                    else f"\\\\ \\cline{{2-{(2 + n_datasets)}}}"
                )
                lines.append(temp)

        caption = (
            "Fisher information"
            if data_group is None
            else f"Fisher information in {data_group} datasets"
        )
        lines.extend(
            [
                r"\end{tabular}",
                f"\\caption{{{caption}}}",
                r"\end{table}",
            ]
        )
        return lines

    def plot(
        self,
        latex_names,
        fig_name,
        title=None,
        summary_only=True,
        figsize=(11, 15),
    ):
        """Plot the heat map of Fisher table.

        The figure is saved both as ``{fig_name}.pdf`` and ``{fig_name}.png``
        and it is closed afterwards.

        Parameters
        ----------
        latex_names : list
            list of coefficients latex names
        fig_name: str
            figure path, without extension
        title: str, None
            plot title, appended to the figure title when given
        summary_only: bool
            if True plot the fisher grouped per datasets,
            else the fine grained dataset per dataset
        figsize : tuple
            figure size
        """
        if summary_only:
            fisher_df = self.summary_table
            quad_fisher_df = self.summary_HOtable
        else:
            fisher_df = self.lin_fisher
            quad_fisher_df = self.quad_fisher

        fig = plt.figure(figsize=figsize)
        if quad_fisher_df is not None:
            ax = fig.add_subplot(121)
        else:
            ax = plt.gca()

        # colour map
        cmap_full = plt.get_cmap("Blues")
        cmap = colors.LinearSegmentedColormap.from_list(
            f"trunc({{{cmap_full.name}}},{{0}},{{0.8}})",
            cmap_full(np.linspace(0, 0.8, 100)),
        )
        norm = colors.BoundaryNorm(np.arange(110, step=10), cmap.N)

        # ticks
        yticks = np.arange(fisher_df.shape[1])
        xticks = np.arange(fisher_df.shape[0])
        x_labels = [f"\\rm{{{name}}}".replace("_", "\\_") for name in fisher_df.index]

        def set_ticks(ax):
            ax.set_yticks(yticks, labels=latex_names, fontsize=15)
            ax.set_xticks(
                xticks,
                labels=x_labels,
                rotation=90,
                fontsize=15,
            )
            ax.tick_params(which="major", top=False, bottom=False, left=False)
            # minor grid
            ax.set_xticks(xticks - 0.5, minor=True)
            ax.set_yticks(yticks - 0.5, minor=True)
            ax.tick_params(which="minor", bottom=False)
            ax.grid(visible=True, which="minor", alpha=0.2)

        def plot_values(ax, df):
            # write the value of each positive entry inside its cell
            for i, row in enumerate(df.values.T):
                for j, elem in enumerate(row):
                    if elem > 0:
                        ax.text(
                            j,
                            i,
                            f"{elem:.1f}",
                            va="center",
                            ha="center",
                            fontsize=8,
                        )

        cax = ax.matshow(fisher_df.values.T, cmap=cmap, norm=norm)
        plot_values(ax, fisher_df)
        set_ticks(ax)
        ax.set_title(r"\rm Linear", fontsize=20, y=-0.08)
        cax1 = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.5)
        colour_bar = fig.colorbar(cax, cax=cax1)

        if quad_fisher_df is not None:
            ax = fig.add_subplot(122)
            cax = ax.matshow(quad_fisher_df.values.T, cmap=cmap, norm=norm)
            plot_values(ax, quad_fisher_df)
            set_ticks(ax)
            ax.set_title(r"\rm Quadratic", fontsize=20, y=-0.08)
            cax1 = make_axes_locatable(ax).append_axes("right", size="10%", pad=0.1)
            colour_bar = fig.colorbar(cax, cax=cax1)

        fig.subplots_adjust(top=0.85)
        # only the last colour bar drawn gets the label
        colour_bar.set_label(
            r"${\rm Normalized\ Value}$",
            fontsize=25,
            labelpad=30,
            rotation=270,
        )

        # do not print a literal "None" when no title is given
        suptitle = "\\rm Fisher\\ information"
        if title is not None:
            suptitle += f":\\ {title}"
        fig.suptitle(suptitle, fontsize=25, y=0.98)

        fig.savefig(f"{fig_name}.pdf")
        fig.savefig(f"{fig_name}.png")
        # release the figure, otherwise it stays alive in the pyplot registry
        plt.close(fig)