# Source code for examples.custom_attributes.custom_management

"""Illustrates customized class instrumentation, using
the :mod:`sqlalchemy.ext.instrumentation` extension package.

In this example, mapped classes are modified to
store their state in a dictionary attached to an attribute
named "_goofy_dict", instead of using __dict__.
This example illustrates how to replace SQLAlchemy's class
descriptors with a user-defined system.


"""

from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy import Text
from sqlalchemy.ext.instrumentation import InstrumentationManager
from sqlalchemy.orm import registry as _reg
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import del_attribute
from sqlalchemy.orm.attributes import get_attribute
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.instrumentation import is_instrumented


# Module-level mapper registry; the demo at the bottom of the file calls
# ``registry.map_imperatively()`` on it to map the example classes.
registry = _reg()


class MyClassState(InstrumentationManager):
    """Instrumentation manager storing per-instance data in ``_goofy_dict``.

    The ``_goofy_dict`` dictionary itself is kept in the instance's
    ``__dict__``; the instance state object is kept inside that dictionary
    under the ``"state"`` key.
    """

    def get_instance_dict(self, class_, instance):
        """Return the ``_goofy_dict`` attribute of *instance*."""
        goofy = instance._goofy_dict
        return goofy

    def initialize_instance_dict(self, class_, instance):
        """Attach a new, empty ``_goofy_dict`` to *instance*."""
        native = instance.__dict__
        native["_goofy_dict"] = {}

    def install_state(self, class_, instance, state):
        """Store *state* under ``"state"`` in the instance's ``_goofy_dict``."""
        goofy = instance.__dict__["_goofy_dict"]
        goofy["state"] = state

    def state_getter(self, class_):
        """Return a callable that fetches the stored state of an instance."""

        def _lookup_state(instance):
            goofy = instance.__dict__["_goofy_dict"]
            return goofy["state"]

        return _lookup_state

class MyClass:
    """Base class whose attributes live in ``_goofy_dict``, not ``__dict__``.

    Attribute access is routed by ``is_instrumented()``: instrumented keys
    go through ``get_attribute`` / ``set_attribute`` / ``del_attribute``;
    every other key is read from and written to the ``_goofy_dict``
    dictionary that ``MyClassState`` places in the instance ``__dict__``.

    NOTE: ``MyClassState`` keeps the instance state in the same dictionary
    under the key ``"state"``, so a plain attribute named ``state`` would
    collide with it.
    """

    __sa_instrumentation_manager__ = MyClassState

    def __init__(self, **kwargs):
        """Assign every keyword argument as an attribute of the instance."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __getattr__(self, key):
        """Resolve *key* after normal attribute lookup has failed.

        Raises:
            AttributeError: if *key* is not instrumented and is not stored
                in ``_goofy_dict``, or ``_goofy_dict`` does not exist yet.
        """
        if is_instrumented(self, key):
            return get_attribute(self, key)
        try:
            # Go through ``__dict__`` here: ``self._goofy_dict`` would
            # re-enter ``__getattr__`` and recurse without end whenever
            # the storage dictionary has not been created yet.
            return self.__dict__["_goofy_dict"][key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key, value):
        """Set *key*, through ``set_attribute`` when it is instrumented."""
        if is_instrumented(self, key):
            set_attribute(self, key, value)
        else:
            self._goofy_dict[key] = value

    def __delattr__(self, key):
        """Delete *key*, through ``del_attribute`` when it is instrumented.

        Raises:
            AttributeError: if *key* is not instrumented and is not stored
                in ``_goofy_dict``.
        """
        if is_instrumented(self, key):
            del_attribute(self, key)
            return
        try:
            del self._goofy_dict[key]
        except KeyError:
            # Report a missing attribute the same way ``__getattr__`` does
            # instead of leaking the dictionary's KeyError.
            raise AttributeError(key) from None


if __name__ == "__main__":
    # Demo: map two classes built on MyClass, persist them, and check that
    # attribute access and a collection removal survive a reload.
    engine = create_engine("sqlite://")
    meta = MetaData()

    table1 = Table(
        "table1",
        meta,
        Column("id", Integer, primary_key=True),
        Column("name", Text),
    )
    table2 = Table(
        "table2",
        meta,
        Column("id", Integer, primary_key=True),
        Column("name", Text),
        Column("t1id", Integer, ForeignKey("table1.id")),
    )
    meta.create_all(engine)

    class A(MyClass):
        pass

    class B(MyClass):
        pass

    registry.map_imperatively(A, table1, properties={"bs": relationship(B)})

    registry.map_imperatively(B, table2)

    a1 = A(name="a1", bs=[B(name="b1"), B(name="b2")])

    assert a1.name == "a1"
    assert a1.bs[0].name == "b1"

    # Use the session as a context manager so it is closed when the demo
    # ends, including when one of the assertions fails.
    with Session(engine) as sess:
        sess.add(a1)

        sess.commit()

        # Session.get() is the primary-key lookup that supersedes the
        # legacy Query.get().
        a1 = sess.get(A, a1.id)

        assert a1.name == "a1"
        assert a1.bs[0].name == "b1"

        a1.bs.remove(a1.bs[0])

        sess.commit()

        a1 = sess.get(A, a1.id)
        assert len(a1.bs) == 1