-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathParseFile.java
More file actions
101 lines (76 loc) · 3.49 KB
/
ParseFile.java
File metadata and controls
101 lines (76 loc) · 3.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
package LogisticRegression;
import java.io.*;
import java.util.*;
/**
 * Command-line driver for one-vs-rest logistic regression.
 *
 * <p>Reads its settings from {@code logistic_input.txt} in the working directory. The file
 * must contain five lines, in this order:
 * <ol>
 *   <li>path of the training set</li>
 *   <li>path of the test set</li>
 *   <li>learning rate (parsed as a {@code double})</li>
 *   <li>cross-validation flag (parsed with {@link Boolean#parseBoolean(String)})</li>
 *   <li>number of folds (parsed as an {@code int})</li>
 * </ol>
 * For each fold one binary model per distinct label is trained, the resulting models are
 * evaluated, and the per-fold accuracies plus their mean are printed to standard output.
 */
public class ParseFile {

    /** Name of the configuration file, resolved against the working directory. */
    private static final String CONFIG_FILE = "logistic_input.txt";

    /**
     * Program entry point: loads the configuration, trains and evaluates the models.
     *
     * <p>I/O failures (including a missing or truncated configuration file) are reported
     * with a stack trace on standard error. A malformed learning rate or fold count
     * propagates as {@link NumberFormatException}.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        try {
            String trainingSetPath;
            String testSetPath;
            double learningRate;
            boolean isCrossValidation;
            int folds;

            BufferedReader configReader = new BufferedReader(new FileReader(new File(CONFIG_FILE)));
            try {
                trainingSetPath = readRequiredLine(configReader, "training set path");
                testSetPath = readRequiredLine(configReader, "test set path");
                // Typed values are trimmed so stray whitespace does not silently turn
                // "true " into false or make the fold count unparsable.
                learningRate = Double.parseDouble(readRequiredLine(configReader, "learning rate").trim());
                isCrossValidation =
                        Boolean.parseBoolean(readRequiredLine(configReader, "cross-validation flag").trim());
                folds = Integer.parseInt(readRequiredLine(configReader, "number of folds").trim());
            } finally {
                // Close the reader even when one of the lines fails to parse.
                configReader.close();
            }

            Data data = new Data();
            CrossValidation crossValidation = new CrossValidation(trainingSetPath);
            List<Double> accuracyList = new ArrayList<Double>();

            // NOTE(review): the loop runs `folds` times even when cross-validation is off;
            // whether each pass then sees identical training data depends on
            // CrossValidation.runCrossValidation -- confirm this is intended.
            for (int fold = 0; fold < folds; fold++) {
                ArrayList<Example> examples =
                        crossValidation.runCrossValidation(folds, data, fold, isCrossValidation);
                LogisticTrain logisticTrain =
                        new LogisticTrain(data.getNumberOfCols(), learningRate, data);

                // One-vs-rest: relabel the examples for each label and train a model on them.
                for (String label : data.getDistinctLabels()) {
                    System.out.println("Building a model for " + label);
                    ArrayList<Example> trainingExamples = makeTrainData(examples, label, data);
                    logisticTrain.train(trainingExamples, label);
                }
                logisticTrain.printModels();

                ArrayList<Example> evaluationExamples;
                if (isCrossValidation) {
                    evaluationExamples = crossValidation.foldTestList;
                } else {
                    TestData testData = new TestData(testSetPath);
                    evaluationExamples = testData.testData();
                }
                double accuracy = logisticTrain.classify(
                        evaluationExamples, new HashMap<String, ArrayList<Double>>());
                accuracyList.add(accuracy);
            }
            printAllAccuracy(accuracyList);
        } catch (IOException e1) {
            // FileNotFoundException is an IOException, so a missing file is reported here too.
            e1.printStackTrace();
        }
    }

    /**
     * Reads the next line of the configuration file, failing loudly when the file ends early.
     *
     * @param reader      open reader positioned at the line to read
     * @param description human-readable name of the expected setting, used in the error message
     * @return the line as read, without its line terminator
     * @throws IOException if the underlying read fails or the end of the file was reached
     */
    private static String readRequiredLine(BufferedReader reader, String description)
            throws IOException {
        String line = reader.readLine();
        if (line == null) {
            throw new IOException("Missing " + description + " in " + CONFIG_FILE);
        }
        return line;
    }

    /**
     * Prints every collected accuracy followed by their mean and {@code 100 - mean}.
     *
     * <p>The "Error" line presumably assumes accuracies are percentages in [0, 100] --
     * verify against {@code LogisticTrain.classify}.
     *
     * @param accuracyList accuracies gathered by {@code main}, one per loop iteration; may be empty
     */
    private static void printAllAccuracy(List<Double> accuracyList) {
        if (accuracyList.isEmpty()) {
            // Avoid the 0/0 division that would otherwise print NaN for both figures.
            System.out.println("No accuracy results to report.");
            return;
        }
        double sum = 0;
        for (Double accuracy : accuracyList) {
            System.out.println(accuracy);
            sum = sum + accuracy;
        }
        double mean = sum / accuracyList.size();
        System.out.println("Accuracy = " + mean);
        System.out.println("Error = " + (100 - mean));
    }

    /**
     * Relabels the examples for a one-vs-rest model: 1 when the example's actual label equals
     * {@code label}, 0 otherwise.
     *
     * <p>The examples are modified in place and the same list instance is returned, so any
     * labels set by a previous call are overwritten.
     *
     * @param dataSet examples to relabel
     * @param label   the label treated as the positive class
     * @param data    unused; kept so the signature stays unchanged
     * @return {@code dataSet} itself, after relabelling
     */
    private static ArrayList<Example> makeTrainData(ArrayList<Example> dataSet, String label, Data data) {
        for (Example example : dataSet) {
            if (example.getActualLabel().equals(label)) {
                example.setLabel(1);
            } else {
                example.setLabel(0);
            }
        }
        return dataSet;
    }

    /**
     * Counts the distinct values found in each column of {@code dataSet} and stores the count
     * for column {@code i} in {@code data.getDistinctValuesPerColumn()[i]}.
     *
     * <p>Not called from anywhere in this class.
     *
     * @param data    supplies the number of columns and receives the per-column counts
     * @param dataSet examples whose values are inspected
     */
    private static void findDistinctValuesInCols(Data data, ArrayList<Example> dataSet) {
        for (int col = 0; col < data.getNumberOfCols(); col++) {
            // A fresh set per column keeps the counts independent of each other.
            Set<Double> uniqueValues = new HashSet<Double>();
            for (Example example : dataSet) {
                uniqueValues.add(example.getValues()[col]);
            }
            data.getDistinctValuesPerColumn()[col] = uniqueValues.size();
        }
    }
}