
Deep learning Bézier control points from a point mesh

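I want to build a neural network that, given a larger set of points lying on a curve, generates the parameters of a Bézier curve (its control points).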

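In other words: what is the Bézier curve through the mesh? I am using this to teach myself TensorFlow. With the code I have built so far, I get "Index out of range using input dim 3; input has only 3 dims [Op:StridedSlice] name: strided_slice", but suggestions on the overall approach are welcome!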

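geomdl is a library that generates a mesh (a set of evaluated curve points) from control points; I use it to generate the training data.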
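
For reference, a minimal NumPy sketch of what "generating a mesh from control points" means for a single Bézier curve, using de Casteljau's algorithm (repeated linear interpolation between neighbouring control points). The helper names and the uniform spacing of the parameter `t` are assumptions for illustration; note that the geomdl code below builds a degree-2 `BSpline.Curve`, which is a piecewise quadratic curve rather than one degree-7 Bézier curve over all 8 control points.

```python
import numpy as np

def de_casteljau(ctrlpts, t):
    """Evaluate a Bezier curve at parameter t in [0, 1] by repeated linear interpolation."""
    pts = np.asarray(ctrlpts, dtype=float)
    while len(pts) > 1:
        # each pass replaces n points with n - 1 points interpolated at ratio t
        pts = (1.0 - t) * pts[:-1] + t * pts[1:]
    return pts[0]

def bezier_points(ctrlpts, n_samples=100):
    """Sample the curve at n_samples uniformly spaced parameter values."""
    return np.array([de_casteljau(ctrlpts, t) for t in np.linspace(0.0, 1.0, n_samples)])

ctrl = [[5.0, 10.0], [15.0, 25.0], [30.0, 30.0], [45.0, 5.0],
        [55.0, 5.0], [70.0, 40.0], [60.0, 60.0], [35.0, 60.0]]
pts = bezier_points(ctrl)
print(pts.shape)  # (100, 2): 100 (x, y) points from 8 control points
```

A Bézier curve always starts at its first control point and ends at its last one, so `pts[0]` and `pts[-1]` reproduce `ctrl[0]` and `ctrl[-1]`.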

    # this program has the following structure:
    # input  = 100 pairs of x-y points sampled along a curve
    # output = 8 pairs of x-y that are the control points of the curve
    #          (the curve built below is a degree-2 B-spline, not a single Bezier curve)
    from geomdl import BSpline
    from geomdl import utilities

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers

    import numpy as np

    # TF 1.x only: TF 2.x runs eagerly by default and has no tf.enable_eager_execution
    tf.enable_eager_execution()

I generate a 100-point mesh from 8 control points:

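    # utilities to generate the training data
    n_evalpts = 100

    def evalpts(ctrlpts):
        """Evaluate a degree-2 B-spline with geomdl and return the sampled (x, y) points."""
        c = BSpline.Curve()
        c.degree = 2
        c.ctrlpts = ctrlpts
        # step between evaluated parameter values, intended to yield n_evalpts points
        c.delta = 1/(n_evalpts-0.5)
        # clamped uniform knot vector: the curve starts/ends at the first/last control point
        c.knotvector = utilities.generate_knot_vector(c.degree, len(c.ctrlpts))
        # accessing c.evalpts evaluates the curve; no render call is needed
        return c.evalpts

    def mevalpts(ctrlptslist):
        """Evaluate one curve per set of control points."""
        return [evalpts(ctrlpts) for ctrlpts in ctrlptslist]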

    #curve = evalpts([[5.0, 10.0], [15.0, 25.0], [30.0, 30.0], [45.0, 5.0], [55.0, 5.0],
    #                 [70.0, 40.0], [60.0, 60.0], [35.0, 60.0]])
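
To see roughly what `evalpts` computes without geomdl, here is a NumPy sketch of a degree-2 B-spline with a clamped uniform knot vector, evaluated with the Cox–de Boor recursion. It assumes `generate_knot_vector` produces a clamped uniform knot vector (as I understand it, its default) and samples exactly 100 uniformly spaced parameters; the helper names are hypothetical, and geomdl's own parameter sampling may differ slightly.

```python
import numpy as np

def clamped_uniform_knots(degree, n_ctrl):
    """degree+1 zeros, evenly spaced interior knots, degree+1 ones."""
    n_interior = n_ctrl - degree - 1
    interior = np.linspace(0.0, 1.0, n_interior + 2)[1:-1]
    return np.concatenate([np.zeros(degree + 1), interior, np.ones(degree + 1)])

def bspline_basis(i, p, u, knots):
    """Cox-de Boor recursion: value of the basis function N_{i,p} at parameter u."""
    if p == 0:
        # half-open knot spans; the last non-empty span also includes u == 1.0
        if knots[i] <= u < knots[i + 1]:
            return 1.0
        if u == knots[-1] and knots[i] < knots[i + 1] == knots[-1]:
            return 1.0
        return 0.0
    left = right = 0.0
    if knots[i + p] > knots[i]:
        left = (u - knots[i]) / (knots[i + p] - knots[i]) * bspline_basis(i, p - 1, u, knots)
    if knots[i + p + 1] > knots[i + 1]:
        right = ((knots[i + p + 1] - u) / (knots[i + p + 1] - knots[i + 1])
                 * bspline_basis(i + 1, p - 1, u, knots))
    return left + right

def bspline_points(ctrlpts, degree=2, n_samples=100):
    """Sample a clamped uniform B-spline at n_samples parameters in [0, 1]."""
    ctrl = np.asarray(ctrlpts, dtype=float)
    knots = clamped_uniform_knots(degree, len(ctrl))
    us = np.linspace(0.0, 1.0, n_samples)
    basis = np.array([[bspline_basis(i, degree, u, knots) for i in range(len(ctrl))]
                      for u in us])            # (n_samples, n_ctrl)
    return basis @ ctrl                        # each point is a weighted sum of control points

ctrl = [[5.0, 10.0], [15.0, 25.0], [30.0, 30.0], [45.0, 5.0],
        [55.0, 5.0], [70.0, 40.0], [60.0, 60.0], [35.0, 60.0]]
pts = bspline_points(ctrl)
print(pts.shape)  # (100, 2)
```

Unlike a single Bézier curve, each point here depends on at most 3 neighbouring control points (degree 2), which is why the curve follows the control polygon more closely.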

Now generate the training data:

    # generate NUM_SAMPLES examples: inputs are 100 (x, y) points, outputs are 8 (x, y) control points
    NUM_SAMPLES = 25
    # note: `outputs` is a tf.Tensor, not a NumPy array (tf.random.uniform in TF 2.x)
    outputs = tf.random_uniform(shape=[NUM_SAMPLES, 8, 2], maxval=5)
    inputs = np.array(mevalpts(outputs.numpy().tolist()))  # shape [NUM_SAMPLES, n_evalpts, 2]
    #inputs = tf.Variable(mevalpts(outputs.numpy().tolist()))
    print(inputs.shape)
    print(outputs.shape)
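
Because the sample parameters are fixed, every curve point is a fixed linear combination of the control points, so a whole batch can be generated with a single basis matrix. This NumPy sketch uses a Bernstein (single Bézier) basis as a stand-in for geomdl's B-spline, purely to show the array shapes involved; `bernstein_matrix` is a hypothetical helper.

```python
import numpy as np
from math import comb

NUM_SAMPLES, N_CTRL, N_EVAL = 25, 8, 100

def bernstein_matrix(n_ctrl, n_samples):
    """(n_samples, n_ctrl) matrix of Bernstein weights at uniformly spaced t in [0, 1]."""
    n = n_ctrl - 1
    t = np.linspace(0.0, 1.0, n_samples)[:, None]
    k = np.arange(n_ctrl)[None, :]
    binom = np.array([comb(n, i) for i in range(n_ctrl)])
    return binom * t**k * (1.0 - t)**(n - k)

rng = np.random.default_rng(0)
outputs = rng.uniform(0.0, 5.0, size=(NUM_SAMPLES, N_CTRL, 2))  # targets: control points
B = bernstein_matrix(N_CTRL, N_EVAL)                             # (100, 8)
# for every sample s: inputs[s] = B @ outputs[s]
inputs = np.einsum('pk,skd->spd', B, outputs)
print(inputs.shape, outputs.shape)  # (25, 100, 2) (25, 8, 2)
```

Keeping both arrays as NumPy arrays from the start also sidesteps any difference between how NumPy and TensorFlow index them.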

Define the model (I tried a custom loss function, but gave up on it):

    def build_model():
        model = keras.Sequential([
            layers.Flatten(input_shape=(n_evalpts, 2)),  # (100, 2) -> 200 values
            layers.Dense(64, activation=tf.nn.relu),
            layers.Dense(64, activation=tf.nn.relu),
            layers.Dense(32, activation=tf.nn.relu),
            layers.Dense(16)  # 8 control points x 2 coordinates, as one flat vector
        ])
        optimizer = tf.keras.optimizers.RMSprop(0.001)
        # the custom loss was abandoned, so use the built-in mean squared error
        model.compile(loss='mean_squared_error',
                      optimizer=optimizer,
                      metrics=['mean_absolute_error', 'mean_squared_error'])
        return model

    model = build_model()
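
To make the shape contract of this model explicit, here is a NumPy sketch of the same layer stack (Flatten, then Dense 64, 64, 32 with ReLU, then a linear Dense 16) with random, untrained weights. It is not Keras, only an illustration: the network maps a `(100, 2)` point set to a flat vector of 16 numbers, so the targets used in training must also be flat, shape `(N, 16)`.

```python
import numpy as np

rng = np.random.default_rng(0)
layer_sizes = [100 * 2, 64, 64, 32, 16]  # Flatten(100, 2) -> 64 -> 64 -> 32 -> 16
# random weights and zero biases, only to trace shapes through an untrained network
weights = [rng.normal(0.0, 0.1, size=(m, n)) for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
biases = [np.zeros(n) for n in layer_sizes[1:]]

def forward(x):
    """x: (batch, 100, 2) point sets -> (batch, 16) flat control-point predictions."""
    h = x.reshape(len(x), -1)             # Flatten: (batch, 200)
    for i, (w, b) in enumerate(zip(weights, biases)):
        h = h @ w + b
        if i < len(weights) - 1:          # ReLU on hidden layers, linear output layer
            h = np.maximum(h, 0.0)
    return h

batch = rng.uniform(0.0, 5.0, size=(25, 100, 2))
pred = forward(batch)
print(pred.shape)  # (25, 16), not (25, 8, 2)
```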

Fit the data, but this doesn't work:

    EPOCHS = 100

    history = model.fit(
      inputs,
      outputs,
      epochs=EPOCHS,
      #validation_split = 0.2,
      #verbose=0
      )
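
A possible baseline for the overall approach: with a fixed knot vector and fixed sample parameters, each sampled point is a fixed linear combination of the control points (`points = B @ ctrl`), so the control points can be recovered directly by linear least squares. This sketch assumes a Bernstein (single Bézier) basis matrix `B` and assumes the parameter value of every sample is known, which holds for this generated data but not necessarily for real point clouds; the same idea applies to a B-spline if `B` is built from its basis functions.

```python
import numpy as np
from math import comb

def bernstein_matrix(n_ctrl, n_samples):
    """(n_samples, n_ctrl) Bernstein basis evaluated at uniformly spaced t in [0, 1]."""
    n = n_ctrl - 1
    t = np.linspace(0.0, 1.0, n_samples)[:, None]
    k = np.arange(n_ctrl)[None, :]
    binom = np.array([comb(n, i) for i in range(n_ctrl)])
    return binom * t**k * (1.0 - t)**(n - k)

B = bernstein_matrix(8, 100)                  # maps 8 control points -> 100 curve points
rng = np.random.default_rng(1)
true_ctrl = rng.uniform(0.0, 5.0, size=(8, 2))
points = B @ true_ctrl                        # (100, 2): the "mesh" the network would see

# least-squares solve of B @ ctrl ~= points, x and y columns at once
est_ctrl, residuals, rank, _ = np.linalg.lstsq(B, points, rcond=None)
print(rank)                                   # 8: B has full column rank
print(np.abs(est_ctrl - true_ctrl).max())     # close to 0 for noise-free points
```

For a batch of shape `(N, 100, 2)`, `np.linalg.pinv(B) @ inputs` broadcasts over `N` and returns `(N, 8, 2)`. Since this map is linear, a single `Dense(16)` layer without activation could in principle represent it exactly on the flattened input, which makes it a useful point of comparison for the deeper network.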

Update:

This is the exact error I get from the last cell:

    ---------------------------------------------------------------------------
    InvalidArgumentError                      Traceback (most recent call last)
    <ipython-input-9-9994db909e15> in <module>
          4   inputs,
          5   outputs,
    ----> 6   epochs=EPOCHS,
          7   #validation_split = 0.2,
          8   #verbose=0

    /usr/local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, max_queue_size, workers, use_multiprocessing, **kwargs)
        878           initial_epoch=initial_epoch,
        879           steps_per_epoch=steps_per_epoch,
    --> 880           validation_steps=validation_steps)
        881 
        882   def evaluate(self,

    /usr/local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_arrays.py in model_iteration(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, mode, validation_in_fit, **kwargs)
        308           if ins and isinstance(ins[-1], int):
        309             # Do not slice the training phase flag.
    --> 310             ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
        311           else:
        312             ins_batch = slice_arrays(ins, batch_ids)

    /usr/local/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in slice_arrays(arrays, start, stop)
        524       if hasattr(start, 'shape'):
        525         start = start.tolist()
    --> 526       return [None if x is None else x[start] for x in arrays]
        527     else:
        528       return [None if x is None else x[start:stop] for x in arrays]

    /usr/local/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in <listcomp>(.0)
        524       if hasattr(start, 'shape'):
        525         start = start.tolist()
    --> 526       return [None if x is None else x[start] for x in arrays]
        527     else:
        528       return [None if x is None else x[start:stop] for x in arrays]

    /usr/local/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py in _slice_helper(tensor, slice_spec, var)
        652         ellipsis_mask=ellipsis_mask,
        653         var=var,
    --> 654         name=name)
        655 
        656 

    /usr/local/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py in strided_slice(input_, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, var, name)
        818       ellipsis_mask=ellipsis_mask,
        819       new_axis_mask=new_axis_mask,
    --> 820       shrink_axis_mask=shrink_axis_mask)
        821 
        822   parent_name = name

    /usr/local/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py in strided_slice(input, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, name)
       9332       else:
       9333         message = e.message
    -> 9334       _six.raise_from(_core._status_to_exception(e.code, message), None)
       9335   # Add nodes to the TensorFlow graph.
       9336   if begin_mask is None:

    /usr/local/lib/python3.7/site-packages/six.py in raise_from(value, from_value)

    InvalidArgumentError: Index out of range using input dim 3; input has only 3 dims [Op:StridedSlice] name: strided_slice/
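
The `StridedSlice` message is about indexing rather than model shapes. In the traceback, Keras slices each training array as `x[start]`, where `start` is a list of batch indices. On a NumPy array a list index selects rows, but here `outputs` is a `tf.Tensor`, and a Python list used as a tensor index appears to be read as one index per dimension, so a batch of 25 indices overruns the 3-D tensor at dimension 3. NumPy shows the same distinction between a list index and a tuple index; `batch_ids` below is a hypothetical batch:

```python
import numpy as np

outputs = np.zeros((25, 8, 2))   # same shape as the generated targets
batch_ids = [3, 0, 7, 1, 4]      # hypothetical shuffled batch of sample indices

rows = outputs[batch_ids]        # list index: selects 5 samples (fancy indexing)
print(rows.shape)                # (5, 8, 2)

try:
    outputs[tuple(batch_ids)]    # tuple index: one index per dimension
except IndexError as err:
    print("IndexError:", err)    # 5 indices for a 3-D array
```

Passing NumPy targets (for example `outputs.numpy()`) avoids this; the `np.reshape` call in the updated cell also returns a NumPy array.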

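I was able to get my own network working here. The basic mistake was that my model produces a flat array of 16 values per sample (output shape `(batch, 16)`), but I was feeding it `(8, 2)` ground-truth targets.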

I changed the last cell to:

    EPOCHS = 1000

    history = model.fit(
      inputs,
      np.reshape(outputs,(NUM_SAMPLES,16)), #previously this was outputs
      epochs=EPOCHS,
      validation_split = 0.2,
      #verbose=0
      )
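
The reshape flattens each `(8, 2)` target in row-major order into `[x0, y0, x1, y1, ..., x7, y7]`, matching the 16 outputs of the last `Dense` layer; predictions can be turned back into control points with the inverse reshape. A NumPy sketch, where `pred` is a stand-in for `model.predict(inputs)`, assumed to return shape `(N, 16)`:

```python
import numpy as np

NUM_SAMPLES = 25
rng = np.random.default_rng(0)
outputs = rng.uniform(0.0, 5.0, size=(NUM_SAMPLES, 8, 2))  # control-point targets

flat_targets = np.reshape(outputs, (NUM_SAMPLES, 16))      # what model.fit now receives
# row-major order: sample 0 starts [x0, y0, x1, y1, ...]
assert np.array_equal(flat_targets[0, :4],
                      [outputs[0, 0, 0], outputs[0, 0, 1], outputs[0, 1, 0], outputs[0, 1, 1]])

pred = flat_targets.copy()          # stand-in for model.predict(inputs), shape (N, 16)
ctrl_pred = pred.reshape(-1, 8, 2)  # back to 8 (x, y) control points per sample
print(ctrl_pred.shape)              # (25, 8, 2)
```

With only 25 samples, `validation_split = 0.2` holds out the last 5 samples (Keras takes the validation data from the end of the arrays, before shuffling), leaving 20 for training, so a larger `NUM_SAMPLES` is likely needed for the validation loss to be meaningful.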
