Kerasで保存済みモデルの重みを初期化する

保存したKerasモデルのアーキテクチャだけ使って、重みを初期化して最初から別データで学習しようとして詰まったのでメモ。

参考にしたサイト

tf.kerasの罠

従来のkerasではバックエンドとしてsessionを呼び出すことで、リセットすることが可能でした。
https://stackoverflow.com/questions/40496069/reset-weights-in-keras-layer

残念ながらtensorflowに組み込まれたkeras(tensorflow2.0以上)ではget_sessionは無いと怒られ、実行できません。tf.kerasでは従来の初期化手法が使えないようです。

Reset/Reinitialize model weights/parameters #341の下の方に行くと、tf.kerasの場合の解決策は以下の通りに示されています。※私の場合動きませんでした

def reset_weights(model):
    for layer in model.layers:
        if isinstance(layer, tf.keras.Model): #if you're using a model as a layer
            reset_weights(layer) #apply function recursively
            continue

        #where are the initializers?
        if hasattr(layer, 'cell'):
            init_container = layer.cell
        else:
            init_container = layer

        for key, initializer in init_container.__dict__.items():
            if "initializer" not in key: #is this item an initializer?
                  continue #if no, skip it

            # find the corresponding variable, like the kernel or the bias
            if key == 'recurrent_initializer': #special case check
                var = getattr(init_container, 'recurrent_kernel')
            else:
                var = getattr(init_container, key.replace("_initializer", ""))

            var.assign(initializer(var.shape, var.dtype))
            #use the initializer`

key.replaceでkernelの部分を呼び出しているようですが、kernelにはassignという関数はありません!と怒られます。

解決策

各layerのinitializerの取得は上記の回答を参考にし、そこから直接重みを生成→set_weightsメソッドで再割り当てする方法を考えました。無事に動きます。

import tensorflow as tf

def reset_weights(model,weights=None):
    for layer in model.layers:
        print("---------------")
        print(layer)
        if isinstance(layer, tf.keras.Model): #if you're using a model as a layer
            reset_weights(layer) #apply function recursively
            continue
        if hasattr(layer, 'cell'):
            init_container = layer.cell
        else:
            init_container = layer

        for key, initializer in init_container.__dict__.items():
            if ("kernel_initializer" or "recurrent_initializer") not in key: #is this item an initializer?
                  continue #if no, skip it
            else:
                print("key:")
                print(key)
                print("initializer:")
                print(initializer)
                
                #replace weights with initialized values
                weights = layer.get_weights()
                weights = [initializer(w.shape, w.dtype) for w in weights] 
                layer.set_weights(weights)
                
loaded_model= load_model("model.h5")       
reset_weights(loaded_model)

出力

reset_weights(model)で重みをリセットできます。elseの中のprint分はただのデバッグ用ですので消去して問題ありません。

各kernel_initializerがweightsの形に合わせた初期重みを生成し、.set_weightsで重みをリセットしました。

biasは初期化していないため全くまっさらなモデルとは言えませんが、再訓練するとlossが増加しており、ほぼ一から訓練しなおせました。

keras側に普通にこの機能実装して欲しいですね。

コメントする

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

Exit mobile version