
How can I get this output when creating a word2vec model using Keras and NumPy in Python?

Hello, I'm having a hard time understanding this. I'm trying to reproduce the output given to me by the person who assigned the project, but I seem to be doing something wrong along the way. Here is the output I want to get: [image: the output I want to have]. And here is the output I actually get: [image: the output I have].

Here's what my code looks like:

# The Keras model/graph would look something like this:
from keras import layers, optimizers, Model
from keras.layers import Input, Flatten, Dense

# Vt (vocabulary size) and vector_dim are assumed to be defined earlier in the script.

# Shared embedding layer
embedding = layers.Embedding(Vt, vector_dim, input_length=1, name='embedding')
# Input: two integers (a pair of tracks)
input_target = Input((1,), dtype='int32')
input_context = Input((1,), dtype='int32')

print(input_target)
print(input_context)

target = embedding(input_target)
context = embedding(input_context)

#target = layers.Reshape((vector_dim,))(target)
#context = layers.Reshape((vector_dim,))(context)

print("----------")
print(target)
print(context)

dot_product = layers.dot([target, context], axes=1)
dot_product = Flatten()(dot_product)
print(dot_product)
#dot_product = layers.Reshape((1,))(dot_product)
#dot_product = layers.Reshape((vector_dim,))(dot_product)

output = Dense(1, activation='sigmoid',name="classif")(dot_product)

# Model definition
Track2Vec = Model(inputs=[input_target, input_context], outputs=output)
Track2Vec.compile(loss='binary_crossentropy', optimizer='adam',metrics=["accuracy"])

Thanks in advance for helping me figure out what is wrong!

The issue is in the layers.dot call: the dot product needs to give you a (1, 1) result rather than (30, 30).

Try swapping your inputs for that layer:

dot_product = layers.dot([context, target], axes=1)

This should work.
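To see where the (30, 30) comes from, here is a minimal NumPy sketch of what a batched dot product does to two embedding outputs. It assumes the embedding output shape is (batch, 1, vector_dim), as it is with input_length=1; the batch size of 4, vector_dim = 30 and the helper batch_dot are hypothetical stand-ins, and batch_dot only mimics the per-sample contraction that the Keras Dot layer performs rather than being its implementation.

```python
import numpy as np

# Hypothetical sizes: vector_dim = 30 matches the (30, 30) shape mentioned above.
batch, vector_dim = 4, 30
rng = np.random.default_rng(0)

# Embedding outputs with input_length=1 have shape (batch, 1, vector_dim).
target = rng.normal(size=(batch, 1, vector_dim))
context = rng.normal(size=(batch, 1, vector_dim))


def batch_dot(x, y, axis):
    """Hypothetical helper: per-sample dot product over `axis`, where `axis`
    counts the batch axis as 0 (the convention used by Dot(axes=...))."""
    # Each sample has lost its batch axis, so the contracted axis index shifts down by one.
    return np.stack([np.tensordot(a, b, axes=(axis - 1, axis - 1)) for a, b in zip(x, y)])


print(batch_dot(target, context, 1).shape)  # (4, 30, 30): contracts the length-1 axis
print(batch_dot(context, target, 1).shape)  # (4, 30, 30): swapping inputs only transposes
print(batch_dot(target, context, 2).shape)  # (4, 1, 1): contracts the embedding axis
```

Contracting axis 1 (the length-1 axis) leaves the outer product of the two 30-dimensional vectors, and swapping the inputs only transposes that matrix. Contracting axis 2, the embedding axis, gives one score per pair, which Flatten then turns into shape (batch, 1). In Keras this corresponds to layers.dot([target, context], axes=2), which also makes the two Reshape lines unnecessary.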

This is the output I get (I want to remove the two Reshape lines): [image: Output]

Try this:

dot_product = Dot(axes=1)([target, context])

You should also import:

from keras.layers import Dot


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM