关于改变kNN算法中k的值:改变kNN算法中k的值-Java

Altering the value of k in kNN algorithm - Java

我已应用 KNN 算法对手写数字进行分类。数字最初是 8*8 的矢量格式,然后拉伸形成一个 1*64 的矢量。

就目前而言,我的代码应用了 kNN 算法,但只使用了 k = 1。在尝试了几件事后,我不完全确定如何更改 k 值,但我一直在抛出错误。如果有人能帮助我朝着正确的方向前进,我将不胜感激。训练数据集可以在这里找到,验证集在这里。

ImageMatrix.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import java.util.*;

public class ImageMatrix {
    private int[] data;
    private int classCode;
    private int curData;
public ImageMatrix(int[] data, int classCode) {
    assert data.length == 64; //maximum array length of 64
    this.data = data;
    this.classCode = classCode;
}

    public String toString() {
        return"Class Code:" + classCode +" Data :" + Arrays.toString(data) +"\
"
; //outputs readable
    }

    public int[] getData() {
        return data;
    }

    public int getClassCode() {
        return classCode;
    }
    public int getCurData() {
        return curData;
    }



}

ImageMatrixDB.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import java.util.*;
import java.io.*;
import java.util.ArrayList;
public class ImageMatrixDB implements Iterable<ImageMatrix> {
    private List<ImageMatrix> list = new ArrayList<ImageMatrix>();

    public ImageMatrixDB load(String f) throws IOException {
        try (
            FileReader fr = new FileReader(f);
            BufferedReader br = new BufferedReader(fr)) {
            String line = null;

            while((line = br.readLine()) != null) {
                int lastComma = line.lastIndexOf(',');
                int classCode = Integer.parseInt(line.substring(1 + lastComma));
                int[] data = Arrays.stream(line.substring(0, lastComma).split(","))
                                   .mapToInt(Integer::parseInt)
                                   .toArray();
                ImageMatrix matrix = new ImageMatrix(data, classCode); // Classcode->100% when 0 -> 0% when 1 - 9..
                list.add(matrix);
            }
        }
        return this;
    }

    public void printResults(){ //output results
        for(ImageMatrix matrix: list){
            System.out.println(matrix);
        }
    }


    public Iterator<ImageMatrix> iterator() {
        return this.list.iterator();
    }

    /// kNN implementation ///
    public static int distance(int[] a, int[] b) {
        int sum = 0;
        for(int i = 0; i < a.length; i++) {
            sum += (a[i] - b[i]) * (a[i] - b[i]);
        }
        return (int)Math.sqrt(sum);
    }


    public static int classify(ImageMatrixDB trainingSet, int[] curData) {
        int label = 0, bestDistance = Integer.MAX_VALUE;
        for(ImageMatrix matrix: trainingSet) {
            int dist = distance(matrix.getData(), curData);
            if(dist < bestDistance) {
                bestDistance = dist;
                label = matrix.getClassCode();
            }
        }
        return label;
    }


    public int size() {

        return list.size(); //returns size of the list

        }


    public static void main(String[] argv) throws IOException {
        ImageMatrixDB trainingSet = new ImageMatrixDB();
        ImageMatrixDB validationSet = new ImageMatrixDB();
        trainingSet.load("cw2DataSet1.csv");
        validationSet.load("cw2DataSet2.csv");
        int numCorrect = 0;
        for(ImageMatrix matrix:validationSet) {
            if(classify(trainingSet, matrix.getData()) == matrix.getClassCode()) numCorrect++;
        } //285 correct
        System.out.println("Accuracy:" + (double)numCorrect / validationSet.size() * 100 +"%");
        System.out.println();
    }


在分类的 for 循环中,您试图找到最接近测试点的训练示例。您需要使用找到最接近测试数据的 K 个训练点的代码来切换它。然后你应该为这些 K 点中的每一个调用 getClassCode 并找到其中大多数(即最频繁)的类代码。然后,分类将返回您找到的主要类代码。

您可以以任何适合您需要的方式打破联系(即,将 2 个最常见的类代码分配给相同数量的训练数据)。

我在Java方面真的很缺乏经验,但是只是通过查看语言参考,我想出了下面的实现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
public static int classify(ImageMatrixDB trainingSet, int[] curData, int k) {
    int label = 0, bestDistance = Integer.MAX_VALUE;
    int[][] distances = new int[trainingSet.size()][2];
    int i=0;

    // Place distances in an array to be sorted
    for(ImageMatrix matrix: trainingSet) {
        distances[i][0] = distance(matrix.getData(), curData);
        distances[i][1] = matrix.getClassCode();
        i++;
    }

    Arrays.sort(distances, (int[] lhs, int[] rhs) -> lhs[0]-rhs[0]);

    // Find frequencies of each class code
    i = 0;
    Map<Integer,Integer> majorityMap;
    majorityMap = new HashMap<Integer,Integer>();
    while(i < k) {
        if( majorityMap.containsKey( distances[i][1] ) ) {
            int currentValue = majorityMap.get(distances[i][1]);
            majorityMap.put(distances[i][1], currentValue + 1);
        }
        else {
            majorityMap.put(distances[i][1], 1);
        }
        ++i;
    }

    // Find the class code with the highest frequency
    int maxVal = -1;
    for (Entry<Integer, Integer> entry: majorityMap.entrySet()) {
        int entryVal = entry.getValue();
        if(entryVal > maxVal) {
            maxVal = entryVal;
            label = entry.getKey();
        }
    }

    return label;
}

您需要做的就是添加 K 作为参数。但是请记住,上面的代码并没有以特定方式处理关系。