;;;; Backpropagation program for training a net having a single hidden layer
;;;; and no pass-through connections.  Updates weights in "batch mode"
;;;; (i.e., after each complete pass through the training data).
#|
How to use this program:

(load "...-data")
(init-backprop-net ...)
(train-backprop-net)
(test-backprop-net ...)
|#

;;; Globals user may wish to set

;; TRAIN-BACKPROP-NET prints a progress line whenever the pass counter is a
;; multiple of this value.
(defvar *display-interval* 100)		;measured in passes through data
;; An output whose absolute difference from the desired value exceeds this
;; counts as an error; smaller differences are treated as zero error when
;; the deltas are computed (see GET-ERRORS).
(defvar *tolerance* 0.1)

;;; Other globals

;; Training patterns, set by INIT-BACKPROP-NET.  Each pattern is a list whose
;; first element is the input vector and whose second is the desired output
;; vector.
(defvar *backprop-data*)
;; NOTE(review): *nbr-inputs* and *nbr-outputs* are used throughout this file
;; but are not defined here; presumably they are defined and set by the
;; data-conversion code that is loaded separately -- confirm.
;(defvar *nbr-inputs*)			;includes bias input
;(defvar *nbr-outputs*)
;; Input-to-hidden weights: one row per hidden unit, one column per input.
(defvar *hidden-weights*)
;; Hidden-to-output weights: one row per output unit, one column per hidden
;; unit plus one for the hidden-layer bias line.
(defvar *output-weights*)
;; Weight changes accumulated over one pass through the data; same
;; dimensions as *hidden-weights* and *output-weights* respectively.
(defvar *delta-h-wts*)
(defvar *delta-o-wts*)
;; NOTE(review): not referenced anywhere else in this file;
;; TRAIN-BACKPROP-NET takes the learning coefficient as an argument instead.
(defvar *learn-coeff*)
;; Hidden-unit activations; displaced onto the leading elements of
;; *hidden-activity-plus-1*, so the two share storage.
(defvar *hidden-activity*)
(defvar *hidden-activity-plus-1*)	;includes bias line
;; Output-unit activations for the pattern most recently propagated.
(defvar *output-activity*)
;; Per-unit delta terms for the hidden and output layers.
(defvar *hidden-deltas*)
(defvar *output-deltas*)
;; Per-output error (desired minus actual) as filled in by GET-ERRORS, with
;; within-tolerance differences stored as 0.0.
(defvar *errors*)

;;; User routines.

(defun init-backprop-net (&optional (nbr-hidden-units 2))
  "Load the training patterns and allocate a fresh net.
NBR-HIDDEN-UNITS is the requested hidden-layer size, bias excluded; values
below 2 are raised to 2.  Stores the converted patterns in *BACKPROP-DATA*,
prints a one-paragraph summary, fills both weight matrices with random
values in [-0.5, 0.5), allocates the work arrays used by training and
testing, and returns the symbol DONE."
  (let* ((hidden (max 2 nbr-hidden-units))
         (hidden+bias (1+ hidden)))
    ;; NOTE(review): *nbr-inputs* and *nbr-outputs* are read below but never
    ;; set in this file; presumably the data-conversion code establishes
    ;; them -- confirm.  The conversion call is kept as the first step.
    (setf *backprop-data*
          (mapcar #'coerce-to-vectors (convert-data-to-nnet-form)))
    (format t
    "~%Building net with ~a input unit~:p (plus bias), ~
     ~a hidden unit~:p (plus bias),~%and ~a output unit~:p.~
     ~%Weights initialized to small random values."
            (1- *nbr-inputs*)
            hidden
            *nbr-outputs*)
    ;; SETF assigns its pairs left to right, so the displaced hidden-activity
    ;; vector can refer to the array assigned just before it.
    (setf *hidden-weights* (random-matrix hidden *nbr-inputs* 0.5)
          *delta-h-wts* (make-array (list hidden *nbr-inputs*))
          *output-weights* (random-matrix *nbr-outputs* hidden+bias 0.5)
          *delta-o-wts* (make-array (list *nbr-outputs* hidden+bias))
          ;; Every slot starts at 1.0; only the last one (the bias line)
          ;; is never overwritten through *hidden-activity*.
          *hidden-activity-plus-1* (make-array hidden+bias
                                               :initial-element 1.0)
          *hidden-activity* (make-array hidden
                                        :displaced-to *hidden-activity-plus-1*)
          *hidden-deltas* (make-array hidden)
          *output-activity* (make-array *nbr-outputs*)
          *output-deltas* (make-array *nbr-outputs*)
          *errors* (make-array *nbr-outputs*))
    'done))

(defun train-backprop-net (&optional (learn-coeff 0.5) (max-passes 5000))
  "Train the net built by INIT-BACKPROP-NET in batch mode.
Each pass presents every pattern in *BACKPROP-DATA*, accumulates the weight
changes, and applies them (scaled by LEARN-COEFF) only at the end of the
pass.  Progress is printed every *DISPLAY-INTERVAL* passes.  Stops as soon
as a pass finds every output within *TOLERANCE* and returns the value of
SUCCESS-MESSAGE; after MAX-PASSES unsuccessful passes returns the value of
FAIL-MESSAGE."
  (let ((divisor (* (length *backprop-data*) *nbr-outputs*)))
    (loop for passes from 0 below max-passes
          do (clear-matrix *delta-h-wts*)
             (clear-matrix *delta-o-wts*)
             (let ((sq-error 0.0)
                   (any-errors nil))
               (dolist (pattern *backprop-data*)
                 (let ((input (first pattern))
                       (target (second pattern)))
                   ;; Forward pass: input -> hidden -> output.
                   (propagate-activity input *hidden-weights*
                                       *hidden-activity*)
                   (propagate-activity *hidden-activity-plus-1*
                                       *output-weights*
                                       *output-activity*)
                   ;; Score this pattern and fill *errors*.
                   (multiple-value-bind (sq-error-incr errors-found)
                       (get-errors target *output-activity*
                                   *tolerance* *errors*)
                     (incf sq-error sq-error-incr)
                     (when errors-found
                       (setq any-errors t)))
                   ;; Backward pass: output deltas, then hidden deltas.
                   (begin-deltas *errors* *output-activity* *output-deltas*)
                   (propagate-deltas *output-deltas* *output-weights*
                                     *hidden-activity* *hidden-deltas*)
                   ;; Accumulate this pattern's contribution to the batch.
                   (increment-by-outer-prod *output-deltas*
                                            *hidden-activity-plus-1*
                                            *delta-o-wts*)
                   (increment-by-outer-prod *hidden-deltas* input
                                            *delta-h-wts*)))
               (when (zerop (mod passes *display-interval*))
                 (show-progress passes (sqrt (/ sq-error divisor))))
               ;; The success check comes before the weight update, so a
               ;; pass that succeeds leaves the weights untouched.
               (unless any-errors
                 (return (success-message passes *tolerance*)))
               (increment-wts learn-coeff *delta-h-wts* *hidden-weights*)
               (increment-wts learn-coeff *delta-o-wts* *output-weights*))
          finally (return (fail-message max-passes *tolerance*)))))

(defun test-backprop-net (input-att-vec)
  "Run one attribute vector through the net, printing each layer.
Does nothing and returns NIL when CHECK-ATT-VECTOR rejects INPUT-ATT-VEC.
Otherwise prints the input, hidden (with bias line) and output activities
and returns whatever PRETTIFY-OUTPUT returns for the decoded output."
  (when (check-att-vector input-att-vec)
    (let ((input (coerce (convert-input-to-euclid-vec input-att-vec)
                         'vector)))
      (show-layer "Input:~%" input "~5,2f")
      ;; Hidden layer: computed into the displaced vector, displayed through
      ;; the full vector so the bias line is shown too.
      (propagate-activity input *hidden-weights* *hidden-activity*)
      (show-layer "Hidden:~%" *hidden-activity-plus-1* "~5,2f")
      (propagate-activity *hidden-activity-plus-1* *output-weights*
                          *output-activity*)
      (show-layer "Output:~%" *output-activity* "~5,2f")
      ;; Hand the decoder a fresh list copy of the output activities.
      (let ((output-list (coerce *output-activity* 'list)))
        (prettify-output (convert-output-to-att-vec output-list))))))

(defun show-backprop-weights ()
  "Print both weight matrices, each under its own heading.
Returns no values."
  ;; Each entry pairs a heading with the variable holding that matrix; the
  ;; variable is looked up only after its heading has been printed.
  (loop for (heading . matrix-var)
          in '(("~%Input-to-hidden weights:~%" . *hidden-weights*)
               ("~%Hidden-to-output weights:~%" . *output-weights*))
        do (format t heading)
           (show-matrix (symbol-value matrix-var)))
  (values))

;;; Auxiliaries

(defun propagate-activity (in wts out)
  "Compute one layer's activations.
For every element i of OUT, stores the squashed weighted sum of IN using
row i of WTS.  OUT is modified in place; returns NIL."
  (loop for unit from 0 below (array-dimension out 0)
        do (setf (aref out unit)
                 (squash
                  ;; Accumulate from 0.0, not 0, so the sum is always a float.
                  (loop with net = 0.0
                        for src from 0 below (array-dimension in 0)
                        do (incf net (* (aref wts unit src) (aref in src)))
                        finally (return net))))))

(defun get-errors (desired-output actual-output tolerance errors)
  "Compare DESIRED-OUTPUT with ACTUAL-OUTPUT element by element.
Fills ERRORS in place: an element gets the difference (desired minus
actual) when its magnitude exceeds TOLERANCE, and 0.0 otherwise.
Returns two values: the sum of all squared differences (tolerance is not
applied to this sum) and T if any element exceeded TOLERANCE, else NIL."
  (loop with sq-error = 0.0
        with any-errors = nil
        for i from 0 below (array-dimension errors 0)
        for diff = (- (aref desired-output i) (aref actual-output i))
        do (incf sq-error (expt diff 2))
           (cond ((> (abs diff) tolerance)
                  (setq any-errors t)
                  (setf (aref errors i) diff))
                 (t
                  (setf (aref errors i) 0.0)))
        finally (return (values sq-error any-errors))))

(defun begin-deltas (errors output deltas)
  "Compute the output-layer deltas.
Each element of DELTAS is set to the matching element of ERRORS multiplied
by SQUASH-PRIME-FOR of the matching element of OUTPUT.  DELTAS is modified
in place; returns NIL."
  (loop for k from 0 below (array-dimension deltas 0)
        for slope = (squash-prime-for (aref output k))
        do (setf (aref deltas k)
                 (* (aref errors k) slope))))

(defun propagate-deltas (out-deltas wts in-activity in-deltas)
  "Push deltas back one layer.
For each element j of IN-DELTAS, sums OUT-DELTAS weighted by column j of
WTS and multiplies by SQUASH-PRIME-FOR of element j of IN-ACTIVITY.
Only the first (array-dimension IN-DELTAS 0) columns of WTS are read.
IN-DELTAS is modified in place; returns NIL."
  (loop for j from 0 below (array-dimension in-deltas 0)
        do (let ((back-sum 0.0))
             (loop for i from 0 below (array-dimension out-deltas 0)
                   do (incf back-sum
                            (* (aref wts i j) (aref out-deltas i))))
             (setf (aref in-deltas j)
                   (* back-sum
                      (squash-prime-for (aref in-activity j)))))))

(defun squash (x)
  "Logistic squashing function 1 / (1 + e^-X), safe for extreme X.
X is clamped to [-87, 87] before exponentiation.  Without the clamp,
(exp (- x)) exceeds the single-float range once X drops below about -88.7
and signals FLOATING-POINT-OVERFLOW in implementations that trap it, which
would abort training as soon as a weighted sum grew that negative.
Inside the clamp range the result is identical to the unclamped formula;
outside it the result saturates at 1.0 or at roughly 1.6e-38."
  (let* ((limit (if (floatp x)
                    (float 87 x)        ;keep X's float format
                    87.0))
         (clamped (max (- limit) (min limit x))))
    (/ 1.0 (+ 1.0 (exp (- clamped))))))

(defun squash-prime-for (y)
  "Return Y times (1 - Y).
Callers in this file pass a unit's activation (a value produced by SQUASH)."
  (let ((complement (- 1.0 y)))
    (* y complement)))

(defun coerce-to-vectors (seq-list)
  "Return a new list holding each sequence of SEQ-LIST coerced to a vector."
  (loop for seq in seq-list
        collect (coerce seq 'vector)))

(defun increment-wts (coeff incr wts)
  "Add COEFF times INCR to WTS, element by element.
WTS is a two-dimensional array modified in place; INCR is read at the same
indices.  Returns NIL."
  (loop for row from 0 below (array-dimension wts 0)
        do (loop for col from 0 below (array-dimension wts 1)
                 for step = (* coeff (aref incr row col))
                 do (incf (aref wts row col) step))))

(defun increment-by-outer-prod (vec1 vec2 matrix)
  "Add the outer product of VEC1 and VEC2 into MATRIX.
Element (i, j) of MATRIX grows by VEC1[i] * VEC2[j].  MATRIX is modified
in place; returns NIL."
  (loop for i from 0 below (array-dimension vec1 0)
        do (loop for j from 0 below (array-dimension vec2 0)
                 for product = (* (aref vec1 i) (aref vec2 j))
                 do (incf (aref matrix i j) product))))

(defun clear-matrix (matrix)
  "Set every element of the two-dimensional array MATRIX to 0.0.
MATRIX is modified in place; returns NIL."
  (loop for row from 0 below (array-dimension matrix 0)
        do (loop for col from 0 below (array-dimension matrix 1)
                 do (setf (aref matrix row col) 0.0))))

(defun random-matrix (rows cols val)
  "Return a new ROWS x COLS array of values drawn by UNIFORM-RANDOM
between minus VAL and VAL.  Elements are drawn in row-major order."
  (let ((result (make-array (list rows cols)))
        (low (- val)))
    (loop for r from 0 below rows
          do (loop for c from 0 below cols
                   do (setf (aref result r c)
                            (uniform-random low val))))
    result))

(defun uniform-random (a b)
  "Return A plus a random value at least 0 and less than (B - A).
The random value has the type of (B - A): float bounds give a float,
integer bounds give an integer."
  (let ((span (- b a)))
    (+ a (random span))))

(defun fail-message (passes tolerance)
  "Print the training-failed report to standard output.
Reports TOLERANCE and the pass count PASSES, with the noun singular only
when PASSES equals 1.  Returns NIL."
  (format *standard-output*
          "~%Not learned to within tolerance ~4,2f after ~a "
          tolerance passes)
  (format *standard-output*
          (if (/= passes 1) "passes." "pass.")))

(defun show-progress (passes rms-error)
  "Print one progress line giving PASSES and RMS-ERROR (four decimals).
Output goes to standard output; returns NIL."
  (format *standard-output*
          "~%~a passes through data: RMS error = ~6,4f"
          passes
          rms-error))

(defun success-message (passes tolerance)
  "Print the training-succeeded report to standard output.
Reports TOLERANCE and the pass count PASSES, with the noun singular only
when PASSES equals 1.  Returns the symbol DONE."
  (format *standard-output*
          "~%Patterns all learned to within tolerance ~4,2f after ~a "
          tolerance passes)
  (format *standard-output*
          (if (/= passes 1) "passes." "pass."))
  'done)

(defun show-matrix (mat &optional (format-string "~6,2f"))
  "Print the two-dimensional array MAT to standard output, one row per line.
Every element is printed with FORMAT-STRING; a newline follows each row.
Returns NIL."
  (loop for row from 0 below (array-dimension mat 0)
        do (loop for col from 0 below (array-dimension mat 1)
                 do (format t format-string (aref mat row col)))
           ;; End the row even when it has no columns.
           (terpri)))
