How to implement a repository for lazy data loading with Neuraxle?
The Neuraxle documentation includes an example that uses a repository to lazily load data within a pipeline:
```python
from neuraxle.pipeline import Pipeline, MiniBatchSequentialPipeline
from neuraxle.base import ExecutionContext
from neuraxle.steps.column_transformer import ColumnTransformer
from neuraxle.steps.flow import TrainOnlyWrapper

# training_data_repository, BaseRepository, ConvertIDsToLoadedData, DateToCosineEncoder,
# CategoricalEnum, Normalizer, DataShuffler and Model are not defined in this example.
training_data_ids = training_data_repository.get_all_ids()

# Register the repository as a service, keyed by its abstract type:
context = ExecutionContext('caching_folder').set_service_locator({
    BaseRepository: training_data_repository
})

pipeline = Pipeline([
    # Replaces the incoming IDs with the data loaded from the repository:
    ConvertIDsToLoadedData().assert_has_services(BaseRepository),
    ColumnTransformer([
        (range(0, 2), DateToCosineEncoder()),
        (3, CategoricalEnum(categories_count=5, starts_at_zero=True)),
    ]),
    Normalizer(),
    TrainOnlyWrapper(DataShuffler()),
    MiniBatchSequentialPipeline([
        Model()
    ], batch_size=128)
]).with_context(context)
```
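The mechanism the example relies on is a service locator: the context maps an abstract type (`BaseRepository`) to a concrete instance, and the ID-conversion step looks the repository up by that type. The sketch below reproduces the idea in plain Python so it runs without Neuraxle. `ToyContext`, `DictRepository` and `convert_ids_to_loaded_data` are hypothetical stand-ins, not Neuraxle's API; they only mirror the method names used in the example.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type


class BaseRepository(ABC):
    """Abstract service type, used as the lookup key in the context."""

    @abstractmethod
    def get_all_ids(self) -> List[int]: ...

    @abstractmethod
    def get_data_from_id(self, _id: int) -> Any: ...


class DictRepository(BaseRepository):
    """Hypothetical concrete repository backed by a plain dict."""

    def __init__(self, data: Dict[int, Any]):
        self._data = data

    def get_all_ids(self) -> List[int]:
        return list(self._data)

    def get_data_from_id(self, _id: int) -> Any:
        return self._data[_id]


class ToyContext:
    """Hypothetical stand-in for ExecutionContext: maps abstract types to instances."""

    def __init__(self):
        self._services: Dict[Type, Any] = {}

    def set_service_locator(self, services: Dict[Type, Any]) -> "ToyContext":
        self._services.update(services)
        return self  # returning self allows the chained call used in the example

    def has_service(self, service_type: Type) -> bool:
        return service_type in self._services

    def get_service(self, service_type: Type) -> Any:
        return self._services[service_type]


def convert_ids_to_loaded_data(ids: List[int], context: ToyContext) -> List[Any]:
    """Roughly what a ConvertIDsToLoadedData step does: find the repo, load each ID."""
    if not context.has_service(BaseRepository):
        raise KeyError("BaseRepository service is not registered")
    repo = context.get_service(BaseRepository)
    return [repo.get_data_from_id(_id) for _id in ids]


repo = DictRepository({0: "a", 1: "b", 2: "c"})
context = ToyContext().set_service_locator({BaseRepository: repo})
ids = repo.get_all_ids()  # the pipeline input is just the IDs: [0, 1, 2]
print(convert_ids_to_loaded_data(ids, context))  # ['a', 'b', 'c']
```

Because the step depends only on the abstract type, swapping the concrete repository (in-memory, file-based, SQL) requires no change to the pipeline itself.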
However, the documentation does not show how to implement the `BaseRepository` and `ConvertIDsToLoadedData` classes. What would be the best way to implement them? Could anyone give an example?
I haven't checked whether the following runs, but it should look something like this. If you find something to change and have tried running it, please edit this answer:
```python
from abc import ABC, abstractmethod
from typing import Dict, List

from neuraxle.base import BaseStep, ExecutionContext
from neuraxle.data_container import DataContainer


class BaseDataRepository(ABC):
    @abstractmethod
    def get_all_ids(self) -> List[int]:
        pass

    @abstractmethod
    def get_data_from_id(self, _id: int) -> object:
        pass


class InMemoryDataRepository(BaseDataRepository):
    def __init__(self, ids, data):
        self.ids: List[int] = ids
        self.data: Dict[int, object] = data

    def get_all_ids(self) -> List[int]:
        return list(self.ids)

    def get_data_from_id(self, _id: int) -> object:
        return self.data[_id]


class ConvertIDsToLoadedData(BaseStep):
    def _transform_data_container(self, data_container: DataContainer, context: ExecutionContext) -> DataContainer:
        repo: BaseDataRepository = context.get_service(BaseDataRepository)
        ids = data_container.data_inputs
        # Replace data ids by their loaded object counterpart:
        data_container.data_inputs = [repo.get_data_from_id(_id) for _id in ids]
        # Assumes Neuraxle 0.5.x, where this hook returns only the DataContainer.
        return data_container


# Hypothetical example data; replace with your own IDs and samples.
ids = [0, 1, 2]
data = {0: "sample 0", 1: "sample 1", 2: "sample 2"}

context = ExecutionContext('caching_folder').set_service_locator({
    # Or use any other class inheriting from `BaseDataRepository` once you switch from
    # this cheap "InMemory" stub to a real database (e.g. SQL).
    BaseDataRepository: InMemoryDataRepository(ids, data)
})
```
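The answer's comment mentions replacing the in-memory stub with a real database such as SQL. A minimal sketch, assuming Python's built-in `sqlite3` as the backend: `SQLiteDataRepository`, the `samples` table schema and `iter_loaded_batches` are hypothetical names chosen for illustration. The repository only reads a row when its ID is requested, and the generator loads one mini-batch at a time.

```python
import sqlite3
from abc import ABC, abstractmethod
from typing import Iterator, List


class BaseDataRepository(ABC):
    @abstractmethod
    def get_all_ids(self) -> List[int]: ...

    @abstractmethod
    def get_data_from_id(self, _id: int) -> object: ...


class SQLiteDataRepository(BaseDataRepository):
    """Hypothetical SQL-backed repository: rows are read only when requested."""

    def __init__(self, connection: sqlite3.Connection, table: str = "samples"):
        self.connection = connection
        # Assumed schema: (id INTEGER PRIMARY KEY, payload TEXT); table name is trusted.
        self.table = table

    def get_all_ids(self) -> List[int]:
        rows = self.connection.execute(f"SELECT id FROM {self.table} ORDER BY id")
        return [row[0] for row in rows]

    def get_data_from_id(self, _id: int) -> object:
        row = self.connection.execute(
            f"SELECT payload FROM {self.table} WHERE id = ?", (_id,)
        ).fetchone()
        if row is None:
            raise KeyError(_id)
        return row[0]


def iter_loaded_batches(
    repo: BaseDataRepository, ids: List[int], batch_size: int
) -> Iterator[List[object]]:
    """Load one mini-batch at a time so at most batch_size items are held in memory."""
    for start in range(0, len(ids), batch_size):
        batch_ids = ids[start:start + batch_size]
        yield [repo.get_data_from_id(_id) for _id in batch_ids]


# An in-memory SQLite database stands in for a real one here.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE samples (id INTEGER PRIMARY KEY, payload TEXT)")
conn.executemany(
    "INSERT INTO samples (id, payload) VALUES (?, ?)",
    [(i, f"sample-{i}") for i in range(5)],
)
repo = SQLiteDataRepository(conn)
ids = repo.get_all_ids()  # [0, 1, 2, 3, 4]
for batch in iter_loaded_batches(repo, ids, batch_size=2):
    print(batch)
# ['sample-0', 'sample-1']
# ['sample-2', 'sample-3']
# ['sample-4']
```

In the Neuraxle setup, such a class would go where `InMemoryDataRepository(ids, data)` is registered in `set_service_locator`. Note that in the documentation example the conversion step sits at the top of the pipeline, so it presumably loads every ID it receives in one go; per-batch loading as sketched here is one way to keep memory bounded.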
For updates, see the issue I opened for this question: https://github.com/Neuraxio/Neuraxle/issues/421