-
Notifications
You must be signed in to change notification settings - Fork 54
/
cnmf.m
67 lines (50 loc) · 1.68 KB
/
cnmf.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
function [ Z, H, dnorm] = cnmf ( X, k, y, varargin )
%CNMF Constrained (label-guided) non-negative matrix factorisation.
%   [Z, H, DNORM] = CNMF(X, K, Y, ...) factorises the m x n data matrix X
%   (one sample per column) as X ~ Z * (A * Hfull)' using multiplicative
%   updates. A is an n x (c + n - l) 0/1 constraint matrix built from the
%   labels Y of the first l = numel(Y) columns of X: labelled samples with
%   the same label share one coefficient row, while every unlabelled sample
%   gets its own row.
%
%   Inputs
%     X - m x n data matrix. NOTE(review): the multiplicative updates are
%         only meaningful for non-negative X; this is not checked here.
%     k - number of components (columns of Z).
%     y - positive-integer labels of columns 1..l of X (row or column
%         vector); [] gives an unconstrained factorisation.
%
%   Name/value options
%     'z0'                    - initial Z (m x k).
%     'h0'                    - k x n matrix; the initial coefficients are
%                               max(abs(A' * pinv(h0)), eps).
%                               A missing/empty z0 or h0 is taken from
%                               NNDSVD(abs(X), k, 0). For a random start,
%                               pass e.g. 'z0', rand(m, k), 'h0', rand(k, n).
%     'bUpdateH'              - also update the coefficients (default 1).
%     'maxiter'               - maximum number of iterations (default 300).
%     'nonlinearity_function' - accepted for compatibility, currently unused
%                               (default @(x) x).
%     'TolFun'                - stop when the error decrease between
%                               iterations 10j-1 and 10j is at most
%                               TolFun * max(1, previous error)
%                               (default 1e-5).
%
%   Outputs
%     Z     - m x k basis.
%     H     - k x (n - l) coefficients of the unlabelled samples.
%     dnorm - Frobenius norm of X - Z * (A * Hfull)' for the returned
%             factors.

n = size(X, 2);
y = double(y(:));          % accept row or column labels
l = numel(y);
if l > n
    error('cnmf:tooManyLabels', ...
        'Got %d labels but X has only %d columns.', l, n);
end

% Process optional arguments. The initial factors default to empty so the
% SVD-based NNDSVD initialisation only runs when the caller omitted them.
pnames = {'z0' 'h0' 'bUpdateH' 'maxiter' 'nonlinearity_function', 'TolFun'};
dflts = {[], [], 1, 300, @(x) x, 1e-5};
[z0, h0, bUpdateH, max_iter, ~, tolfun] = ...
    internal.stats.parseArgs(pnames, dflts, varargin{:});
if isempty(z0) || isempty(h0)
    [z_svd, h_svd] = NNDSVD(abs(X), k, 0);
    if isempty(z0), z0 = z_svd; end
    if isempty(h0), h0 = h_svd; end
end

% Constraint matrix A (n x (c + n - l)): labelled sample i maps to column
% y(i); unlabelled samples map to columns c+1 .. c+n-l, i.e. strictly past
% the largest label, so they can never share a column with a class.
c = max([0; y]);
cols = [y; c + (1:n - l)'];
A = sparse((1:n)', cols, 1, n, c + n - l);

Z = z0;                                  % m x k
H = max(abs(A' * pinv(h0)), eps);        % (c + n - l) x k
recon_err = @(Zc, Hc) sqrt(sum(reshape(X - Zc * Hc' * A', [], 1) .^ 2));

dnorm0 = [];
stale = true;   % true while Z/H changed since dnorm was last computed
for i = 1:max_iter
    if bUpdateH
        numer = A' * X' * Z;
        H = H .* (numer ./ (((A' * A) * H * (Z' * Z)) + eps(numer)));
    end
    numer = X * A * H;
    Z = Z .* (numer ./ (Z * (H' * A') * (A * H) + eps(numer)));
    stale = true;

    % The error is measured at iterations 9, 19, 29, ... (reference value)
    % and 10, 20, 30, ... (convergence check against that reference).
    if mod(i + 1, 10) == 0
        dnorm = recon_err(Z, H);
        stale = false;
        dnorm0 = dnorm;
    elseif mod(i, 10) == 0
        dnorm = recon_err(Z, H);
        stale = false;
        fprintf('...CNMF iteration #%d out of %d, error: %f\n', i, max_iter, dnorm);
        if ~isempty(dnorm0) && dnorm0 - dnorm <= tolfun * max(1, dnorm0)
            fprintf('Stopped at %d: dnorm: %f, dnorm0: %f\n', i, dnorm, dnorm0);
            break;
        end
    end
end
if stale
    % maxiter < 9, or the loop ended between checkpoints: make dnorm
    % describe the factors that are actually returned.
    dnorm = recon_err(Z, H);
end

% Per-sample coefficients are A * H (n x k); return only the unlabelled
% samples, transposed to k x (n - l).
H = A * H;
H = H(l + 1 : end, :)';