Skip to content

Commit

Permalink
ml_trainlda/ml_trainqda: added a "robust" option that uses geometric medians instead of means (i.e., estimate some quantities under a Laplacian rather than a Gaussian noise assumption).
Browse files Browse the repository at this point in the history
  • Loading branch information
chkothe committed Nov 11, 2015
1 parent edd54fb commit 17bfaf6
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 8 deletions.
23 changes: 20 additions & 3 deletions code/machine_learning/ml_trainlda.m
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
arg({'regularization','Regularizer','Regularization'}, 'auto', {'none','auto','shrinkage','independence'}, 'Type of regularization. Regularizes the robustness / flexibility of covariance estimates. Auto is analytical covariance shrinkage, shrinkage is shrinkage as selected via plambda, and independence is feature independence, also selected via plambda.'), ...
arg({'weight_bias','WeightedBias'}, false, [], 'Account for class priors in bias. If you do have unequal probabilities for the different classes, this should be enabled.'), ...
arg({'weight_cov','WeightedCov'}, false, [], 'Account for class priors in covariance. If you do have unequal probabilities for the different classes, it makes sense to enable this.'), ...
arg({'robust','Robust'}, false, [], 'Use robust estimation. Uses geometric medians in place of means; can help if some trials are very noisy.'), ...
arg({'votingScheme','VotingScheme'},'1vR',{'1v1','1vR'},'Voting scheme. If multi-class classification is used, this determine how binary classifiers are arranged to solve the multi-class problem. 1v1 gets slow for large numbers of classes (as all pairs are tested), but can be more accurate than 1vR.'));

% find the class labels
Expand All @@ -115,13 +116,29 @@
for c = 1:length(classes)
X = trials(targets==classes(c),:);
n{c} = size(X,1);
mu{c} = mean(X,1);
if robust
mu{c} = geometric_median(X);
else
mu{c} = mean(X,1);
end
if n{c} == 1
sig{c} = zeros(size(X,2));
elseif strcmp(regularization,'auto')
sig{c} = cov_shrink(X);
if robust
% for lack of a better solution in this case we estimate lambda using a non-robust
% Ledoit-Wolf estimator and then use it in a subsequent robust re-estimation
[dummy,lam] = cov_shrink(X); %#ok<ASGLU>
sig{c} = cov_blockgeom(X,1);
sig{c} = (1-lam)*sig{c} + lam*eye(length(sig{c}))*abs(mean(diag(sig{c})));
else
sig{c} = cov_shrink(X);
end
else
sig{c} = cov(X);
if robust
sig{c} = cov_blockgeom(X,1);
else
sig{c} = cov(X);
end
if ~isempty(plambda) && ~strcmp(regularization,'none')
% plambda-dependent regularization
if strcmp(regularization,'independence')
Expand Down
27 changes: 22 additions & 5 deletions code/machine_learning/ml_trainqda.m
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@
arg_norep('targets'), ...
arg({'lambda','Lambda'}, [], [0 1], 'Within-class covariance regularization parameter. Reasonable range: 0:0.1:1 - greater is stronger. Requires that the regularization mode is set to either "shrinkage" or "independence".'), ...
arg({'kappav','Kappa','kappa'}, [], [0 1], 'Between-class covariance regularization parameter. Reasonable range: 0:0.1:1 - greater is stronger. Requires that the regularization mode is set to either "shrinkage" or "independence".'), ...
arg({'regularization','Regularizer'}, 'auto', {'auto','shrinkage','independence'}, 'Regularization type. Regularizes the robustness / flexibility of covariance estimates. Auto is analytical covariance shrinkage, shrinkage is shrinkage as selected via lambda, and independence is feature independence, also selected via lambda.'), ...
arg({'regularization','Regularizer','Regularization'}, 'auto', {'auto','shrinkage','independence'}, 'Regularization type. Regularizes the robustness / flexibility of covariance estimates. Auto is analytical covariance shrinkage, shrinkage is shrinkage as selected via lambda, and independence is feature independence, also selected via lambda.'), ...
arg({'weight_cov','WeightedCov'}, false, [], 'Account for class priors in covariance. If you do have unequal probabilities for the different classes, it makes sense to enable this.'), ...
arg({'robust','Robust'}, false, [], 'Use robust estimation. Uses geometric medians in place of means; can help if some trials are very noisy.'), ...
arg({'votingScheme','VotingScheme'},'1vR',{'1v1','1vR'},'Voting scheme. If multi-class classification is used, this determine how binary classifiers are arranged to solve the multi-class problem. 1v1 gets slow for large numbers of classes (as all pairs are tested), but can be more accurate than 1vR.'), ...
arg_deprecated({'weight_bias','WeightedBias'},false));

Expand Down Expand Up @@ -97,11 +98,27 @@
% get mean and covariance
X = trials(targets==classes(c),:);
n{c} = size(X,1);
mu{c} = mean(X);
if robust
mu{c} = geometric_median(X);
else
mu{c} = mean(X);
end
if strcmp(regularization,'auto')
sig{c} = cov_shrink(X);
if robust
% for lack of a better solution in this case we estimate lambda using a non-robust
% Ledoit-Wolf estimator and then use it in a subsequent robust re-estimation
[dummy,lam] = cov_shrink(X); %#ok<ASGLU>
sig{c} = cov_blockgeom(X,1);
sig{c} = (1-lam)*sig{c} + lam*eye(length(sig{c}))*abs(mean(diag(sig{c})));
else
sig{c} = cov_shrink(X);
end
else
sig{c} = cov(X);
if robust
sig{c} = cov_blockgeom(X,1);
else
sig{c} = cov(X);
end
if ~isempty(lambda)
% lambda-dependent regularization
if strcmp(regularization,'independence')
Expand All @@ -124,6 +141,6 @@
end

% compute the model
model = struct('c',{1/2*(logdet(sig{1})-logdet(sig{2})) + 1/2*((mu{1}/sig{1})*mu{1}' - (mu{2}/sig{2})*mu{2}')}, ...
model = struct('c',{1/2*(logdet((sig{1}+sig{1}')/2)-logdet((sig{2}+sig{2}')/2)) + 1/2*((mu{1}/sig{1})*mu{1}' - (mu{2}/sig{2})*mu{2}')}, ...
'l',{mu{1}/sig{1} - mu{2}/sig{2}}, 'q',{-1/2*(inv(sig{1}) - inv(sig{2}))}, 'sc_info',{sc_info}, 'classes',{classes}, 'featuremask',{retain});
end

0 comments on commit 17bfaf6

Please sign in to comment.