classdef driftCorrection
    % driftCorrection  Static utilities for drift correction of pointillistic
    % (localization) data: calculating a per-frame drift trajectory,
    % subtracting it from / adding it back to the localization coordinates,
    % and plotting it.
    %
    % The drift trajectory is a struct whose fields (e.g. x_coord, y_coord)
    % are vectors indexed by frame number. Localization data is a struct with
    % coordinate fields of the same names plus a frame_idx field giving the
    % frame of each localization.

    methods (Static)

        function [pointillisticData_accepted, pointillisticData_rejected, driftTrajectory, metadata, sharedParameters] = correct(pointillisticData_accepted, pointillisticData_rejected, driftCorrectionMethod, driftCorrectionSettings)
            % CORRECT  Calculate the drift trajectory and subtract it from the
            % pointillistic data.
            %
            % The trajectory is calculated from the accepted localizations
            % only, then subtracted from both the accepted and the rejected
            % localizations so that both sets are corrected consistently.
            %
            % Inputs:
            %   pointillisticData_accepted - localization struct used for the
            %                                drift calculation; also corrected
            %   pointillisticData_rejected - localization struct that is only
            %                                corrected
            %   driftCorrectionMethod      - method name, see CALCULATE
            %   driftCorrectionSettings    - settings forwarded to the method
            %
            % Outputs:
            %   pointillisticData_accepted, pointillisticData_rejected -
            %       drift-corrected localization structs
            %   driftTrajectory  - the calculated per-frame drift
            %   metadata         - struct with field totalDrift (x, y, unit)
            %   sharedParameters - currently always an empty struct

            driftTrajectory = driftCorrection.calculate(pointillisticData_accepted, driftCorrectionMethod, driftCorrectionSettings);
            pointillisticData_accepted = driftCorrection.subtract(pointillisticData_accepted, driftTrajectory);
            pointillisticData_rejected = driftCorrection.subtract(pointillisticData_rejected, driftTrajectory);

            [totalDrift_x, totalDrift_y] = driftCorrection.totalDrift(driftTrajectory);
            metadata = struct();
            metadata.totalDrift.x = totalDrift_x;
            metadata.totalDrift.y = totalDrift_y;
            metadata.totalDrift.unit = 'camera pixel length';

            sharedParameters = struct();
        end

        function driftTrajectory = calculate(localizationData, driftCorrectionMethod, driftCorrectionSettings)
            % CALCULATE  Calculate the drift trajectory from the raw (not yet
            % drift-corrected) localization data.
            %
            % driftCorrectionMethod selects the algorithm:
            %   'fiducial marker'             -> driftCalculation_fiduciary
            %   'rainSTORM cross correlation' -> driftCalculation_rainSTORM
            %                                    (rainSTORM's auto drift correction)
            %   'BaSDI'                       -> driftCalculation_BaSDI
            % Any other value raises an error.

            switch driftCorrectionMethod
                case 'fiducial marker'
                    driftTrajectory = driftCalculation_fiduciary(localizationData, driftCorrectionSettings);
                case 'rainSTORM cross correlation'
                    driftTrajectory = driftCalculation_rainSTORM(localizationData, driftCorrectionSettings);
                case 'BaSDI'
                    driftTrajectory = driftCalculation_BaSDI(localizationData, driftCorrectionSettings);
                otherwise
                    error('Unknown drift correction type.')
            end
        end

        function localizationData = subtract(localizationData, driftTrajectory)
            % SUBTRACT  Subtract the drift trajectory from the pointillistic
            % data coordinates.
            %
            % For every field of driftTrajectory, the drift value of the frame
            % each localization belongs to (localizationData.frame_idx) is
            % subtracted from the localization coordinate of the same name.
            % The corrected coordinates keep their original array shape.

            localizationData = driftCorrection.applyDrift(localizationData, driftTrajectory, @minus);
        end

        function localizationData = add(localizationData, driftTrajectory)
            % ADD  Add the drift trajectory to the pointillistic data
            % coordinates. This is the inverse of SUBTRACT and is useful for
            % undoing a drift correction.

            localizationData = driftCorrection.applyDrift(localizationData, driftTrajectory, @plus);
        end

        function driftFigure = visualize(driftTrajectory)
            % VISUALIZE  Plot the drift trajectory in an invisible figure,
            % color coding the frame indices with the jet colormap.
            %
            % The y coordinate is plotted horizontally and the negated x
            % coordinate vertically. Returns the figure handle.

            nFrames = numel(driftTrajectory.x_coord);
            [totalDrift_x, totalDrift_y] = driftCorrection.totalDrift(driftTrajectory);

            driftFigure = figure('Visible', 'Off');
            % Draw into this figure's own axes instead of relying on gca.
            ax = axes('Parent', driftFigure);
            scatter(ax, driftTrajectory.y_coord, -driftTrajectory.x_coord, 12, jet(nFrames), '.');
            title(ax, ['Number of Frames: ', num2str(nFrames)]);
            xlabel(ax, ['Total drift, horizontal: ', num2str(totalDrift_y), ' [pixel]']);
            ylabel(ax, ['Total drift, vertical: ', num2str(totalDrift_x), ' [pixel]']);
        end

    end

    methods (Static, Access = private)

        function localizationData = applyDrift(localizationData, driftTrajectory, operation)
            % Combine each coordinate with the drift of its frame using
            % operation (@minus to correct the drift, @plus to undo it).

            frameIndices = localizationData.frame_idx;
            driftFieldNames = fieldnames(driftTrajectory);
            for idxField = 1:numel(driftFieldNames)
                fieldName = driftFieldNames{idxField};
                coordinates = localizationData.(fieldName);
                % Indexing a vector with a vector keeps the orientation of the
                % indexed vector, so a row drift vector would yield a row even
                % for column coordinates; row-vs-column arithmetic would then
                % silently expand into a matrix. Match the coordinates' shape.
                frameDrift = reshape(driftTrajectory.(fieldName)(frameIndices), size(coordinates));
                localizationData.(fieldName) = operation(coordinates, frameDrift);
            end
        end

        function [totalDrift_x, totalDrift_y] = totalDrift(driftTrajectory)
            % Net drift: absolute difference between the last and the first
            % frame positions along x_coord and y_coord.

            totalDrift_x = abs(driftTrajectory.x_coord(end) - driftTrajectory.x_coord(1));
            totalDrift_y = abs(driftTrajectory.y_coord(end) - driftTrajectory.y_coord(1));
        end

    end

end

