Use CompositeType

Some functions return composite types. This example shows how to handle such functions and select the individual fields of their results.

  9 import pytest
 10 from pkg_resources import parse_version
 11 from sqlalchemy import Column
 12 from sqlalchemy import Float
 13 from sqlalchemy import Integer
 14 from sqlalchemy import MetaData
 15 from sqlalchemy import __version__ as SA_VERSION
 16 from sqlalchemy.orm import declarative_base
 17
 18 from geoalchemy2 import Raster
 19 from geoalchemy2 import WKTElement
 20 from geoalchemy2.functions import GenericFunction
 21 from geoalchemy2.types import CompositeType
 22
 23 # Tests imports
 24 from tests import select
 25 from tests import test_only_with_dialects
 26
 27
 28 class SummaryStatsCustomType(CompositeType):
 29     """Define the composite type returned by the function ST_SummaryStatsAgg."""
 30
 31     typemap = {
 32         "count": Integer,
 33         "sum": Float,
 34         "mean": Float,
 35         "stddev": Float,
 36         "min": Float,
 37         "max": Float,
 38     }
 39
 40     cache_ok = True
 41
 42
 43 class ST_SummaryStatsAgg(GenericFunction):
 44     type = SummaryStatsCustomType()
 45     # Set a specific identifier to not override the actual ST_SummaryStatsAgg function
 46     identifier = "ST_SummaryStatsAgg_custom"
 47
 48     inherit_cache = True
 49
 50
# MetaData shared by every mapped class below; the test drops and recreates
# all tables registered on it.
metadata = MetaData()
# Declarative base whose mapped classes register their tables on ``metadata``.
Base = declarative_base(metadata=metadata)
 53
 54
class Ocean(Base):  # type: ignore
    """Mapped ``ocean`` table storing one raster value per row."""

    __tablename__ = "ocean"
    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # Raster column aggregated with ST_SummaryStatsAgg in the test below.
    rast = Column(Raster)

    def __init__(self, rast):
        """Create a row from a raster value (e.g. a ``ST_AsRaster(...)`` expression)."""
        self.rast = rast
 62
 63
 64 @test_only_with_dialects("postgresql")
 65 class TestSTSummaryStatsAgg:
 66     @pytest.mark.skipif(
 67         parse_version(SA_VERSION) < parse_version("1.4"),
 68         reason="requires SQLAlchely>1.4",
 69     )
 70     def test_st_summary_stats_agg(self, session, conn):
 71         metadata.drop_all(conn, checkfirst=True)
 72         metadata.create_all(conn)
 73
 74         # Create a new raster
 75         polygon = WKTElement("POLYGON((0 0,1 1,0 1,0 0))", srid=4326)
 76         o = Ocean(polygon.ST_AsRaster(5, 6))
 77         session.add(o)
 78         session.flush()
 79
 80         # Define the query to compute stats
 81         stats_agg = select([Ocean.rast.ST_SummaryStatsAgg_custom(1, True, 1).label("stats")])
 82         stats_agg_alias = stats_agg.alias("stats_agg")
 83
 84         # Use these stats
 85         query = select(
 86             [
 87                 stats_agg_alias.c.stats.count.label("count"),
 88                 stats_agg_alias.c.stats.sum.label("sum"),
 89                 stats_agg_alias.c.stats.mean.label("mean"),
 90                 stats_agg_alias.c.stats.stddev.label("stddev"),
 91                 stats_agg_alias.c.stats.min.label("min"),
 92                 stats_agg_alias.c.stats.max.label("max"),
 93             ]
 94         )
 95
 96         # Check the query
 97         assert str(query.compile(dialect=session.bind.dialect)) == (
 98             "SELECT "
 99             "(stats_agg.stats).count AS count, "
100             "(stats_agg.stats).sum AS sum, "
101             "(stats_agg.stats).mean AS mean, "
102             "(stats_agg.stats).stddev AS stddev, "
103             "(stats_agg.stats).min AS min, "
104             "(stats_agg.stats).max AS max \n"
105             "FROM ("
106             "SELECT "
107             "ST_SummaryStatsAgg("
108             "ocean.rast, "
109             "%(ST_SummaryStatsAgg_1)s, %(ST_SummaryStatsAgg_2)s, %(ST_SummaryStatsAgg_3)s"
110             ") AS stats \n"
111             "FROM ocean) AS stats_agg"
112         )
113
114         # Execute the query
115         res = session.execute(query).fetchall()
116
117         # Check the result
118         assert res == [(15, 15.0, 1.0, 0.0, 1.0, 1.0)]

Gallery generated by Sphinx-Gallery