Note
Go to the end to download the full example code
Automatically use a function at insert or select
Sometimes the application needs to apply a function in every insert or select.
For example, the application might need geometries in lat/lon coordinates while they are
stored projected in the database. To avoid having to tweak every query with an
``ST_Transform()`` call, it is possible to define a ``TypeDecorator`` that applies the function automatically.
12 import re
13 from typing import Any
14
15 from sqlalchemy import Column
16 from sqlalchemy import Integer
17 from sqlalchemy import MetaData
18 from sqlalchemy import func
19 from sqlalchemy import text
20 from sqlalchemy.orm import declarative_base
21 from sqlalchemy.types import TypeDecorator
22
23 from geoalchemy2 import Geometry
24 from geoalchemy2 import shape
25
26 # Tests imports
27 from tests import test_only_with_dialects
28
# The declarative base owns the MetaData holding the example table; ``metadata`` is the
# same object, used by the tests to drop and (re)create the schema.
Base = declarative_base(metadata=MetaData())
metadata = Base.metadata
32
33
class TransformedGeometry(TypeDecorator):
    """Geometry type that reprojects values in every insert and every select.

    The column is declared in ``db_srid``. Bound values are wrapped in
    ``ST_Transform(..., db_srid)`` and selected values in ``ST_Transform(..., app_srid)``,
    so the application only ever sees ``app_srid`` geometries.
    """

    impl = Geometry

    cache_ok = True

    def __init__(self, db_srid, app_srid, **kwargs):
        """Build the decorated type.

        Args:
            db_srid: SRID used for the underlying ``Geometry`` column.
            app_srid: SRID in which selected geometries are returned.
            **kwargs: Forwarded to ``Geometry``; any ``srid`` entry is overridden by ``db_srid``.
        """
        super().__init__(**{**kwargs, "srid": db_srid})
        self.app_srid = app_srid
        self.db_srid = db_srid

    def column_expression(self, col):
        """Wrap the selected column so that it is returned in ``app_srid``.

        The wrapping expression is typed with a new ``TransformedGeometry`` whose ``db_srid``
        and ``app_srid`` are both ``app_srid``. The result is therefore still decorated, but no
        second transformation is applied afterwards, and the resulting WKBElement carries the
        correct SRID.
        """
        to_binary = getattr(func, self.impl.as_binary)
        reprojected = func.ST_Transform(col, self.app_srid)
        result_type = self.__class__(db_srid=self.app_srid, app_srid=self.app_srid)
        return to_binary(reprojected, type_=result_type)

    def bind_expression(self, bindvalue):
        """Reproject the bound geometry into ``db_srid`` before it is stored."""
        geometry_expr = self.impl.bind_expression(bindvalue)
        return func.ST_Transform(geometry_expr, self.db_srid, type_=self)
66
67
class ThreeDGeometry(TypeDecorator):
    """Geometry type that wraps every inserted value in ``ST_Force3D()``."""

    impl = Geometry

    cache_ok = True

    def column_expression(self, col):
        """The column_expression() method is overridden to set the correct type.

        This is not needed in this example but it is needed if one wants to override other methods
        of the TypeDecorator class, like ``process_result_value()`` for example.
        """
        return getattr(func, self.impl.as_binary)(col, type_=self)

    def bind_expression(self, bindvalue):
        """Force the bound geometry to 3D before it is stored.

        The expression is typed with this decorator, as in ``TransformedGeometry``.
        """
        return func.ST_Force3D(
            self.impl.bind_expression(bindvalue),
            # SQLAlchemy SQL functions take their return type as ``type_``; the previous
            # ``type=`` spelling is not the documented keyword, so the intended type was not set.
            type_=self,
        )
88
89
class Point(Base):  # type: ignore
    """Example model with one plain and two decorated geometry columns.

    The attribute order below is kept deliberately: it defines the column order of the
    generated ``point`` table.
    """

    __tablename__ = "point"

    id = Column(Integer, primary_key=True)

    # Plain geometry column in SRID 4326, no automatic function applied.
    raw_geom = Column(Geometry(srid=4326, geometry_type="POINT"))

    # Stored in SRID 2154, exchanged with the application in SRID 4326.
    geom: Column[Any] = Column(
        TransformedGeometry(db_srid=2154, app_srid=4326, geometry_type="POINT")
    )

    # Every inserted value goes through ST_Force3D().
    three_d_geom: Column[Any] = Column(
        ThreeDGeometry(srid=4326, geometry_type="POINTZ", dimension=3)
    )
98
99
def check_wkb(wkb, x, y):
    """Assert that the point decoded from ``wkb`` equals ``(x, y)`` once rounded to 5 decimals.

    The x coordinate is checked before the y coordinate.
    """
    decoded = shape.to_shape(wkb)
    for axis, expected in (("x", x), ("y", y)):
        assert round(getattr(decoded, axis), 5) == expected
104
105
@test_only_with_dialects("postgresql")
class TestTypeDecorator:
    """Round-trip checks for the ``TransformedGeometry`` and ``ThreeDGeometry`` decorators."""

    def _create_one_point(self, session, conn):
        """Recreate the schema, insert a single ``Point`` and return its primary key.

        The instance is expired after the flush so that later attribute access reloads it
        from the database through the decorated column expressions.
        """
        metadata.drop_all(conn, checkfirst=True)
        metadata.create_all(conn)

        ewkt = "SRID=4326;POINT(5 45)"
        # The same 2D geometry goes into the 3D column as well.
        new_point = Point(raw_geom=ewkt, geom=ewkt, three_d_geom=ewkt)

        session.add(new_point)
        session.flush()
        session.expire(new_point)

        return new_point.id

    def test_transform(self, session, conn):
        self._create_one_point(session, conn)

        # Through the ORM, both geometries are read back in SRID 4326.
        loaded = session.query(Point).one()
        assert loaded.id == 1
        assert loaded.raw_geom.srid == 4326
        check_wkb(loaded.raw_geom, 5, 45)

        assert loaded.geom.srid == 4326
        check_wkb(loaded.geom, 5, 45)

        # The raw SQL bypasses the decorator: the stored value must be in SRID 2154.
        stored = session.execute(text("SELECT id, ST_AsEWKT(geom) AS geom FROM point;")).fetchone()
        assert stored.id == 1
        assert re.match(
            r"SRID=2154;POINT\(857581\.8993196681? 6435414\.7478354[0-9]*\)", stored.geom
        )

        # Compare the automatic transform with an explicit ST_Transform on the raw column.
        obj, raw, explicit = session.query(
            Point,
            Point.raw_geom,
            func.ST_Transform(Point.raw_geom, 2154).label("trans"),
        ).one()

        assert obj.id == 1

        assert obj.geom.srid == 4326
        check_wkb(obj.geom, 5, 45)

        assert obj.raw_geom.srid == 4326
        check_wkb(obj.raw_geom, 5, 45)

        assert raw.srid == 4326
        check_wkb(raw, 5, 45)

        assert explicit.srid == 2154
        check_wkb(explicit, 857581.89932, 6435414.74784)

    def test_force_3d(self, session, conn):
        self._create_one_point(session, conn)

        loaded = session.query(Point).one()
        expected_hex = "01010000a0e6100000000000000000144000000000008046400000000000000000"

        assert loaded.id == 1
        assert loaded.three_d_geom.srid == 4326
        assert loaded.three_d_geom.desc.lower() == expected_hex