Source code for smefit.basis_rotation

# -*- coding: utf-8 -*-
"""Implement the corrections table basis rotation"""
import json

import numpy as np
import pandas as pd

from .compute_theory import flatten


def rotate_to_fit_basis(lin_dict, quad_dict, rotation_matrix_path):
    """
    Rotate the theory corrections from the table basis to the fitting basis.

    Parameters
    ----------
        lin_dict: dict
            theory dictionary with linear operator corrections in the table
            basis, keyed by operator name
        quad_dict: dict
            theory dictionary with quadratic operator corrections in the table
            basis, keyed by ``"<op1>*<op2>"``; empty if quadratic corrections
            are not used
        rotation_matrix_path: str or path-like
            path to a JSON file holding the rotation from the table basis to
            the fitting basis. The file must provide the keys ``"matrix"``
            (rows follow ``"ypars"``, columns follow ``"xpars"``), ``"ypars"``
            (fit basis operator names) and ``"xpars"`` (table basis operator
            names)

    Returns
    -------
        lin_dict_fit_basis: dict
            theory dictionary with linear operator corrections in the fit
            basis (operator name -> list of values)
        quad_dict_fit_basis: dict
            theory dictionary with quadratic operator corrections in the fit
            basis (``"<op1>*<op2>"`` -> array of values). It is empty if
            quadratic corrections are not used, or if none of the provided
            quadratic corrections involves only operators known to the
            rotation matrix
    """
    with open(rotation_matrix_path, encoding="utf-8") as f:
        rot = json.load(f)

    rotation_matrix = pd.DataFrame(
        data=rot["matrix"], index=rot["ypars"], columns=rot["xpars"]
    )
    table_ops = rotation_matrix.columns
    fit_ops = rotation_matrix.index

    # keep only the linear corrections of operators known to the rotation
    lin_df = pd.DataFrame(
        {op: val for op, val in lin_dict.items() if op in table_ops}
    )

    # select the columns corresponding to the
    # relevant corrections for the operator card basis
    R = rotation_matrix[lin_df.columns]
    lin_dict_fit_basis = (lin_df @ R.T).to_dict("list")

    # look at the quadratic corrections?
    quad_dict_fit_basis = {}
    if not quad_dict:
        return lin_dict_fit_basis, quad_dict_fit_basis

    # loop over table basis pairs
    # and build an (n_op_table, n_dat, n_op_fit, n_op_fit) tensor
    tensor = []
    for col, values in quad_dict.items():
        # split once and reuse the pieces for both the filter and the rotation
        ops = col.split("*")
        first, second = ops[0], ops[1]
        if first not in table_ops or second not in table_ops:
            continue
        o1, o2 = ops
        r1r2 = np.outer(rotation_matrix[o1], rotation_matrix[o2])
        tensor.append(np.einsum("i,kj->ikj", values, r1r2))

    # Nothing survived the selection: summing an empty stack would yield a
    # scalar and break the flattening below, so report no quadratic
    # corrections instead of crashing.
    if not tensor:
        return lin_dict_fit_basis, quad_dict_fit_basis

    # sum over table basis entries
    new_quad_corrections = np.array(tensor).sum(axis=0)

    # flatten the tensor (n_dat, n_op_fit, n_op_fit) -> (n_dat, n_op_fit_pairs)
    # NOTE(review): `flatten` comes from .compute_theory; its exact selection
    # of matrix entries is not visible here -- confirm it matches the key
    # flattening below.
    new_quad_matrix = np.array(
        [flatten(correction, axis=1) for correction in new_quad_corrections]
    )

    # build the fit basis pair names with the same flattening as the values
    matrix_new_keys = [[f"{r1}*{r2}" for r2 in fit_ops] for r1 in fit_ops]
    new_keys = flatten(np.array(matrix_new_keys))

    for i, new_key in enumerate(new_keys):
        quad_dict_fit_basis[new_key] = new_quad_matrix[:, i]

    return lin_dict_fit_basis, quad_dict_fit_basis