use strict;
use AI::NNFlex::Backprop;
use AI::NNFlex::Dataset;
use Data::Dumper;
my $n = 0.4;
my $num_epochs = 100;
my $network = AI::NNFlex::Backprop->new(learningrate=>.9,
bias=>1,
);
$network->add_layer(nodes=>3,activationfunction=>'tanh');
#$network->add_layer(nodes=>3,activationfunction=>'tanh');
#$network->add_layer(nodes=>2,activationfunction=>'tanh');
#$network->add_layer(nodes=>3,activationfunction=>'tanh');
$network->add_layer(nodes=>5,activationfunction=>'tanh');
$network->add_layer(nodes=>2,activationfunction=>'sigmoid');
$network->init();
my $test_set = AI::NNFlex::Dataset->new([
[6.28318,1.570795,0], [1,0],
[6.28318,1.570795,1.570795], [0,-1],
[6.28318,1.570795,3.14159], [-1,0],
[6.28318,1.570795,4.712385], [0,1],
[6.28318,1.570795,6.28318], [1,0],
[6.28318,1.570795,7.853975], [0,-1],
[6.28318,3.14159,0], [0,-1],
[6.28318,3.14159,1.570795], [-1,0],
[6.28318,3.14159,3.14159], [0,1],
[6.28318,3.14159,4.712385], [1,0],
[6.28318,3.14159,6.28318], [0,-1],
[6.28318,3.14159,7.853975], [-1,0],
[6.28318,4.712385,0], [-1,0],
[6.28318,4.712385,1.570795], [0,1],
[6.28318,4.712385,3.14159], [1,0],
[6.28318,4.712385,4.712385], [0,-1],
[6.28318,4.712385,6.28318], [-1,0],
[6.28318,4.712385,7.853975], [0,1],
[6.28318,6.28318,0], [0,1],
[6.28318,6.28318,1.570795], [1,0],
[6.28318,6.28318,3.14159], [0,-1],
[6.28318,6.28318,4.712385], [-1,0],
[6.28318,6.28318,6.28318], [0,1],
[6.28318,6.28318,7.853975], [1,0],
[6.28318,7.853975,0], [1,0],
[6.28318,7.853975,1.570795], [0,-1],
[6.28318,7.853975,3.14159], [-1,0],
[6.28318,7.853975,4.712385], [0,1],
[6.28318,7.853975,6.28318], [1,0],
[6.28318,7.853975,7.853975], [0,-1],
[7.853975,0,0], [1,0],
[7.853975,0,1.570795], [0,-1],
[7.853975,0,3.14159], [-1,0],
[7.853975,0,4.712385], [0,1],
[7.853975,0,6.28318], [1,0],
[7.853975,0,7.853975], [0,-1],
[7.853975,1.570795,0], [0,-1],
[7.853975,1.570795,1.570795], [-1,0],
[7.853975,1.570795,3.14159], [0,1],
[7.853975,1.570795,4.712385], [1,0],
[7.853975,1.570795,6.28318], [0,-1],
[7.853975,1.570795,7.853975], [-1,0],
[7.853975,3.14159,0], [-1,0],
[7.853975,3.14159,1.570795], [0,1],
[7.853975,3.14159,3.14159], [1,0],
[7.853975,3.14159,4.712385], [0,-1],
[7.853975,3.14159,6.28318], [-1,0],
[7.853975,3.14159,7.853975], [0,1],
[7.853975,4.712385,0], [0,1],
[7.853975,4.712385,1.570795], [1,0],
[7.853975,4.712385,3.14159], [0,-1],
[7.853975,4.712385,4.712385], [-1,0],
[7.853975,4.712385,6.28318], [0,1],
[7.853975,4.712385,7.853975], [1,0],
[7.853975,6.28318,0], [1,0],
[7.853975,6.28318,1.570795], [0,-1],
[7.853975,6.28318,3.14159], [-1,0],
[7.853975,6.28318,4.712385], [0,1],
[7.853975,6.28318,6.28318], [1,0],
[7.853975,6.28318,7.853975], [0,-1],
[7.853975,7.853975,0], [0,-1],
[7.853975,7.853975,1.570795], [-1,0],
[7.853975,7.853975,3.14159], [0,1],
[7.853975,7.853975,4.712385], [1,0],
[7.853975,7.853975,6.28318], [0,-1],
[7.853975,7.853975,7.853975], [-1,0]
]);
my $train_set = AI::NNFlex::Dataset->new([
[0,0,0], [0,1],
[0,0,1.570795], [1,0],
[0,0,3.14159], [0,-1],
[0,0,4.712385], [-1,0],
[0,0,6.28318], [0,1],
[0,0,7.853975], [1,0],
[0,1.570795,0], [1,0],
[0,1.570795,1.570795], [0,-1],
[0,1.570795,3.14159], [-1,0],
[0,1.570795,4.712385], [0,1],
[0,1.570795,6.28318], [1,0],
[0,1.570795,7.853975], [0,-1],
[0,3.14159,0], [0,-1],
[0,3.14159,1.570795], [-1,0],
[0,3.14159,3.14159], [0,1],
[0,3.14159,4.712385], [1,0],
[0,3.14159,6.28318], [0,-1],
[0,3.14159,7.853975], [-1,0],
[0,4.712385,0], [-1,0],
[0,4.712385,1.570795], [0,1],
[0,4.712385,3.14159], [1,0],
[0,4.712385,4.712385], [0,-1],
[0,4.712385,6.28318], [-1,0],
[0,4.712385,7.853975], [0,1],
[0,6.28318,0], [0,1],
[0,6.28318,1.570795], [1,0],
[0,6.28318,3.14159], [0,-1],
[0,6.28318,4.712385], [-1,0],
[0,6.28318,6.28318], [0,1],
[0,6.28318,7.853975], [1,0],
[0,7.853975,0], [1,0],
[0,7.853975,1.570795], [0,-1],
[0,7.853975,3.14159], [-1,0],
[0,7.853975,4.712385], [0,1],
[0,7.853975,6.28318], [1,0],
[0,7.853975,7.853975], [0,-1],
[1.570795,0,0], [1,0],
[1.570795,0,1.570795], [0,-1],
[1.570795,0,3.14159], [-1,0],
[1.570795,0,4.712385], [0,1],
[1.570795,0,6.28318], [1,0],
[1.570795,0,7.853975], [0,-1],
[1.570795,1.570795,0], [0,-1],
[1.570795,1.570795,1.570795], [-1,0],
[1.570795,1.570795,3.14159], [0,1],
[1.570795,1.570795,4.712385], [1,0],
[1.570795,1.570795,6.28318], [0,-1],
[1.570795,1.570795,7.853975], [-1,0],
[1.570795,3.14159,0], [-1,0],
[1.570795,3.14159,1.570795], [0,1],
[1.570795,3.14159,3.14159], [1,0],
[1.570795,3.14159,4.712385], [0,-1],
[1.570795,3.14159,6.28318], [-1,0],
[1.570795,3.14159,7.853975], [0,1],
[1.570795,4.712385,0], [0,1],
[1.570795,4.712385,1.570795], [1,0],
[1.570795,4.712385,3.14159], [0,-1],
[1.570795,4.712385,4.712385], [-1,0],
[1.570795,4.712385,6.28318], [0,1],
[1.570795,4.712385,7.853975], [1,0],
[1.570795,6.28318,0], [1,0],
[1.570795,6.28318,1.570795], [0,-1],
[1.570795,6.28318,3.14159], [-1,0],
[1.570795,6.28318,4.712385], [0,1],
[1.570795,6.28318,6.28318], [1,0],
[1.570795,6.28318,7.853975], [0,-1],
[1.570795,7.853975,0], [0,-1],
[1.570795,7.853975,1.570795], [-1,0],
[1.570795,7.853975,3.14159], [0,1],
[1.570795,7.853975,4.712385], [1,0],
[1.570795,7.853975,6.28318], [0,-1],
[1.570795,7.853975,7.853975], [-1,0],
[3.14159,0,0], [0,-1],
[3.14159,0,1.570795], [-1,0],
[3.14159,0,3.14159], [0,1],
[3.14159,0,4.712385], [1,0],
[3.14159,0,6.28318], [0,-1],
[3.14159,0,7.853975], [-1,0],
[3.14159,1.570795,0], [-1,0],
[3.14159,1.570795,1.570795], [0,1],
[3.14159,1.570795,3.14159], [1,0],
[3.14159,1.570795,4.712385], [0,-1],
[3.14159,1.570795,6.28318], [-1,0],
[3.14159,1.570795,7.853975], [0,1],
[3.14159,3.14159,0], [0,1],
[3.14159,3.14159,1.570795], [1,0],
[3.14159,3.14159,3.14159], [0,-1],
[3.14159,3.14159,4.712385], [-1,0],
[3.14159,3.14159,6.28318], [0,1],
[3.14159,3.14159,7.853975], [1,0],
[3.14159,4.712385,0], [1,0],
[3.14159,4.712385,1.570795], [0,-1],
[3.14159,4.712385,3.14159], [-1,0],
[3.14159,4.712385,4.712385], [0,1],
[3.14159,4.712385,6.28318], [1,0],
[3.14159,4.712385,7.853975], [0,-1],
[3.14159,6.28318,0], [0,-1],
[3.14159,6.28318,1.570795], [-1,0],
[3.14159,6.28318,3.14159], [0,1],
[3.14159,6.28318,4.712385], [1,0],
[3.14159,6.28318,6.28318], [0,-1],
[3.14159,6.28318,7.853975], [-1,0],
[3.14159,7.853975,0], [-1,0],
[3.14159,7.853975,1.570795], [0,1],
[3.14159,7.853975,3.14159], [1,0],
[3.14159,7.853975,4.712385], [0,-1],
[3.14159,7.853975,6.28318], [-1,0],
[3.14159,7.853975,7.853975], [0,1],
[4.712385,0,0], [-1,0],
[4.712385,0,1.570795], [0,1],
[4.712385,0,3.14159], [1,0],
[4.712385,0,4.712385], [0,-1],
[4.712385,0,6.28318], [-1,0],
[4.712385,0,7.853975], [0,1],
[4.712385,1.570795,0], [0,1],
[4.712385,1.570795,1.570795], [1,0],
[4.712385,1.570795,3.14159], [0,-1],
[4.712385,1.570795,4.712385], [-1,0],
[4.712385,1.570795,6.28318], [0,1],
[4.712385,1.570795,7.853975], [1,0],
[4.712385,3.14159,0], [1,0],
[4.712385,3.14159,1.570795], [0,-1],
[4.712385,3.14159,3.14159], [-1,0],
[4.712385,3.14159,4.712385], [0,1],
[4.712385,3.14159,6.28318], [1,0],
[4.712385,3.14159,7.853975], [0,-1],
[4.712385,4.712385,0], [0,-1],
[4.712385,4.712385,1.570795], [-1,0],
[4.712385,4.712385,3.14159], [0,1],
[4.712385,4.712385,4.712385], [1,0],
[4.712385,4.712385,6.28318], [0,-1],
[4.712385,4.712385,7.853975], [-1,0],
[4.712385,6.28318,0], [-1,0],
[4.712385,6.28318,1.570795], [0,1],
[4.712385,6.28318,3.14159], [1,0],
[4.712385,6.28318,4.712385], [0,-1],
[4.712385,6.28318,6.28318], [-1,0],
[4.712385,6.28318,7.853975], [0,1],
[4.712385,7.853975,0], [0,1],
[4.712385,7.853975,1.570795], [1,0],
[4.712385,7.853975,3.14159], [0,-1],
[4.712385,7.853975,4.712385], [-1,0],
[4.712385,7.853975,6.28318], [0,1],
[4.712385,7.853975,7.853975], [1,0],
[6.28318,0,0], [0,1],
[6.28318,0,1.570795], [1,0],
[6.28318,0,3.14159], [0,-1],
[6.28318,0,4.712385], [-1,0],
[6.28318,0,6.28318], [0,1],
[6.28318,0,7.853975], [1,0]
]);
my $epoch = 1;
my $err = 1;
while($err > .001 && $epoch < 100) {
$err = $train_set->learn($network);
my $outputsRef = $test_set->run($network);
print Dumper($outputsRef);
print "Error: $err\n";
$epoch++;
}
The output of the network with the test set gives the following.
$ perl test1.pl
$VAR1 = [
[
'2.22776546277668e-07',
'0.011408329955622'
],
[
'2.22776546277668e-07',
'0.011408329955622'
],
[
'2.22776546277668e-07',
'0.011408329955622'
],
[
'2.22776546277668e-07',
'0.011408329955622'
],
[
'2.22776546277668e-07',
'0.011408329955622'
],
[
'2.22776546277668e-07',
'0.011408329955622'
],
....
....
Am I handling the output of the network in correctly? The module says that Runs the dataset through the network and returns a reference to an array of output patterns. I guess I am not handling the reference array correctly.
Thanks for all the help. |