簡體   English   中英

tf.where使張量流中的優化器失敗

[英]tf.where causes optimiser to fail in tensorflow

我想檢查我是否可以使用tensorflow而不是pymc3解決問題。 實驗性的想法是,我將定義一個包含切換點的概率系統。 我可以將采樣用作推論方法,但我開始懷疑為什么我不能只使用梯度下降來做到這一點。

我決定做tensorflow梯度搜索,但它似乎像tensorflow是不好受進行梯度搜索時tf.where參與。

您可以在下面找到代碼。

import tensorflow as tf
import numpy as np

x1 = np.random.randn(50)+1
x2 = np.random.randn(50)*2 + 5
x_all = np.hstack([x1, x2])
len_x = len(x_all)
time_all = np.arange(1, len_x + 1)

mu1 = tf.Variable(0, name="mu1", dtype=tf.float32)
mu2 = tf.Variable(5, name = "mu2", dtype=tf.float32)
sigma1 = tf.Variable(2, name = "sigma1", dtype=tf.float32)
sigma2 = tf.Variable(2, name = "sigma2", dtype=tf.float32)
tau = tf.Variable(10, name = "tau", dtype=tf.float32)

mu = tf.where(time_all < tau,
              tf.ones(shape=(len_x,), dtype=tf.float32) * mu1,
              tf.ones(shape=(len_x,), dtype=tf.float32) * mu2)
sigma = tf.where(time_all < tau,
              tf.ones(shape=(len_x,), dtype=tf.float32) * sigma1,
              tf.ones(shape=(len_x,), dtype=tf.float32) * sigma2)

likelihood_arr = tf.log(tf.sqrt(1/(2*np.pi*tf.pow(sigma, 2)))) -tf.pow(x_all - mu, 2)/(2*tf.pow(sigma, 2))
total_likelihood = tf.reduce_sum(likelihood_arr, name="total_likelihood")

optimizer = tf.train.RMSPropOptimizer(0.01)
opt_task = optimizer.minimize(-total_likelihood)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print("these variables should be trainable: {}".format([_.name for _ in tf.trainable_variables()]))
    for step in range(10000):
        _lik, _ = sess.run([total_likelihood, opt_task])
        if step % 1000 == 0:
            variables = {_.name:_.eval() for _ in [mu1, mu2, sigma1, sigma2, tau]}
            print("step: {}, values: {}".format(str(step).zfill(4), variables))

您會注意到,即使tensorflow似乎知道該變量及其漸變,tau參數也不會更改。 關於出什么問題的任何線索? 這是可以在tensorflow中計算的東西嗎,或者我需要其他模式嗎?

tau ,僅在使用condition參數where :( tf.where(time_all < tau, ... 。),這是一個布爾張量由於計算梯度才有意義為連續值時,輸出的斜率相對於tau將為零。

即使忽略tf.where中,使用tau在表達式time_all < tau ,其是恆定的幾乎無處不在,所以具有零梯度。

由於梯度為零,因此無法使用梯度下降法學習tau

根據您的問題,也許可以使用加權和代替p*val1 + (1-p)*val2來代替在兩個值之間進行硬切換,其中p連續依賴於tau

分配的解決方案是正確的答案,但不包含解決我的問題的代碼。 以下代碼片段可以;

import tensorflow as tf
import numpy as np
import os
import uuid

TENSORBOARD_PATH = "/tmp/tensorboard-switchpoint"
# tensorboard --logdir=/tmp/tensorboard-switchpoint

x1 = np.random.randn(35)-1
x2 = np.random.randn(35)*2 + 5
x_all = np.hstack([x1, x2])
len_x = len(x_all)
time_all = np.arange(1, len_x + 1)

mu1 = tf.Variable(0, name="mu1", dtype=tf.float32)
mu2 = tf.Variable(0, name = "mu2", dtype=tf.float32)
sigma1 = tf.Variable(2, name = "sigma1", dtype=tf.float32)
sigma2 = tf.Variable(2, name = "sigma2", dtype=tf.float32)
tau = tf.Variable(15, name = "tau", dtype=tf.float32)
switch = 1./(1+tf.exp(tf.pow(time_all - tau, 1)))

mu = switch*mu1 + (1-switch)*mu2
sigma = switch*sigma1 + (1-switch)*sigma2

likelihood_arr = tf.log(tf.sqrt(1/(2*np.pi*tf.pow(sigma, 2)))) - tf.pow(x_all - mu, 2)/(2*tf.pow(sigma, 2))
total_likelihood = tf.reduce_sum(likelihood_arr, name="total_likelihood")

optimizer = tf.train.AdamOptimizer()
opt_task = optimizer.minimize(-total_likelihood)
init = tf.global_variables_initializer()

tf.summary.scalar("mu1", mu1)
tf.summary.scalar("mu2", mu2)
tf.summary.scalar("sigma1", sigma1)
tf.summary.scalar("sigma2", sigma2)
tf.summary.scalar("tau", tau)
tf.summary.scalar("likelihood", total_likelihood)
merged_summary_op = tf.summary.merge_all()

with tf.Session() as sess:
    sess.run(init)
    print("these variables should be trainable: {}".format([_.name for _ in tf.trainable_variables()]))
    uniq_id = os.path.join(TENSORBOARD_PATH, "switchpoint-" + uuid.uuid1().__str__()[:4])
    summary_writer = tf.summary.FileWriter(uniq_id, graph=tf.get_default_graph())
    for step in range(40000):
        lik, opt, summary = sess.run([total_likelihood, opt_task, merged_summary_op])
        if step % 100 == 0:
            variables = {_.name:_.eval() for _ in [total_likelihood]}
            summary_writer.add_summary(summary, step)
            print("i{}: {}".format(str(step).zfill(5), variables))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM