Automatically use a function at insert or select

Sometimes the application wants to apply a function in an insert or in a select. For example, the application might need the geometry in lat/lon coordinates while it is stored projected in the database. To avoid having to tweak every query with an explicit ST_Transform(), it is possible to define a TypeDecorator that adds the function call automatically.

 12 import re
 13 from typing import Any
 14
 15 from sqlalchemy import Column
 16 from sqlalchemy import Integer
 17 from sqlalchemy import MetaData
 18 from sqlalchemy import func
 19 from sqlalchemy import text
 20 from sqlalchemy.orm import declarative_base
 21 from sqlalchemy.types import TypeDecorator
 22
 23 from geoalchemy2 import Geometry
 24 from geoalchemy2 import shape
 25
 26 # Tests imports
 27 from tests import test_only_with_dialects
 28
# Dedicated MetaData for this example; the tests drop and recreate every table registered in it.
metadata = MetaData()

# Declarative base bound to that MetaData, so the models below register their tables there.
Base = declarative_base(metadata=metadata)
 32
 33
 34 class TransformedGeometry(TypeDecorator):
 35     """This class is used to insert a ST_Transform() in each insert or select."""
 36
 37     impl = Geometry
 38
 39     cache_ok = True
 40
 41     def __init__(self, db_srid, app_srid, **kwargs):
 42         kwargs["srid"] = db_srid
 43         super().__init__(**kwargs)
 44         self.app_srid = app_srid
 45         self.db_srid = db_srid
 46
 47     def column_expression(self, col):
 48         """The column_expression() method is overridden to set the correct type.
 49
 50         This is needed so that the returned element will also be decorated. In this case we don't
 51         want to transform it again afterwards so we set the same SRID to both the ``db_srid`` and
 52         ``app_srid`` arguments.
 53         Without this the SRID of the WKBElement would be wrong.
 54         """
 55         return getattr(func, self.impl.as_binary)(
 56             func.ST_Transform(col, self.app_srid),
 57             type_=self.__class__(db_srid=self.app_srid, app_srid=self.app_srid),
 58         )
 59
 60     def bind_expression(self, bindvalue):
 61         return func.ST_Transform(
 62             self.impl.bind_expression(bindvalue),
 63             self.db_srid,
 64             type_=self,
 65         )
 66
 67
 68 class ThreeDGeometry(TypeDecorator):
 69     """This class is used to insert a ST_Force3D() in each insert."""
 70
 71     impl = Geometry
 72
 73     cache_ok = True
 74
 75     def column_expression(self, col):
 76         """The column_expression() method is overridden to set the correct type.
 77
 78         This is not needed in this example but it is needed if one wants to override other methods
 79         of the TypeDecorator class, like ``process_result_value()`` for example.
 80         """
 81         return getattr(func, self.impl.as_binary)(col, type_=self)
 82
 83     def bind_expression(self, bindvalue):
 84         return func.ST_Force3D(
 85             self.impl.bind_expression(bindvalue),
 86             type=self,
 87         )
 88
 89
class Point(Base):  # type: ignore
    """Example model: a plain Geometry column plus one column per custom type decorator."""

    __tablename__ = "point"
    id = Column(Integer, primary_key=True)
    # Plain Geometry column: stored and returned in SRID 4326, no implicit SQL function.
    raw_geom = Column(Geometry(srid=4326, geometry_type="POINT"))
    # Stored in SRID 2154 in the database; ST_Transform() is added so the application
    # reads and writes SRID 4326 values.
    geom: Column[Any] = Column(
        TransformedGeometry(db_srid=2154, app_srid=4326, geometry_type="POINT")
    )
    # 3D point column; bound values go through ST_Force3D(), so 2D input gets a Z coordinate.
    three_d_geom: Column = Column(ThreeDGeometry(srid=4326, geometry_type="POINTZ", dimension=3))
 98
 99
def check_wkb(wkb, x, y):
    """Assert that the point decoded from ``wkb`` is at (x, y), rounded to 5 decimal places."""
    point = shape.to_shape(wkb)
    # Check x first, then y, reading each coordinate only when it is asserted.
    for axis, expected in (("x", x), ("y", y)):
        assert round(getattr(point, axis), 5) == expected
104
105
@test_only_with_dialects("postgresql")
class TestTypeDecorator:
    """Round-trip checks for the TransformedGeometry and ThreeDGeometry decorators."""

    def _create_one_point(self, session, conn):
        """Recreate the schema, insert a single Point and return its primary key."""
        metadata.drop_all(conn, checkfirst=True)
        metadata.create_all(conn)

        ewkt = "SRID=4326;POINT(5 45)"
        new_point = Point()
        new_point.raw_geom = ewkt
        new_point.geom = ewkt
        # A 2D value goes into the 3D column; the decorator's ST_Force3D() adds the Z coordinate.
        new_point.three_d_geom = ewkt

        session.add(new_point)
        session.flush()
        # Expire the instance so the attributes read afterwards are reloaded from the database.
        session.expire(new_point)

        return new_point.id

    def test_transform(self, session, conn):
        self._create_one_point(session, conn)

        # ORM load: both columns come back in SRID 4326.
        loaded = session.query(Point).one()
        assert loaded.id == 1
        assert loaded.raw_geom.srid == 4326
        check_wkb(loaded.raw_geom, 5, 45)

        assert loaded.geom.srid == 4326
        check_wkb(loaded.geom, 5, 45)

        # A raw SQL query bypasses the decorator and shows what is actually stored (SRID 2154).
        raw_sql = text("SELECT id, ST_AsEWKT(geom) AS geom FROM point;")
        row = session.execute(raw_sql).fetchone()
        assert row.id == 1
        assert re.match(
            r"SRID=2154;POINT\(857581\.8993196681? 6435414\.7478354[0-9]*\)", row.geom
        )

        # Compare the automatic transform with the plain column and with an explicit transform.
        obj, raw_geom, explicit = session.query(
            Point,
            Point.raw_geom,
            func.ST_Transform(Point.raw_geom, 2154).label("trans"),
        ).one()

        assert obj.id == 1

        for geom in (obj.geom, obj.raw_geom, raw_geom):
            assert geom.srid == 4326
            check_wkb(geom, 5, 45)

        assert explicit.srid == 2154
        check_wkb(explicit, 857581.89932, 6435414.74784)

    def test_force_3d(self, session, conn):
        self._create_one_point(session, conn)

        loaded = session.query(Point).one()
        # Little-endian EWKB of a Point with Z and SRID flags, SRID 4326 (e6100000),
        # x=5.0, y=45.0 and z=0.0.
        expected_ewkb = "01010000a0e6100000000000000000144000000000008046400000000000000000"

        assert loaded.id == 1
        assert loaded.three_d_geom.srid == 4326
        assert loaded.three_d_geom.desc.lower() == expected_ewkb

Gallery generated by Sphinx-Gallery