"""
Base Sklearn pipeline wrapper
"""
# Module authorship metadata.
__author__ = "Elisha Yadgaran"
import inspect
import logging
from typing import Any, Dict, List
from simpleml.pipelines.base_pipeline import Pipeline
from simpleml.pipelines.projected_splits import ProjectedDatasetSplit
from simpleml.utils.signature_inspection import signature_kwargs_validator
from .external_pipeline import SklearnExternalPipeline
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
class SklearnPipeline(Pipeline):
    """
    Scikit-Learn Pipeline implementation

    Builds its external pipeline as a ``SklearnExternalPipeline`` and filters
    dataset-split entries down to the keyword arguments that the external
    pipeline's ``fit`` method can accept.
    """

    def _create_external_pipeline(
        self, transformers: List[Any], **kwargs
    ) -> SklearnExternalPipeline:
        """
        Initialize a scikit-learn pipeline object

        :param transformers: pipeline steps, passed positionally to
            ``SklearnExternalPipeline``
        :param kwargs: arbitrary keyword arguments; they are filtered by
            ``signature_kwargs_validator`` against
            ``SklearnExternalPipeline.__init__`` before being forwarded
        :return: the newly constructed ``SklearnExternalPipeline``
        """
        supported_kwargs = signature_kwargs_validator(
            SklearnExternalPipeline.__init__, **kwargs
        )
        return SklearnExternalPipeline(
            transformers,
            # Only supported sklearn params
            **supported_kwargs,
        )

    def _filter_fit_params(self, split: ProjectedDatasetSplit) -> Dict[str, Any]:
        """
        Sklearn Pipelines register arbitrary input kwargs but validate non X,y
        as `stepname__parameter` format

        A key of ``split`` is kept when it names an explicit parameter of
        ``self.external_pipeline.fit`` or contains ``"__"`` (the
        ``stepname__parameter`` routing form). Every other key is dropped and
        a warning is logged.

        :param split: dataset split whose ``items()`` yield (name, value) pairs
        :return: a new dict containing only the supported fit params
        """
        # Only explicitly named parameters count as supported. The names of
        # catch-all ``*args``/``**kwargs`` parameters are not real keywords:
        # a split key that happened to match one would be forwarded verbatim
        # even though it lacks the ``stepname__`` prefix required for
        # non X,y params.
        fit_signature = inspect.signature(self.external_pipeline.fit)
        named_fit_params = {
            name
            for name, param in fit_signature.parameters.items()
            if param.kind
            not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        }

        supported_fit_params = {}
        # Ensure input compatibility with split object
        for split_arg, val in split.items():
            if split_arg in named_fit_params or "__" in split_arg:
                supported_fit_params[split_arg] = val
            else:
                LOGGER.warning(
                    "Unsupported fit param encountered, `%s`. Dropping...",
                    split_arg,
                )
        return supported_fit_params