ml4eft.core.classifier.Fitter

class ml4eft.core.classifier.Fitter(json_path, mc_run, c_name, output_dir, print_log=False)

Bases: object

Training class for the binary classifier

__init__(json_path, mc_run, c_name, output_dir, print_log=False)

Fitter constructor

Parameters
  • json_path (str) – Path to the JSON run card

  • mc_run (int) – Replica number

  • c_name (str) – EFT coefficient for which to learn the ratio function

  • output_dir (str) – Directory in which the trained models are stored

  • print_log (bool, optional) – If True, print training progress to stdout; otherwise write it to a log file only
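A minimal sketch of how the constructor arguments listed above might be assembled. The run card path, the coefficient name `ctGRe`, the output directory, and the helper `build_fitter_kwargs` are hypothetical placeholders, not part of ml4eft. The import is guarded so the snippet also runs where the package is not installed.

```python
from pathlib import Path

try:
    from ml4eft.core.classifier import Fitter
except ImportError:  # ml4eft not installed; the sketch still runs
    Fitter = None


def build_fitter_kwargs(json_path, mc_run, c_name, output_dir, print_log=False):
    """Collect the documented constructor arguments after basic type checks."""
    if not isinstance(mc_run, int):
        raise TypeError("mc_run must be an int (the replica number)")
    if not isinstance(c_name, str):
        raise TypeError("c_name must be a str (the EFT coefficient name)")
    return {
        "json_path": str(json_path),
        "mc_run": mc_run,
        "c_name": c_name,
        "output_dir": str(output_dir),
        "print_log": bool(print_log),
    }


# Hypothetical values, chosen only for illustration.
kwargs = build_fitter_kwargs("runcard.json", 1, "ctGRe", "models")

# Only instantiate when both the package and the run card are available.
if Fitter is not None and Path(kwargs["json_path"]).exists():
    fitter = Fitter(**kwargs)
```

Passing the arguments as keywords keeps the call readable when several replicas (different `mc_run` values) are launched for the same coefficient.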

Methods

__init__(json_path, mc_run, c_name, output_dir, print_log=False)

Fitter constructor

load_data()

Constructs training and validation sets

loss_fn(outputs, labels, w_e)

Loss function

train_classifier(data_train, data_val)

Starts the training of the binary classifier

training_loop(optimizer, train_loader, ...)

Optimize the classifier with optimizer on the training data set train_loader.

weight_reset(m)

Reset the weights and biases associated with the model m.

load_data()

Constructs training and validation sets

Returns

  • data_train (array_like) – Training data set

  • data_val (array_like) – Validation data set

loss_fn(outputs, labels, w_e)

Loss function

Parameters
  • outputs

  • labels

  • w_e

Returns

Average loss of the mini-batch

Return type

torch.Tensor
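The page does not spell out the form of the loss. As an illustration only, the sketch below assumes a weighted binary cross-entropy in which `outputs` are classifier outputs in (0, 1), `labels` are 0 or 1, and `w_e` are per-event weights; the function name `weighted_bce` and the clipping constant `eps` are hypothetical. It uses plain Python floats rather than tensors so that the arithmetic is easy to follow.

```python
import math


def weighted_bce(outputs, labels, w_e, eps=1e-12):
    """Average weighted binary cross-entropy over one mini-batch.

    Assumed form (not confirmed by the documentation):
        loss = mean_e( -w_e * [y_e * log(g_e) + (1 - y_e) * log(1 - g_e)] )
    """
    total = 0.0
    for g, y, w in zip(outputs, labels, w_e):
        g = min(max(g, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -w * (y * math.log(g) + (1 - y) * math.log(1.0 - g))
    return total / len(outputs)


# An undecided classifier (output 0.5) with unit weights gives log(2) per event.
example_loss = weighted_bce([0.5, 0.5], [1, 0], [1.0, 1.0])
```

On tensors, `torch.nn.functional.binary_cross_entropy(input, target, weight=...)` with its default `reduction='mean'` computes the same quantity; whether `loss_fn` uses it is not stated here.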

train_classifier(data_train, data_val)

Starts the training of the binary classifier

Parameters
  • data_train (array_like) – Training data set

  • data_val (array_like) – Validation data set

training_loop(optimizer, train_loader, val_loader)

Optimizes the classifier with optimizer on the training data set train_loader and monitors potential overfitting through val_loader.

Parameters
  • optimizer (torch.optim.Optimizer) – Optimizer, e.g. torch.optim.AdamW

  • train_loader (array_like) – List of torch.utils.data.DataLoader objects, one for the SM and one for the EFT (training)

  • val_loader (array_like) – List of torch.utils.data.DataLoader objects, one for the SM and one for the EFT (validation)
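How the loop reacts to overfitting is not described on this page. One common approach, sketched here as an assumption rather than as the actual implementation, is patience-based early stopping: training ends once the validation loss has failed to improve for a fixed number of epochs. The names `run_epochs`, `train_epoch`, `validate`, and `patience` are hypothetical, and the two callables stand in for a pass over `train_loader` and `val_loader`.

```python
def run_epochs(train_epoch, validate, max_epochs, patience):
    """Train until the validation loss stops improving for `patience` epochs.

    train_epoch: callable returning the training loss of one epoch
    validate:    callable returning the validation loss after that epoch
    Returns the index of the best epoch and the (train, val) loss history.
    """
    best_val = float("inf")
    best_epoch = -1
    stale = 0  # epochs since the last improvement on the validation set
    history = []
    for epoch in range(max_epochs):
        train_loss = train_epoch()
        val_loss = validate()
        history.append((train_loss, val_loss))
        if val_loss < best_val:
            best_val, best_epoch, stale = val_loss, epoch, 0
        else:
            stale += 1
            if stale >= patience:
                break  # validation loss no longer improves: likely overfitting
    return best_epoch, history


# Synthetic losses: training keeps improving while validation turns upward.
train_losses = iter([1.0, 0.8, 0.6, 0.5, 0.4, 0.3])
val_losses = iter([1.0, 0.9, 0.95, 0.97, 0.99, 1.1])
best_epoch, history = run_epochs(
    lambda: next(train_losses), lambda: next(val_losses), max_epochs=6, patience=2
)
```

With these synthetic losses the loop stops after four epochs and reports epoch 1 as the best one, since the validation loss rises from there on while the training loss keeps falling.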

weight_reset(m)

Reset the weights and biases associated with the model m.

Parameters

m (MLP) – Model of type MLP.
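A minimal sketch of a reset helper in the same spirit, assuming the model's layers expose a `reset_parameters()` method, as `torch.nn.Linear` does. The body of the actual method is not shown on this page, so this is an illustration only; `ToyLayer` is a hypothetical stand-in that lets the snippet run without PyTorch.

```python
def weight_reset(m):
    """Re-initialise m if it exposes a callable reset_parameters()."""
    reset = getattr(m, "reset_parameters", None)
    if callable(reset):
        reset()


class ToyLayer:
    """Hypothetical stand-in for a layer with learnable parameters."""

    def __init__(self):
        self.reset_parameters()

    def reset_parameters(self):
        # A real layer would draw fresh random values; zeros keep the demo simple.
        self.weight = 0.0
        self.bias = 0.0


layer = ToyLayer()
layer.weight, layer.bias = 3.0, -1.0  # pretend the layer has been trained
weight_reset(layer)                   # parameters return to their initial state
weight_reset(object())                # objects without parameters are left alone

# With PyTorch, the helper could be applied to every submodule at once:
#     model.apply(weight_reset)
```

`torch.nn.Module.apply` calls the given function on each submodule and on the module itself, which is why a per-module helper of this shape is convenient for resetting a whole network between runs.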