From c441b86ea994d65f50c5dec30c9c97250d88ac98 Mon Sep 17 00:00:00 2001 From: David Monniaux Date: Fri, 7 Jun 2019 15:48:49 +0200 Subject: réseau de neurones MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/monniaux/genann/example4shorter.c | 141 +++++++++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 test/monniaux/genann/example4shorter.c (limited to 'test/monniaux/genann/example4shorter.c') diff --git a/test/monniaux/genann/example4shorter.c b/test/monniaux/genann/example4shorter.c new file mode 100644 index 00000000..ff4ce402 --- /dev/null +++ b/test/monniaux/genann/example4shorter.c @@ -0,0 +1,141 @@ +#include +#include +#include +#include +#include +#include "genann.h" + +#define VERIMAG +#ifdef VERIMAG +#include "../clock.h" +#endif + +/* This example is to illustrate how to use GENANN. + * It is NOT an example of good machine learning techniques. + */ + +const char *iris_data = "example/iris.data"; + +double *input, *class; +int samples; +const char *class_names[] = {"Iris-setosa", "Iris-versicolor", "Iris-virginica"}; + +void load_data() { + /* Load the iris data-set. */ + FILE *in = fopen("example/iris.data", "r"); + if (!in) { + printf("Could not open file: %s\n", iris_data); + exit(1); + } + + /* Loop through the data to get a count. */ + char line[1024]; + while (!feof(in) && fgets(line, 1024, in)) { + ++samples; + } + fseek(in, 0, SEEK_SET); + + printf("Loading %d data points from %s\n", samples, iris_data); + + /* Allocate memory for input and output data. */ + input = malloc(sizeof(double) * samples * 4); + class = malloc(sizeof(double) * samples * 3); + + /* Read the file into our arrays. */ + int i, j; + for (i = 0; i < samples; ++i) { + double *p = input + i * 4; + double *c = class + i * 3; + c[0] = c[1] = c[2] = 0.0; + + if (fgets(line, 1024, in) == NULL) { + perror("fgets"); + exit(1); + } + + char *split = strtok(line, ","); + for (j = 0; j < 4; ++j) { + p[j] = atof(split); + split = strtok(0, ","); + } + + split[strlen(split)-1] = 0; + if (strcmp(split, class_names[0]) == 0) {c[0] = 1.0;} + else if (strcmp(split, class_names[1]) == 0) {c[1] = 1.0;} + else if (strcmp(split, class_names[2]) == 0) {c[2] = 1.0;} + else { + printf("Unknown class %s.\n", split); + exit(1); + } + + /* printf("Data point %d is %f %f %f %f -> %f %f %f\n", i, p[0], p[1], p[2], p[3], c[0], c[1], c[2]); */ + } + + fclose(in); +} + + +int main(int argc, char *argv[]) +{ + printf("GENANN example 4.\n"); + printf("Train an ANN on the IRIS dataset using backpropagation.\n"); + +#ifdef VERIMAG + srand(42); +#else + srand(time(0)); +#endif + + /* Load the data from file. */ + load_data(); + + /* 4 inputs. + * 1 hidden layer(s) of 4 neurons. + * 3 outputs (1 per class) + */ + genann *ann = genann_init(4, 1, 4, 3); + + int i, j; +#ifdef VERIMAG + int loops = 500; +#else + int loops = 5000; +#endif + + /* Train the network with backpropagation. */ + printf("Training for %d loops over data.\n", loops); +#ifdef VERIMAG + clock_prepare(); + clock_start(); +#endif + for (i = 0; i < loops; ++i) { + for (j = 0; j < samples; ++j) { + genann_train(ann, input + j*4, class + j*3, .01); + } + /* printf("%1.2f ", xor_score(ann)); */ + } + + int correct = 0; + for (j = 0; j < samples; ++j) { + const double *guess = genann_run(ann, input + j*4); + if (class[j*3+0] == 1.0) {if (guess[0] > guess[1] && guess[0] > guess[2]) ++correct;} + else if (class[j*3+1] == 1.0) {if (guess[1] > guess[0] && guess[1] > guess[2]) ++correct;} + else if (class[j*3+2] == 1.0) {if (guess[2] > guess[0] && guess[2] > guess[1]) ++correct;} + else {printf("Logic error.\n"); exit(1);} + } +#ifdef VERIMAG + clock_stop(); +#endif + + printf("%d/%d correct (%0.1f%%).\n", correct, samples, (double)correct / samples * 100.0); + +#ifdef VERIMAG + print_total_clock(); +#endif + + genann_free(ann); + free(input); + free(class); + + return 0; +} -- cgit