diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index b2441281dd18..97dd22103969 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -68,11 +68,26 @@ try: # pylint: disable=wrong-import-order, wrong-import-position import resource - - from apache_beam.ml.inference.model_manager import ModelManager except ImportError: resource = None # type: ignore[assignment] - ModelManager = None # type: ignore[assignment] + + +def _try_import_model_manager(throw_error: bool = True): + try: + from apache_beam.ml.inference.model_manager import ModelManager + return ModelManager + except ImportError as e: + if throw_error: + raise ImportError( + "Model Manager is not available. Please ensure that " + "all required packages for inference are installed and up to date." + ) from e + else: + return None + + +# ModelManager is an optional dependency so we don't throw an error here. +ModelManager = _try_import_model_manager(throw_error=False) _NANOSECOND_TO_MILLISECOND = 1_000_000 _NANOSECOND_TO_MICROSECOND = 1_000 @@ -1443,6 +1458,7 @@ def annotations(self): 'model_handler_type': ( f'{self._model_handler.__class__.__module__}' f'.{self._model_handler.__class__.__qualname__}'), + 'model_identifier': self._model_tag, **super().annotations() } @@ -1997,6 +2013,9 @@ def load(): # Ensure the tag we're loading is valid, if not replace it with a valid tag self._cur_tag = self._model_metadata.get_valid_tag(model_tag) if self.use_model_manager: + # Force an import here to avoid missing ModelManager when needed. + # Throw an error if ModelManager is not available since it's required for this code path. + ModelManager = _try_import_model_manager(throw_error=True) logging.info("Using Model Manager to manage models automatically.") model_manager = multi_process_shared.MultiProcessShared( lambda: ModelManager(**self._model_manager_args),