aboutsummaryrefslogtreecommitdiffstats
path: root/test/monniaux/genann/example4shorter.c
diff options
context:
space:
mode:
Diffstat (limited to 'test/monniaux/genann/example4shorter.c')
-rw-r--r--test/monniaux/genann/example4shorter.c141
1 files changed, 141 insertions, 0 deletions
diff --git a/test/monniaux/genann/example4shorter.c b/test/monniaux/genann/example4shorter.c
new file mode 100644
index 00000000..ff4ce402
--- /dev/null
+++ b/test/monniaux/genann/example4shorter.c
@@ -0,0 +1,141 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <time.h>
+#include <string.h>
+#include <math.h>
+#include "genann.h"
+
+#define VERIMAG
+#ifdef VERIMAG
+#include "../clock.h"
+#endif
+
+/* This example is to illustrate how to use GENANN.
+ * It is NOT an example of good machine learning techniques.
+ */
+
+const char *iris_data = "example/iris.data";
+
+double *input, *class;
+int samples;
+const char *class_names[] = {"Iris-setosa", "Iris-versicolor", "Iris-virginica"};
+
+void load_data() {
+ /* Load the iris data-set. */
+ FILE *in = fopen("example/iris.data", "r");
+ if (!in) {
+ printf("Could not open file: %s\n", iris_data);
+ exit(1);
+ }
+
+ /* Loop through the data to get a count. */
+ char line[1024];
+ while (!feof(in) && fgets(line, 1024, in)) {
+ ++samples;
+ }
+ fseek(in, 0, SEEK_SET);
+
+ printf("Loading %d data points from %s\n", samples, iris_data);
+
+ /* Allocate memory for input and output data. */
+ input = malloc(sizeof(double) * samples * 4);
+ class = malloc(sizeof(double) * samples * 3);
+
+ /* Read the file into our arrays. */
+ int i, j;
+ for (i = 0; i < samples; ++i) {
+ double *p = input + i * 4;
+ double *c = class + i * 3;
+ c[0] = c[1] = c[2] = 0.0;
+
+ if (fgets(line, 1024, in) == NULL) {
+ perror("fgets");
+ exit(1);
+ }
+
+ char *split = strtok(line, ",");
+ for (j = 0; j < 4; ++j) {
+ p[j] = atof(split);
+ split = strtok(0, ",");
+ }
+
+ split[strlen(split)-1] = 0;
+ if (strcmp(split, class_names[0]) == 0) {c[0] = 1.0;}
+ else if (strcmp(split, class_names[1]) == 0) {c[1] = 1.0;}
+ else if (strcmp(split, class_names[2]) == 0) {c[2] = 1.0;}
+ else {
+ printf("Unknown class %s.\n", split);
+ exit(1);
+ }
+
+ /* printf("Data point %d is %f %f %f %f -> %f %f %f\n", i, p[0], p[1], p[2], p[3], c[0], c[1], c[2]); */
+ }
+
+ fclose(in);
+}
+
+
+int main(int argc, char *argv[])
+{
+ printf("GENANN example 4.\n");
+ printf("Train an ANN on the IRIS dataset using backpropagation.\n");
+
+#ifdef VERIMAG
+ srand(42);
+#else
+ srand(time(0));
+#endif
+
+ /* Load the data from file. */
+ load_data();
+
+ /* 4 inputs.
+ * 1 hidden layer(s) of 4 neurons.
+ * 3 outputs (1 per class)
+ */
+ genann *ann = genann_init(4, 1, 4, 3);
+
+ int i, j;
+#ifdef VERIMAG
+ int loops = 500;
+#else
+ int loops = 5000;
+#endif
+
+ /* Train the network with backpropagation. */
+ printf("Training for %d loops over data.\n", loops);
+#ifdef VERIMAG
+ clock_prepare();
+ clock_start();
+#endif
+ for (i = 0; i < loops; ++i) {
+ for (j = 0; j < samples; ++j) {
+ genann_train(ann, input + j*4, class + j*3, .01);
+ }
+ /* printf("%1.2f ", xor_score(ann)); */
+ }
+
+ int correct = 0;
+ for (j = 0; j < samples; ++j) {
+ const double *guess = genann_run(ann, input + j*4);
+ if (class[j*3+0] == 1.0) {if (guess[0] > guess[1] && guess[0] > guess[2]) ++correct;}
+ else if (class[j*3+1] == 1.0) {if (guess[1] > guess[0] && guess[1] > guess[2]) ++correct;}
+ else if (class[j*3+2] == 1.0) {if (guess[2] > guess[0] && guess[2] > guess[1]) ++correct;}
+ else {printf("Logic error.\n"); exit(1);}
+ }
+#ifdef VERIMAG
+ clock_stop();
+#endif
+
+ printf("%d/%d correct (%0.1f%%).\n", correct, samples, (double)correct / samples * 100.0);
+
+#ifdef VERIMAG
+ print_total_clock();
+#endif
+
+ genann_free(ann);
+ free(input);
+ free(class);
+
+ return 0;
+}