Is it possible to extract the weight matrix from a network?
Sure. The method model.save_weights(filepath) will do it for you and store the weights in an HDF5 file.
If you want to do it manually, you'd do something like:
for layer in model.layers:
    weights = layer.get_weights()  # list of numpy arrays
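A slightly fuller sketch of the same loop. It assumes a Keras-style model whose layers expose name and get_weights() (as Keras layers do) and collects the shape of every array each layer returns, so you can see which arrays belong to which layer. The helper name weight_shapes is hypothetical.

```python
def weight_shapes(model):
    """Map each layer name to the shapes of the arrays returned by get_weights()."""
    shapes = {}
    for layer in model.layers:
        # get_weights() returns a list of numpy arrays; the list is empty for
        # layers without parameters, such as Dropout or Activation
        shapes[layer.name] = [w.shape for w in layer.get_weights()]
    return shapes
```

Calling print(weight_shapes(model)) on a trained model gives a compact overview; for a Dense layer the list typically holds two shapes, the kernel followed by the bias.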
Thanks! Are the biases included here?
I can't quite figure out the format of the returned weights. Even with
a simple model with 0 hidden layers, I get back a large number of
weights. I also see that activation parameters have weights associated
with them.
Is this format documented, or is going through the code the
only way? Are the biases here too?
I am running this simple model on the Otto dataset:
print("Building model...")
model = Sequential()
model.add(Dense(dims, 10, init='glorot_uniform'))
model.add(PReLU((10,)))
model.add(BatchNormalization((10,)))
model.add(Dropout(0.5))
model.add(Dense(10, nb_classes, init='glorot_uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer="adam")
print("Training model...")
model.fit(train_X, train_y, nb_epoch=1, batch_size=250, validation_sp
And then this:
for layer in model.layers:
    g = layer.get_config()
    h = layer.get_weights()
    print(g)
    print(h)
And the output is:
{'output_dim': 10, 'init': 'glorot_uniform', 'activation': 'linear',
'name': 'Dense', 'input_dim': 94}
[array([[ -7.99849009e-02, -2.45893497e-02, 4.69488823e-02,
6.59520472e-04, 5.47647455e-02, -5.37778421e-02,
-6.25924897e-02, -8.44473401e-02, -8.43347936e-02,
-1.41206931e-02],
[ -7.28297196e-03, -6.94909196e-02, 5.09896591e-02,
1.07518776e-01, -1.40819659e-01, 1.28122287e-03,
-7.18242022e-02, -7.47769234e-02, -5.61066325e-04,
5.22953511e-02],
[ 4.48867873e-02, -5.98358224e-02, -3.34403798e-02,
-6.98303095e-02, 1.42869021e-02, 9.69443043e-03,
-5.72198990e-02, 1.26906643e-01, 1.44663706e-01,
-1.05488360e-01],
[ -4.94764224e-03, -6.72982314e-02, 1.98947019e-02,
-2.34841952e-04, 8.15450053e-02, 9.93956783e-02,
5.33930437e-02, 9.88775421e-02, 2.62561318e-03,
-1.37490951e-01],
[ -9.69757727e-02, -1.93766233e-02, -2.67758527e-02,
1.24528297e-03, -4.38467242e-02, -3.73318840e-02,
-3.62683822e-02, 3.58399029e-02, -2.16549907e-03,
6.28375524e-02],
[ 2.30916925e-02, -5.03360711e-02, -3.48193628e-03,
-1.11732154e-02, -2.15673704e-02, -5.69503631e-02,
9.87205698e-03, -2.12394581e-02, 3.23201450e-02,
8.15617402e-02],
[ -1.22399136e-01, -1.53275598e-03, -2.49756349e-02,
2.55169260e-02, 6.42912644e-02, -8.35097613e-02,
-1.16564598e-01, 5.86628949e-02, -3.66701814e-02,
7.93936586e-02],
[ -8.18904005e-03, 9.61438433e-02, 1.05611869e-01,
-5.29767001e-02, 1.30970726e-01, -1.05098848e-01,
-5.67985073e-02, 8.96001608e-02, -4.06112779e-03,
7.38853071e-02],
[ -1.10601817e-01, 4.01777134e-02, -1.04958505e-01,
-1.86499188e-01, -2.37804665e-02, -3.76688537e-02,
-1.45232209e-01, -6.36418185e-02, -7.63368207e-02,
-1.17734138e-01],
[ -2.40944892e-02, -1.99559485e-02, -3.10534144e-02,
2.51051600e-02, -3.66193065e-02, -6.94967520e-02,
4.28971762e-02, -4.68141541e-02, 2.08197228e-02,
4.14676310e-02],
[ 1.85548637e-01, -2.00046197e-01, -1.74210642e-01,
9.19546529e-02, 2.09622917e-01, 1.66180606e-01,
2.18369084e-01, 1.06692421e-01, 2.23854503e-02,
-1.63984764e-01],
[ -2.25432936e-02, 1.21225411e-02, -1.77713444e-02,
2.49248395e-02, -9.09446976e-02, -2.12943360e-02,
-5.79632667e-02, -7.98705389e-02, 1.45603023e-02,
-3.29101095e-02],
[ 1.06982420e-02, 8.10667692e-03, -7.40422225e-03,
8.73337501e-02, -4.70680094e-02, -6.96424801e-04,
-3.52869543e-02, -5.75422152e-02, -2.50050629e-03,
5.30206962e-02],
[ -1.05554822e-01, 1.05913849e-01, -1.99050840e-01,
-1.09049315e-01, -2.13796040e-01, -2.46177639e-01,
3.93859830e-02, -1.60071728e-01, -2.43849880e-01,
-9.01017306e-02],
[ -2.66245883e-02, 1.02689584e-01, -1.13140709e-01,
-2.50536214e-01, -1.36009340e-01, -9.52662533e-02,
1.40512889e-02, -2.11842101e-01, -2.26055642e-01,
-1.35109722e-01],
[ -7.53730907e-03, -1.64213430e-03, -6.63893471e-02,
-7.66609176e-02, -3.59994294e-02, -9.46264700e-02,
2.58486499e-02, 2.08334686e-03, -8.75103806e-02,
-1.10772615e-02],
[ 1.52794017e-02, 7.91648709e-03, 1.07050460e-01,
-3.69858265e-02, 5.07007745e-02, -5.97656110e-02,
-3.49382162e-02, 1.55669800e-01, 1.03023857e-02,
-4.40477760e-02],
[ 4.07815059e-03, 2.83262161e-02, -1.48807663e-03,
9.73448516e-02, 3.56439926e-03, -1.23538034e-02,
9.76411415e-02, 2.06518137e-02, 6.17017622e-02,
-8.53835248e-02],
[ -8.09946660e-02, 2.61380996e-02, 2.63300387e-02,
9.93539858e-02, -1.10327820e-01, 4.95747644e-02,
-3.89292545e-02, -1.14508714e-02, 3.19129217e-02,
1.23465986e-01],
[ 3.59885269e-03, 2.76080271e-03, -2.78764395e-02,
4.24239723e-02, 1.53515202e-03, -2.69810376e-02,
1.63534000e-02, 5.40332907e-02, 1.14591409e-01,
3.83767457e-02],
[ 2.84191996e-03, -6.19442984e-02, -8.03084884e-03,
-1.51138132e-02, 4.29634390e-02, 3.55585086e-02,
-3.89301471e-02, -3.44196220e-02, -1.92320194e-02,
-6.95947810e-02],
[ 8.89051092e-03, -3.63276576e-02, 3.77512802e-02,
-8.61697928e-02, 4.37051203e-02, 1.25958620e-01,
1.27073221e-02, 1.74475895e-02, 3.51703513e-02,
1.76466516e-02],
[ 7.71217281e-02, 1.81580258e-02, 2.90213940e-02,
4.13172177e-02, 8.94682418e-02, -5.88464358e-03,
-8.54934082e-02, -6.62388395e-03, 6.46094670e-03,
-1.88742827e-02],
[ -2.91908240e-02, -2.65753849e-02, 1.52980919e-02,
6.58449558e-02, -7.86922291e-02, 4.58466034e-02,
-3.39622131e-02, 8.56353043e-02, -3.02583216e-02,
1.06242259e-01],
[ -3.46668468e-02, 5.48445800e-02, -9.99637775e-02,
-1.40776997e-01, -2.33112086e-01, -1.12747490e-01,
-3.38637078e-03, -2.72000452e-01, -2.43135959e-01,
1.40148791e-02],
[ 4.55380939e-02, -2.05104710e-01, -1.23208875e-01,
7.37236719e-02, 3.40874932e-02, 1.38913292e-01,
2.39351600e-01, -1.87328804e-02, 7.01187834e-02,
-5.02784378e-02],
[ -1.56907260e-02, -6.93088723e-02, -1.32831942e-01,
1.40660959e-02, 7.01844406e-02, 6.88994173e-02,
1.09614443e-01, 5.21141313e-03, 2.66728291e-02,
-2.12535714e-01],
[ 8.63241877e-03, -2.00266780e-01, 4.21728732e-03,
-2.71931798e-04, -3.74458866e-02, -1.02733646e-02,
1.26404038e-02, 4.53453957e-02, -5.47819209e-03,
-1.78313855e-01],
[ -3.40579999e-02, 3.68030410e-02, 2.78233896e-02,
-6.87630905e-02, 5.79211738e-02, 3.53004862e-02,
2.97676136e-02, -2.83821290e-03, -6.19672378e-02,
4.38129833e-02],
[ -1.29424714e-01, -5.27272972e-02, 6.69243394e-02,
1.19757129e-01, 3.84862554e-02, 9.18853869e-02,
-4.82323177e-02, 1.87875149e-02, 4.63434479e-02,
5.18075847e-02],
[ -4.84346313e-02, -7.05440572e-03, -1.17486716e-01,
-3.78191092e-02, -1.63198220e-02, 6.35379808e-02,
-2.28866377e-02, -4.73959864e-02, 6.47443882e-02,
1.71767526e-02],
[ 7.92322710e-02, -3.99799449e-03, -2.68663861e-02,
2.16343925e-02, 7.08523118e-02, 9.40224531e-03,
-2.73172165e-02, 3.90645337e-02, 1.43386517e-02,
-3.28124923e-02],
[ -7.49563779e-02, 8.96276663e-03, -3.00325036e-02,
-9.26680367e-02, -1.68292320e-01, -1.46136493e-01,
-2.32867781e-02, -1.37576449e-01, -1.08664407e-01,
-3.03448062e-02],
[ -1.58937859e-01, 1.24359548e-01, 1.72723048e-01,
-7.16377564e-02, 8.15923267e-02, -9.26028128e-02,
-3.52089676e-02, -1.71921631e-01, 1.19313224e-01,
5.04153287e-02],
[ 5.87777373e-03, -9.82951252e-02, -1.44826277e-01,
2.87068172e-03, 5.77770075e-02, 3.83663771e-02,
-7.79279419e-02, 5.30456023e-02, -4.32744374e-02,
-8.80488631e-02],
[ -9.51507092e-02, 2.48270494e-02, 1.85502184e-01,
-6.15492334e-02, 1.18012058e-01, -3.35554541e-02,
1.92680772e-03, 1.25766050e-03, -1.33313640e-01,
3.44885089e-02],
[ -8.23068727e-02, 6.00759323e-03, 3.07414607e-02,
3.78444659e-02, 1.59902793e-02, 1.32360708e-02,
-5.58786321e-02, -8.98881639e-03, 9.11022060e-02,
1.27618163e-01],
[ -8.39977134e-02, -2.19334460e-02, 2.93163173e-02,
-9.67890487e-03, -1.45040381e-01, -1.11580661e-02,
-5.45275977e-02, 3.19171362e-02, 6.28167381e-02,
5.43802318e-02],
[ -3.19924245e-02, 9.04734505e-02, -8.88122521e-03,
-6.81262022e-02, 3.81438529e-02, -1.10441657e-01,
-2.75602345e-02, 4.45400922e-02, 1.65449719e-01,
-4.30487717e-03],
[ -8.83774422e-02, 7.58000898e-02, -1.84174130e-01,
-8.53399835e-02, -1.66059125e-01, -1.49298746e-01,
2.59884981e-02, -1.58198291e-01, -1.70953601e-01,
-2.42574980e-02],
[ 3.76959883e-02, 3.19914391e-02, 8.79207932e-02,
3.22329258e-02, -4.57678893e-03, 6.73169597e-02,
3.67542760e-03, -1.08364242e-02, -7.73627399e-03,
1.05920958e-01],
[ 2.22782407e-01, -1.80813558e-01, -5.38938983e-02,
-1.37104372e-01, 1.53482621e-01, 1.41474566e-01,
6.76793766e-02, 9.57258689e-02, 1.20902074e-01,
-8.28063868e-02],
[ 1.45137965e-02, 4.45537228e-02, -2.40643314e-01,
-5.28557398e-02, -9.96888936e-02, 3.65267589e-02,
3.64279623e-02, -8.18503052e-02, -2.02683058e-01,
-1.82493162e-01],
[ 8.98117186e-02, -3.25222791e-02, 4.78793257e-03,
-3.44254824e-02, -9.32265642e-02, -3.37535001e-02,
-2.59333102e-02, -3.14748988e-02, 1.50924621e-02,
-3.88573589e-02],
[ 9.16648589e-02, -9.07544921e-03, 1.15557421e-02,
-6.69898791e-02, 6.43994728e-02, -1.03488028e-01,
-8.25832880e-02, 1.04004400e-01, -1.32508566e-02,
-2.85507593e-02],
[ 1.18606710e-01, -1.62804455e-01, -8.94677872e-02,
9.15753083e-02, -2.46587921e-02, 7.59351306e-02,
6.87165226e-02, -4.79627016e-02, -3.69506298e-02,
-1.07193629e-01],
[ 1.17239054e-02, -2.67437157e-02, -7.68840391e-04,
8.10397970e-02, -1.99904474e-02, 5.51565844e-02,
5.58673400e-04, 3.17302125e-02, 1.11154388e-01,
1.07863925e-01],
[ -1.04641810e-01, 2.78747603e-02, -1.45182739e-03,
-1.53233877e-01, -6.29874225e-02, -1.29592620e-01,
-1.51859556e-01, -4.14495814e-02, -8.45453923e-02,
-1.15042618e-01],
[ -2.20509269e-02, -1.47940439e-01, -4.76452491e-02,
-2.20822576e-02, -2.04331960e-03, 4.55278328e-02,
7.08764808e-02, -1.73128630e-02, 1.93976654e-02,
-2.61378301e-02],
[ 7.04435092e-02, 2.91919488e-03, 4.20234866e-02,
-7.55152665e-02, -4.63803985e-02, -1.07999505e-01,
-8.16274725e-02, 4.19000215e-02, 7.89168322e-02,
-3.86468662e-04],
[ 5.83150997e-03, -1.60543542e-01, -1.42290587e-02,
-6.85734380e-03, 8.61183085e-03, -4.91538703e-02,
3.61739483e-03, 3.47957593e-02, -9.38960015e-03,
1.87941107e-02],
[ 3.86580568e-02, -6.17538775e-02, -3.52311159e-02,
-6.14824396e-03, 3.28743091e-02, 2.30341515e-02,
4.17623711e-02, 3.78719485e-02, 6.40326123e-02,
-1.25724117e-01],
[ -2.69466654e-02, 8.76149869e-02, 8.43603169e-02,
-2.43710820e-02, -7.43506410e-03, -5.92964338e-03,
1.19778728e-02, 1.49802507e-02, -1.24118188e-01,
1.19810388e-01],
[ 1.43340287e-02, -8.91459774e-02, 3.11272246e-02,
1.25600087e-01, 7.31098870e-02, -9.06883544e-03,
3.69710535e-03, -5.08471161e-02, 7.35172675e-02,
-1.54848409e-01],
[ 4.85676467e-02, -3.05382790e-02, 3.27869005e-02,
6.41000141e-02, 7.11917729e-02, 1.18288425e-01,
6.83290253e-03, 9.93292852e-02, -3.07955208e-03,
-1.28590112e-02],
[ -3.03334700e-02, 1.08860071e-02, 8.97625740e-02,
-4.83023573e-02, -3.12615198e-02, -8.61628871e-02,
-1.04541486e-02, 1.33751888e-01, 8.89201773e-02,
-1.06815586e-01],
[ 1.64302198e-01, -1.73583248e-01, -1.29705113e-01,
1.15130432e-01, 5.06143426e-02, 1.38332743e-01,
8.68229188e-02, 2.03039990e-02, 9.97073441e-03,
-1.84693280e-01],
[ 1.39708295e-02, -1.20358859e-01, 2.09749290e-02,
1.59206874e-01, -1.31967745e-01, 9.07331382e-04,
-2.98483952e-02, -2.11940084e-02, 5.57213581e-02,
1.17593488e-01],
[ 1.95239811e-02, 2.48040812e-02, 1.13217120e-01,
-1.32227109e-01, 4.16574397e-02, 4.98457202e-02,
-3.24562114e-02, -6.51531123e-02, -7.33154552e-02,
-3.98947626e-02],
[ 2.66432299e-01, -2.97074213e-01, -5.99969370e-02,
-5.87481183e-02, 1.60074011e-01, 1.30176286e-01,
2.77686796e-01, 1.31290317e-01, 1.22633535e-01,
-1.09230713e-01],
[ 2.00969073e-02, -1.84964992e-01, -5.83985571e-03,
-7.41362658e-02, 1.38365878e-02, 1.06839467e-01,
-2.74847583e-02, -1.92903743e-02, -5.20260714e-03,
2.50301918e-04],
[ -6.17150407e-02, 8.06537180e-02, 7.02343393e-02,
-5.42272016e-02, 9.97773977e-03, -1.98928364e-01,
-2.17237932e-01, -6.73968443e-03, -1.97736271e-01,
5.00901130e-02],
[ -1.66149509e-02, -2.62200709e-03, -2.08381446e-02,
2.57409840e-02, -2.42515149e-02, -3.21673100e-02,
1.49380691e-02, -2.72776353e-02, -8.78237320e-02,
1.68875553e-02],
[ -5.57623273e-02, 3.65157464e-02, -1.13528556e-01,
9.79066670e-03, -5.63900992e-02, -3.71015632e-02,
-1.53407485e-04, -1.72846924e-01, -1.09153670e-01,
-3.32763024e-02],
[ 5.97840241e-03, 4.14377038e-02, 4.48701934e-03,
-6.99469980e-02, 4.48644760e-02, 1.05679300e-01,
-6.51186109e-02, 9.74034880e-03, 2.48868577e-02,
-3.36158742e-02],
[ 1.20077827e-02, -2.03385970e-02, -2.75987445e-03,
-1.16149630e-02, 1.06375004e-01, 5.39069232e-02,
-8.91773269e-02, 2.75564667e-03, 4.46157051e-02,
5.58006479e-02],
[ -1.16916048e-02, 4.60958859e-02, -2.17421025e-02,
-1.86881864e-01, 4.66562070e-02, 4.02356717e-02,
-1.15677821e-01, 4.66245330e-02, -3.02338327e-02,
1.65767924e-03],
[ -6.92209031e-02, -5.00074868e-02, 3.97814441e-02,
2.01908782e-01, -1.24789202e-01, 1.66400945e-01,
-1.04113481e-01, 2.69553450e-02, 4.85165628e-02,
1.82033767e-01],
[ -1.28536858e-01, -3.02310722e-03, 8.69889310e-02,
8.70096528e-02, -4.53932410e-02, 1.66790953e-01,
-1.28925981e-01, -1.83845627e-02, 2.31594216e-02,
1.55204566e-01],
[ -6.96121049e-02, 5.98652894e-02, -4.84150225e-03,
1.57034354e-02, 3.75668782e-02, -1.64650619e-02,
-1.70003480e-02, 1.68722860e-03, -1.58515948e-02,
-4.94031462e-02],
[ -6.33468947e-02, -8.75024885e-03, 5.29224644e-03,
1.02262168e-01, -1.58498505e-02, 5.67770925e-02,
-5.12286565e-02, 7.78941169e-02, 9.62220805e-02,
1.12947449e-01],
[ -1.71659989e-01, 1.16973947e-01, -1.05326386e-01,
-1.66891420e-01, -1.01974213e-01, -6.23915854e-02,
2.04867068e-02, -2.09272962e-01, -4.98385084e-02,
6.39348977e-02],
[ -2.33088886e-02, 4.77167627e-04, 5.12577177e-03,
8.72510224e-02, -6.98289790e-02, -3.83786347e-02,
-1.61912813e-01, 9.99245590e-03, 6.40901957e-03,
1.06622519e-03],
[ -3.22696147e-03, 7.50338364e-04, -4.13295523e-02,
1.11532490e-02, 2.13605470e-03, 7.01725919e-02,
6.56769972e-02, -1.22168268e-02, -3.98609596e-02,
4.02942635e-02],
[ -6.30274025e-02, -8.49640888e-02, 1.38809509e-02,
2.04496514e-01, -1.47865061e-01, 1.14922183e-01,
-1.48793624e-01, 5.45763123e-02, 4.44691877e-02,
1.79254325e-01],
[ -3.70425750e-02, -2.35610560e-02, -4.93268816e-03,
1.31040411e-02, 2.73776025e-02, 1.15118941e-01,
-1.38940692e-01, 3.52187024e-02, 1.86298744e-02,
1.01562678e-01],
[ 1.76901211e-03, 7.25849291e-02, 4.30484675e-02,
-5.29386234e-02, 8.85783359e-02, -2.10795505e-03,
-1.75977381e-02, -6.26338776e-02, 1.10745137e-02,
-3.01402319e-02],
[ -6.51010148e-03, 4.46026348e-03, 3.29594028e-02,
7.42153750e-02, 4.51823405e-02, 2.38861625e-02,
-1.51084355e-01, 2.58145276e-02, 9.78499330e-02,
1.13541447e-02],
[ -1.70992478e-02, -4.46845213e-02, 6.57091110e-02,
-7.22719951e-02, 1.38301387e-01, 3.39854928e-02,
-2.71362382e-03, 5.89179971e-02, 3.98174930e-02,
-2.14658967e-02],
[ 1.47533650e-01, -1.04364285e-01, -8.94515699e-02,
6.64953902e-02, 2.39417285e-02, -2.14146403e-02,
6.62148791e-02, -1.08153465e-02, 5.16891148e-02,
-7.39043785e-02],
[ 2.40624194e-03, 5.38938896e-02, 3.42756790e-02,
4.45555846e-02, 1.21189843e-01, -4.61384428e-02,
4.66344444e-02, 2.32483197e-02, 1.73072110e-02,
-5.31442668e-02],
[ 4.43278059e-02, -1.25351815e-01, 3.30721381e-03,
-4.40161903e-02, 4.02718580e-02, -2.78869336e-02,
5.86890902e-02, 7.31859328e-02, 2.70040392e-02,
1.43055350e-02],
[ -7.64451720e-02, 2.79163924e-02, -4.95644115e-02,
1.40135399e-04, -3.44646526e-02, -5.21001244e-02,
-1.36563861e-02, -3.01335168e-02, -2.14250554e-02,
3.37404738e-02],
[ 5.63956295e-03, 1.31333839e-02, 1.86413210e-02,
3.75087111e-02, 1.14190531e-01, -1.23009164e-02,
5.04972589e-02, -1.08671571e-02, -1.72429908e-02,
6.35628664e-03],
[ -7.81176131e-02, 1.22820378e-02, -4.88642529e-02,
-1.68663757e-02, -4.93644076e-02, -8.76378246e-03,
6.76276358e-02, -2.72725790e-02, -1.12157596e-01,
6.67786778e-02],
[ -1.35052643e-01, 6.17655018e-02, -9.91510117e-02,
-1.36961912e-01, -9.93016461e-02, -2.23026506e-01,
5.43391763e-02, -7.92043386e-02, -4.64242573e-02,
-1.31950335e-01],
[ 9.02050417e-02, 4.89889014e-02, -3.76353176e-02,
-8.21240609e-02, 5.91559984e-02, -2.88601734e-02,
-2.75170553e-02, 4.00444237e-02, 1.02223471e-01,
-8.85973866e-03],
[ -2.79819946e-02, 1.57770779e-02, -1.74634653e-01,
-1.50595889e-01, -1.98462406e-01, -1.58971182e-01,
5.48719855e-02, -1.66760843e-01, -2.75554553e-01,
2.64535338e-02],
[ -8.80281079e-02, -1.79294346e-02, 7.08099011e-02,
1.95550220e-02, -9.05946128e-02, 5.07943950e-02,
-2.29652551e-02, -4.74400823e-02, -2.48743014e-02,
-3.02452203e-02],
[ -1.41639591e-01, -1.14204327e-02, 6.37410597e-03,
3.71080625e-02, -2.52950025e-01, 9.53980112e-02,
-1.04131975e-01, 1.33543914e-02, 5.09378104e-02,
1.09227010e-01],
[ -1.41634459e-02, 7.14195268e-02, 6.84697215e-02,
-9.43577039e-02, 5.96288760e-03, -1.19336735e-01,
-3.07758752e-02, -5.98731421e-02, -9.55384648e-03,
3.84556520e-04],
[ -2.50815619e-02, 4.53682164e-02, 5.46400441e-02,
-2.23102338e-02, 1.17296376e-01, 5.90221088e-02,
-1.21542180e-02, 8.26175342e-03, 6.07650355e-02,
-3.03832471e-02],
[ 9.42653504e-02, -7.87588973e-02, -6.33857297e-02,
5.43245905e-02, 2.31799575e-02, 1.10878154e-01,
4.02822240e-02, 1.89821951e-02, 6.25729673e-02,
-4.26672129e-02],
[ -2.05037699e-02, 1.29917332e-01, -1.01022859e-01,
-9.12096671e-02, 9.67299134e-02, 4.83183980e-02,
-2.53932230e-02, -5.23461370e-02, 9.16840297e-02,
5.16286812e-02]]), array([ 0.13384319, 0.21280285,
0.12673509, 0.0481298 , 0.15588539,
0.16365858, 0.16476655, 0.21108415, 0.17096487, 0.1085302 ])]
{'name': 'PReLU', 'input_shape': (10,)}
[array([ 0.0933486 , 0.21287123, 0.00783342, 0.01653181, 0.17309322,
0.16339809, 0.17657923, 0.16427643, 0.15932031, 0.1788937 ])]
{'epsilon': 1e-06, 'mode': 0, 'name': 'BatchNormalization',
'input_shape': (10,)}
[array([-0.24547945, -0.21909049, 0.24213387, 0.2301027 , -0.2308098 ,
-0.26039818, 0.25852819, -0.2666662 , -0.23721958,
-0.21551781]), array([-0.18698727, 0.18234116, -0.19378199,
0.14785786, 0.20994197,
-0.17636872, -0.0286329 , 0.19325841, 0.15674127, 0.18424748])]
{'p': 0.5, 'name': 'Dropout'}
[]
{'output_dim': 9, 'init': 'glorot_uniform', 'activation': 'linear',
'name': 'Dense', 'input_dim': 10}
[array([[ 0.09531779, 0.08962424, 0.2058567 , -0.01249131, 0.30507946,
-0.53983095, -0.06707083, 0.19728828, 0.20076663],
[-0.32254363, -0.42794729, -0.25050942, -0.33067779, -0.41560606,
0.44383431, -0.41416482, 0.06048449, -0.41940528],
[ 0.44522338, -0.32401378, -0.47518886, 0.18116826, 0.37900766,
-0.25756208, 0.19068314, 0.20710987, 0.54537 ],
[-0.31705924, -0.24385525, -0.42045825, -0.06039081, -0.23035229,
0.13377106, -0.34439084, 0.50066419, -0.29471258],
[-0.46195437, 0.32470544, 0.23317512, -0.03074575, -0.03195262,
-0.12666796, -0.0689242 , 0.29082389, -0.03790227],
[ 0.08793353, 0.23260939, -0.02797571, 0.08840966, 0.26917765,
-0.42360304, 0.53772368, -0.30116916, 0.14197116],
[-0.02250078, -0.25081609, -0.09740926, 0.40152244, -0.22063047,
0.26022703, -0.34145828, -0.53870416, -0.48703162],
[-0.45320813, 0.50861172, 0.19536802, 0.07115475, 0.16693423,
-0.20375885, -0.26160637, -0.09466415, 0.0106583 ],
[-0.03785352, 0.45134675, 0.46924738, 0.003416 , -0.3582257 ,
-0.38137255, -0.47833458, -0.37664575, 0.31648622],
[-0.20752858, 0.18646851, 0.23139626, -0.32362606, -0.12266297,
0.43289664, -0.07509048, -0.52767154, -0.15752394]]),
array([-0.16717777, 0.16997226, 0.05989215, -0.16888715, -0.16850433,
0.16732873, -0.16105396, 0.09626825, -0.12726908])]
{'activation': 'softmax', 'name': 'Activation'}
[]
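PReLU and BatchNormalization do, of course, have weights: they are learnable layers. And yes, the biases are included: for each Dense layer, get_weights() returns the kernel (here of shape (94, 10), i.e. (input_dim, output_dim)) followed by the bias vector (here of shape (10,)).
Also, the model above has one "hidden layer" of size 10 (but reasoning in terms of hidden layers can be confusing; it is better to think in terms of operations). You are doing one projection from the initial space to a space of size 10, then a second projection from that space to a space of size nb_classes.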
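To make the "projection" view concrete, here is a NumPy sketch of the two Dense projections in the model above, assuming dims = 94 and nb_classes = 9 (the input_dim and output_dim values in the printed configs). Random arrays stand in for the trained kernels and biases, and PReLU, BatchNormalization and Dropout are left out of the forward pass. It also counts the numbers get_weights() returned in the printout, which is why even this small model gives back over a thousand of them.

```python
import numpy as np

dims, hidden, nb_classes = 94, 10, 9  # taken from the printed layer configs

rng = np.random.default_rng(0)
# Stand-ins for Dense.get_weights() == [kernel, bias]; kernel is (input_dim, output_dim)
W1, b1 = rng.normal(size=(dims, hidden)), np.zeros(hidden)
W2, b2 = rng.normal(size=(hidden, nb_classes)), np.zeros(nb_classes)

x = rng.normal(size=(1, dims))  # one example with 94 features
h = x @ W1 + b1                 # first projection: 94 -> 10
logits = h @ W2 + b2            # second projection: 10 -> 9

z = np.exp(logits - logits.max())
probs = z / z.sum()             # softmax over the 9 classes (single example)

# Numbers returned per layer in the printout above
param_counts = {
    "Dense(94->10)": W1.size + b1.size,  # 940 weights + 10 biases
    "PReLU": hidden,                     # one learned slope per unit
    "BatchNormalization": 2 * hidden,    # two arrays of 10 were printed
    "Dense(10->9)": W2.size + b2.size,   # 90 weights + 9 biases
}
total = sum(param_counts.values())
```

The BatchNormalization count reflects the two arrays printed by the Keras version in this thread; newer Keras versions also return the moving mean and variance, so the count there would differ.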
I use the following function to print out the contents of a weight file dumped by Keras:
from __future__ import print_function
import h5py

def print_structure(weight_file_path):
    """
    Prints out the structure of HDF5 file.

    Args:
      weight_file_path (str) : Path to the file to analyze
    """
    f = h5py.File(weight_file_path, 'r')  # open read-only
    try:
        if len(f.attrs.items()):
            print("{} contains: ".format(weight_file_path))
            print("Root attributes:")
        for key, value in f.attrs.items():
            print("  {}: {}".format(key, value))

        if len(f.items()) == 0:
            return

        for layer, g in f.items():
            print("  {}".format(layer))
            print("    Attributes:")
            for key, value in g.attrs.items():
                print("      {}: {}".format(key, value))

            print("    Dataset:")
            for p_name in g.keys():
                param = g[p_name]
                print("      {}: {}".format(p_name, param.shape))
    finally:
        f.close()
The output looks something like this (it is from my model, not yours):
  layer_0
    Attributes:
      nb_params: 2
      subsample: [1 1]
      init: glorot_uniform
      nb_filter: 32
      name: Convolution2D
      activation: linear
      border_mode: full
      nb_col: 3
      stack_size: 3
      nb_row: 3
    Dataset:
      param_0: (32, 3, 3, 3)
      param_1: (32,)
  layer_1
    Attributes:
      nb_params: 0
      activation: relu
      name: Activation
    Dataset:
  layer_2
    Attributes:
      nb_params: 2
      subsample: [1 1]
      init: glorot_uniform
      nb_filter: 32
      name: Convolution2D
      activation: linear
      border_mode: valid
      nb_col: 3
      stack_size: 32
      nb_row: 3
    Dataset:
      param_0: (32, 32, 3, 3)
      param_1: (32,)
  layer_3
    Attributes:
      nb_params: 0
      activation: relu
      name: Activation
    Dataset:
  layer_4
    Attributes:
      nb_params: 0
      name: MaxPooling2D
      ignore_border: True
      poolsize: [2 2]
    Dataset:
So I can tell that layer_0 is a Convolution2D layer whose weights are stored in the 'param_0' dataset with shape (32, 3, 3, 3), which means 32 filters, each with 3 channels and 3 pixels high by 3 pixels wide; the biases are stored in 'param_1' with shape (32,), one for each filter.
To access them, use model.layers[0].params[0] for the weights and model.layers[0].params[1] for the bias. (That is the API of the Keras version used here; in Keras 2, use model.layers[0].get_weights()[0] and model.layers[0].get_weights()[1] instead.)
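A small NumPy sketch of how those two arrays line up, assuming the (nb_filter, stack_size, nb_row, nb_col) layout shown for param_0 above. Zero arrays stand in for the real values, and the filter and channel indices are arbitrary picks for illustration.

```python
import numpy as np

W = np.zeros((32, 3, 3, 3))  # param_0, assumed layout (nb_filter, stack_size, nb_row, nb_col)
b = np.zeros(32)             # param_1: one bias per filter

f, c = 5, 2                  # arbitrary filter and input-channel indices
kernel_fc = W[f, c]          # 3x3 kernel that filter f applies to input channel c
filter_f = W[f]              # all weights of filter f: (3, 3, 3)
bias_f = b[f]                # bias added to filter f's output map

# Parameter count of layer_0: 32*3*3*3 weights + 32 biases
n_params = W.size + b.size

# layer_2 in the same listing: 32 filters over 32 input channels
n_params_layer2 = 32 * 32 * 3 * 3 + 32
```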
@mthrok thank you - this is really helpful!
@mthrok It's extremely helpful. Thank you very much!
Instead of storing weights, I want to store features. How can I do this?
@mthrok I tried your function, and it reports:
print(" {}: {}".format(p_name, param.shape))
AttributeError: 'Group' object has no attribute 'shape'
However, when I tried print(" {}: {}".format(p_name, param.shape)) on its own, it works. Do you have any idea why?
Has anything changed with Keras 2.x? It isn't working for me.
Changing param.shape to param on the failing line seems to eliminate the AttributeError that ronzilllia mentions.
Here is a modification of mthrok's answer that solves the issue "AttributeError: 'Group' object has no attribute 'shape'":
from __future__ import print_function
import h5py

def print_structure(weight_file_path):
    """
    Prints out the structure of HDF5 file.

    Args:
      weight_file_path (str) : Path to the file to analyze
    """
    f = h5py.File(weight_file_path, 'r')  # open read-only
    try:
        if len(f.attrs.items()):
            print("{} contains: ".format(weight_file_path))
            print("Root attributes:")
        for key, value in f.attrs.items():
            print("  {}: {}".format(key, value))

        if len(f.items()) == 0:
            return

        for layer, g in f.items():
            print("  {}".format(layer))
            print("    Attributes:")
            for key, value in g.attrs.items():
                print("      {}: {}".format(key, value))

            print("    Dataset:")
            for p_name in g.keys():
                param = g[p_name]
                # In Keras 2 files each layer group holds a sub-group of datasets
                for k_name in param.keys():
                    print("      {}/{}: {}".format(p_name, k_name, param.get(k_name)[:]))
    finally:
        f.close()
It then prints something like this:
Root attributes:
  layer_names: ['dense_2']
  backend: tensorflow
  keras_version: 2.0.8
  dense_2
    Attributes:
      weight_names: ['dense_2/kernel:0' 'dense_2/bias:0']
    Dataset:
      dense_2/bias:0: [ 2.00016475]
      dense_2/kernel:0: [[ 2.99988198]]
I'm currently working with a tied-weight autoencoder, which requires extracting the weight matrix from a previous convolutional layer. However, some of the code I tried was written for an older version that uses layer.W, which seems to return a matrix with different dimensions from the current layer.get_weights()[0]. Does anyone have an idea how to fix this?
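One likely cause (an assumption, not verified against the code in question) is a change in kernel layout: the old Theano-style listing above stores convolution kernels as (nb_filter, stack_size, nb_row, nb_col), while Keras 2 stores Conv2D kernels as (rows, cols, input_channels, filters). If that is the mismatch, a transpose converts between the two. The function names are hypothetical, and weights trained with the Theano backend may additionally need a spatial flip of each kernel, which this sketch does not do.

```python
import numpy as np

def old_to_new_layout(w_old):
    """Assumed old layout (filters, in_channels, rows, cols) -> (rows, cols, in_channels, filters)."""
    return np.transpose(w_old, (2, 3, 1, 0))

def new_to_old_layout(w_new):
    """(rows, cols, in_channels, filters) -> assumed old layout (filters, in_channels, rows, cols)."""
    return np.transpose(w_new, (3, 2, 0, 1))
```

The two permutations are inverses of each other, so new_to_old_layout(old_to_new_layout(w)) gives back w unchanged.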
I'm trying to drop weights (zero out some weights) from individual layers while testing my ResNet50 model (trained for aerial scene classification), loaded from my model.h5 weight file in Keras. get_weights() solves half the problem, but I'm not sure how to put the weights back after changing them.
Does anyone have an idea how to edit the weights of individual layers and then test the model in Keras?
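A minimal sketch of the round trip, assuming a Keras layer (which has get_weights() and set_weights()): copy the arrays, zero a random fraction of each, and write them back with set_weights(), which expects arrays of the same shapes in the same order. The helper drop_weights and the random-fraction rule are illustrative choices; swap in whatever selection you need (for example, the smallest magnitudes). Note that it touches every array of the layer, including biases and, for BatchNormalization layers, their statistics.

```python
import numpy as np

def drop_weights(layer, fraction, seed=0):
    """Zero out roughly `fraction` of every weight array of `layer` and write them back."""
    rng = np.random.default_rng(seed)
    new_weights = []
    for w in layer.get_weights():
        w = w.copy()
        mask = rng.random(w.shape) < fraction  # True where a weight is dropped
        w[mask] = 0
        new_weights.append(w)
    layer.set_weights(new_weights)             # same shapes, same order
    return new_weights
```

After calling it on the chosen layers (for example drop_weights(model.layers[10], 0.2)), evaluate the model as usual; reload model.h5 to restore the original weights.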
Does anyone have something written that can put each layer into a column of a spreadsheet? I can only see a small number of weights with this; my network is pretty big. I don't see any zeros in the printout, but I want to make sure my LeakyReLU layers aren't getting any (or at least many) weights zeroed out.
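Rather than reading the printout (NumPy abbreviates large arrays with "..."), a per-layer summary can answer the "are weights zeroed out?" question directly. A sketch assuming a Keras model (model.layers, layer.name, layer.get_weights()); the helper name and the tol threshold are illustrative. LeakyReLU itself has no weights, so the layers to inspect are the ones feeding into it.

```python
import numpy as np

def weight_summary(model, tol=0.0):
    """Per-layer parameter count, number of (near-)zero values, and value range."""
    rows = []
    for layer in model.layers:
        arrays = layer.get_weights()
        if not arrays:  # e.g. Activation, Dropout, LeakyReLU
            continue
        flat = np.concatenate([w.ravel() for w in arrays])
        rows.append({
            "layer": layer.name,
            "count": flat.size,
            "zeros": int(np.sum(np.abs(flat) <= tol)),
            "min": float(flat.min()),
            "max": float(flat.max()),
            "mean_abs": float(np.abs(flat).mean()),
        })
    return rows
```

For example, for row in weight_summary(model): print(row) prints one line per weighted layer; raising tol (say to 1e-6) also counts values that are merely tiny.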
I have the weights of a model trained in MATLAB, in a file called weights.mat. How can I load these weights in Keras?
I think Keras can only load .h5 files, and I don't know how to load this one.
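Keras does not read .mat files directly, but scipy.io.loadmat returns a dict of NumPy arrays that can be passed to set_weights(). A sketch under several assumptions: the variable names inside weights.mat are unknown here, so the (kernel_key, bias_key) pairs are hypothetical placeholders; every weighted layer is assumed to be a Dense layer whose MATLAB weight matrix is stored output-by-input, so it is transposed to the (input, output) kernel layout Keras uses; and the Keras model must already be built with matching layer sizes. The function name is hypothetical.

```python
import numpy as np

def load_matlab_dense_weights(model, mat, key_pairs):
    """Copy MATLAB dense-layer weights into the weighted layers of a Keras model.

    mat:       dict of arrays, e.g. the result of scipy.io.loadmat("weights.mat")
    key_pairs: one (kernel_key, bias_key) per weighted layer, in model order
               (hypothetical names; check mat.keys() for the real ones)
    """
    weighted_layers = [layer for layer in model.layers if layer.get_weights()]
    if len(weighted_layers) != len(key_pairs):
        raise ValueError("need one (kernel_key, bias_key) pair per weighted layer")
    for layer, (w_key, b_key) in zip(weighted_layers, key_pairs):
        kernel = np.asarray(mat[w_key]).T      # assumed (out, in) -> Keras (in, out)
        bias = np.asarray(mat[b_key]).ravel()  # loadmat returns vectors as 2-D arrays
        layer.set_weights([kernel, bias])      # shapes must match the layer
```

Usage would look like load_matlab_dense_weights(model, loadmat("weights.mat"), [("W1", "b1"), ("W2", "b2")]), with loadmat imported from scipy.io and the key names replaced by the ones actually stored in the file.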
See my initial attempt to extract layer parameters into a CSV file. I needed to see the magnitude of the layer weights, and this was a way to view them.
https://github.com/kristpapadopoulos/keras_tools/blob/master/extract_parameters.py
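For a quick spreadsheet view without the linked script, here is a sketch using the standard csv module plus NumPy: each layer with weights becomes one column holding all of its parameters flattened (kernel first, then bias), and shorter columns are padded with blanks. It assumes a Keras model (model.layers, layer.name, layer.get_weights()); the function name is hypothetical.

```python
import csv
from itertools import zip_longest

import numpy as np

def weights_to_csv(model, f):
    """Write one CSV column per weighted layer to the open text file f."""
    names, columns = [], []
    for layer in model.layers:
        arrays = layer.get_weights()
        if arrays:  # skip layers without parameters
            names.append(layer.name)
            columns.append(np.concatenate([w.ravel() for w in arrays]).tolist())
    writer = csv.writer(f)
    writer.writerow(names)  # header row: layer names
    writer.writerows(zip_longest(*columns, fillvalue=""))  # pad shorter layers
```

Usage: with open("weights.csv", "w", newline="") as f: weights_to_csv(model, f). The resulting file opens directly in a spreadsheet program.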
@kristpapadopoulos I tried using your code and I am getting this error:
Parameter File: checkpoints/weights_multi_fb_20180513_2.best.hdf5
Extracting Model Parameters to CSV File...
Traceback (most recent call last):
File "print_weights.py", line 54, in
weights[layer].extend(param[k_name].value[:].flatten().tolist())
AttributeError: 'Group' object has no attribute 'value'
Do you have any idea?
I made an update: if the group object has no parameters (i.e. the layer has no parameters), None is assigned to avoid the issue.
https://github.com/kristpapadopoulos/keras_tools/blob/master/extract_parameters.py
If you're just looking to print the weights, I would suggest using the h5dump utility:
https://support.hdfgroup.org/HDF5/doc/RM/Tools.html#Tools-Dump
In reply to the modified print_structure function above:
AttributeError: 'slice' object has no attribute 'encode'
I just encountered the same problem. I solved it by using model.save_weights("path") instead of model.save("path").
@S601327412 It might be caused by the Layer wrapper. If you create a model with a Layer wrapper, there will be a nested group in your h5 structure, which isn't consistent with mthrok's code.
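A version that does not assume a fixed nesting depth can walk the file recursively, which copes with the extra level added by wrappers and with files written by model.save() (where the weights sit under a model_weights group). This sketch only relies on operations h5py groups and datasets support: groups have keys() and item access, datasets have a shape. Plain dicts of NumPy arrays behave the same way, which is how it is exercised here; with h5py you could equivalently test isinstance(item, h5py.Group). The function name is hypothetical.

```python
def list_datasets(group, prefix=""):
    """Return (path, shape) for every dataset below `group`, at any depth."""
    found = []
    for key in group.keys():
        item = group[key]
        path = prefix + "/" + key if prefix else key
        if hasattr(item, "keys"):  # a group: descend into it
            found.extend(list_datasets(item, path))
        else:                      # a dataset (or array): record its shape
            found.append((path, item.shape))
    return found
```

With h5py, open the file read-only (h5py.File("weights.h5", "r")), pass the file object to list_datasets, and read any listed dataset's values with f[path][()].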
Hello,
Sorry for my very bad English.
I am working on a small deep learning package for my calculator.
I have just implemented the predict function.
I would like to train my model with Keras and then give the coefficients to my function.
My function takes as input a list of coefficients (w1, w2, w3, ..., wn).
How can I get just a list of all the coefficients of a model without convolution?
Thank you very much for your reply
Hi @mactul. Perhaps the get_weights() call at the end of this code block is what you are looking for:
from keras.models import Sequential
from keras.layers import Dense, Activation
# For a single-input model with 2 classes (binary classification):
model = Sequential()
model.add(Dense(8, activation='relu', input_dim=10))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])
# Generate dummy data
import numpy as np
data = np.random.random((1000, 10))
labels = np.random.randint(2, size=(1000, 1))
# Train the model, iterating on the data in batches of 32 samples
model.fit(data, labels, epochs=10, batch_size=32)
# print out the weights (coefficients)
print(model.get_weights())
I hope this helps.
Thanks.
I have tried your method and it works well, but how can I convert this array into a list?
I want the coefficients of each neuron's inputs, and their place in the network.
A simple example:
w1-
      -w4
w2-         -w6
      -w5
w3-
[w1, w2, w3, w4, w5, w6]
Sorry, I'm not being very clear.
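One way to turn model.get_weights() into a single flat list [w1, w2, ..., wn] and back, sketched with NumPy. The order is the one Keras uses: layer by layer, each Dense layer's kernel (shape (inputs, units), flattened row by row, so entry i*units + j is the weight from input i to unit j) followed by its bias. The function names are hypothetical; unflatten_weights needs the original shapes to rebuild arrays that model.set_weights() will accept.

```python
import numpy as np

def flatten_weights(weights):
    """List of arrays (as from model.get_weights()) -> flat Python list."""
    return np.concatenate([w.ravel() for w in weights]).tolist()

def unflatten_weights(flat, shapes):
    """Flat list -> list of arrays with the given shapes (for model.set_weights())."""
    arrays, start = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        arrays.append(np.asarray(flat[start:start + size]).reshape(shape))
        start += size
    return arrays
```

Typical use: weights = model.get_weights(), then flat = flatten_weights(weights) and shapes = [w.shape for w in weights]. For a 3-input, 2-unit, 1-output network like the drawing above, flat has 3*2 + 2 + 2*1 + 1 = 11 entries, biases included.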
Hi @mactul. Below I created a simplified version with explicit details of the weights. I did a prediction using model.predict in Keras and also a direct calculation using the weights and biases.
This simplified model has 4 inputs into 2 nodes, then into one output node:
x1    x2    x3    x4                (input)
w1_11 w1_12 w1_13 w1_14 (+bias)     (layer 1, node 1)
w1_21 w1_22 w1_23 w1_24 (+bias)     (layer 1, node 2)
w2_1  w2_2 (+bias)                  (layer 2, output)
In the code below, w1_ij (the weight from input j to node i) is kernel_1[j-1, i-1], and w2_i is kernel_2[i-1].
I hope this helps.
Thanks.
# Simplified version
#
from keras.models import Sequential
from keras.layers import Dense, Activation
# For a single-input model with 2 classes (binary classification):
model = Sequential()
model.add(Dense(2, activation='linear', input_dim=4))
model.add(Dense(1, activation='linear'))
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])
# Generate dummy data
import numpy as np
data = np.random.random((1000, 4)) # input is 1000 examples x 4 features
labels = np.random.randint(2, size=(1000, 1)) # output(label) is 1000 examples x 1
# Train the model, iterating on the data in batches of 32 samples
model.fit(data, labels, epochs=10, batch_size=32)
model.summary()
print()
# print out the weight coefficients of first layer
kernel_1, bias_1 = model.layers[0].get_weights()
print('kernel_1:')
print(kernel_1)
print('bias_1:')
print(bias_1)
print()
# print out the weight coefficients of second output layer
kernel_2, bias_2 = model.layers[1].get_weights()
print('kernel_2:')
print(kernel_2)
print('bias_2:')
print(bias_2)
print()
#predict with keras.model.predict
x_test = np.array([1, 1.5, -4, 3])
x_test = np.expand_dims(x_test, axis=0)
print('test input:')
print(x_test)
print('model.predict:')
print(model.predict(x_test))
#predict with direct calculation using manual summations
print()
node_1_sum = \
x_test[0,0]*kernel_1[0,0]+ \
x_test[0,1]*kernel_1[1,0]+ \
x_test[0,2]*kernel_1[2,0]+ \
x_test[0,3]*kernel_1[3,0]+ bias_1[0]
print('node_1_sum:')
print(node_1_sum)
node_2_sum = \
x_test[0,0]*kernel_1[0,1]+ \
x_test[0,1]*kernel_1[1,1]+ \
x_test[0,2]*kernel_1[2,1]+ \
x_test[0,3]*kernel_1[3,1]+ bias_1[1]
print('node_2_sum:')
print(node_2_sum)
#output layer
output_layer = node_1_sum*kernel_2[0] + node_2_sum*kernel_2[1] + bias_2[0]
print('final result of network using manual calculations = ', output_layer)
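The same calculation in matrix form, as a NumPy sketch: each Dense layer computes x @ kernel + bias, which is exactly the sum the manual code writes out term by term (kernel_1[i, j] multiplies input i into node j). Random arrays stand in for the trained kernel_1, bias_1, kernel_2 and bias_2; with the real ones from get_weights() and linear activations, the result should match model.predict(x_test) up to floating-point rounding.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-ins for the trained weights returned by get_weights()
kernel_1, bias_1 = rng.normal(size=(4, 2)), rng.normal(size=2)  # layer 1: 4 -> 2
kernel_2, bias_2 = rng.normal(size=(2, 1)), rng.normal(size=1)  # layer 2: 2 -> 1

x_test = np.array([[1, 1.5, -4, 3]])  # shape (1, 4), as after expand_dims

hidden = x_test @ kernel_1 + bias_1   # (1, 2): node_1_sum and node_2_sum
output = hidden @ kernel_2 + bias_2   # (1, 1): linear output

# The same thing with explicit loops, mirroring the manual summations above
node_sums = [sum(x_test[0, i] * kernel_1[i, j] for i in range(4)) + bias_1[j]
             for j in range(2)]
manual = node_sums[0] * kernel_2[0, 0] + node_sums[1] * kernel_2[1, 0] + bias_2[0]
```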
@td2014 Hi, I want to get the weights from two models and average them, then put the averaged weights in a new model. All three models have the same structure. How can I implement this? I hope you can help me, thank you.
Hi @AmberrrLiu . I don't know for sure, but you might want to take a look at the first section of this page: https://keras.io/models/about-keras-models/ . It mentions get_weights, set_weights, save_weights, and load_weights functions. It might be possible to get the weights from each of your models, do the averaging using python/numpy, then set the weights in the new model. You can also save and reload if that works for you.
I hope this helps.
Thank you. The set_weights function works well.
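For reference, the averaging step can be sketched as below, assuming models with identical architectures (so get_weights() returns arrays of matching shapes in the same order). The helper name average_weights is hypothetical, and it accepts any number of models.

```python
import numpy as np

def average_weights(models):
    """Element-wise mean of the weights of several identically structured models."""
    weight_lists = [m.get_weights() for m in models]
    # zip groups the i-th array of every model together
    return [np.mean(np.stack(arrays), axis=0) for arrays in zip(*weight_lists)]
```

The result can go straight into the third model with new_model.set_weights(average_weights([model_a, model_b])).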