Source code for renku.infrastructure.database

#
# Copyright 2018-2023- Swiss Data Science Center (SDSC)
# A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
# Eidgenössische Technische Hochschule Zürich (ETHZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom database for store Persistent objects."""

import datetime
import hashlib
import importlib
import io
import json
from enum import Enum
from pathlib import Path
from types import BuiltinFunctionType, FunctionType
from typing import Any, Dict, List, Optional, Union, cast
from uuid import uuid4

import deal
import persistent
import zstandard as zstd
from BTrees.Length import Length
from BTrees.OOBTree import BTree, OOBucket, OOSet, OOTreeSet
from persistent import GHOST, UPTODATE
from persistent.interfaces import IPickleCache
from zc.relation.catalog import Catalog
from ZODB.utils import z64
from zope.interface import implementer
from zope.interface.interface import InterfaceClass

from renku.core import errors
from renku.domain_model.project import Project
from renku.infrastructure.immutable import Immutable
from renku.infrastructure.persistent import Persistent

# Type used for object ids (oids) throughout the database.
OID_TYPE = str
# Values of ``@renku_data_type`` for payloads that are not instances of a named class.
TYPE_TYPE = "type"
FUNCTION_TYPE = "function"
REFERENCE_TYPE = "reference"
SET_TYPE = "set"
FROZEN_SET_TYPE = "frozenset"
# Sentinel used as the default of ``pop`` methods to detect that no default was passed (``None`` is a valid default).
MARKER = object()

# NOTE: NEW and PERSISTED are used as ``_p_serial`` to mark whether an object was read from storage or is new.
NEW = z64  # NOTE: Do not change this value since this is the default when a Persistent object is created
PERSISTED = b"1" * 8


def _is_module_allowed(module_name: str, type_name: str):
    """Checks whether it is allowed to import from the given module for security purposes.

    Args:
        module_name(str): The module name to check.
        type_name(str): The type within the module to check.

    Raises:
        TypeError: If the type is now allowed in the database.
    """

    if module_name not in ["BTrees", "builtins", "datetime", "persistent", "renku", "zc", "zope", "deal"]:
        raise TypeError(f"Objects of type '{type_name}' are not allowed")


def get_type_name(object) -> Optional[str]:
    """Build the fully-qualified name of an object's type.

    Args:
        object: An instance or a class. For a class, the class's own name is returned
            rather than the name of its metaclass.

    Returns:
        Optional[str]: ``"<module>.<qualname>"``, or ``None`` when ``object`` is ``None``.
    """
    if object is None:
        return None

    if isinstance(object, type):
        klass = object
    else:
        klass = type(object)

    module_part = klass.__module__
    name_part = klass.__qualname__
    return f"{module_part}.{name_part}"
def get_class(type_name: Optional[str]) -> Optional[type]:
    """Resolve a fully-qualified type name to the class it names.

    The root package is checked with ``_is_module_allowed`` before it is imported; the remaining
    dotted components are resolved with ``get_attribute``.

    Args:
        type_name(Optional[str]): Dotted name such as ``"renku.domain_model.project.Project"``.

    Returns:
        Optional[type]: The class, or ``None`` when ``type_name`` is ``None``.
    """
    if type_name is None:
        return None

    root_name, *attribute_path = type_name.split(".")
    _is_module_allowed(root_name, type_name)
    root_module = __import__(root_name)
    return get_attribute(root_module, attribute_path)
def get_attribute(object, name: Union[List[str], str]):
    """Resolve a (possibly dotted) attribute path starting from ``object``.

    When a component is neither an attribute of the current object nor an already imported
    sub-module, it is imported as a sub-module of the current object's module, after checking the
    root package with ``_is_module_allowed``.

    Args:
        object: The object (usually a module) to start the lookup from.
        name(Union[List[str], str]): A dotted attribute path, or its already-split components.

    Returns:
        The value of the last attribute in the path.
    """
    import sys

    def _owner_module(target):
        # Modules have no ``__module__`` attribute, so fall back to their ``__name__``.
        if hasattr(target, "__module__"):
            return target.__module__
        return target.__name__

    parts = name if not isinstance(name, str) else name.split(".")
    root_package = _owner_module(object).split(".")[0]

    current = object
    for part in parts:
        owner = _owner_module(current)
        must_import = not hasattr(current, part) and f"{owner}.{part}" not in sys.modules
        if must_import:
            try:
                _is_module_allowed(root_package, current.__name__)
                current = importlib.import_module(f".{part}", package=owner)
                continue
            except ModuleNotFoundError:
                pass
        current = getattr(current, part)

    return current
class RenkuOOBTree(BTree):
    """``BTrees.OOBTree.BTree`` subclass with customized node sizes."""

    # NOTE: Upper bounds on the number of entries in internal nodes and in leaf buckets.
    max_internal_size = 2000
    max_leaf_size = 1000
class Database:
    """The Metadata Object Database.

    This class is equivalent to a ``persistent.DataManager`` and implements
    the ``persistent.interfaces.IPersistentDataManager`` interface.
    """

    ROOT_OID = "root"

    def __init__(self, storage):
        """Create a database on top of ``storage`` and load (or create) its root object.

        Args:
            storage(Storage): The storage used to load and store serialized objects.
        """
        self._storage: Storage = storage
        self._cache = Cache()
        # The pre-cache is used by get to avoid infinite loops when objects load their state
        self._pre_cache: Dict[OID_TYPE, persistent.Persistent] = {}
        # Objects added explicitly by add() or when serializing other objects. After commit they are moved to _cache.
        self._objects_to_commit: Dict[OID_TYPE, persistent.Persistent] = {}
        self._reader: ObjectReader = ObjectReader(database=self)
        self._writer: ObjectWriter = ObjectWriter(database=self)
        self._root: RenkuOOBTree

        self._initialize_root()

    @classmethod
    def from_path(cls, path: Union[Path, str]) -> "Database":
        """Create a Storage and Database using the given path.

        Args:
            path(Union[pathlib.Path, str]): The path of the database.

        Returns:
            The database object.
        """
        storage = Storage(path)
        return cls(storage=storage)

    @staticmethod
    def generate_oid(object: persistent.Persistent) -> OID_TYPE:
        """Generate an ``oid`` for a ``persistent.Persistent`` object based on its id.

        An existing ``_p_oid`` is returned as-is; otherwise the object's ``id`` (or ``_id``) is hashed,
        and if it has neither, a random oid is generated.

        Args:
            object(persistent.Persistent): The object to create an oid for.

        Returns:
            An oid for the object.
        """
        oid = object._p_oid
        if oid:
            assert isinstance(oid, OID_TYPE)
            return oid

        object_id: Optional[str] = getattr(object, "id", None) or getattr(object, "_id", None)
        if object_id:
            return Database.hash_id(object_id)

        return Database.new_oid()

    @staticmethod
    def hash_id(id: str) -> OID_TYPE:
        """Return ``oid`` from id.

        Args:
            id(str): The id to hash.

        Returns:
            OID_TYPE: The SHA3-256 hex digest of the UTF-8 encoded id (64 characters).
        """
        return hashlib.sha3_256(id.encode("utf-8")).hexdigest()

    @staticmethod
    def new_oid():
        """Generate a random ``oid`` of 64 hex characters (same length as ``hash_id`` results)."""
        return f"{uuid4().hex}{uuid4().hex}"

    @staticmethod
    def _get_filename_from_oid(oid: OID_TYPE) -> str:
        return oid.lower()

    def __getitem__(self, key) -> "Index":
        return self._root[key]

    def clear(self):
        """Remove all objects and clear all caches. Objects won't be deleted in the storage."""
        self._cache.clear()
        self._pre_cache.clear()
        self._objects_to_commit.clear()
        # NOTE: Clear root at the end because it will be added to _objects_to_commit when `register` is called.
        self._root.clear()

    def _initialize_root(self):
        """Load the root object from storage, or create and register a new one if it doesn't exist."""
        try:
            self._root = cast(RenkuOOBTree, self.get(Database.ROOT_OID))
        except errors.ObjectNotFoundError:
            self._root = RenkuOOBTree()
            self._root._p_oid = Database.ROOT_OID
            self.register(self._root)

    def add_index(
        self, name: str, object_type: type, attribute: Optional[str] = None, key_type: Optional[type] = None
    ) -> "Index":
        """Add an index.

        Args:
            name(str): The name of the index.
            object_type(type): The type contained within the index.
            attribute(str, optional): The attribute of the contained object to create a key from
                (Default value = None).
            key_type(type, optional): The type of the key (Default value = None).

        Returns:
            Index: The created ``Index`` object.
        """
        assert name not in self._root, f"Index or object already exists: '{name}'"

        index = Index(name=name, object_type=object_type, attribute=attribute, key_type=key_type)
        index._p_jar = self

        self._root[name] = index

        return index

    def add_root_object(self, name: str, obj: Persistent):
        """Add an object to the DB root.

        Args:
            name(str): The key of the object; also used as its oid.
            obj(Persistent): The object to store.
        """
        assert name not in self._root, f"Index or object already exists: '{name}'"

        obj._p_jar = self
        obj._p_oid = name

        self._root[name] = obj

    def add(self, object: persistent.Persistent, oid: Optional[OID_TYPE] = None):
        """Add a new object to the database.

        NOTE: Normally, we add objects to indexes but this method registers an object directly with the
        database (it is not added to any index). Use it only for singleton objects that have no Index defined
        for them (e.g. Project).

        Args:
            object(persistent.Persistent): The object to add.
            oid(OID_TYPE, optional): The oid for the object. The object's ``_p_oid`` is set to this value; when
                it is ``None``, ``register`` generates one (Default value = None).
        """
        assert not oid or isinstance(oid, OID_TYPE), f"Invalid oid type: '{type(oid)}'"
        object._p_oid = oid

        self.register(object)

    def register(self, object: persistent.Persistent):
        """Register a persistent.Persistent object to be stored.

        NOTE: When a persistent.Persistent object is changed it calls this method.

        Args:
            object(persistent.Persistent): The object to register with the database.
        """
        assert isinstance(object, persistent.Persistent), f"Cannot add non-Persistent object: '{object}'"

        if object._p_oid is None:
            object._p_oid = self.generate_oid(object)
        elif isinstance(object, Persistent):
            # NOTE: A safety-net to make sure that all objects have correct p_oid
            expected_oid = Database.hash_id(getattr(object, "id"))
            actual_oid = object._p_oid
            assert actual_oid == expected_oid, f"Object has wrong oid: {actual_oid} != {expected_oid}"

        object._p_jar = self
        self._objects_to_commit[object._p_oid] = object

    def get(self, oid: OID_TYPE) -> persistent.Persistent:
        """Get the object by ``oid``.

        Root entries are returned first, then cached/pending objects; otherwise the object is loaded from
        storage and cached.

        Args:
            oid(OID_TYPE): The oid of the object to get.

        Raises:
            errors.ObjectNotFoundError: If the object is not cached and doesn't exist in storage.

        Returns:
            persistent.Persistent: The object.
        """
        if oid != Database.ROOT_OID and oid in self._root:  # NOTE: Avoid looping if getting "root"
            return self._root[oid]

        obj = self.get_cached(oid)
        if obj is not None:
            return obj

        obj = self.get_from_path(path=self._get_filename_from_oid(oid))

        if isinstance(obj, Persistent):
            obj.freeze()

        # NOTE: Avoid infinite loop if object tries to load its state before it is added to the cache
        self._pre_cache[oid] = obj
        self._cache[oid] = obj
        self._pre_cache.pop(oid)

        return obj

    def get_from_path(
        self, path: str, absolute: bool = False, override_type: Optional[str] = None
    ) -> persistent.Persistent:
        """Load a database object from a path.

        Args:
            path(str): Path of the database object.
            absolute(bool): Whether the path is absolute or a filename inside the database (Default value = False).
            override_type(Optional[str]): load object as a different type than what is set inside `renku_data_type`
                (Default value = None).

        Raises:
            errors.ObjectNotFoundError: If the path doesn't exist.
            errors.IncompatibleParametersError: If ``override_type`` is given but the data has no type.

        Returns:
            persistent.Persistent: The object.
        """
        deal.disable(warn=False)
        # NOTE: Re-enable contracts even if loading fails (e.g. ObjectNotFoundError when the root doesn't exist
        # yet); otherwise they would stay disabled globally.
        try:
            data = self._storage.load(filename=path, absolute=absolute)
            if override_type is not None:
                if "@renku_data_type" not in data:
                    raise errors.IncompatibleParametersError("Cannot override type on found data.")
                data["@renku_data_type"] = override_type
            obj = self._reader.deserialize(data)
            obj._p_changed = 0
            obj._p_serial = PERSISTED
        finally:
            deal.enable()
        return obj

    def get_by_id(self, id: str) -> persistent.Persistent:
        """Return an object by its id.

        Args:
            id(str): The id to look up.

        Returns:
            persistent.Persistent: The object with the given id.
        """
        oid = Database.hash_id(id)
        return self.get(oid)

    def get_cached(self, oid: OID_TYPE) -> Optional[persistent.Persistent]:
        """Return an object if it is in the cache or will be committed.

        Args:
            oid(OID_TYPE): The id of the object to look up.

        Returns:
            Optional[persistent.Persistent]: The cached object, or ``None`` if it isn't cached.
        """
        obj = self._cache.get(oid)
        if obj is not None:
            return obj

        obj = self._pre_cache.get(oid)
        if obj is not None:
            return obj

        obj = self._objects_to_commit.get(oid)
        if obj is not None:
            return obj

        return None

    def remove_root_object(self, name: str) -> None:
        """Remove a root object from the database.

        Args:
            name(str): The name of the root object to remove.
        """
        assert name in self._root, f"Index or object doesn't exist in root: '{name}'"

        obj = self.get(name)
        self.remove_from_cache(obj)

        del self._root[name]

    def new_ghost(self, oid: OID_TYPE, object: persistent.Persistent):
        """Create a new ghost object.

        Args:
            oid(OID_TYPE): The oid of the new ghost object.
            object(persistent.Persistent): The object to create a new ghost entry for.
        """
        object._p_jar = self
        self._cache.new_ghost(oid, object)

    def setstate(self, object: persistent.Persistent):
        """Load the state for a ghost object.

        Args:
            object(persistent.Persistent): The object to set the state on.
        """
        deal.disable(warn=False)
        # NOTE: Always re-enable contracts, even if loading the state fails.
        try:
            data = self._storage.load(filename=self._get_filename_from_oid(object._p_oid))
            self._reader.set_ghost_state(object, data)
            object._p_serial = PERSISTED
            if isinstance(object, Persistent):
                object.freeze()
        finally:
            deal.enable()

    def commit(self):
        """Commit modified and new objects."""
        while self._objects_to_commit:
            _, obj = self._objects_to_commit.popitem()
            if obj._p_changed or obj._p_serial == NEW:
                self._store_object(obj)

    @staticmethod
    def _should_compress(object: persistent.Persistent) -> bool:
        """Return whether ``object`` is stored compressed.

        Catalogs, B-trees/buckets, the project and indexes are stored as plain JSON; everything else is compressed.
        """
        return not isinstance(object, (Catalog, RenkuOOBTree, OOBucket, Project, Index))

    def _store_object(self, object: persistent.Persistent):
        """Serialize and store ``object`` under its oid, then cache it and mark it as persisted."""
        data = self._writer.serialize(object)
        compress = self._should_compress(object)

        self._storage.store(filename=self._get_filename_from_oid(object._p_oid), data=data, compress=compress)

        self._cache[object._p_oid] = object

        object._p_changed = 0  # NOTE: transition from changed to up-to-date
        object._p_serial = PERSISTED

    def persist_to_path(self, object: persistent.Persistent, path: Path):
        """Store an object to path.

        Unlike ``commit``, this neither caches the object nor changes its persistence state.

        Args:
            object(persistent.Persistent): The object to store.
            path(Path): Absolute path of the target file.
        """
        data = self._writer.serialize(object)
        compress = self._should_compress(object)

        self._storage.store(filename=str(path), data=data, compress=compress, absolute=True)

    def remove_from_cache(self, object: persistent.Persistent):
        """Remove an object from cache.

        Entries are only removed when they are this exact object (identity check).

        Args:
            object(persistent.Persistent): The object to remove.
        """
        oid = object._p_oid

        def remove_from(cache):
            existing_entry = cache.get(oid)
            if existing_entry is object:
                cache.pop(oid)

        remove_from(self._cache)
        remove_from(self._pre_cache)
        remove_from(self._objects_to_commit)

    def readCurrent(self, object):
        """We don't use this method but some Persistent logic require its existence.

        Args:
            object: The object to read.
        """
        assert object._p_jar is self
        assert object._p_oid is not None

    def oldstate(self, object, tid):
        """See ``persistent.interfaces.IPersistentDataManager::oldstate``."""
        raise NotImplementedError
@implementer(IPickleCache)
class Cache:
    """Database ``Cache``: maps oids to the Persistent objects loaded or stored by the database."""

    def __init__(self):
        self._entries = {}

    def __len__(self):
        return len(self._entries)

    def __getitem__(self, oid):
        assert isinstance(oid, OID_TYPE), f"Invalid oid type: '{type(oid)}'"
        return self._entries[oid]

    def __setitem__(self, oid, object):
        # NOTE: Validate the entry before touching the mapping; re-adding the very same object is allowed.
        assert isinstance(object, persistent.Persistent), f"Cannot cache non-Persistent objects: '{object}'"
        assert isinstance(oid, OID_TYPE), f"Invalid oid type: '{type(oid)}'"
        assert object._p_jar is not None, "Cached object jar missing"
        assert oid == object._p_oid, f"Cache key does not match oid: {oid} != {object._p_oid}"

        if oid in self._entries and self._entries[oid] is not object:
            raise ValueError(f"The same oid exists: {self._entries[oid]} != {object}")

        self._entries[oid] = object

    def __delitem__(self, oid):
        assert isinstance(oid, OID_TYPE), f"Invalid oid type: '{type(oid)}'"
        del self._entries[oid]

    def clear(self):
        """Drop every cached entry."""
        self._entries.clear()

    def pop(self, oid, default=MARKER):
        """Remove an entry and return it.

        Args:
            oid: The oid of the entry to remove.
            default: Value returned when the oid is absent; when omitted, a missing oid raises
                (Default value = MARKER).

        Raises:
            KeyError: If the oid is absent and no default was given.

        Returns:
            The removed object, or ``default`` when the oid is absent.
        """
        if default is not MARKER:
            return self._entries.pop(oid, default)
        return self._entries.pop(oid)

    def get(self, oid, default=None):
        """See ``IPickleCache``.

        Args:
            oid: The oid of the entry to look up.
            default: Value returned when the oid is absent (Default value = None).

        Returns:
            The cached object, or ``default``.
        """
        assert isinstance(oid, OID_TYPE), f"Invalid oid type: '{type(oid)}'"
        return self._entries.get(oid, default)

    def new_ghost(self, oid, object):
        """See ``IPickleCache``.

        Assigns ``oid`` to ``object``, turns it into a ghost if needed and caches it.
        """
        assert object._p_oid is None, f"Object already has an oid: {object}"
        assert object._p_jar is not None, f"Object does not have a jar: {object}"
        assert oid not in self._entries, f"Duplicate oid: {oid}"

        object._p_oid = oid
        if object._p_state != GHOST:
            object._p_invalidate()

        self[oid] = object
class Index(persistent.Persistent):
    """Database index: a persistent key -> object mapping for objects of a single type."""

    def __init__(self, *, name: str, object_type, attribute: Optional[str], key_type=None):
        """Create an index where keys are extracted using ``attribute`` from an object or a key.

        Args:
            name (str): Index's name. Must be all lowercase.
            object_type: Type of objects that the index points to.
            attribute (Optional[str], optional): Name of an attribute to be used to automatically generate a key
                (e.g. `entity.path`).
            key_type: Type of key objects. If not None then a ``key_object`` of this type must be provided when
                adding/removing entries (Default value = None).
        """
        assert name == name.lower(), f"Index name must be all lowercase: '{name}'."

        super().__init__()

        self._p_oid = f"{name}-index"

        self._name: str = name
        self._object_type = object_type
        self._key_type = key_type
        self._attribute: Optional[str] = attribute
        self._entries: RenkuOOBTree = RenkuOOBTree()
        self._entries._p_oid = name

    def __len__(self):
        return len(self._entries)

    def __contains__(self, key):
        return key in self._entries

    def __getitem__(self, key):
        return self._entries[key]

    def __setitem__(self, key, value):
        # NOTE: if Index is using a key object then we cannot check if key is valid. It's safer to use `add` method
        # instead of setting values directly.
        self._verify_and_get_key(object=value, key_object=None, key=key, missing_key_object_ok=True)

        self._entries[key] = value

    def __getstate__(self):
        return {
            "name": self._name,
            "object_type": get_type_name(self._object_type),
            "key_type": get_type_name(self._key_type),
            "attribute": self._attribute,
            "entries": self._entries,
        }

    def __setstate__(self, data):
        self._name = data.pop("name")
        self._object_type = get_class(data.pop("object_type"))
        self._key_type = get_class(data.pop("key_type"))
        self._attribute = data.pop("attribute")
        self._entries = data.pop("entries")

    def __iter__(self):
        return self._entries.__iter__()

    @property
    def name(self) -> str:
        """Return Index's name."""
        return self._name

    @property
    def object_type(self) -> type:
        """Return Index's object_type."""
        return self._object_type

    def get(self, key, default=None):
        """Return an entry based on its key.

        Args:
            key: The key of the entry to get.
            default: Default value to return if the entry wasn't found (Default value = None).

        Returns:
            The found entry or the default value if it wasn't found.
        """
        return self._entries.get(key, default)

    def pop(self, key, default=MARKER):
        """Remove and return an object.

        Falsy keys (e.g. ``None`` or ``""``) are never looked up and are treated as missing.

        Args:
            key: The key of the entry to remove.
            default: Default value to return if the entry wasn't found (Default value = MARKER).

        Raises:
            KeyError: If a truthy key isn't found and no default was given.

        Returns:
            The removed entry or the default value if it wasn't found. For a falsy key without an explicit
            default, ``None`` is returned.
        """
        if not key:
            # NOTE: Honour an explicit default; keep returning None when no default was given.
            return None if default is MARKER else default

        if default is MARKER:
            return self._entries.pop(key)
        return self._entries.pop(key, default)

    def keys(self, min=None, max=None, excludemin=False, excludemax=False):
        """Return an iterator of keys."""
        return self._entries.keys(min=min, max=max, excludemin=excludemin, excludemax=excludemax)

    def values(self):
        """Return an iterator of values."""
        return self._entries.values()

    def items(self):
        """Return an iterator of keys and values."""
        return self._entries.items()

    def add(self, object: persistent.Persistent, *, key: Optional[str] = None, key_object=None, verify=True):
        """Update index with object.

        If `Index._attribute` is not None then key is automatically generated.
        Key is extracted from `key_object` if it is not None; otherwise, it's extracted from `object`.

        Args:
            object(persistent.Persistent): Object to add.
            key(Optional[str], optional): Key to use in the index (Default value = None).
            key_object: Object to use to extract a key from (Default value = None).
            verify: Whether to check if the key is valid (Default value = True).
        """
        assert isinstance(object, self._object_type), f"Cannot add objects of type '{type(object)}'"

        key = self._verify_and_get_key(
            object=object, key_object=key_object, key=key, missing_key_object_ok=False, verify=verify
        )
        self._entries[key] = object

    def remove(self, object: persistent.Persistent, *, key: Optional[str] = None, key_object=None, verify=True):
        """Remove object from the index.

        If `Index._attribute` is not None then key is automatically generated.
        Key is extracted from `key_object` if it is not None; otherwise, it's extracted from `object`.

        Args:
            object(persistent.Persistent): Object to remove.
            key(Optional[str], optional): Key of the entry to remove (Default value = None).
            key_object: Object to use to extract a key from (Default value = None).
            verify: Whether to check if the key is valid (Default value = True).

        Raises:
            KeyError: If no entry exists for the resulting key.
        """
        assert isinstance(object, self._object_type), f"Cannot remove objects of type '{type(object)}'"

        key = self._verify_and_get_key(
            object=object, key_object=key_object, key=key, missing_key_object_ok=False, verify=verify
        )
        del self._entries[key]

    def generate_key(self, object: persistent.Persistent, *, key_object=None):
        """Return index key for an object.

        Key is extracted from `key_object` if it is not None; otherwise, it's extracted from `object`.

        Args:
            object(persistent.Persistent): The object to generate a key for.
            key_object: The object to derive a key from (Default value = None).

        Returns:
            A key for object.
        """
        return self._verify_and_get_key(object=object, key_object=key_object, key=None, missing_key_object_ok=False)

    def _verify_and_get_key(
        self, *, object: persistent.Persistent, key_object, key, missing_key_object_ok, verify=True
    ):
        """Validate ``key_object``/``key`` against this index's configuration and return the key to use.

        With an ``_attribute``, the key is read from ``key_object`` (or ``object``) and, when ``verify`` is set,
        an explicit ``key`` must match it. Without an ``_attribute``, an explicit ``key`` is required.
        """
        if self._key_type:
            if not missing_key_object_ok:
                assert isinstance(key_object, self._key_type), f"Invalid key type: {type(key_object)} for '{self.name}'"
        else:
            assert key_object is None, f"Index '{self.name}' does not accept 'key_object'"

        if self._attribute:
            key_object = key_object or object
            correct_key = get_attribute(key_object, self._attribute)
            if key is not None:
                if verify:
                    assert key == correct_key, f"Incorrect key for index '{self.name}': '{key}' != '{correct_key}'"
                else:
                    correct_key = key
        else:
            assert key is not None, "No key is provided"
            correct_key = key

        return correct_key
class Storage:
    """Store Persistent objects on the disk."""

    OID_FILENAME_LENGTH = 64

    def __init__(self, path: Union[Path, str]):
        self.path = Path(path)
        self.zstd_compressor = zstd.ZstdCompressor()
        self.zstd_decompressor = zstd.ZstdDecompressor()

    def _get_path(self, filename: str, absolute: bool) -> Path:
        """Return the on-disk path for ``filename``.

        Filenames of oid length are sharded into two sub-directory levels named after their first four characters.
        """
        if absolute:
            return Path(filename)
        if len(filename) == Storage.OID_FILENAME_LENGTH:
            return self.path / filename[0:2] / filename[2:4] / filename
        return self.path / filename

    def store(self, filename: str, data: Union[Dict, List], compress=False, absolute: bool = False):
        """Store object.

        Data is always written as UTF-8 encoded JSON (non-ASCII characters are not escaped).

        Args:
            filename(str): Target file name to store data in.
            data(Union[Dict, List]): The data to store.
            compress(bool): Whether to compress the data or store it as plain json (Default value = False).
            absolute(bool): Whether filename is an absolute path (Default value = False).
        """
        assert isinstance(filename, str)

        path = self._get_path(filename, absolute)
        if not absolute:
            path.parent.mkdir(parents=True, exist_ok=True)

        # NOTE: Use an explicit encoding; with ``ensure_ascii=False`` the locale default could write non-UTF-8 data
        # that ``load`` (which decodes bytes as UTF-8) can't read back.
        if compress:
            with open(path, "wb") as fb, self.zstd_compressor.stream_writer(fb) as compressor:
                with io.TextIOWrapper(compressor, encoding="utf-8") as out:
                    json.dump(data, out, ensure_ascii=False)
        else:
            with open(path, "w", encoding="utf-8") as ft:
                json.dump(data, ft, ensure_ascii=False, sort_keys=True, indent=2)

    def load(self, filename: str, absolute: bool = False):
        """Load data for object with object id oid.

        Args:
            filename(str): The file name of the data to load.
            absolute(bool): Whether the path is absolute or a filename inside the database (Default value: False).

        Raises:
            errors.ObjectNotFoundError: If the file doesn't exist.
            errors.MetadataCorruptError: If the (possibly decompressed) content is not valid JSON.

        Returns:
            The loaded data in dictionary form.
        """
        assert isinstance(filename, str)

        path = self._get_path(filename, absolute)
        if not path.exists():
            raise errors.ObjectNotFoundError(filename)

        with open(path, "rb") as file:
            # NOTE: Compressed files are detected by the zstd magic number in their first four bytes.
            header = int.from_bytes(file.read(4), "little")
            file.seek(0)
            try:
                if header == zstd.MAGIC_NUMBER:
                    with self.zstd_decompressor.stream_reader(file) as zfile:
                        data = json.load(zfile)
                else:
                    data = json.load(file)
            except json.JSONDecodeError as e:
                raise errors.MetadataCorruptError(path) from e

        return data
class ObjectWriter:
    """Serialize objects for storage in storage."""

    def __init__(self, database: Database):
        self._database: Database = database
        # Maps ``id()`` of already-serialized plain objects to their encounter position; reset by ``serialize``.
        self._serialization_cache: Dict[int, Any] = {}

    def serialize(self, object: persistent.Persistent):
        """Convert an object to JSON.

        Args:
            object(persistent.Persistent): Object to serialize.

        Returns:
            dict: Dictionary containing serialized data.
        """
        assert isinstance(object, persistent.Persistent), f"Cannot serialize object of type '{type(object)}': {object}"
        assert object._p_oid, f"Object does not have an oid: '{object}'"
        assert object._p_jar is not None, f"Object is not associated with a Database: '{object}'"

        self._serialization_cache = {}
        state = object.__getstate__()
        was_dict = isinstance(state, dict)
        data = self._serialize_helper(state)
        is_dict = isinstance(data, dict)

        # NOTE: Non-dict state (or state that only became a dict through tagging) is wrapped in a value envelope.
        if not (is_dict and was_dict):
            data = {"@renku_data_value": data}

        data["@renku_data_type"] = get_type_name(object)
        data["@renku_oid"] = object._p_oid

        return data

    def _serialize_helper(self, obj):
        """Recursively convert ``obj`` into JSON-compatible data.

        Args:
            obj: The value to serialize.

        Returns:
            ``None``, scalars, lists, plain dicts, or dicts tagged with ``@renku_data_type``.
        """
        # TODO: Raise an error if an unsupported object is being serialized
        if obj is None:
            return None
        elif isinstance(obj, (int, float, str, bool)):
            return obj
        elif isinstance(obj, list):
            return [self._serialize_helper(value) for value in obj]
        elif isinstance(obj, set):
            return {
                "@renku_data_type": SET_TYPE,
                "@renku_data_value": [self._serialize_helper(value) for value in obj],
            }
        elif isinstance(obj, frozenset):
            return {
                "@renku_data_type": FROZEN_SET_TYPE,
                "@renku_data_value": [self._serialize_helper(value) for value in obj],
            }
        elif isinstance(obj, dict):
            # NOTE: Keys are visited in sorted order so the output is deterministic.
            return {key: self._serialize_helper(value) for key, value in sorted(obj.items(), key=lambda x: x[0])}
        elif isinstance(obj, Index):
            # NOTE: Index objects are not stored as references and are included in their parent object (i.e. root)
            state = self._serialize_helper(obj.__getstate__())
            return {"@renku_data_type": get_type_name(obj), "@renku_oid": obj._p_oid, **state}
        elif isinstance(obj, (OOTreeSet, Length, OOSet)):
            state = self._serialize_helper(obj.__getstate__())
            return {"@renku_data_type": get_type_name(obj), "@renku_data_value": state}
        elif isinstance(obj, persistent.Persistent):
            # NOTE: Other Persistent objects are stored separately; new/changed ones are registered for commit.
            if not obj._p_oid:
                obj._p_oid = Database.generate_oid(obj)
            if obj._p_state not in [GHOST, UPTODATE] or (obj._p_state == UPTODATE and obj._p_serial == NEW):
                self._database.register(obj)
            return {"@renku_data_type": get_type_name(obj), "@renku_oid": obj._p_oid, "@renku_reference": True}
        elif isinstance(obj, datetime.datetime):
            return {"@renku_data_type": get_type_name(obj), "@renku_data_value": obj.isoformat()}
        elif isinstance(obj, tuple):
            value = tuple(self._serialize_helper(value) for value in obj)
            return {"@renku_data_type": get_type_name(obj), "@renku_data_value": value}
        elif isinstance(obj, InterfaceClass):
            # NOTE: Zope interfaces are weird, they're a class with type InterfaceClass, but need to be deserialized
            # as the class (without instantiation)
            return {"@renku_data_type": TYPE_TYPE, "@renku_data_value": f"{obj.__module__}.{obj.__name__}"}
        elif isinstance(obj, type):
            # NOTE: We're storing a type, not an instance
            return {"@renku_data_type": TYPE_TYPE, "@renku_data_value": get_type_name(obj)}
        elif isinstance(obj, (FunctionType, BuiltinFunctionType)):
            name = obj.__name__
            module = getattr(obj, "__module__", None)
            return {"@renku_data_type": FUNCTION_TYPE, "@renku_data_value": f"{module}.{name}"}

        return self._serialize_plain_object(obj)

    def _serialize_plain_object(self, obj):
        """Serialize an arbitrary object from its state, emitting a reference if it was already serialized."""
        obj_id = id(obj)
        if obj_id in self._serialization_cache:
            # NOTE: We already serialized this -> circular/repeat reference
            return {"@renku_data_type": REFERENCE_TYPE, "@renku_data_value": self._serialization_cache[obj_id]}

        # NOTE: The reference used for circular reference is just the position in the serialization cache,
        # as the order is deterministic. So the order in which objects are encountered is their id for referencing.
        self._serialization_cache[obj_id] = len(self._serialization_cache)

        if hasattr(obj, "__getstate__"):
            # NOTE: On Python 3.11+ this just returns __dict__ if __getstate__ isn't implemented.
            value = obj.__getstate__().copy()
        else:
            value = obj.__dict__.copy()

        # NOTE: ``_v_`` attributes are volatile and never stored.
        value = {k: v for k, v in value.items() if not k.startswith("_v_")}
        value = self._serialize_helper(value)
        return {"@renku_data_type": get_type_name(obj), "@renku_data_value": value}
class ObjectReader:
    """Deserialize objects loaded from storage."""

    def __init__(self, database: Database):
        self._classes: Dict[str, type] = {}
        self._database = database

        # a cache for normal (non-persistent objects with an id) to deduplicate them on load
        self._normal_object_cache: Dict[str, Any] = {}
        # Objects created during the current deserialization, in encounter order (targets of REFERENCE_TYPE).
        self._deserialization_cache: List[Any] = []

    def _get_class(self, type_name: str) -> Optional[type]:
        """Return the class for ``type_name``, memoizing successful lookups."""
        cls = self._classes.get(type_name)
        if cls is not None:
            return cls

        cls = get_class(type_name)
        if cls is None:
            return None

        self._classes[type_name] = cls
        return cls

    def set_ghost_state(self, object: persistent.Persistent, data: Dict):
        """Set state of a Persistent ghost object.

        Args:
            object(persistent.Persistent): Object to set state on.
            data(Dict): State to set on the object.
        """
        previous_cache = self._deserialization_cache
        self._deserialization_cache = []
        # NOTE: Restore the outer deserialization stack even if setting the state fails.
        try:
            state = self._deserialize_helper(data, create=False)
            object.__setstate__(state)
        finally:
            self._deserialization_cache = previous_cache

    def deserialize(self, data):
        """Convert JSON to Persistent object.

        Args:
            data: Data to deserialize.

        Returns:
            Deserialized object.
        """
        oid = data["@renku_oid"]

        self._deserialization_cache = []

        obj = self._deserialize_helper(data)

        obj._p_oid = oid
        obj._p_jar = self._database

        return obj

    def _deserialize_helper(self, data, create=True):
        """Recursively convert stored JSON data back into Python objects.

        NOTE: Tagged dicts are consumed destructively (``@renku_data_type``/``@renku_oid`` are popped) and plain
        dicts are updated in place.

        Args:
            data: The JSON data to convert.
            create(bool): Whether to instantiate the tagged top-level object; when False, only its state is
                returned (Default value = True).

        Returns:
            The deserialized value.
        """
        if data is None:
            return None
        elif isinstance(data, (int, float, str, bool)):
            return data
        elif isinstance(data, list):
            return [self._deserialize_helper(value) for value in data]

        assert isinstance(data, dict), f"Data must be a dict: '{type(data)}'"

        if "@renku_data_type" not in data:  # NOTE: A normal dict value
            assert "@renku_oid" not in data
            items = sorted(data.items(), key=lambda x: x[0])
            for key, value in items:
                data[key] = self._deserialize_helper(value)
            return data

        object_type = data.pop("@renku_data_type")
        if object_type in (TYPE_TYPE, FUNCTION_TYPE):
            # NOTE: if we stored a type (not instance), return the type
            return self._get_class(data["@renku_data_value"])
        elif object_type == REFERENCE_TYPE:
            # NOTE: we had a circular reference, we return the (not yet finalized) class here
            return self._deserialization_cache[data["@renku_data_value"]]
        elif object_type == SET_TYPE:
            return {self._deserialize_helper(value) for value in data["@renku_data_value"]}
        elif object_type == FROZEN_SET_TYPE:
            return frozenset([self._deserialize_helper(value) for value in data["@renku_data_value"]])

        cls = self._get_class(object_type)

        if cls is None:
            raise TypeError(f"Couldn't find class '{object_type}'")

        if issubclass(cls, datetime.datetime):
            assert create
            data = data["@renku_data_value"]
            return datetime.datetime.fromisoformat(data)
        elif issubclass(cls, tuple):
            data = data["@renku_data_value"]
            return tuple(self._deserialize_helper(value) for value in data)

        oid: str = data.pop("@renku_oid", None)
        if oid:
            assert isinstance(oid, str)

            if data.get("@renku_reference"):  # A reference
                assert create, f"Cannot deserialize a reference without creating an instance {data}"
                new_object = self._database.get_cached(oid)
                if new_object is not None:
                    return new_object
                assert issubclass(cls, persistent.Persistent)
                new_object = cls.__new__(cls)
                self._database.new_ghost(oid, new_object)
                return new_object
            elif issubclass(cls, Index):
                new_object = self._database.get_cached(oid)
                # NOTE: Compare with None: an empty Index is falsy (it defines __len__) but is still a cache hit.
                if new_object is not None:
                    return new_object
                new_object = cls.__new__(cls)
                new_object._p_oid = oid
                self.set_ghost_state(new_object, data)
                return new_object

        if "@renku_data_value" in data:
            data = data["@renku_data_value"]

        if not create:
            data = self._deserialize_helper(data)
            return data

        if issubclass(cls, persistent.Persistent):
            new_object = cls.__new__(cls)
            new_object._p_oid = oid  # type: ignore[attr-defined]
            self.set_ghost_state(new_object, data)
        elif issubclass(cls, Enum):
            # NOTE: Enum replaces __new__ on classes with its own versions that validates entries
            new_object = cls.__new__(cls, data["_value_"])
            return new_object
        else:
            new_object = cls.__new__(cls)

            # NOTE: we deserialize in the same order as we serialized, so the two stacks here match
            self._deserialization_cache.append(new_object)
            cache_index = len(self._deserialization_cache) - 1

            data = self._deserialize_helper(data)
            assert isinstance(data, dict)

            if "id" in data and data["id"] in self._normal_object_cache:
                existing_object = self._normal_object_cache[data["id"]]

                # NOTE: replace uninitialized object in cache with actual object
                self._deserialization_cache[cache_index] = existing_object
                return existing_object

            if hasattr(new_object, "__setstate__"):
                new_object.__setstate__(data)
            else:
                for name, value in data.items():
                    object.__setattr__(new_object, name, value)

            if issubclass(cls, Immutable):
                new_object = cls.make_instance(new_object)

            if "id" in data and isinstance(data["id"], str) and data["id"].startswith("/"):
                self._normal_object_cache[data["id"]] = new_object

        return new_object