Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ GCC=gcc
CFLAGS=-I ../src/include
LDFLAGS=-L ../src/

TARGETS = xor_train xor_test xor_test_fixed simple_train steepness_train simple_test robot mushroom cascade_train scaling_test scaling_train
DEBUG_TARGETS = xor_train_debug xor_test_debug xor_test_fixed_debug cascade_train_debug
TARGETS = xor_train xor_test xor_test_fixed simple_train steepness_train simple_test robot mushroom cascade_train scaling_test scaling_train robot_adam
DEBUG_TARGETS = xor_train_debug xor_test_debug xor_test_fixed_debug cascade_train_debug robot_adam_debug

all: $(TARGETS)

Expand Down
72 changes: 72 additions & 0 deletions examples/robot_adam.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
Fast Artificial Neural Network Library (fann)
Copyright (C) 2003-2016 Steffen Nissen (steffen.fann@gmail.com)

This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.

This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*/

#include <stdio.h>

#include "fann.h"

/*
 * Train a 3-layer network on the robot dataset using the Adam optimizer,
 * evaluate it on the held-out test set, and save the trained network.
 *
 * Returns 0 on success, 1 if a data file cannot be read or the network
 * cannot be created.
 */
int main(void)
{
	const unsigned int num_layers = 3;
	const unsigned int num_neurons_hidden = 96;
	const float desired_error = 0.00001f;
	struct fann *ann;
	struct fann_train_data *train_data, *test_data;

	unsigned int i;

	printf("Creating network.\n");

	train_data = fann_read_train_from_file("../datasets/robot.train");
	if (train_data == NULL)
	{
		fprintf(stderr, "Error: could not read ../datasets/robot.train\n");
		return 1;
	}

	ann = fann_create_standard(num_layers,
					train_data->num_input, num_neurons_hidden, train_data->num_output);
	if (ann == NULL)
	{
		fprintf(stderr, "Error: could not create network\n");
		fann_destroy_train(train_data);
		return 1;
	}

	printf("Training network.\n");

	fann_set_training_algorithm(ann, FANN_TRAIN_ADAM);
	/* Standard Adam hyper-parameters (Kingma & Ba, 2014); the learning
	 * rate doubles as Adam's step size alpha. */
	fann_set_adam_beta1(ann, 0.9f);
	fann_set_adam_beta2(ann, 0.999f);
	fann_set_adam_epsilon(ann, 1e-8f);
	fann_set_learning_rate(ann, 0.05f);

	fann_train_on_data(ann, train_data, 10000, 100, desired_error);

	printf("Testing network.\n");

	test_data = fann_read_train_from_file("../datasets/robot.test");
	if (test_data == NULL)
	{
		fprintf(stderr, "Error: could not read ../datasets/robot.test\n");
		fann_destroy_train(train_data);
		fann_destroy(ann);
		return 1;
	}

	/* Accumulate MSE over the whole test set without updating weights. */
	fann_reset_MSE(ann);
	for(i = 0; i < fann_length_train_data(test_data); i++)
	{
		fann_test(ann, test_data->input[i], test_data->output[i]);
	}
	printf("MSE error on test data: %f\n", fann_get_MSE(ann));

	printf("Saving network.\n");

	fann_save(ann, "robot_adam.net");

	printf("Cleaning up.\n");
	fann_destroy_train(train_data);
	fann_destroy_train(test_data);
	fann_destroy(ann);

	return 0;
}
18 changes: 18 additions & 0 deletions src/fann.c
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,8 @@ FANN_EXTERNAL void FANN_API fann_destroy(struct fann *ann) {
fann_safe_free(ann->prev_train_slopes);
fann_safe_free(ann->prev_steps);
fann_safe_free(ann->prev_weights_deltas);
fann_safe_free(ann->adam_m);
fann_safe_free(ann->adam_v);
fann_safe_free(ann->errstr);
fann_safe_free(ann->cascade_activation_functions);
fann_safe_free(ann->cascade_activation_steepnesses);
Expand Down Expand Up @@ -888,6 +890,14 @@ FANN_EXTERNAL struct fann *FANN_API fann_copy(struct fann *orig) {
copy->rprop_delta_max = orig->rprop_delta_max;
copy->rprop_delta_zero = orig->rprop_delta_zero;

/* Copy Adam optimizer parameters */
copy->adam_beta1 = orig->adam_beta1;
copy->adam_beta2 = orig->adam_beta2;
copy->adam_epsilon = orig->adam_epsilon;
copy->adam_timestep = orig->adam_timestep;
copy->adam_m = NULL; /* These will be reallocated during training if needed */
copy->adam_v = NULL;
Comment thread
wu1045718093 marked this conversation as resolved.

/* user_data is not deep copied. user should use fann_copy_with_user_data() for that */
copy->user_data = orig->user_data;

Expand Down Expand Up @@ -1552,6 +1562,14 @@ struct fann *fann_allocate_structure(unsigned int num_layers) {
ann->sarprop_temperature = 0.015f;
ann->sarprop_epoch = 0;

/* Variables for use with Adam training (reasonable defaults) */
ann->adam_m = NULL;
ann->adam_v = NULL;
ann->adam_beta1 = 0.9f;
ann->adam_beta2 = 0.999f;
ann->adam_epsilon = 1e-8f;
ann->adam_timestep = 0;

fann_init_error_data((struct fann_error *)ann);

#ifdef FIXEDFANN
Expand Down
6 changes: 6 additions & 0 deletions src/fann_io.c
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ int fann_save_internal_fd(struct fann *ann, FILE *conf, const char *configuratio
fprintf(conf, "rprop_delta_min=%f\n", ann->rprop_delta_min);
fprintf(conf, "rprop_delta_max=%f\n", ann->rprop_delta_max);
fprintf(conf, "rprop_delta_zero=%f\n", ann->rprop_delta_zero);
fprintf(conf, "adam_beta1=%f\n", ann->adam_beta1);
fprintf(conf, "adam_beta2=%f\n", ann->adam_beta2);
fprintf(conf, "adam_epsilon=%f\n", ann->adam_epsilon);
fprintf(conf, "cascade_output_stagnation_epochs=%u\n", ann->cascade_output_stagnation_epochs);
fprintf(conf, "cascade_candidate_change_fraction=%f\n", ann->cascade_candidate_change_fraction);
fprintf(conf, "cascade_candidate_stagnation_epochs=%u\n",
Expand Down Expand Up @@ -420,6 +423,9 @@ struct fann *fann_create_from_fd(FILE *conf, const char *configuration_file) {
fann_scanf("%f", "rprop_delta_min", &ann->rprop_delta_min);
fann_scanf("%f", "rprop_delta_max", &ann->rprop_delta_max);
fann_scanf("%f", "rprop_delta_zero", &ann->rprop_delta_zero);
fann_scanf("%f", "adam_beta1", &ann->adam_beta1);
fann_scanf("%f", "adam_beta2", &ann->adam_beta2);
fann_scanf("%f", "adam_epsilon", &ann->adam_epsilon);
Comment thread
wu1045718093 marked this conversation as resolved.
Outdated
fann_scanf("%u", "cascade_output_stagnation_epochs", &ann->cascade_output_stagnation_epochs);
fann_scanf("%f", "cascade_candidate_change_fraction", &ann->cascade_candidate_change_fraction);
fann_scanf("%u", "cascade_candidate_stagnation_epochs",
Expand Down
97 changes: 96 additions & 1 deletion src/fann_train.c
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,34 @@ void fann_clear_train_arrays(struct fann *ann) {
} else {
memset(ann->prev_train_slopes, 0, (ann->total_connections_allocated) * sizeof(fann_type));
}

/* Allocate and initialize Adam optimizer arrays if using Adam */
if (ann->training_algorithm == FANN_TRAIN_ADAM) {
/* Allocate first moment vector (m) */
if (ann->adam_m == NULL) {
ann->adam_m = (fann_type *)calloc(ann->total_connections_allocated, sizeof(fann_type));
if (ann->adam_m == NULL) {
fann_error((struct fann_error *)ann, FANN_E_CANT_ALLOCATE_MEM);
return;
}
} else {
memset(ann->adam_m, 0, (ann->total_connections_allocated) * sizeof(fann_type));
}

/* Allocate second moment vector (v) */
if (ann->adam_v == NULL) {
ann->adam_v = (fann_type *)calloc(ann->total_connections_allocated, sizeof(fann_type));
if (ann->adam_v == NULL) {
fann_error((struct fann_error *)ann, FANN_E_CANT_ALLOCATE_MEM);
return;
}
} else {
memset(ann->adam_v, 0, (ann->total_connections_allocated) * sizeof(fann_type));
}

/* Reset timestep */
ann->adam_timestep = 0;
}
}

/* INTERNAL FUNCTION
Expand Down Expand Up @@ -682,9 +710,70 @@ void fann_update_weights_irpropm(struct fann *ann, unsigned int first_weight,
}
}

/* INTERNAL FUNCTION
The Adam (Adaptive Moment Estimation) algorithm

Adam combines ideas from momentum and RMSProp:
- Maintains exponential moving averages of gradients (first moment, m)
- Maintains exponential moving averages of squared gradients (second moment, v)
- Uses bias correction to account for initialization at zero

Parameters:
- beta1: exponential decay rate for first moment (default 0.9)
- beta2: exponential decay rate for second moment (default 0.999)
- epsilon: small constant for numerical stability (default 1e-8)
*/
void fann_update_weights_adam(struct fann *ann, unsigned int num_data, unsigned int first_weight,
Comment thread
wu1045718093 marked this conversation as resolved.
unsigned int past_end) {
fann_type *train_slopes = ann->train_slopes;
fann_type *weights = ann->weights;
fann_type *m = ann->adam_m;
fann_type *v = ann->adam_v;

const float learning_rate = ann->learning_rate;
const float beta1 = ann->adam_beta1;
const float beta2 = ann->adam_beta2;
const float epsilon = ann->adam_epsilon;
const float gradient_scale = 1.0f / num_data;

unsigned int i;
float gradient, m_hat, v_hat;
Comment thread
wu1045718093 marked this conversation as resolved.
float beta1_t, beta2_t;

/* Increment timestep */
ann->adam_timestep++;

/* Compute bias correction terms: 1 - beta^t */
beta1_t = 1.0f - powf(beta1, (float)ann->adam_timestep);
beta2_t = 1.0f - powf(beta2, (float)ann->adam_timestep);

for (i = first_weight; i != past_end; i++) {
/* Compute gradient (average over batch) */
gradient = train_slopes[i] * gradient_scale;

/* Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t */
m[i] = beta1 * m[i] + (1.0f - beta1) * gradient;

/* Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2 */
v[i] = beta2 * v[i] + (1.0f - beta2) * gradient * gradient;

/* Compute bias-corrected first moment: m_hat = m_t / (1 - beta1^t) */
m_hat = m[i] / beta1_t;

/* Compute bias-corrected second moment: v_hat = v_t / (1 - beta2^t) */
v_hat = v[i] / beta2_t;

/* Update weights: w_t = w_{t-1} + learning_rate * m_hat / (sqrt(v_hat) + epsilon) */
weights[i] += learning_rate * m_hat / (sqrtf(v_hat) + epsilon);
Comment thread
wu1045718093 marked this conversation as resolved.
Outdated

/* Clear slope for next iteration */
train_slopes[i] = 0.0f;
}
}

/* INTERNAL FUNCTION
The SARprop- algorithm
*/
void fann_update_weights_sarprop(struct fann *ann, unsigned int epoch, unsigned int first_weight,
unsigned int past_end) {
fann_type *train_slopes = ann->train_slopes;
Expand Down Expand Up @@ -919,3 +1008,9 @@ FANN_GET_SET(float, sarprop_temperature)
FANN_GET_SET(enum fann_stopfunc_enum, train_stop_function)
FANN_GET_SET(fann_type, bit_fail_limit)
FANN_GET_SET(float, learning_momentum)

FANN_GET_SET(float, adam_beta1)

FANN_GET_SET(float, adam_beta2)

FANN_GET_SET(float, adam_epsilon)
26 changes: 26 additions & 0 deletions src/fann_train_data.c
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,30 @@ float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data) {
return fann_get_MSE(ann);
}

/*
* Internal train function
*/
/*
 * Internal train function.
 * Runs one epoch of batch training with the Adam optimizer: accumulates
 * slopes over all patterns, then applies a single Adam weight update.
 * Returns the MSE computed during the forward passes of this epoch.
 */
float fann_train_epoch_adam(struct fann *ann, struct fann_train_data *data) {
  unsigned int i;

  /* Lazily allocate/reset the Adam moment vectors. Check both vectors: a
   * partial allocation failure could leave adam_m set but adam_v NULL. */
  if (ann->adam_m == NULL || ann->adam_v == NULL) {
    fann_clear_train_arrays(ann);
  }

  fann_reset_MSE(ann);

  for (i = 0; i < data->num_data; i++) {
    fann_run(ann, data->input[i]);
    fann_compute_MSE(ann, data->output[i]);
    fann_backpropagate_MSE(ann);
    fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
  }

  fann_update_weights_adam(ann, data->num_data, 0, ann->total_connections);

  return fann_get_MSE(ann);
}

/*
* Internal train function
*/
Expand Down Expand Up @@ -211,6 +235,8 @@ FANN_EXTERNAL float FANN_API fann_train_epoch(struct fann *ann, struct fann_trai
return fann_train_epoch_irpropm(ann, data);
case FANN_TRAIN_SARPROP:
return fann_train_epoch_sarprop(ann, data);
case FANN_TRAIN_ADAM:
return fann_train_epoch_adam(ann, data);
case FANN_TRAIN_BATCH:
return fann_train_epoch_batch(ann, data);
case FANN_TRAIN_INCREMENTAL:
Expand Down
24 changes: 22 additions & 2 deletions src/include/fann_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ enum fann_train_enum {
FANN_TRAIN_BATCH,
FANN_TRAIN_RPROP,
FANN_TRAIN_QUICKPROP,
FANN_TRAIN_SARPROP
FANN_TRAIN_SARPROP,
FANN_TRAIN_ADAM
};

/* Constant: FANN_TRAIN_NAMES
Expand All @@ -95,7 +96,7 @@ enum fann_train_enum {
*/
static char const *const FANN_TRAIN_NAMES[] = {"FANN_TRAIN_INCREMENTAL", "FANN_TRAIN_BATCH",
"FANN_TRAIN_RPROP", "FANN_TRAIN_QUICKPROP",
"FANN_TRAIN_SARPROP"};
"FANN_TRAIN_SARPROP", "FANN_TRAIN_ADAM"};

/* Enums: fann_activationfunc_enum

Expand Down Expand Up @@ -754,6 +755,25 @@ struct fann {
*/
fann_type *prev_weights_deltas;

/* Adam optimizer parameters */
/* First moment vector (mean of gradients) for Adam optimizer */
fann_type *adam_m;

/* Second moment vector (variance of gradients) for Adam optimizer */
fann_type *adam_v;

/* Exponential decay rate for the first moment estimates (default 0.9) */
float adam_beta1;

/* Exponential decay rate for the second moment estimates (default 0.999) */
float adam_beta2;

/* Small constant for numerical stability (default 1e-8) */
float adam_epsilon;

/* Current timestep for Adam optimizer */
unsigned int adam_timestep;

#ifndef FIXEDFANN
/* Arithmetic mean used to remove steady component in input data. */
float *scale_mean_in;
Expand Down
Loading