简体   繁体   English

如何从字典值返回 object?

[英]How to return an object from a dictionary value?

I have several classes, each holding its own Neural Network architecture.我有几个班级,每个班级都有自己的神经网络架构。

Based on a user-entered flag, rx_flag , I am trying to retrieve a specific architecture in the driver file.基于用户输入的标志rx_flag ,我正在尝试检索驱动程序文件中的特定体系结构。

I have two problems:我有两个问题:

  1. I am unable to create a dictionary in the usual way.我无法以通常的方式创建字典。 The following format is not working:以下格式无效:
     def build_model(rx_flag): switcher = { 'xss': XSS().get_model(), 'rss': RSS().get_model() } return switcher.get(rx_flag)
  2. After some research, I am finally able to at least construct a dictionary where the models in those classes are stored as values, but when I return them to the main() method, I get the NoneType .经过一番研究,我终于能够至少构建一个字典,将这些类中的模型存储为值,但是当我将它们返回给main()方法时,我得到了NoneType

Here is my class.这是我的 class。 The other classes have a similar template.其他类也有类似的模板。 I have commented out the implementation of __hash__() and __eq__() because storing in the dictionary seems to be working fine without it too.我已经注释掉了__hash__()__eq__()的实现,因为没有它,在字典中存储似乎也可以正常工作。

from model import Model
from keras.layers import Dense
from keras.models import Sequential

class XSS(Model):

    def __init__(self):
        self.num_layers = 2
        self.input_dim = 3
        self.output_dim = 1
        self.architecture = [64, 32]
        self.model = Sequential()

    def get_model( self , arch=[64, 32]):
        # add input layer
        self.model.add(Dense(arch[0], activation='relu', input_shape=(self.input_dim, )))

        # add intermediate layers
        for i in range(1, self.num_layers):
            self.model.add(Dense(arch[i], activation='relu'))

        # add output layer
        self.model.add(Dense(self.output_dim, activation='linear'))
        return self.model

    def get_name( self ):
        return 'xss'

    def get_value( self ):
        return self.__value()

    def __value( self ):
        return (self.model, self.num_layers, self.input_dim, self.output_dim, self.architecture)

    # def __hash__(self):
    #     return (self.hash(self.__value()))
    #
    # def __eq__(self, other):
    #     if isinstance(other, XSS):
    #         return self.__value() == other.__value()
    #     return NotImplemented

This is the driver code:这是驱动程序代码:

import sys

from model import Model
from xss import XSS

def build_model(rx_flag):
    switcher = {}
    obX = XSS()
    switcher[obX.get_name()] = obX.get_model()
    obR = RSS()
    switcher[obR.get_name()] = obR.get_model()
    print(switcher)
    return switcher.get(rx_flag)

if __name__ == '__main__':
    rx_flag = sys.argv[0]
    # create a model instance based on flag
    model = build_model(rx_flag)
    model.summary()

This is the error that I get on attempting model.summary() .这是我在尝试model.summary()时遇到的错误。

Traceback (most recent call last):
File "C:/Users/path/driver.py", line 19, in <module>
    model.summary()
AttributeError: 'NoneType' object has no attribute 'summary'

How can I build the dictionary in a more Pythonic way, and have it return the actual model?如何以更 Pythonic 的方式构建字典,并让它返回实际的 model?

@juanpa.arrivillaga suggested that rx_flag is not what I think it is. @juanpa.arrivillaga建议 rx_flag 不是我想的那样。 They are right.他们是对的。

The code works fine even on PyCharm when I initialize rx_flag as below:当我如下初始化rx_flag时,即使在 PyCharm 上代码也能正常工作:

rx_flag = sys.argv[1]

I was under the impression that it is the first parameter when I enter it in the run configuration on PyCharm.在 PyCharm 的运行配置中输入时,我的印象是它是第一个参数。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM