Note
Go to the end of this page to download the full example code.
Use CompositeType
Some functions return composite types. This example shows how to handle such functions.
9 import pytest
10 from pkg_resources import parse_version
11 from sqlalchemy import Column
12 from sqlalchemy import Float
13 from sqlalchemy import Integer
14 from sqlalchemy import MetaData
15 from sqlalchemy import __version__ as SA_VERSION
16 from sqlalchemy.orm import declarative_base
17
18 from geoalchemy2 import Raster
19 from geoalchemy2 import WKTElement
20 from geoalchemy2.functions import GenericFunction
21 from geoalchemy2.types import CompositeType
22
23 # Tests imports
24 from tests import select
25 from tests import test_only_with_dialects
26
27
class SummaryStatsCustomType(CompositeType):
    """Composite type describing the value returned by ST_SummaryStatsAgg.

    ``typemap`` lists the fields of the composite value, in order, together
    with the SQLAlchemy column type of each field.
    NOTE(review): presumably ``CompositeType`` uses this mapping to expose
    typed attribute access such as ``<column>.count`` - confirm against
    geoalchemy2.types.CompositeType.
    """

    typemap = dict(
        count=Integer,
        sum=Float,
        mean=Float,
        stddev=Float,
        min=Float,
        max=Float,
    )

    # Mark the type as safe for SQLAlchemy's compiled-statement cache.
    cache_ok = True
41
42
class ST_SummaryStatsAgg(GenericFunction):
    """Binding of ST_SummaryStatsAgg whose result uses SummaryStatsCustomType.

    The function is reachable as ``ST_SummaryStatsAgg_custom`` (its
    identifier).
    NOTE(review): no explicit ``name`` is set, so the SQL function name
    presumably defaults to the class name ``ST_SummaryStatsAgg``.
    """

    # Use a distinct identifier so the stock ST_SummaryStatsAgg function
    # provided by GeoAlchemy2 is not overridden.
    identifier = "ST_SummaryStatsAgg_custom"

    # Result type of the function call: the composite summary-stats value.
    type = SummaryStatsCustomType()

    # Allow SQLAlchemy to reuse the parent's cache key generation.
    inherit_cache = True
49
50
# Metadata registry shared by every ORM model in this example; the test
# drops and re-creates the tables through it.
metadata = MetaData()

# Declarative base class bound to the registry above.
Base = declarative_base(metadata=metadata)
53
54
class Ocean(Base):  # type: ignore
    """ORM model for the ``ocean`` table: an integer key and a raster column."""

    __tablename__ = "ocean"

    id = Column(Integer, primary_key=True)
    rast = Column(Raster)

    def __init__(self, rast):
        """Create a row holding ``rast``; the primary key is not set here."""
        self.rast = rast
62
63
@test_only_with_dialects("postgresql")
class TestSTSummaryStatsAgg:
    """Tests for the custom ``ST_SummaryStatsAgg_custom`` composite-type function."""

    # NOTE(review): ``sqlalchemy.orm.declarative_base`` (imported at module
    # level) only exists from SQLAlchemy 1.4, so on older versions this module
    # fails to import before the skipif below can take effect.
    @pytest.mark.skipif(
        parse_version(SA_VERSION) < parse_version("1.4"),
        # Versions strictly below 1.4 are skipped, so 1.4 itself is supported.
        reason="requires SQLAlchemy>=1.4",
    )
    def test_st_summary_stats_agg(self, session, conn):
        """Aggregate raster statistics and read each field of the composite result.

        The test re-creates the ``ocean`` table and inserts one raster. It
        selects ``ST_SummaryStatsAgg_custom(rast, 1, True, 1)`` in a subquery,
        then reads the ``count``/``sum``/``mean``/``stddev``/``min``/``max``
        fields of the composite value in the outer query. Both the compiled
        SQL and the fetched row are checked.

        Args:
            session: session fixture used to add rows, compile and execute.
            conn: connection fixture used for the DDL (drop/create tables).
        """
        # Start from a clean schema.
        metadata.drop_all(conn, checkfirst=True)
        metadata.create_all(conn)

        # Create a new raster by rasterizing a triangle polygon.
        polygon = WKTElement("POLYGON((0 0,1 1,0 1,0 0))", srid=4326)
        o = Ocean(polygon.ST_AsRaster(5, 6))
        session.add(o)
        session.flush()

        # Define the query that computes the aggregated stats as one composite column.
        stats_agg = select([Ocean.rast.ST_SummaryStatsAgg_custom(1, True, 1).label("stats")])
        stats_agg_alias = stats_agg.alias("stats_agg")

        # Read each field of the composite value through attribute access.
        query = select(
            [
                stats_agg_alias.c.stats.count.label("count"),
                stats_agg_alias.c.stats.sum.label("sum"),
                stats_agg_alias.c.stats.mean.label("mean"),
                stats_agg_alias.c.stats.stddev.label("stddev"),
                stats_agg_alias.c.stats.min.label("min"),
                stats_agg_alias.c.stats.max.label("max"),
            ]
        )

        # Check the compiled query: composite fields are accessed as
        # ``(alias.column).field`` and the emitted function name is
        # ST_SummaryStatsAgg, not the custom identifier.
        assert str(query.compile(dialect=session.bind.dialect)) == (
            "SELECT "
            "(stats_agg.stats).count AS count, "
            "(stats_agg.stats).sum AS sum, "
            "(stats_agg.stats).mean AS mean, "
            "(stats_agg.stats).stddev AS stddev, "
            "(stats_agg.stats).min AS min, "
            "(stats_agg.stats).max AS max \n"
            "FROM ("
            "SELECT "
            "ST_SummaryStatsAgg("
            "ocean.rast, "
            "%(ST_SummaryStatsAgg_1)s, %(ST_SummaryStatsAgg_2)s, %(ST_SummaryStatsAgg_3)s"
            ") AS stats \n"
            "FROM ocean) AS stats_agg"
        )

        # Execute the query.
        res = session.execute(query).fetchall()

        # Check the result: 15 pixels, all with value 1.
        # NOTE(review): presumably ST_AsRaster(5, 6) yields a 5x6 grid, of
        # which half the pixels fall inside the triangle - confirm.
        assert res == [(15, 15.0, 1.0, 0.0, 1.0, 1.0)]