
Deep learning Bezier control points from a point mesh

I want to build a neural network that, given a large set of points lying on a curve, produces the parameters (the control points) of a Bezier curve.

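In other words: which Bezier curve passes through the mesh? I am using this to teach myself TensorFlow. With the code I have built so far, I get "Index out of range using input dim 3; input has only 3 dims [Op:StridedSlice] name: strided_slice", but suggestions on the overall approach are welcome too!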

geomdl is a library that generates a mesh from control points; I use it to generate the training data.

    # this program has the following structure:
    # input = 100 pairs of x-y points
    # output = 8 pairs of x-y that are the control points of a bezier curve
    from geomdl import BSpline
    from geomdl import utilities

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers

    import random
    import numpy as np
    # TensorFlow 1.x API; eager execution is on by default in TensorFlow 2.x
    tf.enable_eager_execution()

I generate a mesh of 100 points from 8 control points:

    # these are utilities to generate the training data
    n_evalpts = 100
    def evalpts(ctrlpts):
        c = BSpline.Curve()
        c.degree = 2
        c.ctrlpts = ctrlpts
        # step size chosen so that the curve is evaluated at n_evalpts points
        c.delta = 1/(n_evalpts-0.5)
        c.knotvector = utilities.generate_knot_vector(c.degree, len(c.ctrlpts))
        return c.evalpts

    def mevalpts(ctrlptslist):
        evalptslist=[]
        for i in range(len(ctrlptslist)):
            evalptslist.append(evalpts(ctrlptslist[i]))
        return evalptslist

    #curve = evalpts([[5.0, 10.0], [15.0, 25.0], [30.0, 30.0], [45.0, 5.0], [55.0, 5.0],
    #                 [70.0, 40.0], [60.0, 60.0], [35.0, 60.0]])
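For comparison, here is a minimal sketch of evaluating a single Bezier curve directly from its control points with Bernstein polynomials, without geomdl. Note that this is not the same curve as the one above: the geomdl code builds a degree-2 B-spline over 8 control points, which is piecewise quadratic, whereas a single Bezier curve over 8 control points has degree 7. The function names `bezier_point` and `bezier_evalpts` are hypothetical, and `math.comb` assumes Python 3.8 or later.

```python
from math import comb

def bezier_point(ctrlpts, t):
    """Evaluate a 2D Bezier curve at parameter t in [0, 1]."""
    n = len(ctrlpts) - 1  # curve degree
    x = y = 0.0
    for i, (px, py) in enumerate(ctrlpts):
        # Bernstein basis polynomial B_{i,n}(t)
        b = comb(n, i) * (t ** i) * ((1 - t) ** (n - i))
        x += b * px
        y += b * py
    return [x, y]

def bezier_evalpts(ctrlpts, n_points=100):
    """Sample the curve at n_points evenly spaced parameter values."""
    return [bezier_point(ctrlpts, k / (n_points - 1)) for k in range(n_points)]

# same control points as the commented-out example above
ctrlpts = [[5.0, 10.0], [15.0, 25.0], [30.0, 30.0], [45.0, 5.0],
           [55.0, 5.0], [70.0, 40.0], [60.0, 60.0], [35.0, 60.0]]
pts = bezier_evalpts(ctrlpts)
```

A Bezier curve starts at its first control point and ends at its last one, so `pts[0]` and `pts[-1]` coincide with `ctrlpts[0]` and `ctrlpts[-1]`; the interior control points only pull the curve toward them.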

Now generate the training data:

    # generate NUM_SAMPLES samples: inputs (100 x-y points each) and outputs (8 control points each)
    NUM_SAMPLES = 25
    outputs = tf.random_uniform(shape=[NUM_SAMPLES, 8,2], maxval=5)
    inputs = np.array(mevalpts(outputs.numpy().tolist())) # shape [NUM_SAMPLES, n_evalpts, 2]
    #inputs = tf.Variable(mevalpts(outputs.numpy().tolist()))
    print(inputs.shape)
    print(outputs.shape)

Define the model (I tried a custom loss function, but gave up):

    def build_model():
        model = keras.Sequential([
            layers.Flatten(input_shape=(n_evalpts,2)),
            layers.Dense(64, activation=tf.nn.relu),
            layers.Dense(64, activation=tf.nn.relu),
            layers.Dense(32, activation=tf.nn.relu),
            layers.Dense(16)  # 8 control points x 2 coordinates, flattened
        ])
        optimizer = tf.keras.optimizers.RMSprop(0.001)
        # the custom loss was abandoned, so the built-in MSE is used
        model.compile(loss='mean_squared_error',
                      optimizer=optimizer,
                      metrics=['mean_absolute_error', 'mean_squared_error'])
        return model

    model = build_model()

Fit the data; this is the step that does not work:

    EPOCHS = 100

    history = model.fit(
      inputs,
      outputs,
      epochs=EPOCHS,
      #validation_split = 0.2,
      #verbose=0
      )

Update:

The exact error I get from the last cell is:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-9-9994db909e15> in <module>
      4   inputs,
      5   outputs,
----> 6   epochs=EPOCHS,
      7   #validation_split = 0.2,
      8   #verbose=0

/usr/local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, max_queue_size, workers, use_multiprocessing, **kwargs)
    878           initial_epoch=initial_epoch,
    879           steps_per_epoch=steps_per_epoch,
--> 880           validation_steps=validation_steps)
    881 
    882   def evaluate(self,

/usr/local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_arrays.py in model_iteration(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, mode, validation_in_fit, **kwargs)
    308           if ins and isinstance(ins[-1], int):
    309             # Do not slice the training phase flag.
--> 310             ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
    311           else:
    312             ins_batch = slice_arrays(ins, batch_ids)

/usr/local/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in slice_arrays(arrays, start, stop)
    524       if hasattr(start, 'shape'):
    525         start = start.tolist()
--> 526       return [None if x is None else x[start] for x in arrays]
    527     else:
    528       return [None if x is None else x[start:stop] for x in arrays]

/usr/local/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in <listcomp>(.0)
    524       if hasattr(start, 'shape'):
    525         start = start.tolist()
--> 526       return [None if x is None else x[start] for x in arrays]
    527     else:
    528       return [None if x is None else x[start:stop] for x in arrays]

/usr/local/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py in _slice_helper(tensor, slice_spec, var)
    652         ellipsis_mask=ellipsis_mask,
    653         var=var,
--> 654         name=name)
    655 
    656 

/usr/local/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py in strided_slice(input_, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, var, name)
    818       ellipsis_mask=ellipsis_mask,
    819       new_axis_mask=new_axis_mask,
--> 820       shrink_axis_mask=shrink_axis_mask)
    821 
    822   parent_name = name

/usr/local/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py in strided_slice(input, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, name)
   9332       else:
   9333         message = e.message
-> 9334       _six.raise_from(_core._status_to_exception(e.code, message), None)
   9335   # Add nodes to the TensorFlow graph.
   9336   if begin_mask is None:

/usr/local/lib/python3.7/site-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: Index out of range using input dim 3; input has only 3 dims [Op:StridedSlice] name: strided_slice/

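I was able to get my network working. The basic mistake was that my model produces a flat output of 16 values per sample, but I was feeding it ground-truth targets of shape (8, 2).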

I changed the last cell to:

    EPOCHS = 1000

    history = model.fit(
      inputs,
      np.reshape(outputs,(NUM_SAMPLES,16)), #previously this was outputs
      epochs=EPOCHS,
      validation_split = 0.2,
      #verbose=0
      )
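The reshape in the fix above can be checked on its own with NumPy. The sketch below is only an illustration under stated assumptions: it uses randomly generated stand-in targets instead of the TensorFlow tensor, and the names `targets`, `flat_targets`, `predictions` and `ctrlpts_pred` are hypothetical. It shows that flattening (8, 2) control points into 16 values keeps each x-y pair adjacent, and that a flat prediction can be reshaped back into control points.

```python
import numpy as np

NUM_SAMPLES = 25
rng = np.random.default_rng(0)

# stand-in for the control-point targets: NUM_SAMPLES curves, 8 points, 2 coordinates
targets = rng.uniform(0.0, 5.0, size=(NUM_SAMPLES, 8, 2))

# flatten each sample to 16 values so it matches the Dense(16) output
flat_targets = targets.reshape(NUM_SAMPLES, 16)

# pretend the network reproduced the targets exactly, then
# recover (8, 2) control points from the flat 16-value rows
predictions = flat_targets.copy()
ctrlpts_pred = predictions.reshape(-1, 8, 2)

print(flat_targets.shape)   # (25, 16)
print(ctrlpts_pred.shape)   # (25, 8, 2)
```

Each flat row is laid out as x0, y0, x1, y1, and so on, so reshaping back with `(-1, 8, 2)` restores the original pairs. Note also that `np.reshape` returns a NumPy array, so after the change the targets passed to `fit` are no longer a TensorFlow tensor.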

