Merged
Changes from 7 commits
3 changes: 3 additions & 0 deletions examples/.gitignore
@@ -16,3 +16,6 @@ cascade_train_debug
xor_test_debug
xor_test_fixed_debug
xor_train_debug
robot_adam
robot_adam.net

4 changes: 2 additions & 2 deletions examples/Makefile
@@ -4,8 +4,8 @@ GCC=gcc
CFLAGS=-I ../src/include
LDFLAGS=-L ../src/

TARGETS = xor_train xor_test xor_test_fixed simple_train steepness_train simple_test robot mushroom cascade_train scaling_test scaling_train
DEBUG_TARGETS = xor_train_debug xor_test_debug xor_test_fixed_debug cascade_train_debug
TARGETS = xor_train xor_test xor_test_fixed simple_train steepness_train simple_test robot mushroom cascade_train scaling_test scaling_train robot_adam
DEBUG_TARGETS = xor_train_debug xor_test_debug xor_test_fixed_debug cascade_train_debug robot_adam_debug

all: $(TARGETS)

72 changes: 72 additions & 0 deletions examples/robot_adam.c
@@ -0,0 +1,72 @@
/*
Fast Artificial Neural Network Library (fann)
Copyright (C) 2003-2016 Steffen Nissen (steffen.fann@gmail.com)

This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.

This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*/

#include <stdio.h>

#include "fann.h"

int main()
{
const unsigned int num_layers = 3;
const unsigned int num_neurons_hidden = 96;
const float desired_error = (const float) 0.00001;
struct fann *ann;
struct fann_train_data *train_data, *test_data;

unsigned int i = 0;

printf("Creating network.\n");

train_data = fann_read_train_from_file("../datasets/robot.train");

ann = fann_create_standard(num_layers,
train_data->num_input, num_neurons_hidden, train_data->num_output);

printf("Training network.\n");

fann_set_training_algorithm(ann, FANN_TRAIN_ADAM);
fann_set_adam_beta1(ann, 0.9f);
fann_set_adam_beta2(ann, 0.999f);
fann_set_adam_epsilon(ann, 1e-8f);
fann_set_learning_rate(ann, 0.05f);

fann_train_on_data(ann, train_data, 10000, 100, desired_error);

printf("Testing network.\n");

test_data = fann_read_train_from_file("../datasets/robot.test");

fann_reset_MSE(ann);
for(i = 0; i < fann_length_train_data(test_data); i++)
{
fann_test(ann, test_data->input[i], test_data->output[i]);
}
printf("MSE error on test data: %f\n", fann_get_MSE(ann));

printf("Saving network.\n");

fann_save(ann, "robot_adam.net");

printf("Cleaning up.\n");
fann_destroy_train(train_data);
fann_destroy_train(test_data);
fann_destroy(ann);

return 0;
}
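
Note: like the other FANN examples, robot_adam.c omits error handling. As a hedged sketch (not part of the patch), the two calls above could be hardened inside main() using the NULL returns that fann_read_train_from_file() and fann_create_standard() use to signal failure:

/* Sketch only: defensive checks for the example's two allocation calls. */
train_data = fann_read_train_from_file("../datasets/robot.train");
if (train_data == NULL) {
    fprintf(stderr, "Failed to read ../datasets/robot.train\n");
    return 1;
}

ann = fann_create_standard(num_layers,
    train_data->num_input, num_neurons_hidden, train_data->num_output);
if (ann == NULL) {
    fann_destroy_train(train_data);
    return 1;
}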
44 changes: 44 additions & 0 deletions src/fann.c
@@ -748,6 +748,8 @@ FANN_EXTERNAL void FANN_API fann_destroy(struct fann *ann) {
fann_safe_free(ann->prev_train_slopes);
fann_safe_free(ann->prev_steps);
fann_safe_free(ann->prev_weights_deltas);
fann_safe_free(ann->adam_m);
fann_safe_free(ann->adam_v);
fann_safe_free(ann->errstr);
fann_safe_free(ann->cascade_activation_functions);
fann_safe_free(ann->cascade_activation_steepnesses);
@@ -888,6 +890,14 @@ FANN_EXTERNAL struct fann *FANN_API fann_copy(struct fann *orig) {
copy->rprop_delta_max = orig->rprop_delta_max;
copy->rprop_delta_zero = orig->rprop_delta_zero;

/* Copy Adam optimizer parameters */
copy->adam_beta1 = orig->adam_beta1;
copy->adam_beta2 = orig->adam_beta2;
copy->adam_epsilon = orig->adam_epsilon;
copy->adam_timestep = orig->adam_timestep;
copy->adam_m = NULL;
copy->adam_v = NULL;

/* user_data is not deep copied. user should use fann_copy_with_user_data() for that */
copy->user_data = orig->user_data;

@@ -1008,6 +1018,29 @@ FANN_EXTERNAL struct fann *FANN_API fann_copy(struct fann *orig) {
copy->total_connections_allocated * sizeof(fann_type));
}

/* Copy Adam optimizer moment vectors if they exist */
if (orig->adam_m) {
copy->adam_m = (fann_type *)malloc(copy->total_connections_allocated * sizeof(fann_type));
if (copy->adam_m == NULL) {
fann_error((struct fann_error *)orig, FANN_E_CANT_ALLOCATE_MEM);
fann_destroy(copy);
return NULL;
}
memcpy(copy->adam_m, orig->adam_m,
copy->total_connections_allocated * sizeof(fann_type));
}

if (orig->adam_v) {
copy->adam_v = (fann_type *)malloc(copy->total_connections_allocated * sizeof(fann_type));
if (copy->adam_v == NULL) {
fann_error((struct fann_error *)orig, FANN_E_CANT_ALLOCATE_MEM);
fann_destroy(copy);
return NULL;
}
memcpy(copy->adam_v, orig->adam_v,
copy->total_connections_allocated * sizeof(fann_type));
}

return copy;
}

@@ -1171,6 +1204,9 @@ FANN_EXTERNAL void FANN_API fann_print_parameters(struct fann *ann) {
printf("RPROP decrease factor :%8.3f\n", ann->rprop_decrease_factor);
printf("RPROP delta min :%8.3f\n", ann->rprop_delta_min);
printf("RPROP delta max :%8.3f\n", ann->rprop_delta_max);
printf("Adam beta1 :%f\n", ann->adam_beta1);
printf("Adam beta2 :%f\n", ann->adam_beta2);
printf("Adam epsilon :%.8f\n", ann->adam_epsilon);
printf("Cascade output change fraction :%11.6f\n", ann->cascade_output_change_fraction);
printf("Cascade candidate change fraction :%11.6f\n", ann->cascade_candidate_change_fraction);
printf("Cascade output stagnation epochs :%4d\n", ann->cascade_output_stagnation_epochs);
@@ -1552,6 +1588,14 @@ struct fann *fann_allocate_structure(unsigned int num_layers) {
ann->sarprop_temperature = 0.015f;
ann->sarprop_epoch = 0;

/* Variables for use with Adam training (reasonable defaults) */
ann->adam_m = NULL;
ann->adam_v = NULL;
ann->adam_beta1 = 0.9f;
ann->adam_beta2 = 0.999f;
ann->adam_epsilon = 1e-8f;
ann->adam_timestep = 0;

fann_init_error_data((struct fann_error *)ann);

#ifdef FIXEDFANN
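
A quick way to exercise the copy path above is a minimal standalone check; this sketch assumes the fann_get_adam_beta1()/fann_set_adam_beta1() accessors generated by FANN_GET_SET in fann_train.c (see that file below):

#include <assert.h>

#include "fann.h"

int main()
{
    struct fann *orig = fann_create_standard(3, 2, 4, 1);
    struct fann *copy;

    fann_set_adam_beta1(orig, 0.95f); /* accessor generated by FANN_GET_SET */
    copy = fann_copy(orig);

    /* Hyper-parameters are copied by value; adam_m/adam_v stay NULL here
       because orig never trained with Adam, and are deep-copied otherwise. */
    assert(fann_get_adam_beta1(copy) == 0.95f);

    fann_destroy(orig);
    fann_destroy(copy);
    return 0;
}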
19 changes: 19 additions & 0 deletions src/fann_io.c
@@ -175,6 +175,9 @@ int fann_save_internal_fd(struct fann *ann, FILE *conf, const char *configuratio
fprintf(conf, "rprop_delta_min=%f\n", ann->rprop_delta_min);
fprintf(conf, "rprop_delta_max=%f\n", ann->rprop_delta_max);
fprintf(conf, "rprop_delta_zero=%f\n", ann->rprop_delta_zero);
fprintf(conf, "adam_beta1=%f\n", ann->adam_beta1);
fprintf(conf, "adam_beta2=%f\n", ann->adam_beta2);
fprintf(conf, "adam_epsilon=%.8f\n", ann->adam_epsilon);
fprintf(conf, "cascade_output_stagnation_epochs=%u\n", ann->cascade_output_stagnation_epochs);
fprintf(conf, "cascade_candidate_change_fraction=%f\n", ann->cascade_candidate_change_fraction);
fprintf(conf, "cascade_candidate_stagnation_epochs=%u\n",
@@ -322,6 +325,18 @@ struct fann *fann_create_from_fd_1_1(FILE *conf, const char *configuration_file)
} \
}

/* Optional scanf that sets a default value if the field is not present in the file.
* This is used for new parameters to maintain backward compatibility with older saved networks.
*/
#define fann_scanf_optional(type, name, val, default_val) \
{ \
long pos = ftell(conf); \
if (fscanf(conf, name "=" type "\n", val) != 1) { \
fseek(conf, pos, SEEK_SET); \
*(val) = (default_val); \
} \
}

#define fann_skip(name) \
{ \
if (fscanf(conf, name) != 0) { \
@@ -420,6 +435,10 @@ struct fann *fann_create_from_fd(FILE *conf, const char *configuration_file) {
fann_scanf("%f", "rprop_delta_min", &ann->rprop_delta_min);
fann_scanf("%f", "rprop_delta_max", &ann->rprop_delta_max);
fann_scanf("%f", "rprop_delta_zero", &ann->rprop_delta_zero);
/* Adam parameters are optional for backward compatibility with older saved networks */
fann_scanf_optional("%f", "adam_beta1", &ann->adam_beta1, 0.9f);
fann_scanf_optional("%f", "adam_beta2", &ann->adam_beta2, 0.999f);
fann_scanf_optional("%f", "adam_epsilon", &ann->adam_epsilon, 1e-8f);
fann_scanf("%u", "cascade_output_stagnation_epochs", &ann->cascade_output_stagnation_epochs);
fann_scanf("%f", "cascade_candidate_change_fraction", &ann->cascade_candidate_change_fraction);
fann_scanf("%u", "cascade_candidate_stagnation_epochs",
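
The rewind-and-default behaviour of fann_scanf_optional() can be checked in isolation. A standalone sketch (the macro body is copied verbatim from above; tmpfile() stands in for a saved network file):

#include <stdio.h>

#define fann_scanf_optional(type, name, val, default_val) \
  { \
    long pos = ftell(conf); \
    if (fscanf(conf, name "=" type "\n", val) != 1) { \
      fseek(conf, pos, SEEK_SET); \
      *(val) = (default_val); \
    } \
  }

int main()
{
    float beta1, epsilon;
    FILE *conf = tmpfile();

    /* Write a "config" that has adam_beta1 but, like any pre-Adam save
       file, no adam_epsilon line. */
    fprintf(conf, "adam_beta1=0.900000\n");
    fprintf(conf, "cascade_output_stagnation_epochs=12\n");
    rewind(conf);

    fann_scanf_optional("%f", "adam_beta1", &beta1, 0.9f);      /* present: parsed */
    fann_scanf_optional("%f", "adam_epsilon", &epsilon, 1e-8f); /* absent: default set,
                                                                   file position restored */

    printf("beta1=%f epsilon=%g\n", beta1, epsilon); /* 0.900000, 1e-08 */
    fclose(conf);
    return 0;
}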
97 changes: 96 additions & 1 deletion src/fann_train.c
@@ -527,6 +527,34 @@ void fann_clear_train_arrays(struct fann *ann) {
} else {
memset(ann->prev_train_slopes, 0, (ann->total_connections_allocated) * sizeof(fann_type));
}

/* Allocate and initialize Adam optimizer arrays if using Adam */
if (ann->training_algorithm == FANN_TRAIN_ADAM) {
/* Allocate first moment vector (m) */
if (ann->adam_m == NULL) {
ann->adam_m = (fann_type *)calloc(ann->total_connections_allocated, sizeof(fann_type));
if (ann->adam_m == NULL) {
fann_error((struct fann_error *)ann, FANN_E_CANT_ALLOCATE_MEM);
return;
}
} else {
memset(ann->adam_m, 0, (ann->total_connections_allocated) * sizeof(fann_type));
}

/* Allocate second moment vector (v) */
if (ann->adam_v == NULL) {
ann->adam_v = (fann_type *)calloc(ann->total_connections_allocated, sizeof(fann_type));
if (ann->adam_v == NULL) {
fann_error((struct fann_error *)ann, FANN_E_CANT_ALLOCATE_MEM);
return;
}
} else {
memset(ann->adam_v, 0, (ann->total_connections_allocated) * sizeof(fann_type));
}

/* Reset timestep */
ann->adam_timestep = 0;
}
}

/* INTERNAL FUNCTION
@@ -682,9 +710,70 @@ void fann_update_weights_irpropm(struct fann *ann, unsigned int first_weight,
}
}

/* INTERNAL FUNCTION
The Adam (Adaptive Moment Estimation) algorithm

Adam combines ideas from momentum and RMSProp:
- Maintains exponential moving averages of gradients (first moment, m)
- Maintains exponential moving averages of squared gradients (second moment, v)
- Uses bias correction to account for initialization at zero

Parameters:
- beta1: exponential decay rate for first moment (default 0.9)
- beta2: exponential decay rate for second moment (default 0.999)
- epsilon: small constant for numerical stability (default 1e-8)
*/
void fann_update_weights_adam(struct fann *ann, unsigned int num_data, unsigned int first_weight,
unsigned int past_end) {
fann_type *train_slopes = ann->train_slopes;
fann_type *weights = ann->weights;
fann_type *m = ann->adam_m;
fann_type *v = ann->adam_v;

const float learning_rate = ann->learning_rate;
const float beta1 = ann->adam_beta1;
const float beta2 = ann->adam_beta2;
const float epsilon = ann->adam_epsilon;
const float gradient_scale = 1.0f / num_data;

unsigned int i;
float gradient, m_hat, v_hat;
float beta1_t, beta2_t;

/* Increment timestep */
ann->adam_timestep++;

/* Compute bias correction terms: 1 - beta^t */
beta1_t = 1.0f - powf(beta1, (float)ann->adam_timestep);
beta2_t = 1.0f - powf(beta2, (float)ann->adam_timestep);

for (i = first_weight; i != past_end; i++) {
/* Compute gradient (average over batch) */
gradient = train_slopes[i] * gradient_scale;

/* Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t */
m[i] = beta1 * m[i] + (1.0f - beta1) * gradient;

/* Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2 */
v[i] = beta2 * v[i] + (1.0f - beta2) * gradient * gradient;

/* Compute bias-corrected first moment: m_hat = m_t / (1 - beta1^t) */
m_hat = m[i] / beta1_t;

/* Compute bias-corrected second moment: v_hat = v_t / (1 - beta2^t) */
v_hat = v[i] / beta2_t;

/* Update weights: w_t = w_{t-1} + learning_rate * m_hat / (sqrt(v_hat) + epsilon).
The textbook Adam rule subtracts, but FANN accumulates train_slopes with the
descent sign already applied (same convention as fann_update_weights_batch),
so adding is the correct direction here. */
weights[i] += learning_rate * m_hat / (sqrtf(v_hat) + epsilon);

/* Clear slope for next iteration */
train_slopes[i] = 0.0f;
}
}

/* INTERNAL FUNCTION
The SARprop- algorithm
*/
void fann_update_weights_sarprop(struct fann *ann, unsigned int epoch, unsigned int first_weight,
unsigned int past_end) {
fann_type *train_slopes = ann->train_slopes;
Expand Down Expand Up @@ -919,3 +1008,9 @@ FANN_GET_SET(float, sarprop_temperature)
FANN_GET_SET(enum fann_stopfunc_enum, train_stop_function)
FANN_GET_SET(fann_type, bit_fail_limit)
FANN_GET_SET(float, learning_momentum)

FANN_GET_SET(float, adam_beta1)
FANN_GET_SET(float, adam_beta2)
FANN_GET_SET(float, adam_epsilon)
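
One property of the update above worth seeing concretely: because m and v start at zero, the bias correction makes the very first update a step of roughly learning_rate in the slope's direction, regardless of the gradient's magnitude. A standalone arithmetic check (not part of the patch; compile with -lm):

#include <math.h>
#include <stdio.h>

int main()
{
    const float beta1 = 0.9f, beta2 = 0.999f, epsilon = 1e-8f;
    const float learning_rate = 0.05f;
    const float gradient = 0.25f; /* arbitrary example slope */

    /* First Adam step (t = 1), mirroring fann_update_weights_adam() above. */
    float m = beta1 * 0.0f + (1.0f - beta1) * gradient;
    float v = beta2 * 0.0f + (1.0f - beta2) * gradient * gradient;
    float m_hat = m / (1.0f - powf(beta1, 1.0f)); /* == gradient   */
    float v_hat = v / (1.0f - powf(beta2, 1.0f)); /* == gradient^2 */
    float step = learning_rate * m_hat / (sqrtf(v_hat) + epsilon);

    printf("first step = %f (~= learning_rate = %f)\n", step, learning_rate);
    return 0;
}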
26 changes: 26 additions & 0 deletions src/fann_train_data.c
@@ -137,6 +137,30 @@ float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data) {
return fann_get_MSE(ann);
}

/*
* Internal train function
*/
float fann_train_epoch_adam(struct fann *ann, struct fann_train_data *data) {
unsigned int i;

if (ann->adam_m == NULL) {
fann_clear_train_arrays(ann);
}

fann_reset_MSE(ann);

for (i = 0; i < data->num_data; i++) {
fann_run(ann, data->input[i]);
fann_compute_MSE(ann, data->output[i]);
fann_backpropagate_MSE(ann);
fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
}

fann_update_weights_adam(ann, data->num_data, 0, ann->total_connections);

return fann_get_MSE(ann);
}

/*
* Internal train function
*/
@@ -211,6 +235,8 @@ FANN_EXTERNAL float FANN_API fann_train_epoch(struct fann *ann, struct fann_trai
return fann_train_epoch_irpropm(ann, data);
case FANN_TRAIN_SARPROP:
return fann_train_epoch_sarprop(ann, data);
case FANN_TRAIN_ADAM:
return fann_train_epoch_adam(ann, data);
case FANN_TRAIN_BATCH:
return fann_train_epoch_batch(ann, data);
case FANN_TRAIN_INCREMENTAL:
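
With the dispatch above in place, fann_train_epoch() can drive Adam one epoch at a time, which is useful for custom stopping or logging logic that fann_train_on_data() does not expose. A minimal sketch reusing the robot dataset from the example:

#include <stdio.h>

#include "fann.h"

int main()
{
    struct fann_train_data *data = fann_read_train_from_file("../datasets/robot.train");
    struct fann *ann = fann_create_standard(3, data->num_input, 96, data->num_output);
    unsigned int epoch;
    float mse;

    fann_set_training_algorithm(ann, FANN_TRAIN_ADAM);

    for (epoch = 1; epoch <= 10000; epoch++) {
        /* Dispatches to fann_train_epoch_adam() via the switch above. */
        mse = fann_train_epoch(ann, data);
        if (epoch % 100 == 0)
            printf("Epoch %5u, MSE %f\n", epoch, mse);
        if (mse < 0.00001f)
            break;
    }

    fann_destroy_train(data);
    fann_destroy(ann);
    return 0;
}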