
Problem with a trained model after loading it

I'm trying to fine-tune a wav2vec2 model starting from the facebook/wav2vec2-base-960h pretrained checkpoint. These are my training_args:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir=save_dir,
    group_by_length=True,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
    gradient_accumulation_steps=2,
    evaluation_strategy="steps",
    num_train_epochs=0.5,
    fp16=True,
    save_steps=10,
    eval_steps=10,
    logging_steps=10,
    learning_rate=1e-4,
    warmup_steps=500,
    save_total_limit=2,
)

And this is my trainer:

from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=_common_voice_train,
    eval_dataset=_common_voice_test,
    tokenizer=processor.feature_extractor,
)

When training is over, trainer.evaluate() gives me good results, like this:

reference: "شما امروز صبوری بفرمایین ثبت شده تا امروز با شما هماهنگی انجام بشه"
predicted: "شما امروز سبوری بفرمای سبز شده تا امروز با شما همهمنگی انجام باشه"

But when I load the saved model and use it for inference, I get this:

رچسصجپ هدثج یو تو یتنپ هر وغسهروغج سچ ثزتسه شتذس صمرجچو

I load my model like this:

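import librosa
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

sample_rate = 16_000

# Load the fine-tuned model and its processor from the saved directory
model = Wav2Vec2ForCTC.from_pretrained("/content/drive/MyDrive/model")
processor = Wav2Vec2Processor.from_pretrained("/content/drive/MyDrive/model")
# Read the audio file, resampled to 16 kHz
audio_input, sample_rate = librosa.load("/content60_L4.wav", sr=sample_rate)
input_values = processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values
# Greedy decoding: pick the most likely token per frame, then decode to text
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0])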

I can't find my mistake.

It's a little bit odd. I ran into this issue and solved it in a roundabout way. Try using the same directory for inference as the one you used for training.
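One way to check whether the inference directory really matches the training one is to compare the tokenizer vocabularies of the two. The sketch below is only an illustration: it assumes each directory contains a vocab.json file (the file name the wav2vec2 CTC tokenizer uses), and the helper names load_vocab and diff_vocab, as well as the toy vocabularies, are hypothetical.

```python
import json
from pathlib import Path


def load_vocab(directory):
    """Read the token -> id mapping from <directory>/vocab.json."""
    with open(Path(directory) / "vocab.json", encoding="utf-8") as f:
        return json.load(f)


def diff_vocab(train_vocab, infer_vocab):
    """Return {token: (train_id, infer_id)} for every token whose id differs."""
    tokens = sorted(set(train_vocab) | set(infer_vocab))
    return {
        tok: (train_vocab.get(tok), infer_vocab.get(tok))
        for tok in tokens
        if train_vocab.get(tok) != infer_vocab.get(tok)
    }


# Hypothetical toy vocabularies standing in for two vocab.json files.
train_vocab = {"<pad>": 0, "|": 1, "a": 2, "b": 3}
infer_vocab = {"<pad>": 0, "|": 1, "a": 3, "b": 2}

mismatches = diff_vocab(train_vocab, infer_vocab)
print(mismatches)
```

With real directories you would pass load_vocab(training_dir) and load_vocab(inference_dir) instead of the toy dictionaries. An empty result means both directories map every token to the same id.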

After packaging the model and running it with the plain Python interpreter instead of Conda, I have not seen the bug again.

I do not know the reason for this. Maybe someone can help pin down the cause more precisely.
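One possible explanation, which is an assumption and not something confirmed in this thread, is that the tokenizer vocabulary used at inference assigns different ids to the characters than the one used during training. The model's output ids are then correct, but they are mapped to the wrong letters. The sketch below is a simplified greedy CTC decoder written for illustration only. It is not the library's implementation, and the frame ids and both vocabularies are made up.

```python
def ctc_greedy_decode(frame_ids, id_to_token, blank_id=0):
    """Collapse repeated ids, drop blanks, then map the ids to characters."""
    chars = []
    prev = None
    for idx in frame_ids:
        if idx != prev and idx != blank_id:
            chars.append(id_to_token[idx])
        prev = idx
    return "".join(chars)


# Hypothetical per-frame argmax ids, like torch.argmax(logits, dim=-1) yields.
frame_ids = [3, 3, 0, 1, 0, 2, 2, 0, 1, 4, 4]

# Same characters, but the reloaded vocabulary assigns them different ids.
training_vocab = {0: "<pad>", 1: "a", 2: "l", 3: "s", 4: "m"}
reloaded_vocab = {0: "<pad>", 1: "m", 2: "s", 3: "a", 4: "l"}

good = ctc_greedy_decode(frame_ids, training_vocab)
bad = ctc_greedy_decode(frame_ids, reloaded_vocab)
print(good)  # salam
print(bad)   # amsml
```

If this is what happened, the garbled output would have the same length and spacing pattern as the expected text, just with the letters substituted.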
