function driftCoords = driftCorrection_rainSTORM(pointillisticData, settings)
% A cross-correlation like drift determination. RainSTORM's default algorithm, was implemented by József Németh, as a member of AdOptIm in SZTE.
%
% driftCoords = driftCorrection_rainSTORM(pointillisticData, settings)
%
% Inputs:
%   pointillisticData - localization data with the fields x_coord, y_coord
%                       and frame_idx, plus either std_x and std_y or a
%                       single std field that is then used for both axes.
%   settings          - structure with the fields
%                         maximalFrameStep : upper limit (in frames) of the
%                                            section half-width and of the
%                                            step between section centers
%                         frameRange       : [firstFrame lastFrame]
%
% Output:
%   driftCoords - structure with the nFrames-by-1 fields x_coord and
%                 y_coord holding the cumulative sum of the interpolated
%                 per-frame drift velocities; row k belongs to the frame
%                 frameRange(1)+k-1.
%
% Throws an error when a localization has a frame index above
% frameRange(2). When the frame range is too short to form any section, a
% warning is issued and zero drift is returned.

    % needs to be column data
    x_coords(:, 1) = pointillisticData.x_coord;
    y_coords(:, 1) = pointillisticData.y_coord;
    try
        x_stds(:, 1) = pointillisticData.std_x;
        y_stds(:, 1) = pointillisticData.std_y;
    catch
        % no separate x/y precision available: use the common one for both
        x_stds(:, 1) = pointillisticData.std;
        y_stds(:, 1) = pointillisticData.std;
    end
    frameIndices(:, 1) = double(pointillisticData.frame_idx);

    maximalFrameStep = settings.maximalFrameStep;
    frameRange = double(settings.frameRange);
    nFrames = frameRange(2)-frameRange(1)+1;
    if frameRange(2) < max( frameIndices )
        error("The given frame number is less than the maximal localization frame index.");
    end

    % Modified by NT on 2020.04.16.:
    % half-width of a section and the step between the section centers
    W = min(ceil(nFrames/5), maximalFrameStep);
    STEP = min(ceil(nFrames/5), maximalFrameStep);

    % center frames of the sections
    FRAMES = frameRange(1)+W:STEP:frameRange(2)-W;
    nSections = numel( FRAMES );

    if nSections == 0
        % The range is too short (or the step is not positive) to place a
        % single section, so there is nothing to estimate the drift from.
        warning('driftCorrection_rainSTORM:noSection', ...
            'The frame range is too short to form any drift estimation section, zero drift is returned.');
        driftCoords.x_coord = zeros(max(nFrames, 0), 1);
        driftCoords.y_coord = zeros(max(nFrames, 0), 1);
        return
    end

    hWaitBar = waitbar(0, 'Drift Estimation: 0%' );
    % remove the waitbar on every exit path (also on errors); the indexing
    % guards against a waitbar that the user has already closed
    waitBarCleanup = onCleanup( @() delete( hWaitBar( ishandle(hWaitBar) ) ) ); %#ok<NASGU>

    % one drift velocity per section; sections without localizations keep
    % the "0" values
    dirList = zeros( nSections, 2 );
    % start value of the optimization, carried over from section to section
    driftVelocity = [0.0 0.0];

    % Forming the sections
    for iSection = 1:nSections
        midFrame = FRAMES(iSection);

        % Filter interval
        map = frameIndices >= midFrame-W & frameIndices < midFrame+W;
        if ~any(map)
            % if there were no localizations in the given frame range, leave the "0" values as the drift "direction"
            warning(['Could not perform the drift correction between the ', num2str(midFrame-W), ' and ', num2str(midFrame+W), ' frames because there are no localization in it.']);
            continue
        end

        % calculate the linear drift of the section
        driftVelocity = estimateLinearDrift( x_coords(map), y_coords(map), frameIndices(map), x_stds(map), y_stds(map), driftVelocity );

        fprintf( 1, 'dx: %.8f   dy:%.8f\n', driftVelocity(1), driftVelocity(2)  );

        dirList(iSection, :) = driftVelocity;

        if ishandle( hWaitBar )
            perc = iSection / nSections;
            waitbar( perc, hWaitBar, sprintf('Drift Estimation: %d%%', round(perc*100) ) )
        end
    end

    % the two ends of the frame range take the velocity of the first and of
    % the last section, respectively
    dirList = [ dirList(1,:); dirList; dirList(end,:) ];

    intType = 'linear'; % 'pchip';
    % interpolate the drift for each frame; the time axis is counted from
    % the first frame of the range (0 ... nFrames-1), so the section
    % centers have to be expressed relative to frameRange(1) as well
    TList = [0, FRAMES - frameRange(1), nFrames-1];
    dxList = interp1( TList, dirList(:,1), 0:nFrames-1, intType, 'extrap' );
    dyList = interp1( TList, dirList(:,2), 0:nFrames-1, intType, 'extrap' );
    % calculate the drift's coordinate shift for every frame
    dx = cumsum( dxList );
    dy = cumsum( dyList );

    driftCoords.x_coord = dx';
    driftCoords.y_coord = dy';

end


function dir = estimateLinearDrift( x_coords, y_coords, frameIndices, x_stds, y_stds, x0 )
% Estimates one linear drift velocity from a section of localizations.
%
% Localization pairs that are close in space but far apart in frame index
% are collected block by block, then the velocity that minimizes the
% objective "fun" (the negative sum of Gaussian kernels of the
% velocity-compensated pair displacements) is searched with "fmincon",
% using successively narrower kernels.
%
% Inputs:
%   x_coords, y_coords - localization coordinates
%   frameIndices       - frame index of each localization
%   x_stds, y_stds     - localization precision values along x and y
%   x0                 - optional start value [vx vy]; [0 0] when omitted
%
% Output:
%   dir - estimated drift velocity [vx vy]; each component is bounded to
%         the [-0.002, 0.002] interval
%
% NOTE: no precision based prefiltering is applied here, the precision
% only determines the order in which the localizations are paired.

    % --- constants of the pair selection
    blockSize = 1000;      % number of localizations in a block
    pairLimit = 500000;    % stop collecting once more pairs than this are found
    maxDistance = 2;       % spatial distance threshold of a pair

    nLocs = length( x_coords );
    nBlocks = ceil( nLocs / blockSize );

    % a pair must be separated by more than half of the covered frame span
    minFrameGap = (max(frameIndices)-min(frameIndices)) * 0.5;

    % all-to-all differences: element (r, c) is first(c) - second(r)
    pairwiseDiff = @(first, second) bsxfun(@minus, first', second);

    % the worse of the two precision values characterizes a localization;
    % the localizations are processed in ascending order of this value
    worstStd = max( [x_stds y_stds], [], 2 );
    [~, byPrecision] = sort( worstStd, 'ascend' );

    DX = [];
    DY = [];
    DT = [];

    tic
    % go through the block pairs (rowBlock, colBlock) with
    % colBlock <= rowBlock along the anti-diagonals, so every combination
    % is visited once and a block is also paired with itself
    for diagonal = 1:nBlocks*2
        for colBlock = 1:nBlocks+1
            rowBlock = diagonal - (colBlock-1);

            insideTriangle = rowBlock >= 1 && colBlock >= 1 && rowBlock <= nBlocks && colBlock <= nBlocks && colBlock <= rowBlock;
            if ~insideTriangle
                continue
            end

            rowFirst = (rowBlock-1)*blockSize + 1;
            rowLast = min( nLocs, rowFirst+blockSize-1 );
            colFirst = (colBlock-1)*blockSize + 1;
            % the second block is never longer than the first one
            colLast = min( [nLocs, colFirst+blockSize-1, colFirst+(rowLast-rowFirst)] );

            % localization indices belonging to the two blocks
            rowLocs = byPrecision( rowFirst:rowLast );
            colLocs = byPrecision( colFirst:colLast );

            % coordinate and frame index differences of every pair
            dxAll = pairwiseDiff( x_coords( rowLocs ), x_coords( colLocs ) );
            dyAll = pairwiseDiff( y_coords( rowLocs ), y_coords( colLocs ) );
            dtAll = pairwiseDiff( frameIndices( rowLocs ), frameIndices( colLocs ) );

            squaredDistance = dxAll.^2 + dyAll.^2;

            % keep the pairs that are close in space and distant in time
            keep = squaredDistance < maxDistance^2 & abs(dtAll) > minFrameGap;

            DX = [DX; dxAll(keep)];
            DY = [DY; dyAll(keep)];
            DT = [DT; dtAll(keep)];

            if length(DX) > pairLimit
                break;
            end
        end

        if length(DX) > pairLimit
            break;
        end
    end
    toc
    fprintf( 1, 'Selected point pairs: %d\n', length(DX) );

    % --- settings of the optimization
    options = optimoptions('fmincon');
    options.TolFun = 1e-15;
    options.DiffMaxChange = 1e-2;
    options.DiffMinChange = 1e-6;
    options.MaxIter = 200;
    options.Display = 'none'; % 'iter';

    lowerBound = [-0.002 -0.002];
    upperBound = [ 0.002  0.002];

    if nargin < 6
        x0 = [ 0.0  0.0];
    end

    % perform the cross-correlation iteratively: every kernel width starts
    % from the result obtained with the previous (wider) one
    velocity = x0;
    for kernelWidth = [1 0.5 0.25]
        objective = @(v) fun( v, kernelWidth, DX, DY, DT );

        % sum of Gaussian functions fitting and maximizing
        velocity = fmincon( objective, velocity, [], [], [], [], lowerBound, upperBound, [], options );
    end
    dir = velocity;
end


function stop = callbackFunction(dir, optimValues, state)
% Optimizer output function that prints the current drift velocity.
%
% Inputs:
%   dir         - current point of the optimization, [vx vy]
%   optimValues - not used
%   state       - not used
%
% Output:
%   stop - always false, so the optimization is never interrupted here

    fprintf( 1, 'dx: %.8f   dy:%.8f\n', dir(1), dir(2) );

    stop = false;
end


function r = fun( dir, s, DX, DY, DT )
% Objective of the drift velocity search: the negative sum of Gaussian
% kernels evaluated on the velocity-compensated pair displacements.
%
% Inputs:
%   dir        - drift velocity candidate, [vx vy]
%   s          - standard deviation of the Gaussian kernel
%   DX, DY, DT - coordinate and frame index differences of the pairs
%
% Output:
%   r - negative of the kernel sum (0 when there are no pairs)

    % displacement left after removing the drift accumulated over DT
    residualX = DX - dir(1)*DT;
    residualY = DY - dir(2)*DT;
    squaredResidual = residualX.^2  + residualY.^2;

    kernelValues = exp( -squaredResidual / (2*s^2) );
    r = -sum( kernelValues );
end

% Detailed description

% This drift correction estimates the drift purely based on the
% localization coordinates of the close frames of the image stack.

% Since the algorithm can only calculate linear drift, it divides the frame
% stack into smaller sections, determines the linear drift for every
% section and the resultant drift is calculated from the drift velocities
% of the subsequent sections. The drift correction basically performs
% cross-correlation on the localizations with Gaussian kernel applied on
% them. For performance reasons, the correlation is performed only on
% selected localization pairs. The lag (displacement) of the correlation
% function of each localization pair is calculated from the drift
% velocity, from the frame difference and from the spatial distance of the
% localization pair. The linear drift velocity is calculated from the
% maximum of the sum of the correlation functions of the localization
% pairs.

% In more detail, the drift correction algorithm goes through the frames
% with a step size of min(ceil(nFrames/5), maximalFrameStep) frames (the
% "STEP" and "W" variables; 1000 frames in the example used below) and at
% every step cuts out the frames from "midFrame-W" up to, but not
% including, "midFrame+W" around the current one (about 2000 frames in the
% example). So the overlap between adjacent cuts is "W" (1000) frames. The
% current drift velocity is calculated from this cut-out stack, so this
% step size is the resolution of the drift calculation.
% The algorithm then finds the localizations belonging to this cut out
% stack ("map" variable), the localizations are ordered with their
% localization precision value in ascending order ("orderedIndices") and
% divides them into block with size of 1000 ("points") localizations
% ("indices" and "indices2") (the last block can be smaller) (so the
% localizations with better precision are taken into account in the drift
% correction with higher priority (although with the same weight) than the
% localizations with worse precision, thus more likely accounted for when
% there are many localizations in the cut-out frames, because of the
% localization pair number thresholding).
% Pairs of blocks are formed in all possible combination (using the "ii",
% "jj", "i" and "j" variables), without repetition and every block can form
% a pair with itself. The distances ("DX_", "DY_") and frame index
% differences ("DT_") of the localizations are calculated in every block
% pair. Next, a filtering step selects localization pairs (another "map"
% variable) for the correlation (the spatial distance must be lower than 2
% pixels ("distThresh") and the frame index difference must be higher than
% half of the frame span covered by the localizations of the cut-out stack
% ("timeThresh")). If
% more than 500000 ("nPairs") localization pairs are formed, the algorithm
% does not take into account any more localization pairs (for performance
% reasons?).
% Next the cross-correlation is performed. The negative of the sum of
% Gaussian functions (the "fun()" function) of the quantity
% [“localization distance” - “frame difference” * “drift velocity”]
% is minimized (i.e. the sum itself is maximized) with Matlab’s “fmincon”
% iterative function with respect
% to the drift velocity ("dir" variable). The lower ("lb") and the upper
% bound ("ub") of the “x” and “y” components of the drift velocity during
% the iteration are -0.002 pixel/frame and 0.002 pixel/frame, respectively.
% This whole correlation is also done iteratively: the first step is done
% with a Gaussian kernel standard deviation of 1 pixel, then of 0.5 pixel
% and finally of 0.25 pixel ("sList" variable), and the initial drift
% velocity is taken from the previous step for the “fmincon” function (for
% increased reliability?).
% After the drift velocities are determined for every "STEP"-th frame, they
% are interpolated (extrapolated if needed) for each frame ("dxList" and
% "dyList") and the drift coordinates ("dx" and "dy") are given by the
% cumulative sum of these velocities.

