function [dataset, mdl] = train_classifier(dataset,lam)
% TRAIN_CLASSIFIER  Train the cell classifier (lasso-regularised logistic
% regression) on the expert-labelled samples and re-label the whole dataset.
%
% INPUT:
%   [dataset] a structure that contains the fields
%       - features  : N x d feature matrix
%       - labels_ex : N x 1 human / expert labels, 1 or -1; 0 means unlabeled
%       - balance   : (optional) when true, the labelled samples are passed
%                     through balance_pretrained_dataset before fitting.
%                     Treated as false when the field is absent.
%   [lam] (optional, default "auto") regularisation strength.
%       - numeric     : divided by the number of labelled samples and then
%                       passed to fitclinear as 'Lambda'
%       - non-numeric : (e.g. "auto" or 'auto') passed to fitclinear as is
% OUTPUT:
%   [dataset] : the input dataset with these fields written:
%       - labels_ml      : N x 1 classifier labels, 1 or -1
%       - labels_ml_prob : N x 1 score of class 1 (second score column
%                          returned by predict, ClassNames being [0, 1])
%       - mdl            : the fitted model
%   [mdl]     : the cell classifier, a logistic regression model
if nargin<2
    lam = "auto";
end

% Only samples carrying an expert label (non-zero) are used for training.
labeled_idxs = find(dataset.labels_ex ~= 0);
labeled_idxs = labeled_idxs(:);
num_samples  = numel(labeled_idxs);
assert(num_samples > 0, 'there are no labeled data to train on');

% Force a column so the response has the same orientation regardless of
% whether labels_ex was stored as a row or a column vector.
labeled_ex01 = dataset.labels_ex(labeled_idxs);
labeled_ex01 = labeled_ex01(:);

% Scale numeric penalties only; both "auto" and 'auto' must reach
% fitclinear untouched (a char vector would otherwise be divided as numbers).
if isnumeric(lam)
    lam = lam / num_samples;
end

% Expert labels use -1 / 1, the model is trained on 0 / 1.
labeled_ex01(labeled_ex01==-1) = 0;
assert(all(labeled_ex01 == 0 | labeled_ex01 == 1), 'The array labeled_ex01 is not binary.');

% The balance flag is optional; without it the raw labelled set is used.
if isfield(dataset, 'balance')
    balance_flag = dataset.balance;
else
    balance_flag = false;
end

train_features = dataset.features(labeled_idxs,:);
train_labels   = labeled_ex01;
if ~balance_flag
    % train on the labelled samples as they are
else
    % balance the training dataset; the helper works with -1 / 1 labels
    dataset_2bb.features = train_features;
    dataset_2bb.labels   = train_labels;
    dataset_2bb.labels(dataset_2bb.labels==0) = -1;
    dataset_b = balance_pretrained_dataset(dataset_2bb);
    dataset_b.labels(dataset_b.labels==-1) = 0;
    train_features = dataset_b.features;
    train_labels   = dataset_b.labels;
end

mdl = fitclinear(train_features, train_labels, ...
        'learner', 'logistic', 'regularization', 'lasso',...
        'ClassNames', [0, 1], 'Prior', 'empirical','Lambda',lam);

% Predict on every sample, labelled or not.
[pred, pred_probs] = predict(mdl, dataset.features);

dataset.labels_ml_prob = pred_probs(:,2);
dataset.labels_ml = pred;
dataset.labels_ml(dataset.labels_ml==0) = -1; % because we label it in -1 and 1

dataset.mdl = mdl;
end