def reset_weights(model):
for layer in model.layers:
if isinstance(layer, tf.keras.Model): #if you're using a model as a layer
reset_weights(layer) #apply function recursively
continue
#where are the initializers?
if hasattr(layer, 'cell'):
init_container = layer.cell
else:
init_container = layer
for key, initializer in init_container.__dict__.items():
if "initializer" not in key: #is this item an initializer?
continue #if no, skip it
# find the corresponding variable, like the kernel or the bias
if key == 'recurrent_initializer': #special case check
var = getattr(init_container, 'recurrent_kernel')
else:
var = getattr(init_container, key.replace("_initializer", ""))
var.assign(initializer(var.shape, var.dtype))
#use the initializer`
import tensorflow as tf
def reset_weights(model,weights=None):
for layer in model.layers:
print("---------------")
print(layer)
if isinstance(layer, tf.keras.Model): #if you're using a model as a layer
reset_weights(layer) #apply function recursively
continue
if hasattr(layer, 'cell'):
init_container = layer.cell
else:
init_container = layer
for key, initializer in init_container.__dict__.items():
if ("kernel_initializer" or "recurrent_initializer") not in key: #is this item an initializer?
continue #if no, skip it
else:
print("key:")
print(key)
print("initializer:")
print(initializer)
#replace weights with initialized values
weights = layer.get_weights()
weights = [initializer(w.shape, w.dtype) for w in weights]
layer.set_weights(weights)
loaded_model= load_model("model.h5")
reset_weights(loaded_model)
保存したKerasモデルのアーキテクチャだけ使って、重みを初期化して最初から別データで学習しようとして詰まったのでメモ。
目次
参考にしたサイト
tf.kerasの罠
従来のkerasではバックエンドとしてsessionを呼び出すことで、リセットすることが可能でした。
https://stackoverflow.com/questions/40496069/reset-weights-in-keras-layer
残念ながらtensorflowに組み込まれたkeras(tensorflow2.0以上)ではget_sessionは無いと怒られ、実行できません。tf.kerasでは従来の初期化手法が使えないようです。
Reset/Reinitialize model weights/parameters #341の下の方に行くと、tf.kerasの場合の解決策は以下の通りに示されています。※私の場合動きませんでした
key.replace
でkernelの部分を呼び出しているようですが、kernelにはassignという関数はありません!と怒られます。解決策
各layerのinitializerの取得は上記の回答を参考にし、そこから直接重みを生成→set_weightsメソッドで再割り当てする方法を考えました。無事に動きます。
出力
reset_weights(model)
で重みをリセットできます。elseの中のprint分はただのデバッグ用ですので消去して問題ありません。各kernel_initializerがweightsの形に合わせた初期重みを生成し、
.set_weights
で重みをリセットしました。biasは初期化していないため全くまっさらなモデルとは言えませんが、再訓練するとlossが増加しており、ほぼ一から訓練しなおせました。
keras側に普通にこの機能実装して欲しいですね。