Source code for smefit.compute_theory

# -*- coding: utf-8 -*-

"""
Module for the generation of theory predictions
"""
import jax
import jax.numpy as jnp
import numpy as np


def flatten(quad_mat, axis=0):
    """
    Keep the upper triangle (diagonal included) of a square matrix
    and return its entries as a flat array, in row-major order.

    Parameters
    ----------
        quad_mat: numpy.ndarray
            tensor to flatten
        axis: int
            axis of ``quad_mat`` whose length fixes the size of the
            triangular index set; the selection itself always acts on
            the first two axes of ``quad_mat``

    Returns
    -------
        numpy.ndarray
            entries ``quad_mat[i, j]`` with ``i <= j``
    """
    # The indices are built with plain numpy from the static shape.
    upper_rows, upper_cols = np.triu_indices(quad_mat.shape[axis])
    return quad_mat[upper_rows, upper_cols]
def make_predictions(
    dataset, coefficients_values, use_quad, use_multiplicative_prescription, rgemat=None
):
    """
    Build the theory predictions of a dataset, corrected by the
    contribution of a given set of |SMEFT| coefficients.

    Parameters
    ----------
        dataset : DataTuple
            dataset tuple; its ``LinearCorrections``, ``QuadraticCorrections``
            and ``SMTheory`` attributes are read here
        coefficients_values : numpy.ndarray
            |EFT| coefficients values
        use_quad: bool
            if True include also |HO| corrections
        use_multiplicative_prescription: bool
            if True add the |EFT| contribution as a k-factor
        rgemat: numpy.ndarray
            solution matrix of the RGE. A 3-dimensional array is treated as a
            stack of matrices, one per row of the correction tables (dynamic
            scale); any other array is applied as a single matrix.
            With None the coefficients are used as they are.

    Returns
    -------
        corrected_theory : numpy.ndarray
            SM + EFT theory predictions
    """
    coefficients_values = jnp.array(coefficients_values)

    has_rge = rgemat is not None
    # A rank-3 rgemat carries one evolution matrix per data point.
    dynamic_scale = has_rge and len(rgemat.shape) == 3

    # --- linear correction -------------------------------------------------
    if dynamic_scale:
        # Evolved coefficients, one set per data point; reused below for the
        # quadratic term.
        wcs = jnp.einsum("nij,j->ni", rgemat, coefficients_values)
        summed_corrections = jnp.einsum("ij,ij->i", dataset.LinearCorrections, wcs)
    elif has_rge:
        summed_corrections = jnp.einsum(
            "ij,jk,k->i",
            dataset.LinearCorrections,
            rgemat,
            coefficients_values,
        )
    else:
        summed_corrections = jnp.einsum(
            "ij,j->i", dataset.LinearCorrections, coefficients_values
        )

    # --- quadratic correction ----------------------------------------------
    if use_quad:
        if dynamic_scale:
            # Per-point outer product of the evolved coefficients, each one
            # reduced to its upper triangle.
            pair_products = jnp.einsum("ni,nj->nij", wcs, wcs)
            pair_products_flat = jax.vmap(flatten)(pair_products)
            quad_term = jnp.einsum(
                "ij,ij->i", dataset.QuadraticCorrections, pair_products_flat
            )
        else:
            # A single coefficient vector is shared by all data points: either
            # the evolved one or the input itself.
            if has_rge:
                base_coeffs = jnp.einsum("ij,j->i", rgemat, coefficients_values)
            else:
                base_coeffs = coefficients_values
            pair_products = jnp.outer(base_coeffs, base_coeffs)
            # note @ is slower when running with mpiexec
            quad_term = jnp.einsum(
                "ij,j->i", dataset.QuadraticCorrections, flatten(pair_products)
            )
        summed_corrections = summed_corrections + quad_term

    # --- combine with the SM prediction --------------------------------------
    if use_multiplicative_prescription:
        return dataset.SMTheory * (1.0 + summed_corrections)
    return dataset.SMTheory + summed_corrections