ml4eft.core.classifier.Fitter
- class ml4eft.core.classifier.Fitter(json_path, mc_run, c_name, output_dir, print_log=False)
Bases: object
Training class
- __init__(json_path, mc_run, c_name, output_dir, print_log=False)
Fitter constructor
- Parameters
json_path (str) – Path to the JSON run card
mc_run (int) – Replica number
c_name (str) – EFT coefficient for which to learn the ratio function
output_dir (str) – Path to the directory where the trained models are stored
print_log (bool, optional) – Set to True to print the training progress to stdout; otherwise it is written to a log file only
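The c_name parameter refers to the ratio function that the binary classifier learns. As an illustration of how a classifier output can be turned into such a ratio, the following sketch applies the standard likelihood-ratio trick. It assumes, rather than takes from ml4eft, that the classifier output f approximates P(EFT | x) with SM events labelled 0 and EFT events labelled 1 carrying equal total weight, so that the ratio is p_EFT(x)/p_SM(x) = f/(1 - f). The helper ratio_from_classifier is hypothetical.

```python
import math


def ratio_from_classifier(f: float) -> float:
    """Turn a classifier output f ~ P(EFT | x) into the ratio p_EFT(x) / p_SM(x).

    Assumption: f is the optimal decision function f = p_EFT / (p_SM + p_EFT),
    reached when SM events (label 0) and EFT events (label 1) carry equal total weight.
    """
    if not 0.0 < f < 1.0:
        raise ValueError("classifier output must lie strictly between 0 and 1")
    return f / (1.0 - f)


# Toy check with known densities at a single point x
p_sm, p_eft = 0.2, 0.6
f_opt = p_eft / (p_sm + p_eft)  # output of a perfectly trained classifier
assert math.isclose(ratio_from_classifier(f_opt), p_eft / p_sm)
```

The mapping f/(1 - f) is monotonic, so a better-trained classifier translates directly into a better estimate of the ratio.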
Methods
__init__(json_path, mc_run, c_name, output_dir, print_log=False) – Fitter constructor
load_data() – Constructs training and validation sets
loss_fn(outputs, labels, w_e) – Loss function
train_classifier(data_train, data_val) – Starts the training of the binary classifier
training_loop(optimizer, train_loader, ...) – Optimize the classifier with optimizer on the training data set train_loader.
weight_reset(m) – Reset the weights and biases associated with the model m.
- load_data()
Constructs training and validation sets
- Returns
data_train (array_like) – Training data set
data_val (array_like) – Validation data set
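A training/validation split of this kind is often built by shuffling the events reproducibly and holding out a fixed fraction. The sketch below shows that generic technique only; the helper split_train_val, the validation fraction, and seeding the shuffle with the replica number mc_run are assumptions for illustration, not the documented behaviour of load_data.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_train_val(events: Sequence[T], val_fraction: float, seed: int) -> Tuple[List[T], List[T]]:
    """Shuffle events reproducibly and split them into training and validation sets.

    Hypothetical helper: seeding with the replica number (mc_run) is an assumption,
    chosen so that each replica gets a different but reproducible split.
    """
    if not 0.0 <= val_fraction < 1.0:
        raise ValueError("val_fraction must be in [0, 1)")
    rng = random.Random(seed)  # local RNG, leaves the global random state untouched
    shuffled = list(events)
    rng.shuffle(shuffled)
    n_val = int(round(val_fraction * len(shuffled)))
    return shuffled[n_val:], shuffled[:n_val]


# Usage: 10 toy events, 20% held out for validation, replica number 3
data_train, data_val = split_train_val(range(10), val_fraction=0.2, seed=3)
```

Using a local random.Random instance keeps the split independent of any other code that draws random numbers.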
- loss_fn(outputs, labels, w_e)
Loss function
- Parameters
outputs (torch.Tensor) – Output of the decision function
labels (torch.Tensor) – Classification labels
w_e (torch.Tensor) – Event weights
- Returns
Average loss of the mini-batch
- Return type
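To make the roles of outputs, labels and w_e concrete, the following pure-Python sketch computes an event-weighted binary cross-entropy averaged over a mini-batch. It assumes that outputs are probabilities in (0, 1) and that labels are 0 for SM and 1 for EFT events; the exact loss used by ml4eft is not stated here and may differ. The function name weighted_bce is hypothetical.

```python
import math
from typing import Sequence


def weighted_bce(outputs: Sequence[float], labels: Sequence[int],
                 w_e: Sequence[float], eps: float = 1e-12) -> float:
    """Event-weighted binary cross-entropy averaged over a mini-batch.

    Assumes outputs are probabilities f in (0, 1), labels are 0 (SM) or 1 (EFT),
    and each event's contribution is scaled by its weight w_e.
    """
    if not (len(outputs) == len(labels) == len(w_e)):
        raise ValueError("outputs, labels and w_e must have the same length")
    if not outputs:
        raise ValueError("empty mini-batch")
    total = 0.0
    for f, y, w in zip(outputs, labels, w_e):
        f = min(max(f, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -w * (y * math.log(f) + (1 - y) * math.log(1.0 - f))
    return total / len(outputs)  # mean over the events in the mini-batch


# Usage: one SM and one EFT event, both predicted at 0.5 with unit weight
loss = weighted_bce([0.5, 0.5], [0, 1], [1.0, 1.0])
```

With tensors, the same quantity is given by torch.nn.functional.binary_cross_entropy(outputs, labels, weight=w_e) with float labels, whose default reduction='mean' also divides by the number of events.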
- train_classifier(data_train, data_val)
Starts the training of the binary classifier
- Parameters
data_train (array_like) – Training data set
data_val (array_like) – Validation data set
- training_loop(optimizer, train_loader, val_loader)
Optimize the classifier with optimizer on the training data set train_loader. Keeps track of potential overfitting through val_loader.
- Parameters
optimizer (torch.optim.Optimizer) – Optimizer, e.g. torch.optim.AdamW
train_loader (array_like) – List of torch.utils.data.DataLoader objects, one for the SM and the EFT (training)
val_loader (array_like) – List of torch.utils.data.DataLoader objects, one for the SM and the EFT (validation)
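One common way to keep track of overfitting through a validation set is early stopping: the validation loss is evaluated after every epoch, and training halts once it has not improved for a given number of epochs. The sketch below illustrates that technique in plain Python; the callables train_step and val_loss (standing in for passes over the SM and EFT loaders), the patience parameter, and the function name are assumptions, not the actual stopping rule of training_loop.

```python
from typing import Callable, List, Tuple


def training_loop_sketch(
    train_step: Callable[[], float],
    val_loss: Callable[[], float],
    max_epochs: int,
    patience: int,
) -> Tuple[int, List[float]]:
    """Run epochs until the validation loss stops improving (early stopping).

    train_step stands in for one epoch over the (SM, EFT) training loaders and
    val_loss for an evaluation on the (SM, EFT) validation loaders; both are
    hypothetical. Returns the best epoch and the validation-loss history.
    """
    best_loss = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    history: List[float] = []
    for epoch in range(max_epochs):
        train_step()
        loss = val_loss()
        history.append(loss)
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0  # a real loop would checkpoint the model here
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # validation loss has stalled: likely onset of overfitting
    return best_epoch, history


# Usage with scripted validation losses that fall and then rise again
losses = iter([1.0, 0.8, 0.7, 0.75, 0.9, 1.1, 1.3])
best_epoch, history = training_loop_sketch(lambda: 0.0, lambda: next(losses),
                                           max_epochs=7, patience=2)
```

Keeping the model from the best epoch rather than the last one is what protects against the overfitting that the validation loss exposes.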