-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathRestrictedBoltzmannMachine.hs
228 lines (183 loc) · 9.79 KB
/
RestrictedBoltzmannMachine.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
{-# LANGUAGE PatternGuards, ScopedTypeVariables #-}
-- | Base Restricted Boltzmann machine.
-- http://en.wikipedia.org/wiki/Restricted_Boltzmann_machine
module Hopfield.Boltzmann.RestrictedBoltzmannMachine where
import Data.Maybe
import Control.Monad
import Control.Monad.Random
import Data.List
import Data.Vector ((!))
import qualified Data.Vector as V
import qualified Numeric.Container as NC
import Hopfield.Common
import Hopfield.Util
-- In the case of the Boltzmann Machine the weight matrix establishes the
-- weights between visible and hidden neurons
-- w i j - connection between visible neuron i and hidden neuron j
-- | Learning rate: scales the weight delta applied on every contrastive
-- divergence update step (see 'updateWeights').
-- http://en.wikipedia.org/wiki/Restricted_Boltzmann_machine#Training_algorithm
learningRate :: Double
learningRate = 0.1
-- | Layer tag: tells the update functions whether a given pattern belongs
-- to the hidden layer or the visible layer of the machine.
data Mode = Hidden | Visible
  deriving(Eq, Show)
-- | Phase of network use: 'Training' during contrastive divergence,
-- 'Matching' when running a pattern against a trained network (where a
-- bias unit is prepended to the pattern — see 'getActivationProbability').
data Phase = Training | Matching
  deriving(Eq, Show)
-- | A trained restricted Boltzmann machine together with the data it was
-- trained on.
data BoltzmannData = BoltzmannData {
        weightsB :: Weights      -- ^ the weights of the network
      , patternsB :: [Pattern]   -- ^ the patterns which were used to train it
      , nr_hiddenB :: Int        -- ^ number of neurons in the hidden layer
      , pattern_to_binaryB :: [(Pattern, [Int])] -- ^ the binary representation of the pattern index
      -- the pattern_to_binary field will not replace the patternsB field as it does
      -- not contain duplicated patterns, which might be required for statistical
      -- analysis in clustering and super attractors
    }
  deriving(Show)
-- | Retrieves the dimension of the weights matrix corresponding to the given
-- mode: for 'Hidden' it is the width of the matrix (length of the first row),
-- for 'Visible' it is the height (number of rows).
getDimension :: Mode -> Weights -> Int
getDimension mode ws = case mode of
  Hidden  -> V.length (ws ! 0)
  Visible -> V.length ws
-- | Flips a layer tag to the opposite layer.
notMode :: Mode -> Mode
notMode m = case m of
  Visible -> Hidden
  Hidden  -> Visible
-- | @buildBoltzmannData patterns@ trains a boltzmann network with @patterns@.
-- The number of hidden neurons is set to the number of visible neurons
-- (the length of the first pattern). Errors out on an empty pattern list.
buildBoltzmannData :: MonadRandom m => [Pattern] -> m BoltzmannData
buildBoltzmannData [] = error "Train patterns are empty"
buildBoltzmannData pats@(first : _) =
  -- V.length already yields an Int, so no numeric conversion is needed;
  -- matching on the head also avoids the partial function 'head'.
  buildBoltzmannData' pats (V.length first)
-- | @buildBoltzmannData' patterns nr_hidden@: Takes a list of patterns and
-- builds a Boltzmann network (by training) in which these patterns are
-- stable states. The result of this function can be used to run a pattern
-- against the network, by using 'matchPatternBoltzmann'.
buildBoltzmannData' :: MonadRandom m => [Pattern] -> Int -> m BoltzmannData
buildBoltzmannData' [] _ = error "Train patterns are empty"
buildBoltzmannData' pats nr_hidden
  | patLen == 0
      = error "Cannot have empty patterns"
  | not (all ((== patLen) . V.length) pats)
      = error "All training patterns must have the same length"
  | otherwise = do
      (ws, pats_with_binary) :: (Weights, [(Pattern, [Int])]) <- trainBoltzmann pats nr_hidden
      return $ BoltzmannData ws pats nr_hidden pats_with_binary
  where
    -- Length of the first pattern; all others must match it.
    patLen = V.length (head pats)
-- | Pure version of 'updateNeuron' for testing: the caller supplies the
-- random draw @r@. The neuron fires (returns 1) exactly when @r@ falls
-- below its activation probability, and stays off (returns 0) otherwise.
updateNeuron' :: Double -> Phase -> Mode -> Weights -> Pattern -> Int -> Int
updateNeuron' r phase mode ws pat index
  | r < prob  = 1
  | otherwise = 0
  where
    prob = getActivationProbability phase mode ws pat index
-- | Probability that the neuron @index@ in the layer opposite to @mode@
-- activates, given the pattern @pat@ of type @mode@. In the 'Matching'
-- phase a bias unit (fixed to 1) is prepended to the pattern. Errors out
-- if the computed value is not a valid probability.
getActivationProbability :: Phase -> Mode -> Weights -> Pattern -> Int -> Double
getActivationProbability phase mode ws pat index
  | 0 <= prob && prob <= 1 = prob
  | otherwise              = error (show prob)
  where
    prob = activation (sum weighted)
    -- Dot product between the relevant weight row/column and the pattern.
    weighted = case mode of
      Hidden  -> [ (ws ! index ! i) *. (extended ! i) | i <- idxs ]
      Visible -> [ (ws ! i ! index) *. (extended ! i) | i <- idxs ]
    -- Pattern extended with the bias unit when matching.
    extended = case phase of
      Matching -> V.cons 1 pat
      Training -> pat
    idxs = [0 .. V.length extended - 1]
-- | @updateNeuron phase mode ws pat index@, given a vector @pat@ of type
-- @mode@, stochastically updates the neuron with number @index@ in the layer
-- of the opposite type, by drawing a uniform random value in [0, 1].
updateNeuron :: MonadRandom m => Phase -> Mode -> Weights -> Pattern -> Int -> m Int
updateNeuron phase mode ws pat index =
  (\draw -> updateNeuron' draw phase mode ws pat index) `liftM` getRandomR (0.0, 1.0)
-- | @getCounterPattern phase mode ws pat@, given a vector @pat@ of type
-- @mode@, computes the values of all the neurons in the layer of the
-- opposite type. Errors out if @pat@ is incompatible with the network.
getCounterPattern :: MonadRandom m => Phase -> Mode -> Weights -> Pattern -> m Pattern
getCounterPattern phase mode ws pat =
  case validPattern phase mode ws pat of
    Just err -> error err
    Nothing  -> V.fromList `liftM` mapM (updateNeuron phase mode ws pat) targets
  where
    targets = [0 .. getDimension (notMode mode) ws - offset]
    -- During matching the bias row/column is excluded from the update.
    offset = case phase of
      Training -> 1
      Matching -> 2
-- | One step which updates the weights in the CD-n training process.
-- The weights are changed according to one of the training patterns.
-- Performs a single Gibbs step (visible -> hidden -> visible) and nudges
-- the weights by @learningRate * (positive gradient - negative gradient)@.
-- http://en.wikipedia.org/wiki/Restricted_Boltzmann_machine#Training_algorithm
updateWeights :: MonadRandom m => Weights -> Pattern -> m Weights
updateWeights ws v = do
  -- Prepend the bias unit (fixed at 1) to the visible pattern.
  let biased_v = V.cons 1 v
  -- Sample the hidden layer from the (biased) visible data ...
  h <- getCounterPattern Training Visible ws biased_v
  -- ... then reconstruct the visible layer from that hidden sample.
  v' <- getCounterPattern Training Hidden ws h
  let f = fromDataVector . fmap fromIntegral
      -- Outer products of visible states with hidden activation probabilities.
      -- NOTE(review): the positive term uses the unbiased @v@ for getSigmaH
      -- while the negative term uses the full reconstruction @v'@ — confirm
      -- this asymmetry is intended.
      pos = NC.toLists $ (f biased_v) `NC.outer` (fromDataVector $ getSigmaH v) -- "positive gradient"
      neg = NC.toLists $ (f v') `NC.outer` (fromDataVector $ getSigmaH v') -- "negative gradient"
      d_ws = map (map (* learningRate)) $ combine (-) pos neg -- weights delta
      new_weights = combine (+) (list2D ws) d_ws
      nr_hidden = V.length $ ws ! 0
      -- Vector of activation probabilities of the hidden units given y.
      getSigmaH y = V.fromList [getActivationProbability Training Visible ws y x | x <- [0.. nr_hidden - 1] ]
  return $ vector2D new_weights
-- | The training function for the Boltzmann Machine.
-- We are using the contrastive divergence algorithm CD-1
-- TODO see if making the vis
-- (we could extend to CD-n, but "In practice, CD-1 has been shown to work surprisingly well."
-- @trainBoltzmann pats nr_hidden@ where @pats@ are the training patterns
-- and @nr_hidden@ is the number of neurons to be created in the hidden layer.
-- Returns the trained weights (including bias row/column) together with the
-- binary class encoding assigned to each distinct pattern.
-- http://en.wikipedia.org/wiki/Restricted_Boltzmann_machine#Training_algorithm
trainBoltzmann :: MonadRandom m => [Pattern] -> Int -> m (Weights, [(Pattern, [Int])])
trainBoltzmann pats nr_hidden = do
  -- Small random initial weights, one row per (extended) visible unit.
  weights_without_bias <- genWeights
  -- add biases as a dimension of the matrix, in order to include them in the
  -- contrastive divergence algorithm: a zero column is prepended to every
  -- row, and a zero row is prepended to the matrix.
  let ws = [0: x | x <- weights_without_bias]
      ws_start = (replicate (nr_hidden + 1) 0) : ws
  updated_ws <- foldM updateWeights (vector2D ws_start) pats'
  return (updated_ws, paths_with_binary_indices)
  where
    genWeights = replicateM nr_visible . replicateM nr_hidden $ normal 0.0 0.01
    -- Binary index encoding for each distinct pattern (from Hopfield.Util).
    paths_with_binary_indices = getBinaryIndices pats
    -- Training vectors: each pattern extended with its class encoding.
    pats' = [(V.++) x $ encoding x | x <- pats]
    encoding x = V.fromList . fromJust $ lookup x paths_with_binary_indices
    nr_visible = V.length $ pats' !! 0
-- | The activation function for the network (the logistic sigmoid),
-- mapping any real input into the open interval (0, 1).
-- http://en.wikipedia.org/wiki/Sigmoid_function
activation :: Double -> Double
activation x = recip (1.0 + exp (negate x))
-- | @validPattern phase mode weights pattern@
-- Returns an error string in a Just if the @pattern@ is not compatible
-- with @weights@ and Nothing otherwise. @mode@ gives the type of the pattern,
-- which is checked (Visible or Hidden). During 'Matching' one unit of the
-- network dimension is reserved for the bias, so the expected size shrinks
-- by one.
validPattern :: Phase -> Mode -> Weights -> Pattern -> Maybe String
validPattern phase mode ws pat
  | expected_dim /= V.length pat = Just $ "Size of pattern must match network size in " ++ show phase ++ " " ++ show mode
  | not (V.all (\x -> elem x [0, 1]) pat) = Just "Non binary element in Boltzmann pattern"
  | otherwise = Nothing
  where
    expected_dim
      | phase == Training = network_dim
      | otherwise         = network_dim - 1
    network_dim = getDimension mode ws
-- | Checks that the weight matrix is non-empty and rectangular (every row
-- has the same length as the first); returns an error string in a Just on
-- failure, Nothing otherwise.
validWeights :: Weights -> Maybe String
validWeights ws
  | V.null ws  = Just "The matrix of weights is empty"
  | ragged     = Just "weights matrix ill formed"
  | otherwise  = Nothing
  where
    row_len = V.length (ws ! 0)
    ragged  = V.any ((/= row_len) . V.length) ws
-- | Updates a visible pattern using the Boltzmann machine: samples the
-- hidden layer from it, then reconstructs the visible layer.
updateBoltzmann :: MonadRandom m => Weights -> Pattern -> m Pattern
updateBoltzmann ws pat =
  getCounterPattern Matching Visible ws pat >>= getCounterPattern Matching Hidden ws
-- | Free energy of a visible pattern under the network: the negative bias
-- term minus the softplus of each hidden unit's total input. Lower free
-- energy means higher probability of the pattern.
-- see http://www.cs.toronto.edu/~hinton/absps/guideTR.pdf section 16.1
-- And stack overflow discussion
-- http://stackoverflow.com/questions/9944568/the-free-energy-approximation-equation-in-restriction-boltzmann-machines
-- http://www.dmi.usherb.ca/~larocheh/publications/class_set_rbms_uai.pdf
getFreeEnergy :: Weights -> Pattern -> Double
getFreeEnergy ws pat
  | Just e <- validWeights ws = error e
  | Just e <- validPattern Matching Visible ws pat = error e
  | otherwise = - biases - sum (map f xs)
  -- Row 0 / column 0 of the weight matrix hold the biases, so pattern
  -- element i pairs with weight row (i + 1) and hidden unit j with
  -- column j starting from 1.
  where w i j = ((ws :: Weights) ! i ! j) :: Double
        -- Visible-bias contribution: sum of w_{i+1,0} * pat_i.
        biases = sum [ w (i + 1) 0 *. (pat ! i) | i <- [0 .. p - 1] ]
        -- Total input of each hidden unit j (its bias plus weighted pattern).
        xs = [ w 0 j + sum [ w (i + 1) j *. (pat ! i) | i <- [0 .. p - 1] ] | j <- [1 .. (V.length $ ws ! 0) - 1]]
        -- Softplus.
        f x = log (1 + exp x)
        p = V.length pat
-- | Matches a pattern against the given network: returns the index (into
-- the training patterns) of the stored pattern whose class extension gives
-- the given pattern the highest probability exp(-freeEnergy).
matchPatternBoltzmann :: MonadRandom m => BoltzmannData -> Pattern -> m Int
matchPatternBoltzmann (BoltzmannData ws pats _ pats_with_binary) pat = do
  -- Run one Gibbs update on the pattern extended with an arbitrary class
  -- encoding (the first one) to settle it into the network.
  hot_pat <- updateBoltzmann ws ((V.++) pat (V.fromList $ snd $ head pats_with_binary))
  -- Keep only the data part of the updated pattern, dropping the encoding.
  let h = V.take (V.length $ head pats) hot_pat
      -- Re-extend the settled pattern with the candidate pattern's encoding.
      extendWithClass p = ((V.++) h (V.fromList . fromJust $ lookup p pats_with_binary) )
      -- Unnormalized probability of a visible configuration.
      getPatternProbability x = exp $ (- getFreeEnergy ws x)
      fromPatToIndex p = fromJust $ p `elemIndex` pats
  -- Pick the training pattern maximizing the probability of the extension.
  return $ fst $ maximumBy (compareBy snd) [(fromPatToIndex p, (getPatternProbability . extendWithClass) p) | p <- pats]