import binascii

from sqlalchemy import event
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import Table
from sqlalchemy.sql import expression
from sqlalchemy.sql import type_coerce
from sqlalchemy.types import UserDefinedType
# Python datatypes
class GisElement(object):
    """Mixin that gives geometry values a readable str() and repr().

    Subclasses must provide a ``desc`` attribute or property holding the
    value's textual description; this class only formats it.
    """

    def __repr__(self):
        parts = (type(self).__name__, id(self), self.desc)
        return "<%s at 0x%x; %r>" % parts

    def __str__(self):
        # str() is exactly the description supplied by the subclass.
        return self.desc
class BinaryGisElement(GisElement, expression.Function):
    """Represents a Geometry value expressed as binary.

    Renders in SQL as ``ST_GeomFromEWKB(data)``. ``data`` is kept on the
    instance so it can be shown in hexadecimal form.
    """

    def __init__(self, data):
        # Keep the raw bytes (any bytes-like object) for as_hex / desc.
        self.data = data
        expression.Function.__init__(
            self, "ST_GeomFromEWKB", data, type_=Geometry(coerce_="binary")
        )

    @property
    def desc(self):
        """Textual description used by str()/repr(): the hex form of data."""
        return self.as_hex

    @property
    def as_hex(self):
        """Return ``data`` as a lowercase hexadecimal ``str``.

        ``binascii.hexlify`` returns ``bytes`` on Python 3. Decoding keeps
        ``str(element)`` working (``__str__`` must return a ``str``) and lets
        the value compare equal to ordinary string literals.
        """
        return binascii.hexlify(self.data).decode("ascii")
class TextualGisElement(GisElement, expression.Function):
    """A Geometry value given as text, rendered via ``ST_GeomFromText``.

    :param desc: the geometry's text form, e.g. a WKT string such as
        ``"LINESTRING(0 0,1 1)"``, or a bound parameter when constructed by
        ``Geometry.bind_expression``. It is kept as ``self.desc`` for
        str()/repr().
    :param srid: passed as the second argument of ``ST_GeomFromText``.
    """

    def __init__(self, desc, srid=-1):
        self.desc = desc
        sql_args = (desc, srid)
        expression.Function.__init__(
            self, "ST_GeomFromText", *sql_args, type_=Geometry
        )
# SQL datatypes.
class Geometry(UserDefinedType):
    """Base PostGIS Geometry column type.

    :param dimension: stored on the type and passed to ``AddGeometryColumn``
        by this module's DDL hooks.
    :param srid: stored on the type and passed to ``AddGeometryColumn`` by
        this module's DDL hooks.
    :param coerce_: how values travel between Python and SQL.
        ``"text"`` uses ``ST_GeomFromText`` / ``ST_AsText`` and yields
        ``TextualGisElement`` results. ``"binary"`` uses ``ST_GeomFromEWKB``
        / ``ST_AsBinary`` and yields ``BinaryGisElement`` results. Any other
        value raises ``ValueError`` once the type is used in SQL.
    """

    name = "GEOMETRY"

    def __init__(self, dimension=None, srid=-1, coerce_="text"):
        self.dimension = dimension
        self.srid = srid
        self.coerce = coerce_

    class comparator_factory(UserDefinedType.Comparator):
        """Define custom operations for geometry types."""

        # A class that defines __eq__ without __hash__ gets __hash__ = None;
        # keep the parent's hash so comparator objects stay hashable.
        __hash__ = UserDefinedType.Comparator.__hash__

        # override the __eq__() operator: render == as "~="
        def __eq__(self, other):
            return self.op("~=")(other)

        # add a custom operator: render .intersects() as "&&"
        def intersects(self, other):
            return self.op("&&")(other)

        # any number of GIS operators can be overridden/added here
        # using the techniques above.

    def coerce_compared_value(self, op, value):
        """Treat plain Python values compared against this type as Geometry.

        This makes e.g. ``col == "LINESTRING(...)"`` bind the string through
        ``bind_expression``, i.e. wrapped in the parsing SQL function.
        """
        return self

    # Spelling used by the original example; kept for backward compatibility.
    _coerce_compared_value = coerce_compared_value

    def get_col_spec(self, **kw):
        return self.name

    def _unsupported_coerce(self):
        """Build the error raised for an unrecognised ``coerce_`` mode."""
        return ValueError(
            "unsupported coerce_ value %r; expected 'text' or 'binary'"
            % (self.coerce,)
        )

    def bind_expression(self, bindvalue):
        """Wrap a bound value in the SQL function that parses it."""
        if self.coerce == "text":
            return TextualGisElement(bindvalue)
        if self.coerce == "binary":
            return BinaryGisElement(bindvalue)
        raise self._unsupported_coerce()

    def column_expression(self, col):
        """Wrap a selected column in the SQL function that serializes it."""
        if self.coerce == "text":
            return func.ST_AsText(col, type_=self)
        if self.coerce == "binary":
            return func.ST_AsBinary(col, type_=self)
        raise self._unsupported_coerce()

    def bind_processor(self, dialect):
        """Send GisElement values as their ``desc``; pass others through."""

        def process(value):
            if isinstance(value, GisElement):
                return value.desc
            return value

        return process

    def result_processor(self, dialect, coltype):
        """Wrap non-NULL result values in the element class for ``coerce``."""
        if self.coerce == "text":
            fac = TextualGisElement
        elif self.coerce == "binary":
            fac = BinaryGisElement
        else:
            # Previously ``assert False``; under ``python -O`` that fell
            # through and left ``fac`` unbound.
            raise self._unsupported_coerce()

        def process(value):
            if value is None:
                return None
            return fac(value)

        return process

    def adapt(self, impltype, **kw):
        """Copy this type's settings onto ``impltype``.

        ``**kw`` is accepted (and ignored) to match SQLAlchemy's
        ``TypeEngine.adapt(cls, **kw)`` signature.
        """
        return impltype(
            dimension=self.dimension, srid=self.srid, coerce_=self.coerce
        )
# other datatypes can be added as needed.
class Point(Geometry):
    """Geometry column type rendered as ``POINT`` in DDL."""

    name = "POINT"


class Curve(Geometry):
    """Geometry column type rendered as ``CURVE`` in DDL."""

    name = "CURVE"


class LineString(Curve):
    """Curve subtype rendered as ``LINESTRING`` in DDL."""

    name = "LINESTRING"
# ... etc.
# DDL integration
# PostGIS historically has required AddGeometryColumn/DropGeometryColumn
# and other management methods in order to create PostGIS columns. Newer
# versions don't appear to require these special steps anymore. However,
# here we illustrate how to set up these features in any case.
def setup_ddl_events():
    """Register Table DDL listeners implementing the AddGeometryColumn flow.

    While CREATE TABLE / DROP TABLE run, Geometry columns are temporarily
    hidden from the table. After create they are added with
    ``AddGeometryColumn``; before drop they are removed with
    ``DropGeometryColumn``.

    NOTE(review): the listeners are attached to the ``Table`` class, so
    they fire for every table. Calling this function more than once would
    presumably register duplicate listeners — confirm it is called once.
    """

    @event.listens_for(Table, "before_create")
    def before_create(target, connection, **kw):
        dispatch("before-create", target, connection)

    @event.listens_for(Table, "after_create")
    def after_create(target, connection, **kw):
        dispatch("after-create", target, connection)

    @event.listens_for(Table, "before_drop")
    def before_drop(target, connection, **kw):
        dispatch("before-drop", target, connection)

    @event.listens_for(Table, "after_drop")
    def after_drop(target, connection, **kw):
        dispatch("after-drop", target, connection)

    def dispatch(event, table, bind):
        """Apply the Geometry-column workaround for one DDL phase.

        :param event: one of "before-create", "after-create",
            "before-drop", "after-drop". (This parameter shadows the
            module-level ``event`` import; the module is not used here.)
        :param table: the Table being created or dropped.
        :param bind: the connection the DDL runs on; the PostGIS
            management functions are executed on it.
        """
        if event in ("before-create", "before-drop"):
            regular_cols = [
                c for c in table.c if not isinstance(c.type, Geometry)
            ]
            gis_cols = set(table.c).difference(regular_cols)
            # Stash the full collection; the matching "after-*" phase
            # pops it back.
            table.info["_saved_columns"] = table.c

            # temporarily patch a set of columns not including the
            # Geometry columns
            # NOTE(review): ColumnCollection(*cols) and assigning
            # table.columns rely on SQLAlchemy internals whose signatures
            # differ between releases — verify against the installed
            # version.
            table.columns = expression.ColumnCollection(*regular_cols)

            if event == "before-drop":
                # Remove Geometry columns via PostGIS before the
                # geometry-free DROP TABLE runs.
                for c in gis_cols:
                    bind.execute(
                        select(
                            func.DropGeometryColumn(
                                "public", table.name, c.name
                            )
                        ).execution_options(autocommit=True)
                    )

        elif event == "after-create":
            # Restore the real columns, then add each Geometry column.
            table.columns = table.info.pop("_saved_columns")
            for c in table.c:
                if isinstance(c.type, Geometry):
                    # NOTE(review): dimension is None when Geometry() was
                    # built without one; confirm AddGeometryColumn
                    # accepts that.
                    bind.execute(
                        select(
                            func.AddGeometryColumn(
                                table.name,
                                c.name,
                                c.type.srid,
                                c.type.name,
                                c.type.dimension,
                            )
                        ).execution_options(autocommit=True)
                    )

        elif event == "after-drop":
            table.columns = table.info.pop("_saved_columns")


# Install the DDL listeners at import time.
setup_ddl_events()
# illustrate usage
if __name__ == "__main__":
    # Demo script: needs a reachable PostGIS-enabled PostgreSQL database.
    from sqlalchemy import (
        create_engine,
        MetaData,
        Column,
        Integer,
        String,
    )
    from sqlalchemy.orm import declarative_base
    from sqlalchemy.orm import sessionmaker

    engine = create_engine(
        "postgresql://scott:tiger@localhost/test", echo=True
    )

    # Unbound MetaData with the engine passed explicitly to DDL calls;
    # bound metadata (MetaData(engine)) is deprecated in SQLAlchemy 1.4.
    metadata = MetaData()
    Base = declarative_base(metadata=metadata)

    class Road(Base):
        __tablename__ = "roads"

        road_id = Column(Integer, primary_key=True)
        road_name = Column(String)
        road_geom = Column(Geometry(2))

    metadata.drop_all(engine)
    metadata.create_all(engine)

    session = sessionmaker(bind=engine)()

    # Add objects. We can use strings...
    session.add_all(
        [
            Road(
                road_name="Jeff Rd",
                road_geom="LINESTRING(191232 243118,191108 243242)",
            ),
            Road(
                road_name="Geordie Rd",
                road_geom="LINESTRING(189141 244158,189265 244817)",
            ),
            Road(
                road_name="Paul St",
                road_geom="LINESTRING(192783 228138,192612 229814)",
            ),
            Road(
                road_name="Graeme Ave",
                road_geom="LINESTRING(189412 252431,189631 259122)",
            ),
            Road(
                road_name="Phil Tce",
                road_geom="LINESTRING(190131 224148,190871 228134)",
            ),
        ]
    )

    # or use an explicit TextualGisElement
    # (similar to saying func.ST_GeomFromText())
    r = Road(
        road_name="Dave Cres",
        road_geom=TextualGisElement(
            "LINESTRING(198231 263418,198213 268322)", -1
        ),
    )
    session.add(r)

    # pre flush, the TextualGisElement represents the string we sent.
    assert str(r.road_geom) == "LINESTRING(198231 263418,198213 268322)"

    session.commit()

    # After commit the attribute is expired (default expire_on_commit).
    # Accessing it reloads the row; the value is selected through ST_AsText
    # and arrives as a new TextualGisElement holding the same text.
    assert str(r.road_geom) == "LINESTRING(198231 263418,198213 268322)"

    r1 = session.query(Road).filter(Road.road_name == "Graeme Ave").one()

    # illustrate the overridden __eq__() operator.

    # strings come in as TextualGisElements
    r2 = (
        session.query(Road)
        .filter(Road.road_geom == "LINESTRING(189412 252431,189631 259122)")
        .one()
    )

    r3 = session.query(Road).filter(Road.road_geom == r1.road_geom).one()

    assert r1 is r2 is r3

    # core usage just fine:
    road_table = Road.__table__
    stmt = select(road_table).where(
        road_table.c.road_geom.intersects(r1.road_geom)
    )
    print(session.execute(stmt).fetchall())

    # TODO: for some reason the auto-generated labels have the internal
    # replacement strings exposed, even though PG doesn't complain

    # look up the hex binary version, using SQLAlchemy casts
    as_binary = session.scalar(
        select(type_coerce(r.road_geom, Geometry(coerce_="binary")))
    )
    assert as_binary.as_hex == (
        "01020000000200000000000000b832084100000000"
        "e813104100000000283208410000000088601041"
    )

    # back again, same method !
    as_text = session.scalar(
        select(type_coerce(as_binary, Geometry(coerce_="text")))
    )
    assert as_text.desc == "LINESTRING(198231 263418,198213 268322)"

    session.rollback()
    session.close()

    metadata.drop_all(engine)