Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions sdks/python/apache_beam/ml/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,26 @@
try:
# pylint: disable=wrong-import-order, wrong-import-position
import resource

from apache_beam.ml.inference.model_manager import ModelManager
except ImportError:
resource = None # type: ignore[assignment]
ModelManager = None # type: ignore[assignment]


def _try_import_model_manager(throw_error: bool = True):
try:
from apache_beam.ml.inference.model_manager import ModelManager
return ModelManager
except ImportError as e:
if throw_error:
raise ImportError(
"Model Manager is not available. Please ensure that "
"all required packages for inference are installed and up to date."
) from e
Comment thread
AMOOOMA marked this conversation as resolved.
else:
return None


# ModelManager is an optional dependency so we don't throw an error here.
ModelManager = _try_import_model_manager(throw_error=False)

_NANOSECOND_TO_MILLISECOND = 1_000_000
_NANOSECOND_TO_MICROSECOND = 1_000
Expand Down Expand Up @@ -1443,6 +1458,7 @@ def annotations(self):
'model_handler_type': (
f'{self._model_handler.__class__.__module__}'
f'.{self._model_handler.__class__.__qualname__}'),
'model_identifier': self._model_tag,
**super().annotations()
}

Expand Down Expand Up @@ -1997,6 +2013,9 @@ def load():
# Ensure the tag we're loading is valid, if not replace it with a valid tag
self._cur_tag = self._model_metadata.get_valid_tag(model_tag)
if self.use_model_manager:
# Force an import here to avoid missing ModelManager when needed.
# Throw an error if ModelManager is not available since it's required for this code path.
ModelManager = _try_import_model_manager(throw_error=True)
logging.info("Using Model Manager to manage models automatically.")
model_manager = multi_process_shared.MultiProcessShared(
lambda: ModelManager(**self._model_manager_args),
Comment thread
AMOOOMA marked this conversation as resolved.
Expand Down
Loading