简体   繁体   中英

Rewrite TensorFlow 1.x code as a TensorFlow 2.x version

I need to rewrite some TensorFlow 1.x code as a TensorFlow 2.x version. I rewrote the commented-out TF1 code as follows (I changed some of the activations and initializers myself):

def model(X, nact):
    """Policy/value network: TF2 Keras-layer rewrite of the TF1 `conv`/`fc` helpers.

    Args:
        X: Observation tensor fed to the first Conv2D. In ``build_model`` it is a
            float32 placeholder of shape (None, max_clause, max_var, 2 * n_stack).
        nact: Number of actions; the width of the policy head.

    Returns:
        Tuple ``(pi_fil, vf[:, 0])``. The first element is the policy logits with
        1e32 subtracted from actions the mask marks as invalid. The second is
        the value head output with its size-1 last axis sliced away.
    """
    # NOTE(review): `orthogonal` is not defined in this snippet; presumably an
    # orthogonal kernel initializer with gain sqrt(2) (in TF2, e.g.
    # tf.keras.initializers.Orthogonal(gain=np.sqrt(2))) — confirm.
    # NOTE(review): unlike the commented TF1 line, X is not cast to float32 here.
    # Keras Conv2D defaults to strides=1 and padding='valid'; check that this
    # matches the padding the TF1 `conv` helper used.
    # h = conv(tf.cast(X, tf.float32), nf=32, rf=8, stride=1, init_scale=np.sqrt(2))
    h = tf.keras.layers.Conv2D(filters=32,
                               kernel_size=8,
                               activation='relu',
                               kernel_initializer=orthogonal(np.sqrt(2)))(X)
    # h2 = conv(h, nf=64, rf=4, stride=1, init_scale=np.sqrt(2))
    h2 = tf.keras.layers.Conv2D(filters=64,
                                kernel_size=4,
                                activation='relu',
                                kernel_initializer=orthogonal(np.sqrt(2)))(h)
    # NOTE(review): further layers are elided in the original post; they must
    # define `h4`, which both heads below consume.
    . . .
    # Policy head: one linear (unactivated) logit per action.
    # pi = fc(h4, nact, act=lambda x: x)
    pi = tf.keras.layers.Dense(units=nact,
                               activation='linear',
                               kernel_initializer=orthogonal(np.sqrt(2)))(h4)
    # Value head: a single tanh unit, so its output lies in [-1, 1].
    # vf = fc(h4, 1, act=lambda x: tf.tanh(x))
    vf = tf.keras.layers.Dense(units=1,
                               activation='tanh',
                               kernel_initializer=orthogonal(np.sqrt(2)))(h4)

    # filter out non-valid actions from pi
    # An action's validity is the max of X over axis 1. Then (valid - 1) * 1e32
    # is 0 where valid == 1 and -1e32 where valid == 0, which masks those logits.
    # Assumes X is 0/1-valued — confirm.
    # NOTE(review): for X of shape (batch, max_clause, max_var, 2 * n_stack),
    # reshape(..., [-1, nact]) yields batch * n_stack rows. That matches pi's
    # batch dimension only when n_stack == 1 — confirm.
    valid = tf.reduce_max(tf.cast(X, tf.float32), axis=1)
    valid_flat = tf.reshape(valid, [-1, nact])
    pi_fil = pi + (valid_flat - tf.ones(tf.shape(valid_flat))) * 1e32

    return pi_fil, vf[:, 0]

Further on, I have the following functions:

def build_model(args):
    """Create the TF1 input placeholders, network outputs and combined loss.

    Args:
        args: Config object. It reads ``max_clause``, ``max_var`` and
            ``n_stack`` (input shape) and ``l2_coeff`` (L2 weight).

    Returns:
        Tuple ``(X, Y, Z, p, v, params, loss)``:
        observation placeholder, policy-label placeholder, value-label
        placeholder, policy logits, value output, trainable variables, and the
        scalar loss = mean softmax cross-entropy + MSE + l2_coeff * L2.
    """
    rows = args.max_clause
    cols = args.max_var
    channels = 2
    action_count = channels * cols
    obs_shape = (None, rows, cols, channels * args.n_stack)

    # Placeholders are created in the same order as before so their
    # auto-generated graph names are unchanged.
    obs_in = tf.placeholder(tf.float32, obs_shape)
    policy_target = tf.placeholder(tf.float32, (None, action_count))
    value_target = tf.placeholder(tf.float32, None)

    logits, value = model(obs_in, action_count)
    trainable = tf.trainable_variables()

    with tf.name_scope("loss"):
        per_sample_ce = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=policy_target, logits=logits)
        policy_loss = tf.reduce_mean(per_sample_ce)
        value_loss = tf.losses.mean_squared_error(
            labels=value_target, predictions=value)
        l2_penalty = tf.add_n([tf.nn.l2_loss(weight) for weight in trainable])
        total_loss = policy_loss + value_loss + args.l2_coeff * l2_penalty

    return (obs_in, policy_target, value_target,
            logits, value, trainable, total_loss)

def self_play(args, status_track):
    """Build the model graph and load saved weights into a TF1 session.

    Of ``build_model``'s outputs, this function keeps the observation
    placeholder ``X``, the policy/value outputs ``p``/``v`` and the trainable
    ``params``. The label placeholders and the loss are discarded.

    Args:
        args: Config object. It is passed to ``build_model``, and
            ``args.save_dir`` is joined with the model directory.
        status_track: Object providing ``get_model_dir()``.
    """
    X, _, _, p, v, params, _ = build_model(args)

    # TF2 note: with eager execution disabled, the equivalents are
    # tf.compat.v1.Session and tf.compat.v1.global_variables_initializer.
    with tf.Session() as sess:
        # Initialize all variables before loading saved values into `params`.
        sess.run(tf.global_variables_initializer())
        model_dir = status_track.get_model_dir()
        # `load` is defined elsewhere; presumably it returns op(s) that assign
        # the values saved under this directory to `params` — confirm.
        sess.run(load(params, os.path.join(args.save_dir, model_dir)))
        # NOTE(review): the rest of this body is elided in the original post.
        . . .

def super_train(args, status_track):
    """Build the model plus an Adam training op, then load starting weights.

    Args:
        args: Config object. It is passed to ``build_model``, and
            ``args.save_dir`` is joined with the starter-model directory.
        status_track: Object providing ``get_sl_starter()``.
    """
    X, Y, Z, _, _, params, loss = build_model(args)
    # Adam with learning rate 1e-3 minimizing the combined loss from build_model.
    # TF2 note: tf.compat.v1.train.AdamOptimizer is the graph-mode equivalent.
    with tf.name_scope("train"):
        train_step = tf.train.AdamOptimizer(1e-3).minimize(loss)

    with tf.Session() as sess:
        # Initializes every global variable, including the optimizer slot
        # variables created by minimize() above.
        sess.run(tf.global_variables_initializer())
        model_dir = status_track.get_sl_starter()
        # NOTE(review): `params` was captured inside build_model, before
        # minimize() ran, so it holds only the model's trainable variables.
        # Adam's slot variables keep their freshly-initialized values.
        # `load` is defined elsewhere; presumably it assigns saved values to
        # `params` — confirm.
        sess.run(load(params, os.path.join(args.save_dir, model_dir)))
        # NOTE(review): the rest of this body is elided in the original post.
        . . .

How can I rewrite these two functions in TensorFlow 2.x, i.e., in a Keras-like style?

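TensorFlow 2.x-compatible code snippet. Note that it still builds a graph with `tf.compat.v1.placeholder`, so eager execution must be disabled first with `tf.compat.v1.disable_eager_execution()`.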

def build_model(args):
    """Build the policy/value graph and its loss with the TF2 ``tf.compat.v1`` API.

    This is still graph-mode code. ``tf.compat.v1.placeholder`` raises a
    RuntimeError under eager execution, so callers must run
    ``tf.compat.v1.disable_eager_execution()`` first. They must also evaluate
    the returned tensors with a ``tf.compat.v1.Session``.

    Args:
        args: Config object. It reads ``max_clause``, ``max_var`` and
            ``n_stack`` (input shape) and ``l2_coeff`` (L2 weight).

    Returns:
        Tuple ``(X, Y, Z, p, v, params, loss)``:
        observation placeholder, policy-label placeholder, value-label
        placeholder, policy logits, value output, trainable variables, and the
        scalar loss = mean softmax cross-entropy + MSE + l2_coeff * L2.
    """
    nh = args.max_clause
    nw = args.max_var
    nc = 2
    nact = nc * nw
    ob_shape = (None, nh, nw, nc * args.n_stack)
    X = tf.compat.v1.placeholder(tf.float32, ob_shape)
    Y = tf.compat.v1.placeholder(tf.float32, (None, nact))
    Z = tf.compat.v1.placeholder(tf.float32, None)

    p, v = model(X, nact)
    params = tf.compat.v1.trainable_variables()
    with tf.name_scope("loss"):
        # In TF2, tf.nn.softmax_cross_entropy_with_logits is the former *_v2 op.
        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=Y, logits=p))
        # tf.keras.metrics/losses.mean_squared_error take (y_true, y_pred) and
        # reject labels=/predictions= with a TypeError. The compat.v1 loss
        # keeps the TF1 signature and its mean reduction.
        value_loss = tf.compat.v1.losses.mean_squared_error(
            labels=Z, predictions=v)
        lossL2 = tf.add_n([tf.nn.l2_loss(vv) for vv in params])
        loss = cross_entropy + value_loss + args.l2_coeff * lossL2

    return X, Y, Z, p, v, params, loss

The technical posts on this site are published under the CC BY-SA 4.0 license. If you reprint them, please include the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM