[英]Is it possible to load a pretrained Pytorch model from a GCS bucket URL without first persisting locally?
我是在 Google Dataflow 的背景下問這個問題的,但也是一般的。
使用 PyTorch,我可以引用包含多個文件的本地目錄,這些文件構成一個預訓練模型。 我碰巧使用的是 Roberta 模型,但其他人的界面是一樣的。
ls some-directory/
added_tokens.json
config.json
merges.txt
pytorch_model.bin
special_tokens_map.json vocab.json
from pytorch_transformers import RobertaModel
# this works
model = RobertaModel.from_pretrained('/path/to/some-directory/')
但是,我的預訓練模型存儲在 GCS 存儲桶中。 我們稱之為gs://my-bucket/roberta/
。
在 Google Dataflow 中加載這個模型的上下文中,我試圖保持無狀態並避免持久化到磁盤,所以我更喜歡直接從 GCS 獲取這個模型。 據我了解,PyTorch 通用接口方法from_pretrained()
可以采用本地目錄或 URL 的字符串表示形式。 但是,我似乎無法從 GCS URL 加載模型。
# this fails
model = RobertaModel.from_pretrained('gs://my-bucket/roberta/')
# ValueError: unable to parse gs://mahmed_bucket/roberta-base as a URL or as a local path
如果我嘗試使用目錄 blob 的公共 https URL,它也會失敗,盡管這可能是由於缺乏身份驗證,因為在可以創建客戶端的 python 環境中引用的憑據不會轉換為對https://storage.googleapis
公共請求https://storage.googleapis
# this fails, probably due to auth
bucket = gcs_client.get_bucket('my-bucket')
directory_blob = bucket.blob(prefix='roberta')
model = RobertaModel.from_pretrained(directory_blob.public_url)
# ValueError: No JSON object could be decoded
# and for good measure, it also fails if I append a trailing /
model = RobertaModel.from_pretrained(directory_blob.public_url + '/')
# ValueError: No JSON object could be decoded
我知道GCS 實際上沒有子目錄,它實際上只是存儲桶名稱下的一個平面命名空間。 但是,似乎我被身份驗證的必要性和 PyTorch 阻止了gs://
。
我可以通過首先在本地保存文件來解決這個問題。
from pytorch_transformers import RobertaModel
from google.cloud import storage
import tempfile
local_dir = tempfile.mkdtemp()
gcs = storage.Client()
bucket = gcs.get_bucket(bucket_name)
blobs = bucket.list_blobs(prefix=blob_prefix)
for blob in blobs:
blob.download_to_filename(local_dir + '/' + os.path.basename(blob.name))
model = RobertaModel.from_pretrained(local_dir)
但這似乎是一個黑客,我一直在想我一定錯過了一些東西。 當然有一種方法可以保持無狀態,而不必依賴磁盤持久性!
謝謝您的幫助! 我也很高興被指出任何重復的問題,因為我肯定找不到任何問題。
編輯和澄清
我的 Python 會話已經通過 GCS 身份驗證,這就是為什么我能夠在本地下載 blob 文件,然后使用load_frompretrained()
指向該本地目錄
load_frompretrained()
需要目錄引用,因為它需要問題頂部列出的所有文件,而不僅僅是pytorch-model.bin
為了澄清問題 #2,我想知道是否有某種方法可以為 PyTorch 方法提供一個嵌入了加密憑據或類似內容的請求 URL。 有點遠,但我想確保我沒有錯過任何東西。
為了澄清問題 #3(除了對下面一個答案的評論之外),即使有一種方法可以在我不知道的 URL 中嵌入憑據,我仍然需要引用一個目錄而不是單個 blob,並且我不知道 GCS 子目錄是否會被識別為這樣,因為(正如 Google 文檔所述)GCS 中的子目錄是一種錯覺,它們並不代表真正的目錄結構。 所以我認為這個問題無關緊要,或者至少被問題 #2 阻止,但這是我追尋的一個話題,所以我仍然很好奇。
主要編輯:
您可以在 Dataflow 工作人員上安裝輪文件,也可以使用工作人員臨時存儲在本地持久化二進制文件!
確實(目前截至 2019 年 11 月)您無法通過提供--requirements
參數來做到這一點。 相反,您必須像這樣使用setup.py
。 假設任何常量 IN CAPS 都在別處定義。
REQUIRED_PACKAGES = [
'torch==1.3.0',
'pytorch-transformers==1.2.0',
]
setup(
name='project_dir',
version=VERSION,
packages=find_packages(),
install_requires=REQUIRED_PACKAGES)
運行腳本
python setup.py sdist
python project_dir/my_dataflow_job.py \
--runner DataflowRunner \
--project ${GCP_PROJECT} \
--extra_package dist/project_dir-0.1.0.tar.gz \
# SNIP custom args for your job and required Dataflow Temp and Staging buckets #
在作業中,這里是在自定義 Dataflow 運算符的上下文中從 GCS 下載和使用模型。 為方便起見,我們將一些實用程序方法包裝在一個單獨的模塊中(對於繞過 Dataflow 依賴上傳很重要),並在自定義運算符的 LOCAL SCOPE 而非全局范圍內導入它們。
class AddColumn(beam.DoFn):
PRETRAINED_MODEL = 'gs://my-bucket/blah/roberta-model-files'
def get_model_tokenizer_wrapper(self):
import shutil
import tempfile
import dataflow_util as util
try:
return self.model_tokenizer_wrapper
except AttributeError:
tmp_dir = tempfile.mkdtemp() + '/'
util.download_tree(self.PRETRAINED_MODEL, tmp_dir)
model, tokenizer = util.create_model_and_tokenizer(tmp_dir)
model_tokenizer_wrapper = util.PretrainedPyTorchModelWrapper(
model, tokenizer)
shutil.rmtree(tmp_dir)
self.model_tokenizer_wrapper = model_tokenizer_wrapper
logging.info(
'Successfully created PretrainedPyTorchModelWrapper')
return self.model_tokenizer_wrapper
def process(self, elem):
model_tokenizer_wrapper = self.get_model_tokenizer_wrapper()
# And now use that wrapper to process your elem however you need.
# Note that when you read from BQ your elements are dictionaries
# of the column names and values for each BQ row.
代碼庫中 SEPARATE MODULE 中的實用函數。 在我們的項目根目錄中,它位於 dataflow_util/init.py 中,但您不必那樣做。
from contextlib import closing
import logging
import apache_beam as beam
import numpy as np
from pytorch_transformers import RobertaModel, RobertaTokenizer
import torch
class PretrainedPyTorchModelWrapper():
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def download_tree(gcs_dir, local_dir):
gcs = beam.io.gcp.gcsio.GcsIO()
assert gcs_dir.endswith('/')
assert local_dir.endswith('/')
for entry in gcs.list_prefix(gcs_dir):
download_file(gcs, gcs_dir, local_dir, entry)
def download_file(gcs, gcs_dir, local_dir, entry):
rel_path = entry[len(gcs_dir):]
dest_path = local_dir + rel_path
logging.info('Downloading %s', dest_path)
with closing(gcs.open(entry)) as f_read:
with open(dest_path, 'wb') as f_write:
# Download the file in chunks to avoid requiring large amounts of
# RAM when downloading large files.
while True:
file_data_chunk = f_read.read(
beam.io.gcp.gcsio.DEFAULT_READ_BUFFER_SIZE)
if len(file_data_chunk):
f_write.write(file_data_chunk)
else:
break
def create_model_and_tokenizer(local_model_path_str):
"""
Instantiate transformer model and tokenizer
:param local_model_path_str: string representation of the local path
to the directory containing the pretrained model
:return: model, tokenizer
"""
model_class, tokenizer_class = (RobertaModel, RobertaTokenizer)
# Load the pretrained tokenizer and model
tokenizer = tokenizer_class.from_pretrained(local_model_path_str)
model = model_class.from_pretrained(local_model_path_str)
return model, tokenizer
伙計們,你有它! 更多細節可以在這里找到: https : //beam.apache.org/documentation/sdks/python-pipeline-dependencies/
我發現這整個質疑鏈是無關緊要的,因為 Dataflow 只允許您在工作人員上安裝源代碼分發包,這意味着您實際上無法安裝 PyTorch。
當您提供一個 requirements.txt
文件時,Dataflow 將使用--no-binary
標志進行安裝,這會阻止安裝 Wheel (.whl) 軟件包並且只允許源分發 (.tar.gz)。我決定嘗試在 Google Dataflow 上為 PyTorch 推出我自己的源代碼分發版,其中一半是 C++,一半是 Cuda,另一部分是知道什么是傻瓜的差事。
感謝你們一路上的投入。
我對 Pytorch 或 Roberta 模型不太了解,但我會嘗試回答您關於 GCS 的詢問:
1.-“那么有沒有辦法加載存儲在 GCS 中的預訓練模型?”
如果您的模型可以直接從二進制文件加載 Blob:
from google.cloud import storage
client = storage.Client()
bucket = client.get_bucket("bucket name")
blob = bucket.blob("path_to_blob/blob_name.ext")
data = blob.download_as_string() # you will have your binary data transformed into string here.
2.-“在這種情況下執行公共 URL 請求時有沒有辦法進行身份驗證?”
這是棘手的部分,因為根據您運行腳本的上下文,它將使用默認服務帳戶進行身份驗證。 因此,當您使用官方 GCP 庫時,您可以:
A.- 授予該默認服務帳戶訪問您的存儲桶/對象的權限。
B.- 創建一個新的服務帳戶並在腳本中對其進行身份驗證(您還需要為該服務帳戶生成身份驗證令牌):
from google.cloud import storage
from google.oauth2 import service_account
VISION_SCOPES = ['https://www.googleapis.com/auth/devstorage']
SERVICE_ACCOUNT_FILE = 'key.json'
cred = service_account.Credentials.from_service_account_file(SERVICE_ACCOUNT_FILE, scopes=VISION_SCOPES)
client = storage.Client(credentials=cred)
bucket = client.get_bucket("bucket_name")
blob = bucket.blob("path/object.ext")
data = blob.download_as_string()
然而這是有效的,因為官方庫在后台處理對 API 調用的身份驗證,所以在 from_pretrained() 函數的情況下不起作用。
因此,另一種方法是將對象設為公開,這樣您就可以在使用公共 url 時訪問它。
3.-“即使有一種方法可以進行身份驗證,子目錄不存在仍然是一個問題嗎?”
不確定您的意思是在這里,您的存儲桶中可以有文件夾。
目前我沒有和 Roberta 一起玩,而是和 Bert 一起玩 NER 的令牌分類,但我認為它具有相同的機制..
下面是我的代碼:
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'your_gcs_auth.json'
# initiate storage
client = storage.Client()
en_bucket = client.get_bucket('your-gcs-bucketname')
# get blob
en_model_blob = en_bucket.get_blob('your-modelname-in-gcsbucket.bin')
en_model = en_model_blob.download_as_string()
# because model downloaded into string, need to convert it back
buffer = io.BytesIO(en_model)
# prepare loading model
state_dict = torch.load(buffer, map_location=torch.device('cpu'))
model = BertForTokenClassification.from_pretrained(pretrained_model_name_or_path=None, state_dict=state_dict, config=main_config)
model.load_state_dict(state_dict)
我不確定download_as_string()
方法是否將數據保存到本地磁盤,但根據我的經驗,如果我執行download_to_filename()
該函數會將模型下載到我的本地。
此外,如果您修改了轉換器網絡的配置(並將其放入 GCS 並需要加載),您還需要修改類PretrainedConfig
,因為它可以處理由download_as_string()
函數生成的文件。
歡呼,希望它有幫助
正如您正確指出的那樣,開箱即用的pytorch-transformers
似乎不支持這一點,但這主要是因為它無法將文件鏈接識別為 URL。
經過一番搜索,我在這個源文件中找到了相應的錯誤信息,大約在第 144-155 行。
當然,您可以嘗試將'gs'
標記添加到第 144 行,然后將您與 GCS 的連接解釋為http
請求(第 269-272 行)。 如果 GCS 接受這一點,那應該是唯一需要改變才能工作的事情。
如果這不起作用,唯一的直接解決方法是實現類似於 Amazon S3 存儲桶函數的功能,但我對 S3 和 GCS 存儲桶的了解還不夠,無法在這里做出任何有意義的判斷。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.