[英]How to return an object from a dictionary value?
I have several classes, each holding its own Neural Network architecture.我有几个班级,每个班级都有自己的神经网络架构。
Based on a user-entered flag, rx_flag
, I am trying to retrieve a specific architecture in the driver file.基于用户输入的标志
rx_flag
,我正在尝试检索驱动程序文件中的特定体系结构。
I have two problems:我有两个问题:
def build_model(rx_flag): switcher = { 'xss': XSS().get_model(), 'rss': RSS().get_model() } return switcher.get(rx_flag)
main()
method, I get the NoneType
.main()
方法时,我得到了NoneType
。 Here is my class.这是我的 class。 The other classes have a similar template.
其他类也有类似的模板。 I have commented out the implementation of
__hash__()
and __eq__()
because storing in the dictionary seems to be working fine without it too.我已经注释掉了
__hash__()
和__eq__()
的实现,因为没有它,在字典中存储似乎也可以正常工作。
from model import Model
from keras.layers import Dense
from keras.models import Sequential
class XSS(Model):
def __init__(self):
self.num_layers = 2
self.input_dim = 3
self.output_dim = 1
self.architecture = [64, 32]
self.model = Sequential()
def get_model( self , arch=[64, 32]):
# add input layer
self.model.add(Dense(arch[0], activation='relu', input_shape=(self.input_dim, )))
# add intermediate layers
for i in range(1, self.num_layers):
self.model.add(Dense(arch[i], activation='relu'))
# add output layer
self.model.add(Dense(self.output_dim, activation='linear'))
return self.model
def get_name( self ):
return 'xss'
def get_value( self ):
return self.__value()
def __value( self ):
return (self.model, self.num_layers, self.input_dim, self.output_dim, self.architecture)
# def __hash__(self):
# return (self.hash(self.__value()))
#
# def __eq__(self, other):
# if isinstance(other, XSS):
# return self.__value() == other.__value()
# return NotImplemented
This is the driver code:这是驱动程序代码:
import sys
from model import Model
from xss import XSS
def build_model(rx_flag):
switcher = {}
obX = XSS()
switcher[obX.get_name()] = obX.get_model()
obR = RSS()
switcher[obR.get_name()] = obR.get_model()
print(switcher)
return switcher.get(rx_flag)
if __name__ == '__main__':
rx_flag = sys.argv[0]
# create a model instance based on flag
model = build_model(rx_flag)
model.summary()
This is the error that I get on attempting model.summary()
.这是我在尝试
model.summary()
时遇到的错误。
Traceback (most recent call last):
File "C:/Users/path/driver.py", line 19, in <module>
model.summary()
AttributeError: 'NoneType' object has no attribute 'summary'
How can I build the dictionary in a more Pythonic way, and have it return the actual model?如何以更 Pythonic 的方式构建字典,并让它返回实际的 model?
@juanpa.arrivillaga suggested that rx_flag is not what I think it is. @juanpa.arrivillaga建议 rx_flag 不是我想的那样。 They are right.
他们是对的。
The code works fine even on PyCharm when I initialize rx_flag
as below:当我如下初始化
rx_flag
时,即使在 PyCharm 上代码也能正常工作:
rx_flag = sys.argv[1]
I was under the impression that it is the first parameter when I enter it in the run configuration on PyCharm.在 PyCharm 的运行配置中输入时,我的印象是它是第一个参数。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.