简体   繁体   English

Python:BERT Model 池错误 - mean() 收到了无效的 arguments 组合 - 得到(str,int)

[英]Python: BERT Model Pooling Error - mean() received an invalid combination of arguments - got (str, int)

I am writing the code to train a bert model on my dataset.我正在编写代码来在我的数据集上训练一个bert model。 By when I run the code it throws an error in the average pool layer.当我运行代码时,它会在平均池层中引发错误。 I am unable to understand what causes this error.我无法理解导致此错误的原因。

Model Model

class BERTBaseUncased(nn.Module):
    def __init__(self, bert_path):
        super(BERTBaseUncased, self).__init__()
        self.bert_path = bert_path
        self.bert = transformers.BertModel.from_pretrained(self.bert_path)
        self.bert_drop = nn.Dropout(0.3)
        self.out = nn.Linear(768 * 2, 1)

    def forward(
            self,
            ids,
            mask,
            token_type_ids
    ):
        o1, _ = self.bert(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids)
        
        apool = torch.mean(o1, 1)
        mpool, _ = torch.max(o1, 1)
        cat = torch.cat((apool, mpool), 1)

        bo = self.bert_drop(cat)
        p2 = self.out(bo)
        return p2

Error错误

Exception in device=TPU:0: mean() received an invalid combination of arguments - got (str, int), but expected one of:
 * (Tensor input, *, torch.dtype dtype)
 * (Tensor input, tuple of names dim, bool keepdim, *, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of ints dim, bool keepdim, *, torch.dtype dtype, Tensor out)

Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 228, in _start_fn
    fn(gindex, *args)
  File "<ipython-input-12-94e926c1f4df>", line 4, in _mp_fn
    a = _run()
  File "<ipython-input-5-ef9fa564682f>", line 146, in _run
    train_loop_fn(para_loader.per_device_loader(device), model, optimizer, device, scheduler=scheduler)
  File "<ipython-input-5-ef9fa564682f>", line 22, in train_loop_fn
    token_type_ids=token_type_ids
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 577, in __call__
    result = self.forward(*input, **kwargs)
  File "<ipython-input-11-9196e0d23668>", line 73, in forward
    apool = torch.mean(o1, 1)
TypeError: mean() received an invalid combination of arguments - got (str, int), but expected one of:
 * (Tensor input, *, torch.dtype dtype)
 * (Tensor input, tuple of names dim, bool keepdim, *, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of ints dim, bool keepdim, *, torch.dtype dtype, Tensor out)

I am trying to run this on a Kaggle TPU.我正在尝试在 Kaggle TPU 上运行它。 How to fix this?如何解决这个问题?

Since one of the 3.X updates, the models return now task-specific output objects (which are dictionaries) instead of plain tuples.由于 3.X 更新之一,模型现在返回特定于任务的 output 对象(它们是字典)而不是普通元组。 You can either force the model to return a tuple by specifying return_dict=False :您可以通过指定return_dict=False来强制 model 返回一个元组:

o1, _ = self.bert(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids,
            return_dict=False)

or by utilizing the basemodeloutputwithpoolingandcrossattentions object:或者通过使用带有池化和交叉注意的基本模型输出object

o = self.bert(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids)
#you can view the other attributes with o.keys()
o1 = o.last_hidden_state

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 我如何摆脱这个错误:max() 收到了一个无效的参数组合 - 得到了 (str, int),但需要以下之一:*(张量输入) - How do I get rid of this error: max() received an invalid combination of arguments - got (str, int), but expected one of: * (Tensor input) 如何解决“TypeError: max() 收到无效的参数组合 - 得到(线性,int),但预期”? - How can I solve "TypeError: max() received an invalid combination of arguments - got (Linear, int), but expected"? torch.Tensor() new() 收到了无效的 arguments 组合 - got (list, dtype=torch.dtype) - torch.Tensor() new() received an invalid combination of arguments - got (list, dtype=torch.dtype) nn.LSTM() 收到 arguments 的无效组合 - nn.LSTM() received an invalid combination of arguments conv1d() 收到 arguments 的无效组合 - conv1d() received an invalid combination of arguments torch.addmm接收到无效的参数组合 - torch.addmm received an invalid combination of arguments Python 错误“预计最多输入 1 个参数,得到 3 个”是什么意思? - What does the Python error 'imput expected at most 1 arguments, got 3' mean? Python dataframe -.astype(str).astype(int) 给出错误 ValueError: invalid literal for int() with base 10: '' - Python dataframe - .astype(str).astype(int) gives error ValueError: invalid literal for int() with base 10: '' Python无效的Int()错误 - Python Invalid Int () Error 描述符“值”需要一个“ dict”对象,但接收到一个“ int” /“ str” Python - descriptor 'values' requires a 'dict' object but received a 'int'/'str' python
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM