"""This module defines specific functions for MSSQL dialect."""
import hashlib
import math
import re
from collections.abc import Mapping
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import text
from sqlalchemy.dialects.mssql.base import ischema_names as _mssql_ischema_names
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import expression
from sqlalchemy.sql import operators
from sqlalchemy.sql import visitors
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy.sql.elements import BindParameter
from sqlalchemy.sql.elements import Null
from sqlalchemy.sql.sqltypes import NullType
from sqlalchemy.types import LargeBinary
from sqlalchemy.types import TypeDecorator
from sqlalchemy.types import UnicodeText
from geoalchemy2 import functions
from geoalchemy2.admin.dialects.common import _check_spatial_type
from geoalchemy2.admin.dialects.common import _spatial_idx_name
from geoalchemy2.elements import WKBElement
from geoalchemy2.elements import WKTElement
from geoalchemy2.exc import ArgumentError
from geoalchemy2.types import Geography
from geoalchemy2.types import Geometry
from geoalchemy2.types.dialects.mssql import _normalize_wkt_for_mssql
from geoalchemy2.types.dialects.mssql import _to_mssql_wkt
from geoalchemy2.types.dialects.mssql import bind_processor_process as _type_bind_processor_process
# Make the MSSQL dialect's reflection type table resolve the spatial column
# type names to the GeoAlchemy types.
_mssql_ischema_names.update({"geometry": Geometry, "geography": Geography})

# Register GeoAlchemy's spatial index kwargs so SQLAlchemy accepts them on Index(...).
for _dialect_kwarg in ("bounding_box", "cells_per_object", "grids", "using", "with"):
    Index.argument_for("mssql", _dialect_kwarg, None)
# Spatial index bounding boxes, ordered (xmin, ymin, xmax, ymax).
# Used by default for geometry columns whose SRID is 4326.
_MSSQL_WORLD_BOUNDING_BOX = (-180.0, -90.0, 180.0, 90.0)
# Used by default for geometry columns with any other SRID.
_MSSQL_DEFAULT_BOUNDING_BOX = (-1e9, -1e9, 1e9, 1e9)

# Upper-case geometry type name -> mixed-case spelling of the same type.
_MSSQL_GEOMETRY_TYPE_NAMES = dict(
    POINT="Point",
    LINESTRING="LineString",
    POLYGON="Polygon",
    MULTIPOINT="MultiPoint",
    MULTILINESTRING="MultiLineString",
    MULTIPOLYGON="MultiPolygon",
    GEOMETRYCOLLECTION="GeometryCollection",
)

# Reverse lookup keyed by the upper-cased mixed-case spelling.
_MSSQL_GEOMETRY_TYPE_LOOKUP = {
    mixed_case.upper(): upper_case
    for upper_case, mixed_case in _MSSQL_GEOMETRY_TYPE_NAMES.items()
}

# Message raised for any malformed ``mssql_bounding_box`` value.
_MSSQL_BOUNDING_BOX_ERROR = (
    "mssql_bounding_box must be a 4-value tuple/list or comma-separated string "
    "formatted as finite numeric coordinates: xmin, ymin, xmax, ymax"
)

# Prefixes of the bind parameter keys generated for dynamic EWKT / EWKB values.
_MSSQL_DYNAMIC_EWKT_KEY_PREFIX = "_geoalchemy2_mssql_ewkt"
_MSSQL_DYNAMIC_EWKB_KEY_PREFIX = "_geoalchemy2_mssql_ewkb"
# Option name; its consumer is outside this part of the file.
_MSSQL_DISABLE_DYNAMIC_EWKT_SPLIT_OPTION = "geoalchemy2_mssql_disable_dynamic_ewkt_split"
def _quote_mssql_identifier(name):
    """Return ``name`` wrapped in square brackets, doubling every ``]`` it contains."""
    escaped_name = name.replace("]", "]]")
    return "[" + escaped_name + "]"
def _quote_mssql_string(value):
    """Return ``value`` as an ``N'...'`` literal with every single quote doubled."""
    return "N'" + value.replace("'", "''") + "'"
def _quote_mssql_table_name(table_name, schema=None):
    """Return the bracket-quoted table reference, prefixed by the schema when one is given."""
    name_parts = [schema, table_name] if schema else [table_name]
    return ".".join(_quote_mssql_identifier(part) for part in name_parts)
def _get_mssql_full_table_name(table_name, schema=None):
    """Return ``schema.table_name`` (unquoted), or just ``table_name`` without a schema."""
    return f"{schema}.{table_name}" if schema else table_name
def _format_mssql_number(value):
    """Format a numeric value as text for inclusion in generated MSSQL DDL.

    The compact ``"g"`` presentation is used whenever it parses back to exactly
    the same float (for example ``-180`` or ``1e+09``). Because ``"g"`` keeps
    only 6 significant digits, values that need more precision (for example
    ``4649776.22``) fall back to ``repr``, which always round-trips.

    Values that cannot be converted to ``float`` are returned through ``str``.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):  # pragma: no cover
        return str(value)
    compact = format(number, "g")
    # Keep the short form only when it is lossless; otherwise the coordinate
    # would be silently rounded in the emitted SQL.
    if float(compact) == number:
        return compact
    return repr(number)
def _format_mssql_bounding_box(bounding_box):
    """Validate a bounding box and return it as ``"xmin, ymin, xmax, ymax"`` text.

    Accepts a 4-item tuple/list or a comma-separated string.

    Raises:
        ArgumentError: if the value has the wrong container type, does not hold
            exactly four items, or any item is not a finite number.
    """
    if isinstance(bounding_box, str):
        bounding_box = [part.strip() for part in bounding_box.split(",")]
    if not isinstance(bounding_box, (tuple, list)):
        raise ArgumentError(_MSSQL_BOUNDING_BOX_ERROR)

    try:
        xmin, ymin, xmax, ymax = bounding_box
        coordinates = [float(coordinate) for coordinate in (xmin, ymin, xmax, ymax)]
    except (TypeError, ValueError) as exc:
        raise ArgumentError(_MSSQL_BOUNDING_BOX_ERROR) from exc

    # NaN and +/-inf convert to float successfully but are not usable coordinates.
    if any(not math.isfinite(coordinate) for coordinate in coordinates):
        raise ArgumentError(_MSSQL_BOUNDING_BOX_ERROR)

    return ", ".join(map(_format_mssql_number, coordinates))
def _base_mssql_geometry_type(geometry_type):
    """Return the upper-cased geometry type with a trailing ``ZM``, ``Z`` or ``M`` removed.

    ``None`` is passed through unchanged.
    """
    if geometry_type is None:
        return None
    upper_name = geometry_type.upper()
    # "ZM" is tested first so that both dimension letters are stripped together.
    for suffix in ("ZM", "Z", "M"):
        if upper_name.endswith(suffix):
            return upper_name[: -len(suffix)]
    return upper_name
def _mssql_geometry_type_constraint_value(geometry_type):
    """Return the mixed-case name for ``geometry_type``.

    ``None`` is returned for a missing type, for the generic ``GEOMETRY`` type
    and for any name absent from ``_MSSQL_GEOMETRY_TYPE_NAMES``.
    """
    base_name = _base_mssql_geometry_type(geometry_type)
    if base_name in (None, "GEOMETRY"):
        return None
    return _MSSQL_GEOMETRY_TYPE_NAMES.get(base_name)
def _mssql_geometry_type_constraint_prefix(geometry_type):
    """Return the upper-case base type name to use as a text prefix in a CHECK constraint.

    ``None`` is returned for a missing type, for the generic ``GEOMETRY`` type
    and for any name absent from ``_MSSQL_GEOMETRY_TYPE_NAMES``.
    """
    base_name = _base_mssql_geometry_type(geometry_type)
    if base_name in (None, "GEOMETRY") or base_name not in _MSSQL_GEOMETRY_TYPE_NAMES:
        return None
    return base_name
def _mssql_spatial_constraint_name(table_name, column_name, constraint_type):
    """Return the CHECK constraint name ``ck_<table>_<column>_<constraint_type>``."""
    return "ck_{}_{}_{}".format(table_name, column_name, constraint_type)
def _column_regex(column_name):
    """Return a regex fragment matching ``column_name`` inside a constraint definition.

    Two spellings are accepted: bracket-quoted (``[name]``, with ``]`` doubled)
    and bare. The bare spelling is guarded by a negative lookbehind so it cannot
    match the tail of a longer identifier (``geom`` inside ``my_geom``), which
    would otherwise attribute another column's constraint to this column.
    """
    quoted_column = re.escape(column_name.replace("]", "]]"))
    unquoted_column = re.escape(column_name)
    # Word characters plus "@", "#" and "$" are treated as identifier characters.
    return rf"(?:\[{quoted_column}\]|(?<![\w@#$]){unquoted_column})"
def _default_mssql_bounding_box(col_type, is_geography=False):
    """Return the default spatial index bounding box for a column type.

    Geography columns get ``None``. Geometry columns get the world box when the
    type's ``srid`` is 4326 and the large default box otherwise.
    """
    if is_geography:
        return None
    srid = getattr(col_type, "srid", None)
    return _MSSQL_WORLD_BOUNDING_BOX if srid == 4326 else _MSSQL_DEFAULT_BOUNDING_BOX
def _get_mssql_column_type_name(bind, table_name, column_name, schema=None):
    """Return the type name recorded in the system catalog for a column.

    The result is the first column of the first row returned by the query
    (``Result.scalar()``).
    """
    params = {
        "full_table_name": _get_mssql_full_table_name(table_name, schema=schema),
        "column_name": column_name,
    }
    statement = text(
        """SELECT t.name
FROM sys.columns AS c
JOIN sys.types AS t
ON c.user_type_id = t.user_type_id
WHERE c.object_id = OBJECT_ID(:full_table_name) AND c.name = :column_name"""
    )
    result = bind.execute(statement, params)
    return result.scalar()
def _get_mssql_spatial_column_constraints(bind, table_name, column_name, schema=None):
    """Infer ``(geometry_type, srid)`` for a column from the table's CHECK constraints.

    Every CHECK constraint definition of the table is scanned for three shapes
    that reference the column: an ``STSrid = <int>`` comparison, an
    ``STGeometryType() = '<name>'`` comparison, and an ``AsTextZM() ... LIKE
    '<prefix>%'`` test. Without a match the defaults ``"GEOMETRY"`` and ``-1``
    are returned. Type names absent from ``_MSSQL_GEOMETRY_TYPE_LOOKUP`` leave
    the current value unchanged.
    """
    column_pattern = _column_regex(column_name)
    srid_re = re.compile(
        rf"{column_pattern}\s*\.\s*(?:\[STSrid\]|STSrid)\s*=\s*\(?\s*(-?\d+)\s*\)?",
        re.IGNORECASE,
    )
    exact_type_re = re.compile(
        rf"{column_pattern}\s*\.\s*(?:\[STGeometryType\]|STGeometryType)\s*"
        rf"\(\s*\)\s*=\s*\(?\s*N?'([^']+)'",
        re.IGNORECASE,
    )
    prefix_type_re = re.compile(
        rf"{column_pattern}\s*\.\s*(?:\[AsTextZM\]|AsTextZM)\s*\(\s*\)\s*\)*\s+"
        rf"LIKE\s+\(?\s*N?'([^'%]+)%",
        re.IGNORECASE,
    )

    definitions = bind.execute(
        text(
            """SELECT definition
FROM sys.check_constraints
WHERE parent_object_id = OBJECT_ID(:full_table_name)"""
        ),
        {"full_table_name": _get_mssql_full_table_name(table_name, schema=schema)},
    ).scalars()

    found_srid = -1
    found_type = "GEOMETRY"
    for definition in definitions:
        srid_hit = srid_re.search(definition)
        if srid_hit:
            found_srid = int(srid_hit.group(1))
        # The exact comparison wins; the LIKE-prefix form is only tried without it.
        type_hit = exact_type_re.search(definition) or prefix_type_re.search(definition)
        if type_hit:
            found_type = _MSSQL_GEOMETRY_TYPE_LOOKUP.get(type_hit.group(1).upper(), found_type)
    return found_type, found_srid
def _get_mssql_spatial_indexes(bind, table_name, schema=None, column_name=None):
    """Return the spatial indexes of a table as a list of dictionaries.

    Each entry has the keys ``name``, ``column_name`` and ``dialect_options``.
    ``dialect_options`` holds whichever of ``mssql_using``,
    ``mssql_cells_per_object``, ``mssql_bounding_box`` and ``mssql_grids`` the
    catalog row provides; ``mssql_grids`` is only set when all four grid levels
    are present. Passing ``column_name`` restricts the result to that column.
    """
    filters = ["i.object_id = OBJECT_ID(:full_table_name)", "i.type_desc = 'SPATIAL'"]
    params = {"full_table_name": _get_mssql_full_table_name(table_name, schema=schema)}
    if column_name is not None:
        filters.append("c.name = :column_name")
        params["column_name"] = column_name
    where_sql = " AND ".join(filters)

    query = text(
        f"""SELECT
i.name AS index_name,
c.name AS column_name,
si.tessellation_scheme,
sit.cells_per_object,
sit.bounding_box_xmin,
sit.bounding_box_ymin,
sit.bounding_box_xmax,
sit.bounding_box_ymax,
sit.level_1_grid_desc,
sit.level_2_grid_desc,
sit.level_3_grid_desc,
sit.level_4_grid_desc
FROM sys.indexes AS i
JOIN sys.index_columns AS ic
ON i.object_id = ic.object_id
AND i.index_id = ic.index_id
JOIN sys.columns AS c
ON ic.object_id = c.object_id
AND ic.column_id = c.column_id
LEFT JOIN sys.spatial_indexes AS si
ON i.object_id = si.object_id
AND i.index_id = si.index_id
LEFT JOIN sys.spatial_index_tessellations AS sit
ON i.object_id = sit.object_id
AND i.index_id = sit.index_id
WHERE {where_sql}
ORDER BY i.name"""
    )

    found = []
    for row in bind.execute(query, params).mappings():
        options = {}

        scheme = row["tessellation_scheme"]
        if scheme is not None:
            options["mssql_using"] = scheme

        cells = row["cells_per_object"]
        if cells is not None:
            options["mssql_cells_per_object"] = int(cells)

        # Only xmin is tested; the other three edges are taken as they come.
        if row["bounding_box_xmin"] is not None:
            options["mssql_bounding_box"] = tuple(
                row[f"bounding_box_{edge}"] for edge in ("xmin", "ymin", "xmax", "ymax")
            )

        grid_levels = tuple(
            level
            for level in (row[f"level_{number}_grid_desc"] for number in range(1, 5))
            if level is not None
        )
        if len(grid_levels) == 4:
            options["mssql_grids"] = grid_levels

        found.append(
            {
                "name": row["index_name"],
                "column_name": row["column_name"],
                "dialect_options": options,
            }
        )
    return found
def _get_mssql_spatial_index_with_clauses(col_type, idx_kwargs, is_geography=False):
    """Build the list of ``WITH (...)`` options for a ``CREATE SPATIAL INDEX`` statement.

    Options are read from ``mssql_with`` (string or iterable of strings),
    ``mssql_grids``, ``mssql_cells_per_object`` and ``mssql_bounding_box``. For
    non-geography columns a ``BOUNDING_BOX`` option is inserted first unless
    one of the collected clauses already starts with ``BOUNDING_BOX``; it comes
    from ``mssql_bounding_box`` or, failing that, from the type's default.
    """
    clauses = []

    extra_with = idx_kwargs.get("mssql_with")
    if extra_with:
        if isinstance(extra_with, str):
            clauses.append(extra_with)
        else:
            clauses.extend(extra_with)

    grids = idx_kwargs.get("mssql_grids")
    if grids:
        if not isinstance(grids, str):
            grids_clause = f"GRIDS = ({', '.join(str(level) for level in grids)})"
        elif grids.lstrip().upper().startswith("GRIDS"):
            # A string that already spells out the whole option is used verbatim.
            grids_clause = grids
        else:
            grids_clause = f"GRIDS = ({grids})"
        clauses.append(grids_clause)

    cells_per_object = idx_kwargs.get("mssql_cells_per_object")
    if cells_per_object is not None:
        clauses.append(f"CELLS_PER_OBJECT = {int(cells_per_object)}")

    if is_geography:
        return clauses
    if any(clause.lstrip().upper().startswith("BOUNDING_BOX") for clause in clauses):
        return clauses

    bounding_box = idx_kwargs.get("mssql_bounding_box") or _default_mssql_bounding_box(
        col_type,
        is_geography=is_geography,
    )
    if bounding_box is not None:
        clauses.insert(0, f"BOUNDING_BOX = ({_format_mssql_bounding_box(bounding_box)})")
    return clauses
def create_spatial_index(
    bind, table_name, column_name, col_type, schema=None, index_name=None, **idx_kwargs
):
    """Emit a ``CREATE SPATIAL INDEX`` statement for one column.

    When ``col_type`` is not recognised as Geometry or Geography, the column's
    type name is read from the database to decide which it is. The tessellation
    scheme comes from ``mssql_using`` and defaults to ``GEOGRAPHY_AUTO_GRID`` or
    ``GEOMETRY_AUTO_GRID``; a falsy ``mssql_using`` omits the ``USING`` clause.
    """
    dialect = bind.dialect
    if not index_name:
        index_name = _spatial_idx_name(table_name, column_name)
    table_ref = _quote_mssql_table_name(table_name, schema=schema)

    is_geography = _check_spatial_type(col_type, Geography, dialect)
    if not _check_spatial_type(col_type, (Geometry, Geography), dialect):
        # Unknown Python-side type: ask the database what the column really is.
        db_type_name = _get_mssql_column_type_name(bind, table_name, column_name, schema=schema)
        is_geography = str(db_type_name).lower() == "geography"

    default_scheme = "GEOGRAPHY_AUTO_GRID" if is_geography else "GEOMETRY_AUTO_GRID"
    using = idx_kwargs.get("mssql_using", default_scheme)

    statement = (
        f"CREATE SPATIAL INDEX {_quote_mssql_identifier(index_name)} "
        f"ON {table_ref} ({_quote_mssql_identifier(column_name)})"
    )
    if using:
        statement += f" USING {using}"
    with_clauses = _get_mssql_spatial_index_with_clauses(
        col_type, idx_kwargs, is_geography=is_geography
    )
    if with_clauses:
        statement += f" WITH ({', '.join(with_clauses)})"
    bind.execute(text(statement))
def create_spatial_constraints(bind, table_name, column_name, col_type, schema=None):
    """Add CHECK constraints that enforce the SRID and geometry type of a column.

    An SRID constraint is added when the resolved type's ``srid`` is >= 0. A
    geometry type constraint, comparing the upper-cased ``AsTextZM()`` text
    against a type-name prefix, is added when the type names a specific
    geometry type. Both constraints accept NULL values.
    """
    resolved_type = _resolve_mssql_spatial_type(col_type, bind.dialect)
    table_ref = _quote_mssql_table_name(table_name, schema=schema)
    column_ref = _quote_mssql_identifier(column_name)

    def add_check(constraint_type, predicate):
        constraint_name = _mssql_spatial_constraint_name(table_name, column_name, constraint_type)
        bind.execute(
            text(
                f"ALTER TABLE {table_ref} ADD CONSTRAINT "
                f"{_quote_mssql_identifier(constraint_name)} CHECK "
                f"({column_ref} IS NULL OR {predicate})"
            )
        )

    if getattr(resolved_type, "srid", -1) >= 0:
        add_check("srid", f"{column_ref}.STSrid = {int(resolved_type.srid)}")

    type_prefix = _mssql_geometry_type_constraint_prefix(
        getattr(resolved_type, "geometry_type", None)
    )
    if type_prefix is not None:
        add_check(
            "geometry_type",
            f"UPPER({column_ref}.AsTextZM()) LIKE {_quote_mssql_string(f'{type_prefix}%')}",
        )
def drop_spatial_constraints(bind, table_name, column_name, schema=None):
    """Drop the SRID and geometry type CHECK constraints of a column, if present.

    Only constraints that exist on the table under the names produced by
    ``_mssql_spatial_constraint_name`` are dropped.
    """
    table_ref = _quote_mssql_table_name(table_name, schema=schema)
    params = {"full_table_name": _get_mssql_full_table_name(table_name, schema=schema)}
    for constraint_type in ("srid", "geometry_type"):
        params[f"{constraint_type}_constraint_name"] = _mssql_spatial_constraint_name(
            table_name,
            column_name,
            constraint_type,
        )

    lookup = text(
        """SELECT DISTINCT cc.name
FROM sys.check_constraints AS cc
WHERE cc.parent_object_id = OBJECT_ID(:full_table_name)
AND cc.name IN (:srid_constraint_name, :geometry_type_constraint_name)"""
    )
    # Materialise the names before issuing any DROP on the same connection.
    existing_names = list(bind.execute(lookup, params).scalars())

    for existing_name in existing_names:
        drop_statement = (
            f"ALTER TABLE {table_ref} DROP CONSTRAINT "
            f"{_quote_mssql_identifier(existing_name)}"
        )
        bind.execute(text(drop_statement))
def drop_spatial_index(bind, table_name, index_name, schema=None):
    """Emit ``DROP INDEX <index> ON <table>`` with both identifiers bracket-quoted."""
    quoted_index = _quote_mssql_identifier(index_name)
    quoted_table = _quote_mssql_table_name(table_name, schema=schema)
    bind.execute(text(f"DROP INDEX {quoted_index} ON {quoted_table}"))
[docs]
def reflect_geometry_column(inspector, table, column_info):
    """Reflect a geometry or geography column with the MSSQL dialect.

    ``column_info["type"]`` is replaced in place by a ``Geometry`` or
    ``Geography`` instance built from the catalog: the column's type name and
    nullability, whether a spatial index covers it, and the geometry type and
    SRID inferred from CHECK constraints. Columns whose reflected type is not a
    spatial or ``NullType`` instance, or whose database type is neither
    ``geometry`` nor ``geography``, are left untouched.
    """
    if not isinstance(column_info.get("type"), (Geometry, Geography, NullType)):
        return

    bind = inspector.bind
    column_name = column_info["name"]
    schema = table.schema or inspector.default_schema_name

    catalog_row = bind.execute(
        text(
            """SELECT t.name, c.is_nullable
FROM sys.columns AS c
JOIN sys.types AS t ON c.user_type_id = t.user_type_id
WHERE c.object_id = OBJECT_ID(:full_table_name) AND c.name = :column_name"""
        ),
        {
            "full_table_name": _get_mssql_full_table_name(table.name, schema=schema),
            "column_name": column_name,
        },
    ).one()
    db_type_name, is_nullable = catalog_row
    db_type_name = db_type_name.lower()
    if db_type_name not in ("geometry", "geography"):
        return

    indexes = _get_mssql_spatial_indexes(
        bind,
        table.name,
        schema=schema,
        column_name=column_name,
    )
    geometry_type, srid = _get_mssql_spatial_column_constraints(
        bind,
        table.name,
        column_name,
        schema=schema,
    )

    type_class = Geography if db_type_name == "geography" else Geometry
    column_info["type"] = type_class(
        geometry_type=geometry_type,
        srid=srid,
        spatial_index=bool(indexes),
        nullable=bool(is_nullable),
        _spatial_index_reflected=True,
    )
def _is_mssql_generated_spatial_index(idx, table, col):
    """Tell whether ``idx`` looks like the index generated for a spatial column.

    The result is truthy when the index carries ``_column_flag``, covers exactly
    ``col``, uses the name from ``_spatial_idx_name`` and the column type has
    ``spatial_index`` set.
    """
    indexed_columns = [*idx.columns.values()]
    # One short-circuiting chain: the first falsy operand is the result.
    return (
        getattr(idx, "_column_flag", False)
        and len(indexed_columns) == 1
        and indexed_columns[0] is col
        and idx.name == _spatial_idx_name(table.name, col.name)
        and getattr(col.type, "spatial_index", False)
    )
[docs]
def before_create(table, bind, **kw):
    """Remove spatial indexes from CREATE TABLE so they can be emitted separately.

    A missing non-default schema is created first. Every index that involves a
    spatial column is then detached from ``table.indexes``; those that are not
    the generated per-column spatial index are remembered in
    ``table.info["_after_create_indexes"]``.
    """
    schema = table.schema
    if schema and schema != bind.dialect.default_schema_name:
        quoted_schema = _quote_mssql_identifier(schema)
        # Both values end up inside N'...' literals, so single quotes are doubled.
        schema_literal = schema.replace("'", "''")
        quoted_schema_literal = quoted_schema.replace("'", "''")
        bind.exec_driver_sql(
            f"IF SCHEMA_ID(N'{schema_literal}') IS NULL "
            f"EXEC(N'CREATE SCHEMA {quoted_schema_literal}')"
        )

    deferred_indexes = []
    table.info["_after_create_indexes"] = deferred_indexes

    # Iterate over a snapshot because table.indexes is modified in the loop.
    for idx in set(table.indexes):
        spatial_col = next(
            (
                col
                for col in table.columns
                if _check_spatial_type(col.type, (Geometry, Geography), bind.dialect)
                and col in idx.columns.values()
            ),
            None,
        )
        if spatial_col is None:
            continue
        table.indexes.remove(idx)
        if not _is_mssql_generated_spatial_index(idx, table, spatial_col):
            deferred_indexes.append(idx)
def after_create(table, bind, **kw):
    """Add spatial constraints and spatial indexes once the table exists.

    The work happens in three passes: note which spatial columns are covered by
    a deferred index, create constraints plus the default spatial index for the
    remaining columns, then re-attach and create the deferred indexes.
    """
    dialect = bind.dialect
    spatial_types = (Geometry, Geography)
    deferred_indexes = table.info.pop("_after_create_indexes", [])

    # Spatial columns whose index is created from a deferred Index below.
    covered_columns = set()
    for idx in deferred_indexes:
        idx_columns = list(idx.columns.values())
        is_single_column = len(idx_columns) == 1
        for col in idx_columns:
            if _check_spatial_type(col.type, spatial_types, dialect) and (
                is_single_column or idx.name == _spatial_idx_name(table.name, col.name)
            ):
                covered_columns.add(col.name)

    for col in table.columns:
        if not _check_spatial_type(col.type, spatial_types, dialect):
            continue
        create_spatial_constraints(bind, table.name, col.name, col.type, schema=table.schema)
        if getattr(col.type, "spatial_index", False) and col.name not in covered_columns:
            create_spatial_index(bind, table.name, col.name, col.type, schema=table.schema)

    for idx in deferred_indexes:
        table.indexes.add(idx)
        idx_columns = list(idx.columns.values())
        only_spatial_column = len(idx_columns) == 1 and _check_spatial_type(
            idx_columns[0].type,
            spatial_types,
            dialect,
        )
        if not only_spatial_column:
            idx.create(bind=bind)
            continue
        create_spatial_index(
            bind,
            table.name,
            idx_columns[0].name,
            idx_columns[0].type,
            schema=table.schema,
            index_name=idx.name,
            **idx.kwargs,
        )
def before_drop(table, bind, **kw):
    """Do nothing; no work is done before a table is dropped."""
    return None
def after_drop(table, bind, **kw):
    """Do nothing; no work is done after a table is dropped."""
    return None
def _process_wkt_value(value, strip_srid=False):
    """Convert a bound spatial value to MSSQL-compatible WKT text.

    ``WKTElement`` values contribute their ``data``; ``WKBElement`` and raw
    binary values go through ``_to_mssql_wkt``. Resulting strings optionally
    lose their SRID prefix and are then normalised. Anything that is not a
    string at that point is returned unchanged.
    """
    if isinstance(value, WKTElement):
        value = value.data
    elif isinstance(value, (WKBElement, bytes, bytearray, memoryview)):
        value = _to_mssql_wkt(value)

    if isinstance(value, str):
        if strip_srid:
            # Group 3 of the pattern is used as the text without the SRID prefix.
            value = WKTElement._REMOVE_SRID.match(value).group(3)
        if isinstance(value, str):
            value = _normalize_wkt_for_mssql(value)
    return value
def _process_ewkt_srid_value(value, default_srid=0):
    """Extract the SRID carried by a bound EWKT-like value.

    A non-negative ``srid`` on a ``WKTElement`` or ``WKBElement`` wins. For
    strings (including the ``data`` of a ``WKTElement`` with a negative SRID)
    the SRID captured by ``WKTElement._REMOVE_SRID`` is used. ``default_srid``
    is returned when no SRID can be found.

    Raises:
        ArgumentError: if the captured SRID cannot be converted to ``int``.
    """
    if value is None:
        return default_srid

    if isinstance(value, WKTElement):
        if value.srid >= 0:
            return value.srid
        value = value.data
    elif isinstance(value, WKBElement):
        return value.srid if value.srid >= 0 else default_srid

    if not isinstance(value, str):
        return default_srid

    srid = WKTElement._REMOVE_SRID.match(value).group(2)
    if srid is None:
        return default_srid
    try:
        return int(srid)
    except (ValueError, TypeError):  # pragma: no cover
        raise ArgumentError(
            f"The SRID ({srid}) of the supplied value can not be casted to integer"
        ) from None
def _process_wkb_value(value, extended=False):
    """Convert a bound value to the binary payload sent to MSSQL.

    With ``extended`` set, the value (wrapped in a ``WKBElement`` if needed) is
    passed through ``as_wkb()`` first. Otherwise a ``WKBElement`` contributes
    its ``data`` and other values pass through. ``memoryview`` results become
    ``bytes``; ``None`` stays ``None``.
    """
    if value is None:
        return None

    element = value if isinstance(value, WKBElement) else None
    if extended:
        if element is None:
            element = WKBElement(value, extended=None)
        value = element.as_wkb().data
    elif element is not None:
        value = element.data

    return value.tobytes() if isinstance(value, memoryview) else value
def _process_ewkb_srid_value(value, default_srid=0):
    """Extract the SRID carried by a bound EWKB-like value.

    A ``WKBElement`` reports its own ``srid``; raw bytes-like or string values
    are wrapped in a ``WKBElement`` to read theirs. ``default_srid`` is returned
    for ``None``, for other types and for negative SRIDs.
    """
    if value is None:
        return default_srid

    if isinstance(value, WKBElement):
        srid = value.srid
    elif isinstance(value, (bytes, bytearray, memoryview, str)):
        srid = WKBElement(value, extended=None).srid
    else:
        return default_srid

    return srid if srid >= 0 else default_srid
[docs]
class _MSSQLWKTBindType(TypeDecorator):
    """Bind-side type that sends spatial values to MSSQL as WKT text."""

    impl = UnicodeText
    # Allow SQLAlchemy to cache statements using this bind adapter.
    cache_ok = True

    def __init__(self, strip_srid=False, spatial_type=None):
        super().__init__()
        self.strip_srid = strip_srid
        self.spatial_type = spatial_type

    def process_bind_param(self, value, dialect):
        """Return the MSSQL WKT text value for a bound spatial object."""
        if self.spatial_type is None:
            return _process_wkt_value(value, strip_srid=self.strip_srid)
        # A known spatial type delegates to that type's own bind processing.
        return _type_bind_processor_process(self.spatial_type, value, dialect)
[docs]
class _MSSQLWKBBindType(TypeDecorator):
    """Bind-side type that sends spatial values to MSSQL as WKB binary."""

    impl = LargeBinary
    # Allow SQLAlchemy to cache statements using this bind adapter.
    cache_ok = True

    def __init__(self, extended=False):
        super().__init__()
        self.extended = extended

    def process_bind_param(self, value, dialect):
        """Return the MSSQL WKB value for a bound spatial object."""
        is_extended = self.extended
        return _process_wkb_value(value, extended=is_extended)
[docs]
class _MSSQLDynamicEWKTTextBindType(TypeDecorator):
    """Bind-side type that supplies the WKT text part of a dynamic EWKT value."""

    impl = UnicodeText
    # Allow SQLAlchemy to cache statements using this bind adapter.
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Return the text component from a dynamic EWKT value."""
        wkt_text = _process_wkt_value(value, strip_srid=True)
        return wkt_text
[docs]
class _MSSQLDynamicEWKTSRIDBindType(TypeDecorator):
    """Bind-side type that supplies the SRID part of a dynamic EWKT value."""

    impl = Integer
    # Allow SQLAlchemy to cache statements using this bind adapter.
    cache_ok = True

    def __init__(self, default_srid=0):
        super().__init__()
        self.default_srid = default_srid

    def process_bind_param(self, value, dialect):
        """Return the SRID component from a dynamic EWKT value."""
        fallback_srid = self.default_srid
        return _process_ewkt_srid_value(value, default_srid=fallback_srid)
[docs]
class _MSSQLDynamicEWKBSRIDBindType(TypeDecorator):
    """Bind-side type that supplies the SRID part of a dynamic EWKB value."""

    impl = Integer
    # Allow SQLAlchemy to cache statements using this bind adapter.
    cache_ok = True

    def __init__(self, default_srid=0):
        super().__init__()
        self.default_srid = default_srid

    def process_bind_param(self, value, dialect):
        """Return the SRID component from a dynamic EWKB value."""
        fallback_srid = self.default_srid
        return _process_ewkb_srid_value(value, default_srid=fallback_srid)
class _MSSQLDynamicEWKTCallable:
    """Share one result of ``source_callable`` between several consecutive callers.

    The wrapped callable is invoked once per round of ``_consumer_count`` calls
    (2 initially); every call in the round gets the same value.
    """

    def __init__(self, source_callable):
        self.source_callable = source_callable
        self._consumer_count = 2
        self._pending = None
        self._remaining = 0

    def add_consumers(self, count):
        """Increase the number of calls served by each fetched value."""
        self._consumer_count = self._consumer_count + count

    def __call__(self):
        if self._remaining == 0:
            # Start of a new round: fetch a fresh value for all consumers.
            self._pending = self.source_callable()
            self._remaining = self._consumer_count
        self._remaining -= 1
        if self._remaining != 0:
            return self._pending
        # Last consumer of the round: hand over the value and drop the reference.
        value, self._pending = self._pending, None
        return value
def _coerce_wkt_bind_clause(wkt_clause, strip_srid=False, literal=False, spatial_type=None):
    """Adapt a value-carrying clause so that it binds as MSSQL WKT text.

    Clauses without a ``value`` attribute are returned unchanged. With
    ``literal`` set, a new unique bind parameter holding the already converted
    text replaces the clause; otherwise the clause is type-coerced to
    ``_MSSQLWKTBindType`` and converted at bind time.
    """
    if not hasattr(wkt_clause, "value"):
        return wkt_clause

    if not literal:
        bind_type = _MSSQLWKTBindType(strip_srid=strip_srid, spatial_type=spatial_type)
        return expression.type_coerce(wkt_clause, bind_type)

    converted_text = _process_wkt_value(wkt_clause.value, strip_srid=strip_srid)
    return expression.bindparam(
        key=wkt_clause.key,
        value=converted_text,
        type_=UnicodeText(),
        unique=True,
    )
def _coerce_wkb_bind_clause(wkb_clause, extended=False, literal=False):
    """Adapt a value-carrying clause so that it binds as MSSQL WKB binary.

    Clauses without a ``value`` attribute are returned unchanged. With
    ``literal`` set, a new unique bind parameter holding the already converted
    bytes replaces the clause; otherwise the clause is type-coerced to
    ``_MSSQLWKBBindType`` and converted at bind time.
    """
    if not hasattr(wkb_clause, "value"):
        return wkb_clause

    if not literal:
        return expression.type_coerce(wkb_clause, _MSSQLWKBBindType(extended=extended))

    converted_bytes = _process_wkb_value(wkb_clause.value, extended=extended)
    return expression.bindparam(
        key=wkb_clause.key,
        value=converted_bytes,
        type_=LargeBinary(),
        unique=True,
    )
def _should_coerce_wkt_bind_clause(wkt_clause):
    """Tell whether a clause's value needs conversion before binding as WKT.

    True for spatial elements and raw binary values, and for strings that
    ``_normalize_wkt_for_mssql`` would change. False for clauses without a
    ``value`` attribute and for any other value type.
    """
    if not hasattr(wkt_clause, "value"):
        return False
    bound_value = wkt_clause.value
    if isinstance(bound_value, str):
        return _normalize_wkt_for_mssql(bound_value) != bound_value
    return isinstance(bound_value, (WKTElement, WKBElement, bytes, bytearray, memoryview))
def _should_coerce_wkt_bind_clause_for_text(wkt_clause, strip_srid=False):
    """Tell whether a clause bound to a text constructor needs conversion.

    On top of ``_should_coerce_wkt_bind_clause``, when ``strip_srid`` is set a
    string value (or ``WKTElement`` data) starting with ``SRID=`` also counts.
    """
    if _should_coerce_wkt_bind_clause(wkt_clause):
        return True
    if not (strip_srid and hasattr(wkt_clause, "value")):
        return False
    bound_value = wkt_clause.value
    if isinstance(bound_value, WKTElement):
        bound_value = bound_value.data
    return isinstance(bound_value, str) and bound_value.startswith("SRID=")
def _is_bindparam_clause(clause):
    """Tell whether ``clause`` is a SQLAlchemy ``BindParameter``."""
    is_bind_parameter = isinstance(clause, BindParameter)
    return is_bind_parameter
def _is_mssql_auto_constructor_bindparam(clause, constructor_name):
    """Tell whether ``clause`` is a unique bind parameter whose ``_orig_key`` is ``constructor_name``."""
    if not _is_bindparam_clause(clause):
        return False
    # The short-circuit result is returned as-is, not coerced to bool.
    return getattr(clause, "unique", False) and (
        getattr(clause, "_orig_key", None) == constructor_name
    )
def _should_coerce_wkb_bind_clause(wkb_clause):
    """Tell whether a clause carries a ``WKBElement`` or raw binary value."""
    if not hasattr(wkb_clause, "value"):
        return False
    return isinstance(wkb_clause.value, (WKBElement, bytes, bytearray, memoryview))
def _infer_srid_from_wkb_clause(wkb_clause, default_srid, extended=False):
    """Read the SRID from the value already attached to a WKB clause.

    A ``WKBElement`` reports its own ``srid``. Raw binary is only inspected when
    ``extended`` is set. ``default_srid`` is returned for clauses without a
    value, for other value types and for negative SRIDs.
    """
    if not hasattr(wkb_clause, "value"):
        return default_srid

    bound_value = wkb_clause.value
    if isinstance(bound_value, WKBElement):
        srid = bound_value.srid
    elif extended and isinstance(bound_value, (bytes, bytearray, memoryview)):
        srid = WKBElement(bound_value, extended=None).srid
    else:
        return default_srid

    return srid if srid >= 0 else default_srid
def _is_spatial_clause(clause, dialect=None):
    """Tell whether the clause's ``type`` is a Geometry or Geography type."""
    clause_type = getattr(clause, "type", None)
    return _check_spatial_type(clause_type, (Geometry, Geography), dialect)
def _is_spatial_function_target(clause, dialect=None):
    """Tell whether a clause can be the receiver of a spatial method call.

    That is the case for spatially typed clauses and for clauses whose
    ``value`` is a ``WKTElement`` or ``WKBElement``.
    """
    bound_value = getattr(clause, "value", None)
    return _is_spatial_clause(clause, dialect) or isinstance(
        bound_value, (WKTElement, WKBElement)
    )
def _resolve_mssql_spatial_type(spatial_type, dialect):
    """Return the dialect implementation of a ``TypeDecorator``; other types pass through."""
    if not isinstance(spatial_type, TypeDecorator):
        return spatial_type
    return spatial_type.load_dialect_impl(dialect)
def _is_mssql_spatial_constructor(clause):
    """Tell whether ``clause`` is one of the text- or WKB-based spatial constructor functions."""
    constructor_classes = (
        functions.ST_GeomFromText,
        functions.ST_GeogFromText,
        functions.ST_GeomFromEWKT,
        functions.ST_GeomFromWKB,
        functions.ST_GeogFromWKB,
        functions.ST_GeomFromEWKB,
    )
    return isinstance(clause, constructor_classes)
def _is_mssql_text_constructor(clause):
    """Tell whether ``clause`` is a spatial constructor function that takes WKT/EWKT text."""
    text_constructor_classes = (
        functions.ST_GeomFromText,
        functions.ST_GeogFromText,
        functions.ST_GeomFromEWKT,
    )
    return isinstance(clause, text_constructor_classes)
def _is_mssql_wkb_constructor(clause):
    """Tell whether ``clause`` is a spatial constructor function that takes WKB/EWKB binary."""
    wkb_constructor_classes = (
        functions.ST_GeomFromWKB,
        functions.ST_GeogFromWKB,
        functions.ST_GeomFromEWKB,
    )
    return isinstance(clause, wkb_constructor_classes)
def _unwrap_mssql_constructor_clauses(clauses, predicate):
    """Replace a single nested constructor by its own argument clauses.

    Returns ``(clauses, inner_constructor)``. When ``clauses`` holds exactly one
    item accepted by ``predicate``, the result is that item's clauses as a list
    together with the item itself; otherwise ``(clauses, None)``.
    """
    if len(clauses) == 1:
        (candidate,) = clauses
        if predicate(candidate):
            return list(candidate.clauses), candidate
    return clauses, None
def _spatial_constructor_matches_target(constructor_type, target_type, dialect):
    """Tell whether both types, once resolved, belong to the same spatial family.

    The families are Geometry and Geography; both types must be in the same one.
    """
    resolved_constructor = _resolve_mssql_spatial_type(constructor_type, dialect)
    resolved_target = _resolve_mssql_spatial_type(target_type, dialect)

    def both_are(family):
        return _check_spatial_type(resolved_constructor, family, dialect) and _check_spatial_type(
            resolved_target, family, dialect
        )

    return both_are(Geometry) or both_are(Geography)
def _coerce_mssql_spatial_method_argument(target_clause, other_clause, dialect):
    """Make the argument of a spatial method call agree with the receiver's type.

    A spatial constructor of the other family is cloned and retyped to the
    receiver's resolved type. A clause that is not spatially typed is
    type-coerced to it. Anything else is returned unchanged.
    """
    target_type = _resolve_mssql_spatial_type(getattr(target_clause, "type", None), dialect)

    if _is_mssql_spatial_constructor(other_clause):
        constructor_type = getattr(other_clause, "type", None)
        if not _spatial_constructor_matches_target(constructor_type, target_type, dialect):
            # Work on a clone so the caller's clause keeps its original type.
            retyped_clause = other_clause._clone()
            retyped_clause.type = target_type
            return retyped_clause

    if _is_spatial_clause(other_clause, dialect):
        return other_clause
    return expression.type_coerce(other_clause, target_type)
def _compile_mssql_function_fallback(element, compiler, **kw):
    """Compile ``element`` with the compiler's own ``visit_function``."""
    default_sql = compiler.visit_function(element, **kw)
    return default_sql
def _compile_mssql_method(element, compiler, method_name, property_=False, **kw):
    """Compile a function call as ``target.method(args)`` or ``target.property``.

    The first clause is the receiver; remaining clauses become the arguments and
    are not compiled at all when ``property_`` is set. Falls back to the default
    function compilation when there is no spatial receiver.
    """
    clauses = list(element.clauses)
    if not clauses or not _is_spatial_function_target(clauses[0], compiler.dialect):
        return _compile_mssql_function_fallback(element, compiler, **kw)

    receiver_clause, *argument_clauses = clauses
    receiver_sql = compiler.process(receiver_clause, **kw)
    if property_:
        return f"{receiver_sql}.{method_name}"

    arguments_sql = ", ".join(compiler.process(argument, **kw) for argument in argument_clauses)
    return f"{receiver_sql}.{method_name}({arguments_sql})"
def _compile_mssql_binary_method(element, compiler, method_name, **kw):
    """Compile a two-operand spatial function as ``target.method(other)``.

    The second clause is first coerced to the receiver's spatial type; clauses
    beyond the second are ignored. Falls back to the default function
    compilation when there are fewer than two clauses or no spatial receiver.
    """
    clauses = list(element.clauses)
    if len(clauses) < 2 or not _is_spatial_function_target(clauses[0], compiler.dialect):
        return _compile_mssql_function_fallback(element, compiler, **kw)

    receiver_clause, raw_operand = clauses[0], clauses[1]
    operand_clause = _coerce_mssql_spatial_method_argument(
        receiver_clause,
        raw_operand,
        compiler.dialect,
    )
    # The receiver is compiled before the operand.
    receiver_sql = compiler.process(receiver_clause, **kw)
    operand_sql = compiler.process(operand_clause, **kw)
    return f"{receiver_sql}.{method_name}({operand_sql})"
def _compile_mssql_dwithin(element, compiler, **kw):
    """Compile a distance-within test as a ``CASE`` over ``STDistance``.

    The result is 1 when the distance between the first two clauses is at most
    the third clause, else 0. Falls back to the default function compilation
    when there are fewer than three clauses or no spatial receiver.
    """
    clauses = list(element.clauses)
    if len(clauses) < 3 or not _is_spatial_function_target(clauses[0], compiler.dialect):
        return _compile_mssql_function_fallback(element, compiler, **kw)

    receiver_clause = clauses[0]
    operand_clause = _coerce_mssql_spatial_method_argument(
        receiver_clause,
        clauses[1],
        compiler.dialect,
    )
    # Compiled strictly in this order: receiver, operand, distance.
    receiver_sql, operand_sql, distance_sql = [
        compiler.process(clause, **kw)
        for clause in (receiver_clause, operand_clause, clauses[2])
    ]
    return (
        f"CASE WHEN {receiver_sql}.STDistance({operand_sql}) <= {distance_sql} "
        "THEN 1 ELSE 0 END"
    )
def _mssql_little_endian_binary_from_big_endian(binary_expr):
    """Return SQL that concatenates bytes 4, 3, 2, 1 of ``binary_expr``, reversing a 4-byte word."""
    reversed_bytes = [f"SUBSTRING({binary_expr}, {position}, 1)" for position in range(4, 0, -1)]
    return " + ".join(reversed_bytes)
def _mssql_ewkb_type_from_iso_type(wkb_type):
    """Return SQL that turns a WKB type-code expression into an EWKB type word.

    The generated expression keeps ``code % 1000``, always adds 536870912, adds
    2147483648 when ``code / 1000`` is 1 or 3, and adds 1073741824 when it is
    2 or 3.
    """
    dimension_code = f"({wkb_type} / 1000)"
    terms = (
        f"CONVERT(bigint, {wkb_type} % 1000)",
        "536870912",
        f"CASE WHEN {dimension_code} IN (1, 3) THEN 2147483648 ELSE 0 END",
        f"CASE WHEN {dimension_code} IN (2, 3) THEN 1073741824 ELSE 0 END",
    )
    return " + ".join(terms)
def _mssql_binary4_from_unsigned_int(unsigned_int_expr):
    """Return SQL yielding the low 4 bytes of the expression converted to an 8-byte binary."""
    as_binary8 = f"CONVERT(binary(8), CONVERT(bigint, ({unsigned_int_expr})))"
    return f"SUBSTRING({as_binary8}, 5, 4)"
def _mssql_clause_has_bind_parameter(clause):
    """Tell whether any element yielded by ``visitors.iterate(clause)`` is a ``BindParameter``."""
    for child in visitors.iterate(clause):
        if isinstance(child, BindParameter):
            return True
    return False
def _compile_mssql_derived_geometry_target(target, body, alias_name):
    """Wrap ``target`` in a one-row derived table and apply ``body`` to its column.

    ``target`` appears once in the generated SQL, as column ``[geom]`` of the
    aliased subquery; ``body`` receives the qualified column reference instead.
    """
    quoted_alias = _quote_mssql_identifier(alias_name)
    quoted_column = _quote_mssql_identifier("geom")
    column_ref = quoted_alias + "." + quoted_column
    return (
        f"(SELECT {body(column_ref)} "
        f"FROM (SELECT {target} AS {quoted_column}) AS {quoted_alias})"
    )
def _compile_mssql_ewkb_body(target):
    """Return SQL that assembles an EWKB value from ``target.AsBinaryZM()``.

    The output is NULL for a NULL target. Otherwise it is the byte-order byte,
    the rewritten 4-byte type word, the 4-byte SRID and the rest of the WKB
    starting at byte 6, with both words written in the byte order announced by
    the first byte (0x01 selects the little-endian branch).
    """
    wkb = f"{target}.AsBinaryZM()"
    srid_word = f"CONVERT(binary(4), {target}.STSrid)"

    # Type code read from bytes 2-5, in each of the two possible byte orders.
    type_code_le = (
        "CONVERT(int, "
        + " + ".join(f"SUBSTRING({wkb}, {position}, 1)" for position in (5, 4, 3, 2))
        + ")"
    )
    type_code_be = f"CONVERT(int, SUBSTRING({wkb}, 2, 4))"

    ewkb_type_le = _mssql_little_endian_binary_from_big_endian(
        _mssql_binary4_from_unsigned_int(_mssql_ewkb_type_from_iso_type(type_code_le))
    )
    ewkb_type_be = _mssql_binary4_from_unsigned_int(_mssql_ewkb_type_from_iso_type(type_code_be))

    srid_le = _mssql_little_endian_binary_from_big_endian(srid_word)
    srid_be = srid_word

    payload = f"SUBSTRING({wkb}, 6, DATALENGTH({wkb}) - 5)"
    return (
        f"CASE WHEN {target} IS NULL THEN NULL "
        f"WHEN SUBSTRING({wkb}, 1, 1) = 0x01 THEN "
        f"CAST(0x01 AS varbinary(max)) + {ewkb_type_le} + "
        f"{srid_le} + {payload} "
        f"ELSE CAST(0x00 AS varbinary(max)) + {ewkb_type_be} + "
        f"{srid_be} + {payload} END"
    )
def _compile_mssql_as_ewkb(element, compiler, **kw):
    """Compile an EWKB output function for MSSQL.

    When the receiver contains bind parameters, it is placed in a derived table
    so that its SQL is emitted only once. Falls back to the default function
    compilation when there is no spatial receiver.
    """
    clauses = list(element.clauses)
    if not clauses or not _is_spatial_function_target(clauses[0], compiler.dialect):
        return _compile_mssql_function_fallback(element, compiler, **kw)

    receiver_clause = clauses[0]
    receiver_sql = compiler.process(receiver_clause, **kw)
    if not _mssql_clause_has_bind_parameter(receiver_clause):
        return _compile_mssql_ewkb_body(receiver_sql)
    return _compile_mssql_derived_geometry_target(
        receiver_sql,
        _compile_mssql_ewkb_body,
        "geoalchemy2_mssql_ewkb",
    )
def _compile_mssql_ewkt_body(target):
    """Return SQL producing ``SRID=<srid>;<AsTextZM text>`` for ``target``, or NULL for a NULL target."""
    ewkt_sql = f"CONCAT('SRID=', {target}.STSrid, ';', {target}.AsTextZM())"
    return f"CASE WHEN {target} IS NULL THEN NULL ELSE {ewkt_sql} END"
def _compile_mssql_as_ewkt(element, compiler, **kw):
    """Compile an EWKT output function for MSSQL.

    When the receiver contains bind parameters, it is placed in a derived table
    so that its SQL is emitted only once. Falls back to the default function
    compilation when there is no spatial receiver.
    """
    clauses = list(element.clauses)
    if not clauses or not _is_spatial_function_target(clauses[0], compiler.dialect):
        return _compile_mssql_function_fallback(element, compiler, **kw)

    receiver_clause = clauses[0]
    receiver_sql = compiler.process(receiver_clause, **kw)
    if not _mssql_clause_has_bind_parameter(receiver_clause):
        return _compile_mssql_ewkt_body(receiver_sql)
    return _compile_mssql_derived_geometry_target(
        receiver_sql,
        _compile_mssql_ewkt_body,
        "geoalchemy2_mssql_ewkt",
    )
def _compile_mssql_srid_clause(clause, compiler, default_srid, **kw):
    """Compile the SRID argument of a spatial constructor.

    A clause carrying a negative integer value compiles to the literal ``"0"``.
    A missing clause (``None``) compiles to ``str(default_srid)``. Every other
    clause is handed to the compiler.
    """
    bound_value = getattr(clause, "value", None)
    if bound_value is not None:
        try:
            is_negative = int(bound_value) < 0
        except (TypeError, ValueError):  # pragma: no cover
            is_negative = False
        if is_negative:
            return "0"

    if clause is None:
        return str(default_srid)
    return compiler.process(clause, **kw)
def _mssql_dynamic_ewkt_bind_keys(source_bind):
    """Return the ``(text_key, srid_key)`` pair derived from a source bind parameter.

    Each key is made of the EWKT prefix, a sanitised form of the bind's original
    key (``param`` when nothing is left) and the first 8 hex digits of the SHA-1
    of its current key.
    """
    original_name = str(getattr(source_bind, "_orig_key", None) or source_bind.key)
    current_key = str(source_bind.key)

    # Runs of anything but ASCII letters, digits and "_" collapse to one "_".
    sanitised = re.sub(r"[^0-9A-Za-z_]+", "_", original_name).strip("_")
    name_token = sanitised or "param"
    digest = hashlib.sha1(current_key.encode("utf-8")).hexdigest()[:8]

    key_base = "_".join((_MSSQL_DYNAMIC_EWKT_KEY_PREFIX, name_token, digest))
    return key_base + "_text", key_base + "_srid"
def _mssql_dynamic_ewkt_bind_identifier(source_bind):
return getattr(source_bind, "_identifying_key", source_bind.key)
def _make_mssql_dynamic_ewkt_bind_clauses(wkt_clause, default_srid=0):
    """Build the ``(text_bind, srid_bind)`` pair that replaces one EWKT bind.

    Both binds inherit ``required`` from ``wkt_clause``.  If the source bind
    has a callable, a single ``_MSSQLDynamicEWKTCallable`` wrapping it is
    shared by both binds; otherwise, when the source is not required, its
    ``value`` is copied to both.  ``default_srid`` is forwarded to the SRID
    bind type.
    """
    text_key, srid_key = _mssql_dynamic_ewkt_bind_keys(wkt_clause)

    is_required = wkt_clause.required
    shared_options = {"required": is_required}
    source_callable = getattr(wkt_clause, "callable", None)
    if source_callable is not None:
        shared_options["callable_"] = _MSSQLDynamicEWKTCallable(source_callable)
    elif not is_required:
        shared_options["value"] = getattr(wkt_clause, "value", None)

    text_bind = expression.bindparam(
        key=text_key,
        type_=_MSSQLDynamicEWKTTextBindType(),
        **shared_options,
    )
    srid_bind = expression.bindparam(
        key=srid_key,
        type_=_MSSQLDynamicEWKTSRIDBindType(default_srid=default_srid),
        **shared_options,
    )
    return text_bind, srid_bind
def _get_mssql_dynamic_ewkt_bind_clauses(wkt_clause, compiler, default_srid=0):
    """Return the dynamic EWKT bind pair for ``wkt_clause``, cached per compiler.

    The cache is a dict stored on the compiler instance and keyed by the bind
    identifier, so the same source bind yields the same pair of bind clauses
    within one compilation.
    """
    bind_cache = getattr(compiler, "_geoalchemy2_mssql_dynamic_ewkt_bind_cache", None)
    if bind_cache is None:
        bind_cache = {}
        compiler._geoalchemy2_mssql_dynamic_ewkt_bind_cache = bind_cache

    identifier = _mssql_dynamic_ewkt_bind_identifier(wkt_clause)
    if identifier in bind_cache:
        return bind_cache[identifier]

    bind_clauses = _make_mssql_dynamic_ewkt_bind_clauses(
        wkt_clause,
        default_srid=default_srid,
    )
    bind_cache[identifier] = bind_clauses
    return bind_clauses
def _mssql_dynamic_ewkb_bind_keys(source_bind, default_srid=0):
    """Derive the ``(wkb_key, srid_key)`` bind names for a dynamic EWKB bind.

    The shared base combines the module prefix, a sanitised form of the
    bind's ``_orig_key`` (falling back to ``key``), a sanitised
    ``default_srid`` and the first eight hex digits of the SHA-1 of
    ``"<key>:<default_srid>"``.
    """
    bind_key = str(source_bind.key)
    readable_name = str(getattr(source_bind, "_orig_key", None) or source_bind.key)

    def _sanitize(raw, fallback):
        # Collapse every run of characters outside [0-9A-Za-z_] into one "_".
        return re.sub(r"[^0-9A-Za-z_]+", "_", raw).strip("_") or fallback

    name_token = _sanitize(readable_name, "param")
    srid_token = _sanitize(str(default_srid), "0")
    digest = hashlib.sha1(f"{bind_key}:{default_srid}".encode()).hexdigest()[:8]
    key_base = f"{_MSSQL_DYNAMIC_EWKB_KEY_PREFIX}_{name_token}_srid_{srid_token}_{digest}"
    return f"{key_base}_wkb", f"{key_base}_srid"
def _make_mssql_dynamic_ewkb_bind_clauses(wkb_clause, default_srid=0, shared_callable=None):
    """Build the ``(wkb_bind, srid_bind)`` pair that replaces one EWKB bind.

    Both binds inherit ``required`` from ``wkb_clause``.  If the source bind
    has a callable, both binds use ``shared_callable`` (a new
    ``_MSSQLDynamicEWKTCallable`` is created when none is supplied);
    otherwise, when the source is not required, its ``value`` is copied to
    both.  ``shared_callable`` is ignored for binds without a callable.
    """
    wkb_key, srid_key = _mssql_dynamic_ewkb_bind_keys(wkb_clause, default_srid=default_srid)

    is_required = wkb_clause.required
    shared_options = {"required": is_required}
    source_callable = getattr(wkb_clause, "callable", None)
    if source_callable is not None:
        if shared_callable is None:
            shared_callable = _MSSQLDynamicEWKTCallable(source_callable)
        shared_options["callable_"] = shared_callable
    elif not is_required:
        shared_options["value"] = getattr(wkb_clause, "value", None)

    wkb_bind = expression.bindparam(
        key=wkb_key,
        type_=_MSSQLWKBBindType(extended=True),
        **shared_options,
    )
    srid_bind = expression.bindparam(
        key=srid_key,
        type_=_MSSQLDynamicEWKBSRIDBindType(default_srid=default_srid),
        **shared_options,
    )
    return wkb_bind, srid_bind
def _get_mssql_dynamic_ewkb_bind_clauses(wkb_clause, compiler, default_srid=0):
    """Return the dynamic EWKB bind pair for ``wkb_clause``, cached per compiler.

    Bind pairs are cached on the compiler by ``(bind identifier,
    default_srid)``.  Callable wrappers are cached separately by bind
    identifier alone, so one source bind used with several default SRIDs
    shares a single wrapper; each reuse registers two more consumers on it.
    """
    bind_cache = getattr(compiler, "_geoalchemy2_mssql_dynamic_ewkb_bind_cache", None)
    if bind_cache is None:
        bind_cache = {}
        compiler._geoalchemy2_mssql_dynamic_ewkb_bind_cache = bind_cache

    identifier = _mssql_dynamic_ewkt_bind_identifier(wkb_clause)
    cache_key = (identifier, default_srid)
    if cache_key in bind_cache:
        return bind_cache[cache_key]

    shared_callable = None
    source_callable = getattr(wkb_clause, "callable", None)
    if source_callable is not None:
        callable_cache = getattr(
            compiler,
            "_geoalchemy2_mssql_dynamic_ewkb_callable_cache",
            None,
        )
        if callable_cache is None:
            callable_cache = {}
            compiler._geoalchemy2_mssql_dynamic_ewkb_callable_cache = callable_cache

        shared_callable = callable_cache.get(identifier)
        if shared_callable is None:
            shared_callable = _MSSQLDynamicEWKTCallable(source_callable)
            callable_cache[identifier] = shared_callable
        else:
            # The wrapper is about to serve another (wkb, srid) bind pair.
            shared_callable.add_consumers(2)

    bind_clauses = _make_mssql_dynamic_ewkb_bind_clauses(
        wkb_clause,
        default_srid=default_srid,
        shared_callable=shared_callable,
    )
    bind_cache[cache_key] = bind_clauses
    return bind_clauses
def _collect_mssql_dynamic_ewkt_source_binds(clauseelement, dialect):
if not hasattr(clauseelement, "get_children"):
return ()
source_binds = []
seen_source_identifiers = set()
for element in visitors.iterate(clauseelement):
if not isinstance(element, functions.ST_GeomFromEWKT):
continue
clauses = list(element.clauses)
if len(clauses) != 1 or not _is_bindparam_clause(clauses[0]):
continue
candidate_spatial_type = _resolve_mssql_spatial_type(element.type, dialect)
if (
_check_spatial_type(candidate_spatial_type, (Geometry, Geography), dialect)
and getattr(candidate_spatial_type, "srid", -1) >= 0
):
continue
source_identifier = _mssql_dynamic_ewkt_bind_identifier(clauses[0])
if source_identifier in seen_source_identifiers:
continue
seen_source_identifiers.add(source_identifier)
source_binds.append(clauses[0])
return tuple(source_binds)
def _collect_mssql_dynamic_ewkb_source_binds(clauseelement, dialect):
if not hasattr(clauseelement, "get_children"):
return ()
source_binds = []
seen_bind_keys = set()
for element in visitors.iterate(clauseelement):
if not isinstance(element, functions.ST_GeomFromEWKB):
continue
clauses = list(element.clauses)
if len(clauses) != 1 or not _is_bindparam_clause(clauses[0]):
continue
if _is_mssql_auto_constructor_bindparam(clauses[0], "ST_GeomFromEWKB"):
continue
candidate_spatial_type = _resolve_mssql_spatial_type(element.type, dialect)
if (
_check_spatial_type(candidate_spatial_type, (Geometry, Geography), dialect)
and getattr(candidate_spatial_type, "srid", -1) >= 0
):
continue
default_srid = element.type.srid if element.type.srid >= 0 else 0
bind_key = (_mssql_dynamic_ewkt_bind_identifier(clauses[0]), default_srid)
if bind_key in seen_bind_keys:
continue
seen_bind_keys.add(bind_key)
source_binds.append((clauses[0], default_srid))
return tuple(source_binds)
def _compile_mssql_statement_bind_name_map(clauseelement, dialect):
if not hasattr(clauseelement, "compile"):
return {}
if hasattr(clauseelement, "execution_options"):
clauseelement = clauseelement.execution_options(
**{_MSSQL_DISABLE_DYNAMIC_EWKT_SPLIT_OPTION: True}
)
compiled = clauseelement.compile(dialect=dialect)
bind_name_map = {}
for bind, compiled_name in compiled.bind_names.items():
bind_name_map.setdefault(
getattr(bind, "_identifying_key", bind.key),
compiled_name,
)
return bind_name_map
def _get_mssql_dynamic_ewkt_bind_mappings(clauseelement, dialect):
    """Describe how caller parameters map onto the dynamic EWKT binds.

    Returns a tuple of ``(candidate_keys, text_key, srid_key)`` entries, one
    per source bind found in ``clauseelement``.  ``candidate_keys`` lists, in
    order and without duplicates or ``None``, the bind's ``key``, its
    ``_orig_key`` and its compiled name.  The tuple is empty when the
    statement has no dynamic EWKT source binds.
    """
    source_binds = _collect_mssql_dynamic_ewkt_source_binds(clauseelement, dialect)
    if not source_binds:
        return ()

    compiled_names = _compile_mssql_statement_bind_name_map(clauseelement, dialect)
    mappings = []
    for bind in source_binds:
        identifier = _mssql_dynamic_ewkt_bind_identifier(bind)
        possible_keys = (
            bind.key,
            getattr(bind, "_orig_key", None),
            compiled_names.get(identifier),
        )
        lookup_keys = []
        for key in possible_keys:
            if key is None or key in lookup_keys:
                continue
            lookup_keys.append(key)
        text_key, srid_key = _mssql_dynamic_ewkt_bind_keys(bind)
        mappings.append((tuple(lookup_keys), text_key, srid_key))
    return tuple(mappings)
def _get_mssql_dynamic_ewkb_bind_mappings(clauseelement, dialect):
    """Describe how caller parameters map onto the dynamic EWKB binds.

    Returns a tuple of ``(candidate_keys, wkb_key, srid_key)`` entries, one
    per ``(source bind, default_srid)`` pair found in ``clauseelement``.
    ``candidate_keys`` lists, in order and without duplicates or ``None``,
    the bind's ``key``, its ``_orig_key`` and its compiled name.  The tuple
    is empty when the statement has no dynamic EWKB source binds.
    """
    source_binds = _collect_mssql_dynamic_ewkb_source_binds(clauseelement, dialect)
    if not source_binds:
        return ()

    compiled_names = _compile_mssql_statement_bind_name_map(clauseelement, dialect)
    mappings = []
    for bind, default_srid in source_binds:
        identifier = _mssql_dynamic_ewkt_bind_identifier(bind)
        possible_keys = (
            bind.key,
            getattr(bind, "_orig_key", None),
            compiled_names.get(identifier),
        )
        lookup_keys = []
        for key in possible_keys:
            if key is None or key in lookup_keys:
                continue
            lookup_keys.append(key)
        wkb_key, srid_key = _mssql_dynamic_ewkb_bind_keys(bind, default_srid=default_srid)
        mappings.append((tuple(lookup_keys), wkb_key, srid_key))
    return tuple(mappings)
def _expand_mssql_dynamic_ewkt_param_mapping(parameters, dynamic_bind_mappings):
if not isinstance(parameters, Mapping):
return parameters, False
expanded_parameters = parameters
changed = False
for source_keys, text_key, srid_key in dynamic_bind_mappings:
source_key = next((key for key in source_keys if key in parameters), None)
if source_key is None:
continue
if text_key in parameters and srid_key in parameters:
continue
if expanded_parameters is parameters:
expanded_parameters = dict(parameters)
source_value = parameters[source_key]
expanded_parameters.setdefault(text_key, source_value)
expanded_parameters.setdefault(srid_key, source_value)
changed = True
return expanded_parameters, changed
def before_execute(conn, clauseelement, multiparams, params, execution_options):
    """Expand caller parameters so the dynamic EWKT/EWKB binds receive values.

    Has the signature of a SQLAlchemy ``before_execute`` listener and returns
    ``(clauseelement, multiparams, params)`` (presumably registered with
    ``retval=True`` -- confirm at the registration site).  The statement is
    returned unchanged.  Each parameter mapping that supplies a source bind
    gains the derived bind keys; parameter sets that need no expansion are
    returned as the same objects.  When any entry of ``multiparams`` changes,
    the whole sequence is returned as a tuple.
    """
    dialect = conn.dialect
    bind_mappings = _get_mssql_dynamic_ewkt_bind_mappings(
        clauseelement,
        dialect,
    ) + _get_mssql_dynamic_ewkb_bind_mappings(clauseelement, dialect)
    if not bind_mappings:
        return clauseelement, multiparams, params

    new_multiparams = multiparams
    if multiparams:
        expansions = [
            _expand_mssql_dynamic_ewkt_param_mapping(entry, bind_mappings)
            for entry in multiparams
        ]
        if any(entry_changed for _, entry_changed in expansions):
            new_multiparams = tuple(entry for entry, _ in expansions)

    # The helper returns its input object when nothing had to be added, so
    # the results can be returned without checking the changed flag.
    new_params, _ = _expand_mssql_dynamic_ewkt_param_mapping(params, bind_mappings)
    return clauseelement, new_multiparams, new_params
def _compile_mssql_geom_from_text(element, compiler, strip_srid=False, **kw):
    """Compile a WKT/EWKT constructor to ``<type>::STGeomFromText(wkt, srid)``.

    ``strip_srid`` marks the input as EWKT; it is also enabled when the
    unwrapped inner constructor is ``ST_GeomFromEWKT``.

    Three paths are visible here:

    * EWKT supplied through a single bind parameter, with no fixed-SRID
      spatial type, no ``literal_binds`` and the split option not disabled:
      the bind is replaced by a cached pair of dynamic text/SRID binds.
    * An explicit second argument: it is compiled as the SRID.
    * Otherwise the SRID is the element type's SRID (``0`` when negative),
      overridden by an ``SRID=`` prefix found in a literal EWKT value.
    """
    arguments, inner_constructor = _unwrap_mssql_constructor_clauses(
        list(element.clauses),
        _is_mssql_text_constructor,
    )
    strip_srid = strip_srid or isinstance(inner_constructor, functions.ST_GeomFromEWKT)
    source_clause = arguments[0]
    literal_binds = kw.get("literal_binds", False)
    compiler_options = getattr(compiler, "execution_options", {})
    split_disabled = bool(
        compiler_options.get(_MSSQL_DISABLE_DYNAMIC_EWKT_SPLIT_OPTION, False)
    )

    # Only a Geometry/Geography type with a non-negative SRID counts as fixed.
    fixed_srid_type = None
    if strip_srid:
        resolved_type = _resolve_mssql_spatial_type(element.type, compiler.dialect)
        if (
            _check_spatial_type(resolved_type, (Geometry, Geography), compiler.dialect)
            and getattr(resolved_type, "srid", -1) >= 0
        ):
            fixed_srid_type = resolved_type

    use_dynamic_binds = (
        strip_srid
        and fixed_srid_type is None
        and len(arguments) == 1
        and _is_bindparam_clause(source_clause)
        and not literal_binds
        and not split_disabled
    )
    if use_dynamic_binds:
        text_bind, srid_bind = _get_mssql_dynamic_ewkt_bind_clauses(source_clause, compiler)
        wkt_sql = compiler.process(text_bind, **kw)
        srid_sql = compiler.process(srid_bind, **kw)
        return f"{element.type.name}::STGeomFromText({wkt_sql}, {srid_sql})"

    wkt_clause = source_clause
    needs_coercion = (
        literal_binds
        or _is_bindparam_clause(source_clause)
        or _should_coerce_wkt_bind_clause_for_text(
            source_clause,
            strip_srid=strip_srid,
        )
    )
    if needs_coercion:
        wkt_clause = _coerce_wkt_bind_clause(
            source_clause,
            strip_srid=strip_srid,
            literal=literal_binds,
            spatial_type=fixed_srid_type,
        )
    wkt_sql = compiler.process(wkt_clause, **kw)

    if len(arguments) > 1:
        srid_sql = _compile_mssql_srid_clause(arguments[1], compiler, 0, **kw)
    else:
        srid = element.type.srid if element.type.srid >= 0 else 0
        if strip_srid and hasattr(source_clause, "value"):
            literal_value = source_clause.value
            if isinstance(literal_value, WKTElement):
                literal_value = literal_value.data
            if isinstance(literal_value, str):
                # Group 2 of the pattern is read as the SRID of the prefix.
                embedded_srid = WKTElement._REMOVE_SRID.match(literal_value).group(2)
                if embedded_srid is not None:
                    srid = int(embedded_srid)
        srid_sql = str(srid)
    return f"{element.type.name}::STGeomFromText({wkt_sql}, {srid_sql})"
def _compile_mssql_geom_from_wkb(element, compiler, extended=False, **kw):
    """Compile a WKB/EWKB constructor to ``<type>::STGeomFromWKB(wkb, srid)``.

    ``extended`` marks the input as EWKB; it is also enabled when the
    unwrapped inner constructor is ``ST_GeomFromEWKB``.

    Three paths are visible here:

    * EWKB supplied through a single, non auto-constructor bind parameter,
      with no fixed-SRID spatial type, no ``literal_binds`` and the split
      option not disabled: the bind is replaced by a cached pair of dynamic
      WKB/SRID binds.
    * An explicit second argument: it is compiled as the SRID.
    * Otherwise the SRID is inferred by ``_infer_srid_from_wkb_clause`` with
      the element type's SRID (``0`` when negative) as the default.

    With ``literal_binds`` and a clause exposing ``value``, the WKB is
    rendered inline as a ``0x...`` hexadecimal literal.
    """
    arguments, inner_constructor = _unwrap_mssql_constructor_clauses(
        list(element.clauses),
        _is_mssql_wkb_constructor,
    )
    extended = extended or isinstance(inner_constructor, functions.ST_GeomFromEWKB)
    source_clause = arguments[0]
    literal_binds = kw.get("literal_binds", False)
    compiler_options = getattr(compiler, "execution_options", {})
    split_disabled = bool(
        compiler_options.get(_MSSQL_DISABLE_DYNAMIC_EWKT_SPLIT_OPTION, False)
    )

    # Only a Geometry/Geography type with a non-negative SRID counts as fixed.
    fixed_srid_type = None
    if extended:
        resolved_type = _resolve_mssql_spatial_type(element.type, compiler.dialect)
        if (
            _check_spatial_type(resolved_type, (Geometry, Geography), compiler.dialect)
            and getattr(resolved_type, "srid", -1) >= 0
        ):
            fixed_srid_type = resolved_type

    use_dynamic_binds = (
        extended
        and fixed_srid_type is None
        and len(arguments) == 1
        and _is_bindparam_clause(source_clause)
        and not _is_mssql_auto_constructor_bindparam(source_clause, "ST_GeomFromEWKB")
        and not literal_binds
        and not split_disabled
    )
    if use_dynamic_binds:
        fallback_srid = element.type.srid if element.type.srid >= 0 else 0
        wkb_bind, srid_bind = _get_mssql_dynamic_ewkb_bind_clauses(
            source_clause,
            compiler,
            default_srid=fallback_srid,
        )
        wkb_sql = compiler.process(wkb_bind, **kw)
        srid_sql = compiler.process(srid_bind, **kw)
        return f"{element.type.name}::STGeomFromWKB({wkb_sql}, {srid_sql})"

    wkb_clause = source_clause
    needs_coercion = (
        literal_binds
        or (extended and _is_bindparam_clause(source_clause))
        or _should_coerce_wkb_bind_clause(arguments[0])
    )
    if needs_coercion:
        wkb_clause = _coerce_wkb_bind_clause(
            arguments[0], extended=extended, literal=literal_binds
        )

    if literal_binds and hasattr(wkb_clause, "value"):
        wkb_sql = f"0x{WKBElement._wkb_to_hex(wkb_clause.value)}"
    else:
        wkb_sql = compiler.process(wkb_clause, **kw)

    if len(arguments) > 1:
        srid_sql = _compile_mssql_srid_clause(arguments[1], compiler, 0, **kw)
    else:
        fallback_srid = element.type.srid if element.type.srid >= 0 else 0
        srid_sql = str(
            _infer_srid_from_wkb_clause(source_clause, fallback_srid, extended=extended)
        )
    return f"{element.type.name}::STGeomFromWKB({wkb_sql}, {srid_sql})"
@compiles(functions.ST_GeomFromText, "mssql")  # type: ignore
def _MSSQL_ST_GeomFromText(element, compiler, **kw):
    """Compile ``ST_GeomFromText`` for MSSQL via the shared WKT constructor compiler."""
    compiled_sql = _compile_mssql_geom_from_text(element, compiler, **kw)
    return compiled_sql
@compiles(functions.ST_GeogFromText, "mssql")  # type: ignore
def _MSSQL_ST_GeogFromText(element, compiler, **kw):
    """Compile ``ST_GeogFromText`` for MSSQL via the shared WKT constructor compiler."""
    compiled_sql = _compile_mssql_geom_from_text(element, compiler, **kw)
    return compiled_sql
@compiles(functions.ST_GeomFromEWKT, "mssql")  # type: ignore
def _MSSQL_ST_GeomFromEWKT(element, compiler, **kw):
    """Compile ``ST_GeomFromEWKT`` for MSSQL, with ``strip_srid`` enabled."""
    compiled_sql = _compile_mssql_geom_from_text(element, compiler, strip_srid=True, **kw)
    return compiled_sql
@compiles(functions.ST_GeomFromWKB, "mssql")  # type: ignore
def _MSSQL_ST_GeomFromWKB(element, compiler, **kw):
    """Compile ``ST_GeomFromWKB`` for MSSQL via the shared WKB constructor compiler."""
    compiled_sql = _compile_mssql_geom_from_wkb(element, compiler, **kw)
    return compiled_sql
@compiles(functions.ST_GeogFromWKB, "mssql")  # type: ignore
def _MSSQL_ST_GeogFromWKB(element, compiler, **kw):
    """Compile ``ST_GeogFromWKB`` for MSSQL via the shared WKB constructor compiler."""
    compiled_sql = _compile_mssql_geom_from_wkb(element, compiler, **kw)
    return compiled_sql
@compiles(functions.ST_GeomFromEWKB, "mssql")  # type: ignore
def _MSSQL_ST_GeomFromEWKB(element, compiler, **kw):
    """Compile ``ST_GeomFromEWKB`` for MSSQL, with ``extended`` enabled."""
    compiled_sql = _compile_mssql_geom_from_wkb(element, compiler, extended=True, **kw)
    return compiled_sql
@compiles(functions.ST_AsBinary, "mssql")  # type: ignore
def _MSSQL_ST_AsBinary(element, compiler, **kw):
    """Compile ``ST_AsBinary`` for MSSQL using the ``STAsBinary`` method name."""
    method_name = "STAsBinary"
    return _compile_mssql_method(element, compiler, method_name, **kw)
@compiles(functions.ST_AsEWKB, "mssql")  # type: ignore
def _MSSQL_ST_AsEWKB(element, compiler, **kw):
    """Compile ``ST_AsEWKB`` for MSSQL via ``_compile_mssql_as_ewkb``."""
    compiled_sql = _compile_mssql_as_ewkb(element, compiler, **kw)
    return compiled_sql
@compiles(functions.ST_AsText, "mssql")  # type: ignore
def _MSSQL_ST_AsText(element, compiler, **kw):
    """Compile ``ST_AsText`` for MSSQL using the ``AsTextZM`` method name."""
    method_name = "AsTextZM"
    return _compile_mssql_method(element, compiler, method_name, **kw)
@compiles(functions.ST_AsEWKT, "mssql")  # type: ignore
def _MSSQL_ST_AsEWKT(element, compiler, **kw):
    """Compile ``ST_AsEWKT`` for MSSQL via ``_compile_mssql_as_ewkt``."""
    compiled_sql = _compile_mssql_as_ewkt(element, compiler, **kw)
    return compiled_sql
@compiles(functions.ST_GeometryType, "mssql")  # type: ignore
def _MSSQL_ST_GeometryType(element, compiler, **kw):
    """Compile ``ST_GeometryType`` for MSSQL using the ``STGeometryType`` method name."""
    method_name = "STGeometryType"
    return _compile_mssql_method(element, compiler, method_name, **kw)
@compiles(functions.ST_SRID, "mssql")  # type: ignore
def _MSSQL_ST_SRID(element, compiler, **kw):
    """Compile ``ST_SRID`` for MSSQL using ``STSrid`` with ``property_=True``."""
    member_name = "STSrid"
    return _compile_mssql_method(element, compiler, member_name, property_=True, **kw)
@compiles(functions.ST_Buffer, "mssql")  # type: ignore
def _MSSQL_ST_Buffer(element, compiler, **kw):
    """Compile ``ST_Buffer`` for MSSQL using the ``STBuffer`` method name."""
    method_name = "STBuffer"
    return _compile_mssql_method(element, compiler, method_name, **kw)
@compiles(functions.ST_Area, "mssql")  # type: ignore
def _MSSQL_ST_Area(element, compiler, **kw):
    """Compile ``ST_Area`` for MSSQL using the ``STArea`` method name."""
    method_name = "STArea"
    return _compile_mssql_method(element, compiler, method_name, **kw)
@compiles(functions.ST_Length, "mssql")  # type: ignore
def _MSSQL_ST_Length(element, compiler, **kw):
    """Compile ``ST_Length`` for MSSQL using the ``STLength`` method name."""
    method_name = "STLength"
    return _compile_mssql_method(element, compiler, method_name, **kw)
@compiles(functions.ST_Distance, "mssql")  # type: ignore
def _MSSQL_ST_Distance(element, compiler, **kw):
    """Compile ``ST_Distance`` for MSSQL using the binary ``STDistance`` method name."""
    method_name = "STDistance"
    return _compile_mssql_binary_method(element, compiler, method_name, **kw)
@compiles(functions.ST_DWithin, "mssql")  # type: ignore
def _MSSQL_ST_DWithin(element, compiler, **kw):
    """Compile ``ST_DWithin`` for MSSQL via ``_compile_mssql_dwithin``."""
    compiled_sql = _compile_mssql_dwithin(element, compiler, **kw)
    return compiled_sql
@compiles(functions.ST_Within, "mssql")  # type: ignore
def _MSSQL_ST_Within(element, compiler, **kw):
    """Compile ``ST_Within`` for MSSQL using the binary ``STWithin`` method name."""
    method_name = "STWithin"
    return _compile_mssql_binary_method(element, compiler, method_name, **kw)
@compiles(functions.ST_Equals, "mssql")  # type: ignore
def _MSSQL_ST_Equals(element, compiler, **kw):
    """Compile ``ST_Equals`` for MSSQL using the binary ``STEquals`` method name."""
    method_name = "STEquals"
    return _compile_mssql_binary_method(element, compiler, method_name, **kw)
@compiles(functions.ST_Contains, "mssql")  # type: ignore
def _MSSQL_ST_Contains(element, compiler, **kw):
    """Compile ``ST_Contains`` for MSSQL using the binary ``STContains`` method name."""
    method_name = "STContains"
    return _compile_mssql_binary_method(element, compiler, method_name, **kw)
@compiles(functions.ST_Intersects, "mssql")  # type: ignore
def _MSSQL_ST_Intersects(element, compiler, **kw):
    """Compile ``ST_Intersects`` for MSSQL using the binary ``STIntersects`` method name."""
    method_name = "STIntersects"
    return _compile_mssql_binary_method(element, compiler, method_name, **kw)
@compiles(functions.ST_Disjoint, "mssql")  # type: ignore
def _MSSQL_ST_Disjoint(element, compiler, **kw):
    """Compile ``ST_Disjoint`` for MSSQL using the binary ``STDisjoint`` method name."""
    method_name = "STDisjoint"
    return _compile_mssql_binary_method(element, compiler, method_name, **kw)
@compiles(functions.ST_Touches, "mssql")  # type: ignore
def _MSSQL_ST_Touches(element, compiler, **kw):
    """Compile ``ST_Touches`` for MSSQL using the binary ``STTouches`` method name."""
    method_name = "STTouches"
    return _compile_mssql_binary_method(element, compiler, method_name, **kw)
@compiles(functions.ST_Overlaps, "mssql")  # type: ignore
def _MSSQL_ST_Overlaps(element, compiler, **kw):
    """Compile ``ST_Overlaps`` for MSSQL using the binary ``STOverlaps`` method name."""
    method_name = "STOverlaps"
    return _compile_mssql_binary_method(element, compiler, method_name, **kw)
@compiles(functions.ST_Crosses, "mssql")  # type: ignore
def _MSSQL_ST_Crosses(element, compiler, **kw):
    """Compile ``ST_Crosses`` for MSSQL using the binary ``STCrosses`` method name."""
    method_name = "STCrosses"
    return _compile_mssql_binary_method(element, compiler, method_name, **kw)
@compiles(BinaryExpression, "mssql")  # type: ignore
def _MSSQL_binary_expression(binary, compiler, override_operator=None, **kw):
    """Compile ``==`` / ``!=`` on spatial clauses for MSSQL.

    When either side of an equality or inequality is a spatial clause:

    * comparison against SQL ``NULL`` renders ``<target> IS [NOT ]NULL``;
    * any other comparison renders ``<target>.STEquals(<other>) = 1`` for
      ``==`` and ``= 0`` for ``!=``, after coercing the other operand with
      ``_coerce_mssql_spatial_method_argument``.

    Every other binary expression is delegated to ``compiler.visit_binary``.
    """
    operator = override_operator or binary.operator

    def _default():
        return compiler.visit_binary(binary, override_operator=override_operator, **kw)

    if operator not in (operators.eq, operators.ne):
        return _default()

    # The left operand takes precedence as the spatial target.
    if _is_spatial_clause(binary.left, compiler.dialect):
        spatial_side, other_side = binary.left, binary.right
    elif _is_spatial_clause(binary.right, compiler.dialect):
        spatial_side, other_side = binary.right, binary.left
    else:
        return _default()

    if isinstance(other_side, Null):
        compiled_target = compiler.process(spatial_side, **kw)
        return f"{compiled_target} IS {'NOT ' if operator is operators.ne else ''}NULL"

    other_side = _coerce_mssql_spatial_method_argument(
        spatial_side,
        other_side,
        compiler.dialect,
    )
    compiled_target = compiler.process(spatial_side, **kw)
    compiled_other = compiler.process(other_side, **kw)
    equals_call = f"{compiled_target}.STEquals({compiled_other})"
    return f"{equals_call} = {1 if operator is operators.eq else 0}"