Channel: MATLAB Central Newsreader Recent Posts

Re: KNN classifier with ROC Analysis

"Aaronne" wrote in message <ki9lon$qgq$1@newscl01ah.mathworks.com>...
> Hi Smart guys,
>
> I wrote the following code to plot the ROC curve for my KNN classifier:
>
> load fisheriris;
>
> features = meas;
> featureSelcted = features;
> numFeatures = size(meas,2); % number of feature columns (size(meas,1) is the number of observations)
>
> %% Define ground truth
> groundTruthGroup = species;
>
> %% Construct a KNN classifier
> KNNClassifierObject = ClassificationKNN.fit(featureSelcted, groundTruthGroup, 'NumNeighbors', 3, 'Distance', 'euclidean');
>
> % Predict resubstitution response of k-nearest neighbor classifier
> [KNNLabel, KNNScore] = resubPredict(KNNClassifierObject);
>
> % Compute the ROC curve for class 1 (setosa) from the KNN scores
> groundTruthNumericalLable = [ones(50,1); zeros(50,1); -1.*ones(50,1)];
> [FPR, TPR, Thr, AUC, OPTROCPT] = perfcurve(groundTruthNumericalLable(:,1), KNNScore(:,1), 1);
>
> Then we can plot the FPR vs TPR to get the ROC curve.
>
> However, the FPR and TPR differ from what I get with my own implementation. The code above returns only three points on the ROC curve, whereas my implementation below returns 151 points, since the data set has 150 observations.
>
> patternsKNN = [KNNScore(:,1), groundTruthNumericalLable(:,1)];
> patternsKNN = sortrows(patternsKNN, -1);
> groundTruthPattern = patternsKNN(:,2);
>
> POS = cumsum(groundTruthPattern==1);
> TPR = POS/sum(groundTruthPattern==1);
> NEG = cumsum(groundTruthPattern==0);
> FPR = NEG/sum(groundTruthPattern==0);
>
> FPR = [0; FPR];
> TPR = [0; TPR];
>
>
> How can I tune perfcurve so that it outputs all the points of the ROC curve? Thanks a lot.
>
>
>
> A.

Hi, Aaronne,

What do you mean by "let it output all the points for the ROC"? I think the points that suffice to plot the ROC curve are in FPR, TPR.
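
One possible reason for the small number of points, assuming the scores come from a vote among 3 neighbours: each score can only be 0, 1/3, 2/3 or 1, so there are at most four distinct thresholds, and as far as I know perfcurve returns one point per distinct threshold rather than one per observation. The sketch below uses made-up labels and scores (not the fisheriris results) and only base MATLAB to show the idea.

```matlab
% Hypothetical data: 3-NN style scores can only be 0, 1/3, 2/3 or 1.
labels = [1 1 1 1 0 1 0 0 0 0]';        % 1 = positive class, 0 = negative
scores = [3 3 2 3 2 1 1 0 0 0]' / 3;    % fraction of the 3 neighbours voting positive

% One ROC point per distinct threshold, from the highest score down.
thr = sort(unique(scores), 'descend');
TPR = zeros(numel(thr), 1);
FPR = zeros(numel(thr), 1);
for k = 1:numel(thr)
    predicted = scores >= thr(k);
    TPR(k) = sum(predicted & labels == 1) / sum(labels == 1);
    FPR(k) = sum(predicted & labels == 0) / sum(labels == 0);
end

% Prepend the (0,0) point obtained when nothing is predicted positive.
FPR = [0; FPR];
TPR = [0; TPR];
```

With ten observations this gives five ROC points, not eleven. Stepping through the observations one at a time, as a cumulative sum over sorted scores does, splits each group of tied scores in an arbitrary order, so the extra points depend on how the ties happen to be sorted.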

I wrote a script that does almost the same thing as perfcurve in MATLAB; see below. I hope it helps.

function [X,Y,T,AUC]=calculate_ROC_3(labels,scores,posclass)
% sort in parallel
% in ROC curve, there is no room for ACC or BACC
% 2013-06-17

n=length(scores);
X=zeros(n,1);
Y=zeros(n,1);
T=zeros(n,1);
for i=1:n
    labels_predict=scores>=scores(i);
    
    % calculate the confusion matrix
    % labels must be 0/1; the codes are 0=TN, 1=FP, 2=FN, 3=TP
    code=labels(:)*2+labels_predict(:);
    % count each code; codes that do not occur get a count of 0
    counts=accumarray(code+1,1,[4,1]);
    
    % statistics, see Wikipedia
    TN=counts(1);
    FP=counts(2);
    FN=counts(3);
    TP=counts(4);
    TPR=TP/(TP+FN); % sensitivity
    SPC=TN/(FP+TN); % specificity
    FPR=FP/(FP+TN); % false positive rate
    ACC=(TP+TN)/(TP+TN+FP+FN); % accuracy
    BACC=(TPR+SPC)/2; % balanced accuracy
    dPrime=norminv(TPR)-norminv(FPR); % d'
    
    X(i)=FPR;
    Y(i)=TPR;
    T(i)=scores(i);
end

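% X contains many equal values, so sort the (X,Y) pairs together
ix=sort_parallel(X,Y);
X=X(ix);
Y=Y(ix);
T=T(ix); % keep the thresholds aligned with the sorted points

% add the starting point (0,0), where nothing is predicted positive
X=[0;X];
Y=[0;Y];
T=[T(1);T]; % no score corresponds to (0,0), so repeat the highest threshold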

% area under the curve by the trapezoidal rule, so that diagonal segments
% produced by tied scores are not overestimated
AUC=trapz(X,Y);

if posclass==0
    X=1-X;
    Y=1-Y;
    AUC=1-AUC;
end

function ix=sort_parallel(x,y)
% x, y are vectors with the same length
% sort (xi, yi) in order
% xi, yi are in [0,1]
% when xi>xj, yi>yj
% when xi=xj, yi and yj could be different
% 2013-06-17

% sort by x and break ties in x by y (lexicographic order on the pairs)
[~,ix]=sortrows([x(:),y(:)]);
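
A note on how the summation rule affects the AUC when scores are tied: a tied group of positives and negatives gives a diagonal segment on the curve, and a right-endpoint rectangle sum counts the full rectangle under it, while the trapezoidal rule counts only the area under the segment. The points below are made up for illustration.

```matlab
% Hypothetical ROC points; the step from (0,0.5) to (0.5,1) is a diagonal
% segment of the kind produced by a group of tied scores.
X = [0; 0; 0.5; 1];
Y = [0; 0.5; 1; 1];

% Right-endpoint rectangle sum: each step uses the height at its right end.
aucRect = sum(diff(X) .* Y(2:end));

% Trapezoidal rule: each step uses the average of its two end heights.
aucTrap = trapz(X, Y);
```

Here aucRect is 1 and aucTrap is 0.875. The two rules agree whenever every step is purely horizontal or purely vertical, that is, when no positive and negative observations share a score.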

Best Regards,
Jing.
