% Biophys J. 2015 Nov 3;109(9):1772-80. doi: 10.1016/j.bpj.2015.09.017.
% Optimal Drift Correction for Superresolution Localization Microscopy with Bayesian Inference.
%
% Software:  https://github.com/jiyuuchc/BaSDI/releases/

function driftCoords = driftCalculation_BaSDI(pointillisticData, settings)
% driftCoords = driftCalculation_BaSDI(pointillisticData, settings)
% Estimate a per-frame drift trace with BaSDI.
%
% Frames 1..settings.frameRange(2) are split into consecutive segments of
% settings.nFrames frames. The localizations of each segment are handed to
% BaSDI as one "frame", the most likely drift per segment is extracted with
% viterbi, and the result is interpolated (pchip, with extrapolation) to
% every frame.
%
% pointillisticData: struct with fields x_coord, y_coord and frame_idx
%                    (one entry per localization).
% settings:          struct with fields
%                      nFrames     - number of frames per segment
%                      frameRange  - frameRange(2) is the last frame
%                      scaleFactor - factor applied to the coordinates
%                                    before they are handed to BaSDI
% driftCoords:       struct with fields x_coord and y_coord, each a column
%                    vector with one drift value per frame
%                    (frameRange(2) entries).
    
    % force column vectors so the two coordinates can be concatenated
    x_coords = pointillisticData.x_coord(:);
    y_coords = pointillisticData.y_coord(:);
    frameIndices = pointillisticData.frame_idx(:);
    
    nFrames = settings.nFrames;
    
    stackSize = settings.frameRange(2);
    scaleFactor = settings.scaleFactor;
    
    if stackSize < 2*nFrames
        warning('Too few frames (%d) for BaSDI drift correction with "stackSize" parameter %d, decrease the parameter. Now returning zero drift.', stackSize, nFrames)
        % one zero per frame, the same length as the regular output
        driftCoords.x_coord = zeros(stackSize, 1);
        driftCoords.y_coord = zeros(stackSize, 1);
        return
    end
    
    % first and last frame of every segment; the last one may be shorter
    segStarts = 1:nFrames:stackSize;
    segEnds = min(stackSize, segStarts + nFrames - 1);
    nSegments = numel(segStarts);
    
    O = cell(1, nSegments);
    for i = 1:nSegments
        map = frameIndices >= segStarts(i) & frameIndices <= segEnds(i);
        O{i} = scaleFactor*[x_coords(map) y_coords(map)];
    end
    
    S = BaSDI(O, 1);
    
    drift = viterbi( S.e );
    drift = drift / scaleFactor;
    
    % One sample position per segment, so that it always matches the number
    % of rows in drift. For a full segment this equals start + nFrames/2;
    % a shorter last segment gets its own midpoint.
    segCentres = (segStarts + segEnds + 1) / 2;
    
    intType = 'pchip';
    dx = interp1( segCentres, drift(:,1), 1:stackSize, intType, 'extrap' );
    dy = interp1( segCentres, drift(:,2), 1:stackSize, intType, 'extrap' );
    
    driftCoords.x_coord = dx';
    driftCoords.y_coord = dy';

end


function v = viterbi(e, p, eps)
% v = viterbi(e, p, eps)
% Return the most likely state (drift) sequence based on the Viterbi
% algorithm.
%
% e:   dh x dw x f array; e(:,:,k) is the evidence for every shift in
%      frame k.
% p:   weight of a one-pixel move relative to staying in place
%      (default 0.1); diagonal moves are weighted p^2.
% eps: creep term (default 0). It is subtracted from the normalized
%      transition weights, and eps times the best previous score is used as
%      a floor that allows a jump from the best previous state.
% v:   f x 2 array. v(k,1) is the offset along the first dimension of e and
%      v(k,2) along the second, measured from the centre of the map and
%      negated.
%
% If no state can be reached in some frame (all scores are zero, e.g.
% because the evidence of consecutive frames does not overlap and eps is
% 0), the search restarts from the evidence of that frame and links every
% state to the best state of the previous frame, instead of dividing by
% zero and producing NaN.

if (nargin < 2)
    p = 0.1;
end

if (nargin < 3)
    eps = 0;
end

[dh, dw, f] = size(e);

%default transition matrix
T = [p^2, p, p^2, p, 1, p, p^2, p, p^2]';
T = T / sum(T) - eps; %normalize
[ox, oy] = meshgrid(-1:1, -1:1);
s = ox(:) * (dh + 2) + oy(:); %linear-index offsets of the 3x3 neighbourhood

%forward
h = zeros(dh + 2, dw + 2, f);   % back-pointers (linear indices, padded grid)
vn = zeros(dh + 2, dw + 2);     % scores, padded with a one-pixel zero border
e1 = e(:,:,1);
if ~any(e1(:) > 0)
    e1 = ones(dh, dw);          % uninformative first frame
end
vn(2:dh+1, 2:dw+1) = e1;
vp = vn;
[m, idx2] = max(vp(:));
m = m * eps;

for i = 2:f
    for col = 2:dw+1
        for row = 2:dh+1
            sn = s + (col - 1) * (dh + 2) + row;
            [mp, idx] = max(vp(sn) .* T);
            idx = sn(idx);
            if (mp < m )
                % creep: jump from the best previous state
                mp = m;
                idx = idx2;
            end
            vn(row, col) = mp * e(row-1, col-1, i);
            h(row, col, i) = idx;
        end
    end
    
    vmax = max(vn(:));
    if ~(vmax > 0)
        % nothing reachable: restart from this frame's evidence
        ei = e(:,:,i);
        if ~any(ei(:) > 0)
            ei = ones(dh, dw);
        end
        vn(:) = 0;
        vn(2:dh+1, 2:dw+1) = ei;
        h(2:dh+1, 2:dw+1, i) = idx2;
        vmax = max(vn(:));
    end
    
    vp = vn / vmax;             % rescale to avoid underflow
    [m, idx2] = max(vp(:));
    m = m * eps;
end

%backward
v = zeros(f,2);
[temp, idx] = max(vp(:));
idx = idx - 1;
v(f,2) = floor(idx / (dh + 2));
v(f,1) = idx - v(f,2) * (dh + 2);

for i = f:-1:2
    hp = h(:,:,i);
    idx = hp(idx + 1);
    idx = idx - 1;
    v(i-1,2) = floor(idx / (dh + 2));
    v(i-1,1) = idx - v(i-1,2) * (dh + 2);
end

% measure from the centre of the map and flip the sign
v(:,1) = v(:,1) - (dh + 1)/2;
v(:,2) = v(:,2) - (dw + 1)/2;
v = -v;
end



function S = BaSDI(O, pixel_size)
% Bayesian Super-resolution Drift Inference
% S = BaSDI(O, pixel_size)
% O: input localization data
%    'O' can either be a simple array or a cell array.
%    If O is a simple array, each row of the array represents one
%    localization event in the format of (y,x,frame). Frame starts at 1.
%    If O is a cell array, each cell is a two-column array representing a
%    single image frame. Cells may be empty (frames without localizations).
%    The y,x coordinates can be in any physical unit, e.g. 'nm', as long as
%    it is the same unit as 'pixel_size'. Coordinate (0,0) represents the
%    top-left of the image.
% pixel_size: The pixel size used for rendering the final corrected
%    super-resolution image. Defaults to 1 (with a warning) if omitted.
%
% S: output structure.
%    S.theta: the corrected image.
%    S.g:     posterior distribution function P(d_k|o,theta).
%    S.e:     drift probability of each frame without considering the prior
%             probability distribution P(d). Can be used as the input for
%             computing the most likely drift trace using viterbi.

if (nargin == 1) 
    pixel_size = 1;
    warning('BaSDI:Preprocessing', 'Only one input. Assuming pixel_size is 1');
end

% convert the array into a cell array, one cell per frame
if ~iscell(O)
    if ~ismatrix(O)
        error('BaSDI:Preprocessing', 'Wrong input format');
    end
    
    if (size(O, 2) ~= 3)
        error('BaSDI:Preprocessing', 'Wrong input format');        
    end
    
    frame_no = floor(O(:,end));
    max_frame = max([max(frame_no), 0]);
    OC = cell(1, max_frame);
    for i = 1:max_frame
        OC{i} = O(frame_no == i, 1:2);
    end
    
    O = OC;
end

% Estimating image size and converting coordinates into pixels

padding = 60; % padding some empty pixels around image borders

L = length(O);
mc = [0,0];
for i = 1:L
    O{i} = floor( O{i} / pixel_size ) + padding;
    % max over an empty frame is itself empty and would wipe out mc
    if ~isempty(O{i})
        mc = max(mc, max(O{i},[],1));
    end
end

% round up to the next multiple of 10
mc = (floor(mc/10) + 1) * 10;
h = mc(1) + padding;
w = mc(2) + padding;

if (h*w > 5e7) 
    warning('BaSDI:Preprocessing', 'Very large images. May take very long time');
end

S = BaSDI_main(O, h, w);

end

function S = BaSDI_main(O, h, w)
% S = BaSDI_main(O, h, w)
% Run the annealed EM optimization.
% O: localization dataset in pixel coordinates. Cell array, one cell per
%    frame.
% h: image height
% w: image width
% S: result structure of the last BaSDI_iter call (fields theta, e, g, dim).
%
% A waitbar shows the progress. It is closed when the function returns,
% also when an iteration fails with an error.

% Change these to suit your needs
% ---------------------------------

% Annealing schedule. 
% The algorithm runs multiple rounds of optimization with a decreasing
% smoothing parameter (scale) for each round. The annealing schedule helps
% the convergence of the optimization.
% scale: starts at the value below and is decreased by anneal_step after
% each round until it drops below endScale; it controls the size of the
% low-pass (box) filter used to smooth the theta image. You can reduce the
% starting value of scale (to save time) if your data is of good quality
% (high sampling rate). Reducing it too much will result in convergence
% problems.
scale = 2.4; % starting smoothing parameter
anneal_step = 0.4; % speed of scale decreasing
endScale = 1.2; % last smoothing parameter that is still used

% convergence control in each round of optimization
% In each round, the EM iterations are run until convergence is achieved
% (controlled by the cvge parameter) or the maximum number of iterations has
% been reached.
cvge = 0.3; % convergence test criterion
max_iter = 5; % maximum number of iterations for each round of optimization

%others
p = 0.2; % amplitude (sigma^2) of drift
creep = 0.001/h/w; % creep probability. Set to 0 if your system doesn't have a creep problem
max_shift = 60; % maximum drift (pixels) that is being calculated
resolution = 2; % Localization uncertainty (FWHM) in pixels

% ---------------------------------------------

% setting up parameters
theta = construct_palm(O, h, w); % initial image without drift correction

parameters.p = p;
parameters.eps = creep;
parameters.smooth = resolution;
parameters.max_shift = max_shift;
parameters.scale = scale;

% remember the current figure (without creating one) so it can be restored
hForm = get(0, 'CurrentFigure');
hWaitBar = waitbar(0, 'Drift Estimation: 0%' );
closeWaitBar = onCleanup(@() delete(hWaitBar(ishandle(hWaitBar))));

startScale = scale;
max_allIters = (1+(startScale-endScale)/anneal_step) * max_iter;
allIters = 0;

%start
iter_r = 1;
d = zeros(length(O), 2);
while (scale >= endScale)

    iter = 0;
    display(['round - ' int2str(iter_r)]);
    
    c = [0 0];
    while ((c(1) == 0 || c(2) == 0) && iter < max_iter)
        
        fs = round(exp(scale)); % side length of the box filter

        % NOTE(review): conv2 is called without 'same', so the smoothed
        % image is fs-1 pixels larger than theta in each dimension -
        % confirm this is intended.
        S = BaSDI_iter(O, h, w, parameters, conv2(theta,ones(fs,fs))); 

        theta = S.theta;
        iter = iter + 1;
        allIters = allIters + 1;
        perc = allIters / max_allIters;
        
        % update the progress display; skip it if the user closed the bar
        if ishandle(hWaitBar)
            set(0, 'CurrentFigure', hWaitBar); 
            waitbar( perc, hWaitBar, sprintf('Drift Estimation: %d%%', round(perc*100) ) )
        end
        if ~isempty(hForm) && ishandle(hForm)
            set(0, 'CurrentFigure', hForm); 
        end

        d_out = processing_result(S.g);
        c = testing_converge(d, d_out, cvge);
        d = d_out;

    end
    
    scale = scale - anneal_step;
    iter_r = iter_r + 1;

end

end


function [S,P] = BaSDI_iter(O, h, w, parameters, theta)
% [S,P] = BaSDI_iter(O, h, w, parameters, theta)
% One EM iteration.
% O:          localization dataset. Cell array, one cell per frame.
% h, w:       image height and width.
% parameters: struct with optional fields p, eps, max_shift and smooth;
%             missing fields get the defaults listed below.
% theta:      current image estimate; built with construct_palm if omitted.
% S:          struct with fields theta (updated image), e (per-frame drift
%             evidence), g (posterior drift distributions) and dim ([h,w]).
% P:          the parameters actually used, including filled-in defaults.

% Preprocessing
if (nargin < 5)
    theta = construct_palm(O, h, w);
end

if (nargin < 4)
    parameters = struct();
end

% fill in every option the caller did not provide
defaults = {'p', 0.1; 'eps', 0; 'max_shift', 60; 'smooth', 2};
for k = 1:size(defaults, 1)
    name = defaults{k, 1};
    if ~isfield(parameters, name)
        parameters.(name) = defaults{k, 2};
    end
end

% localizations too close to the border cannot be shifted by max_shift
OC = remove_border(O, h, w, parameters.max_shift);

%E step
disp('E step');
blurred = PSFBlur(theta, parameters.smooth);
e = EXY(blurred, OC, parameters.max_shift);
[g, g_s] = for_back(e, parameters.p, parameters.eps);

%M step
disp('M step');
S.theta = update_theta(O, h, w, g);
S.e = e;
S.g = g;
S.dim = [h,w];

P = parameters;

end



function O1 = cat_cellarray(O, n)
% O1 = cat_cellarray(O, n)
% Combine every n consecutive cells of O into one cell by concatenating
% their contents along the first dimension.
% O:  cell array.
% n:  number of cells per group. The last group holds the remaining cells
%     if length(O) is not a multiple of n.
% O1: 1 x ceil(length(O)/n) cell array (empty if O is empty).

nCells = length(O);
nGroups = ceil(nCells / n);

O1 = cell(1, nGroups);
for i = 1:nGroups
    first = (i-1)*n + 1;
    last = min(i*n, nCells); % do not run past the end of O
    O1{i} = cat(1, O{first:last});
end

end

function c = center_of_mass(I)
% c = center_of_mass(I)
% Intensity-weighted centre of a 2D image.
% I: 2D array of weights.
% c: 1x2 vector; c(1) is the centre along the first dimension (row index),
%    c(2) along the second dimension (column index), both 1-based.
%    The result is NaN if the weights sum to zero.
%
% ndgrid is used so that the coordinate grids have the same h x w layout as
% I; this also makes the function work for non-square images.

I = double(I);
[h,w] = size(I);
[y,x] = ndgrid(1:h, 1:w); % y(i,j) = i (row), x(i,j) = j (column)

total = sum(I(:));

c = [sum(sum(I.*y)), sum(sum(I.*x))] / total;

end

function I = construct_palm(O, h, w, d)
% Output a reconstructed super-resolution image based on input data and drift
% I = construct_palm(O, h, w, d)
% Inputs:
% O: localization dataset. A cell array with N elements. N is the number of
%    image frames. Each cell holds 0-based (row, column) pixel coordinates.
% h: height of the image
% w: width of the image
% d: drift trace. Nx2 array. Optional.
% Output:
% I: reconstructed super-resolution image. A 2D array (h x w).
%    Without d, I counts the localizations per pixel; localizations whose
%    pixel index falls outside the image are counted in the first or last
%    pixel. With d, every frame is rendered as a binary image, shifted by
%    -d(i,:) and summed.

L = length(O);

if (nargin < 4)
    % efficient algorithm for the no-drift PALM construction: count the
    % linear pixel indices of all localizations at once
    I = zeros(h, w);
    o = cat(1, O{:});
    if ~isempty(o)
        idx = o(:,2) * h + o(:,1) + 1;
        idx = idx(~isnan(idx));
        idx = min(max(round(idx), 1), h*w); % out-of-range goes to the end bins
        if ~isempty(idx)
            I = reshape(accumarray(idx, 1, [h*w, 1]), h, w);
        end
    end

else 

    I = zeros(h, w);

    for i = 1:L
        J = ij_to_image(O{i}, h, w);
        I = I + shift_image(J, -d(i, :));
    end
    
end

end


function e = EXY(theta, O, max_shift)
% e = EXY(theta, O, max_shift)
% Compute P(dx,dy|theta,O) for each individual frame as a function of drift d.
% theta:     proportional to the real image, no need to be normalized.
% O:         a cell array of 0-based coordinates for every frame.
% max_shift: largest shift (in pixels) that is evaluated (default 20).
% e:         (2*max_shift+1) x (2*max_shift+1) x length(O) array; slice k
%            is the normalized drift evidence of frame k. Frames without
%            localizations get a constant (uniform) distribution.

if (nargin < 3)
    max_shift = 20;
end

bg = 1; % allow a fixed false positive probability as a background signal. 

[h,w] = size(theta);
if (h <= max_shift * 2 + 1 || w <= max_shift * 2 + 1)
    error('Image size too small to allow shifting calculation');
end

n_shift = max_shift * 2 + 1;
uniform_p = n_shift^(-2); % uniform probability over all shifts

% initialize 
theta = theta + bg; % add a small chance for bg noise
logtheta = log(theta); % compute in log space to avoid overflow

% Preallocate, so that every slice has the full size even if the first
% frame (or every frame) is empty.
nFrames = length(O);
e = zeros(n_shift, n_shift, nFrames);

for k = 1:nFrames
    o = O{k};
    if ~isempty(o) % make sure it's not an empty frame (no molecule)
        
        % all the real work is done in exyf2
        e(:,:,k) = exyf2(logtheta, o, max_shift, k);
        
    else % for an empty frame, simply assign a constant probability distribution
        
        e(:,:,k) = uniform_p;

    end
end
end


function e = exyf2(logtheta, o, max_shift, k)
% e = exyf2(logtheta, o, max_shift, k)
% Compute e(dx, dy) = P(dx,dy|theta,o) for a single frame.
% logtheta:  logarithm of the image estimate.
% o:         0-based coordinates of the localizations of the frame.
% max_shift: largest shift that is evaluated.
% k:         frame number (not used in the computation).
% e:         (2*max_shift+1) x (2*max_shift+1) map, normalized to sum 1.

[h,w] = size(logtheta);
n = 2 * max_shift + 1;

% binary image of the localizations with the max_shift border cut off
oi = ij_to_image(o - max_shift, h - max_shift * 2, w - max_shift * 2);

% summed log-likelihood for every shift: convolution with the flipped image
score = conv2(rot90(logtheta, 2), oi, 'valid');

% convert back to linear scale, relative to the best shift
score = score - max(score(:));
e = exp(score(n:-1:1, n:-1:1));

% normalize
total = sum(e(:));
if total > 0
    e = e / total;
end

end


function [g, g_s] = for_back(e_xy, p, eps)
% [g, g_s] = for_back(e_xy, p, eps)
% Forward-backward algorithm for computing the marginal probability of a
% Markovian process.
% e_xy: h x w x f array; e_xy(:,:,k) is the evidence for every shift in
%       frame k.
% p:    variance of the Gaussian transition kernel (default 0.2).
% eps:  creep offset that is added to the transition kernel (default 0).
% g:    h x w x f array; g(:,:,k) is the product of the forward and the
%       backward messages of frame k, normalized to sum 1 (left as is if
%       its sum is 0).
% g_s:  1 x f vector, sum of the accumulated log scaling factors of the
%       forward and backward passes.

if (nargin < 2)
    p = 0.2; % default variance of the gaussian filter 
end

if (nargin < 3)
    eps = 0;
end

% h,w is the size of the shifting matrix
% f is the number of frames
[h,w,f] = size(e_xy);

%default transition matrix
T = fspecial('gaussian', [3, 3], sqrt(p)); 

%forward computation
a = zeros(h, w, f);
a(:,:,1) = e_xy(:,:,1);
a_s = zeros(1,f);

for i = 2:f
    a_t = ofs_filter2(T, a(:,:,i-1), eps) .* e_xy(:,:,i);
    a_m = max(a_t(:));
    if a_m > 0
        a(:,:,i) = a_t / a_m; % rescale to avoid underflow
    else
        a(:,:,i) = 0*a_t;
    end
    a_s(i) = a_s(i - 1) + log(a_m);
end

%backward computation
b = ones(h, w, f); % the message of the last frame is all ones
b_s = zeros(1,f);

for i = f-1:-1:1
    b_t = ofs_filter2(T, b(:,:,i+1) .* e_xy(:,:,i), eps);
    b_m = max(b_t(:));
    if b_m > 0
        b(:,:,i) = b_t / b_m;
    else
        b(:,:,i) = 0*b_t;
    end
    b_s(i) = b_s(i+1) + log(b_m);
end

%calculate the probability
g = a.*b;

for i = 1:f
    gk = g(:,:,i);
    s = sum(gk(:));
    if s > 0
        g(:,:,i) = gk / s;
    end
end
g_s = a_s + b_s;
end


function I = gen_img(s)
% I = gen_img(s)
% Generate the ground truth test image of size s x s: a centred square of
% value 10 with a zero-valued cross cut through its middle.

mid = s/2;
outer = round(mid - s/10):round(mid + s/10); % extent of the square
inner = round(mid - s/33):round(mid + s/33); % width of the cross arms

I = zeros(s,s);
I(outer, outer) = 10;
I(outer, inner) = 0;
I(inner, outer) = 0;

end


function O = gen_palm_data(I, n, d)
% Simple simulation to generate a localization imaging dataset
% O = gen_palm_data(I, n, d)
% Inputs:
% I: ground truth image.
% n: average number of molecules detected in each frame
% d: a drift trace (N x 2 array). The number of rows of d also determines
%    the number of image frames generated. d is rounded to whole pixels.
% Output:
% O: a 1 x N cell array representing the localization dataset. Each cell
%    is a two-column array of 0-based (row, column) pixel coordinates.

[h,w] = size(I);

d = round(d);
ni = double(I(:));
ni = ni / sum(ni) * n; % detection probability of every pixel

% number of frames = number of rows (length(d) would be 2 for a 1x2 trace)
nFrames = size(d, 1);

O = cell(1, nFrames);
for i = 1:nFrames
    r = rand(length(ni),1);
    img = circshift(reshape(r < ni,h,w), d(i,:));
    
    % 0-based linear indices of the detected pixels, always as a column
    idx = find(img(:)) - 1;

    cols = floor(idx / h);
    O{i} = [idx - cols * h, cols];
end

end


function I = ij_to_image(ij, h, w)
% I = ij_to_image(ij, h, w)
% Convert a set of coordinates to a binary image.
% ij:   two-column array of 0-based coordinates; column 1 runs along the
%       first image dimension, column 2 along the second.
% h, w: image height and width.
% I:    h x w array that is 1 at every listed coordinate and 0 elsewhere.

I = zeros(h,w);

if size(ij,1) == 0
    return
end

rows0 = ij(:,1);
cols0 = ij(:,2);
I(cols0 * h + rows0 + 1) = 1; % 1-based linear index of each coordinate
end


function J = ofs_filter2(T, I, eps)
% J = ofs_filter2(T, I, eps)
% Offset 2D filter: filter I with T and add eps times the sum of all
% elements of I to every output element. The original author describes it
% as equivalent to filter2(T + eps, I), but faster.

offset = sum(I(:)) * eps;
J = filter2(T, I) + offset;

end


function [d_out, sigma] = processing_result(g)
% [d_out, sigma] = processing_result(g) 
% Compute some simple statistics of the drift.
% g:     h x w x f array of marginal posterior distributions of the drift,
%        as computed by for_back. Each slice is normalized to sum 1 here.
% d_out: f x 2 array; the negated expectation value of the shift of every
%        frame, first along the first dimension of g, then along the
%        second. Shifts are measured from the centre of the map.
% sigma: 1 x f vector; square root of the summed variances of the
%        distribution along both dimensions.

[h,w,f] = size(g);
half_h = (h - 1)/2;
half_w = (w - 1)/2;
[x,y] = meshgrid(-half_w:half_w, -half_h:half_h); % x: columns, y: rows

cx = zeros(1,f);
cy = zeros(1,f);
sigma = zeros(1,f);

for i = 1:f
    gn = g(:,:,i);
    gn = gn / sum(gn(:));
    
    cx(i) = sum(sum(x.*gn));
    cy(i) = sum(sum(y.*gn));
    
    sdx = sum(sum((x.^2).*gn)) - cx(i)^2;
    sdy = sum(sum((y.^2).*gn)) - cy(i)^2;
    
    % rounding can make the variance of a very narrow distribution
    % slightly negative, which would turn sigma into a complex number
    total_var = sdx + sdy;
    if total_var < 0
        total_var = 0;
    end
    sigma(i) = sqrt(total_var);
end

d_out = -[cy; cx]';
end


function J = PSFBlur(I, psfFWHM)
% J = PSFBlur(I, psfFWHM)
% Gaussian filtering of an image with circular boundary handling.
% I:       input image.
% psfFWHM: full width at half maximum of the Gaussian, in pixels.
% J:       filtered image.

sigma = psfFWHM / 2.355;        % FWHM -> standard deviation
ksize = floor(psfFWHM * 3);     % kernel side length, at least 5 pixels
if (ksize < 5)
    ksize = 5;
end

kernel = fspecial('gaussian', [ksize, ksize], sigma);
J = imfilter(I, kernel, 'circular');
end


function OC = remove_border(O, h, w, m)
% OC = remove_border(O, h, w, m)
% Remove localization data that are within an m pixel border of the image.
% O: localization dataset. Cell array of N elements; each cell is a
%    two-column array of 0-based coordinates. As in ij_to_image, column 1
%    runs along the image height and column 2 along the image width.
% h: height of image.
% w: width of image.
% m: border size.
% OC: new localization dataset (1 x N cell array) with molecules at the
%     border removed. A frame without remaining molecules is [].

nFrames = length(O);
OC = cell(1, nFrames);

for i = 1:nFrames
    o = O{i};
    if isempty(o)
        OC{i} = [];
        continue
    end
    
    r = o(:,1); % along the height
    c = o(:,2); % along the width
    keep = (r >= m) & (r < h - m) & (c >= m) & (c < w - m);
    
    if any(keep)
        OC{i} = o(keep, 1:2);
    else
        OC{i} = [];
    end
end
end


function J = shift_image(I, d)
% J = shift_image(I, d)
% Shift image I by a translational drift d.
% I: input image. 2D array
% d: two-element vector (dy, dx): d(1) is the shift along the rows (down is
%    positive), d(2) the shift along the columns (right is positive).
%    Non-integer shifts are rounded to the nearest pixel.
% J: output image of the same size as I. Pixels that are shifted in from
%    outside the image are 0; pixels shifted out of the image are dropped.

[h,w] = size(I);
d = round(d); % sub-pixel shifts cannot be used as array indices
dy = d(1);
dx = d(2);

% rows/columns of I that stay inside the image, and where they end up in J
rows_src = max(1, 1 - dy):min(h, h - dy);
rows_dst = max(1, 1 + dy):min(h, h + dy);
cols_src = max(1, 1 - dx):min(w, w - dx);
cols_dst = max(1, 1 + dx):min(w, w + dx);

J = zeros(h,w);
J(rows_dst, cols_dst) = I(rows_src, cols_src);
end


function c = testing_converge(t, t1, eps)
% c = testing_converge(t, t1, eps)
% Test whether the spread of the change from t to t1 is small compared to
% the spread of t1 itself.
% t, t1: previous and current values (same size).
% eps:   relative tolerance (default 0.001).
% c:     true where std(t - t1) is smaller than eps times std(t1). For
%        matrix inputs std works column by column, so c has one element
%        per column.

if (nargin < 3)
    eps = 0.001;
end

change_spread = std(t - t1);
tolerance = eps * std(t1);
c = ( change_spread < tolerance );
end


function theta = update_theta(O,h,w,g)
% theta = update_theta(O,h,w,g)
% The M step: obtaining new theta values.
% O: localization dataset. A cell array.
% h: image height
% w: image width
% g: P(dx,dy). Marginal posterior distributions of the drift from the last
%    iteration.
% theta: h x w image. Every frame is rendered as a binary image and
%    convolved with its drift distribution; the sum over all frames is then
%    shifted back by the rounded mean shift of the first frame.

[dh, dw, frames] = size(g);

% mean shift of the first frame
half = (dh - 1)/2;
[gx, gy] = meshgrid(-half:half, -half:half);
g1 = g(:,:,1);
total = sum(g1(:));
if total > 0
    g1 = g1 / total;
end
cx = sum(sum(gx.*g1));
cy = sum(sum(gy.*g1));

% accumulate the drift-blurred frames
theta = zeros(h,w);
for k = 1:frames
    frame_img = ij_to_image(O{k}, h, w);
    theta = theta + conv2(frame_img, g(:,:,k), 'same');
end

theta = circshift(theta, - round([cy, cx]));
end




function O = xyf2cells(xyf)
% O = xyf2cells(xyf)
% Convert an Nx3 array of x, y, frame# data to a cell array.
% The result is useful as an input for BaSDI.
% xyf: Nx3 array (a 3xN array with N > 3 is transposed first).
% O:   cell array with one cell for every frame number between the
%      smallest and the largest one that occurs; each cell holds the first
%      two columns of the rows that belong to that frame.

[nRows, nCols] = size(xyf);

% Do something useful if the user mistakenly used a 3xN matrix instead of Nx3
if (nCols > 3 && nRows == 3) 
    xyf = xyf';
    nCols = 3;
end

% Well it's beyond hope
if (nCols < 3 )
    error('BaSDI:matrixSize', 'The input matrix size should be Nx3')
end

frames = floor(xyf(:,3));

O = {};
for f = min(frames):max(frames)
    O{end+1} = xyf(frames == f, 1:2);
end
end






