
Workaround for Keras memory leaks #2

@asuiu


Since the keras.Model.predict method has been shown to leak memory when called repeatedly, a workaround is needed: periodically clear the Keras backend state by calling K.clear_session().
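For illustration, here is a minimal sketch of the workaround in its simplest form, applied in the caller's loop rather than inside ModelObject. The helper name predict_with_periodic_clear is hypothetical. To keep the sketch runnable without TensorFlow, the predict and clear functions are passed in. In real use they would presumably be model.predict and K.clear_session.

```python
from typing import Callable, Iterable, List, TypeVar

X = TypeVar("X")
Y = TypeVar("Y")


def predict_with_periodic_clear(
    predict_fn: Callable[[X], Y],
    batches: Iterable[X],
    clear_fn: Callable[[], None],
    every: int = 128,
) -> List[Y]:
    """Apply predict_fn to each batch and call clear_fn after every `every` calls.

    Hypothetical helper: predict_fn stands in for model.predict and
    clear_fn stands in for keras.backend.clear_session.
    """
    if every < 1:
        raise ValueError("every must be >= 1")
    results: List[Y] = []
    for i, batch in enumerate(batches, start=1):
        results.append(predict_fn(batch))
        if i % every == 0:
            # Drop accumulated backend state every `every` predictions.
            clear_fn()
    return results
```

Whether a model built before the clear keeps working afterwards depends on the Keras/TensorFlow version and execution mode, so it is worth verifying on the version in use before relying on this pattern.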

Please add this periodic clearing to ModelObject as follows:

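import logging
import threading
from abc import ABC
from typing import Dict, List

from keras import backend as K  # or: from tensorflow.keras import backend as K
from keras.models import Model

# ModelParams and DataEncoder are this project's own types; import them from their module.


class ModelObject(ABC):
    def __init__(
            self,
            mp: ModelParams,
            model: Model,
            encoders: Dict[str, List[DataEncoder]],
            state_autoclear: int = 128,
    ) -> None:
        """
        :param mp: the model parameters
        :param model: the keras model
        :param encoders: the encoders used to encode/decode the data
        :param state_autoclear: the number of calls to predict/evaluate methods after which the keras session is cleared
        """
        self.mp = mp
        self.model = model
        self.encoders = encoders
        self._state_autoclear = state_autoclear
        self._state_calls = 0
        # `self._state_calls += 1` is a read-modify-write, not an atomic operation,
        # even under the GIL, so a lock guards the counter and its check-and-reset.
        self._state_lock = threading.Lock()

    def _check_clear_state(self) -> None:
        with self._state_lock:
            self._state_calls += 1
            if self._state_calls >= self._state_autoclear:
                K.clear_session()
                logging.debug("Cleared keras session after %d calls to predict/evaluate methods", self._state_calls)
                self._state_calls = 0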

Then call self._check_clear_state() from the predict() and evaluate() methods.
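A possible way to wire the counter into predict() and evaluate() is sketched below, under stated assumptions. AutoClearingPredictor is a hypothetical stand-in for ModelObject that keeps only the auto-clear logic. The model is assumed to be any object with predict() and evaluate() methods, and clear_fn stands in for K.clear_session so the sketch runs without Keras. The clears attribute exists only for inspection.

```python
import threading
from typing import Any, Callable


class AutoClearingPredictor:
    """Hypothetical stand-in for ModelObject showing only the auto-clear wiring."""

    def __init__(self, model: Any, clear_fn: Callable[[], None], state_autoclear: int = 128) -> None:
        if state_autoclear < 1:
            raise ValueError("state_autoclear must be >= 1")
        self.model = model
        self._clear_fn = clear_fn  # stands in for keras.backend.clear_session
        self._state_autoclear = state_autoclear
        self._state_calls = 0
        self.clears = 0  # how many times clear_fn has run (for inspection only)
        # The increment and the check-and-reset must happen together,
        # so both are done while holding this lock.
        self._state_lock = threading.Lock()

    def _check_clear_state(self) -> None:
        with self._state_lock:
            self._state_calls += 1
            if self._state_calls >= self._state_autoclear:
                self._clear_fn()
                self.clears += 1
                self._state_calls = 0

    def predict(self, x: Any) -> Any:
        try:
            return self.model.predict(x)
        finally:
            # Count the call even if model.predict raised.
            self._check_clear_state()

    def evaluate(self, x: Any, y: Any) -> Any:
        try:
            return self.model.evaluate(x, y)
        finally:
            self._check_clear_state()
```

Two design notes on this sketch. Calling the check from finally counts failed calls as well, which is a judgment call. The lock serializes only the counter: it does not stop another thread from being inside model.predict() while the session is cleared, so heavily concurrent use would need a broader lock or a reader/writer scheme around the model calls.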

Metadata

Assignees: No one assigned
Labels: enhancement (New feature or request)
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
