diff --git a/CHANGES.rst b/CHANGES.rst index 9a22a69..790865d 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,10 +3,17 @@ Changelog of threedi-schema -0.228.4 (unreleased) +0.229.0 (unreleased) -------------------- -- Nothing changed yet. +- Rename sqlite table "tags" to "tag" +- Remove indices referring to removed tables in previous migrations +- Make model_settings.use_2d_rain and model_settings.friction_averaging booleans +- Remove columns referencing v2 in geometry_column +- Ensure correct use_* values when matching tables have no data +- Use custom types for comma separated and table text fields to strip extra white space +- Correct direction of dwf and surface map +- Remove v2 related views from sqlite 0.228.3 (2024-12-10) diff --git a/threedi_schema/__init__.py b/threedi_schema/__init__.py index 462a471..7723ea6 100644 --- a/threedi_schema/__init__.py +++ b/threedi_schema/__init__.py @@ -2,6 +2,6 @@ from .domain import constants, custom_types, models # NOQA # fmt: off -__version__ = '0.228.4.dev0' +__version__ = '0.229.0.dev0' # fmt: on diff --git a/threedi_schema/domain/custom_types.py b/threedi_schema/domain/custom_types.py index cdff404..2aec458 100644 --- a/threedi_schema/domain/custom_types.py +++ b/threedi_schema/domain/custom_types.py @@ -1,6 +1,8 @@ +import re + import geoalchemy2 from packaging import version -from sqlalchemy.types import Integer, TypeDecorator, VARCHAR +from sqlalchemy.types import Integer, Text, TypeDecorator, VARCHAR class Geometry(geoalchemy2.types.Geometry): @@ -66,6 +68,49 @@ class IntegerEnum(CustomEnum): impl = Integer +def clean_csv_string(value: str) -> str: + return re.sub(r"\s*,\s*", ",", value.strip()) + + +class CSVText(TypeDecorator): + impl = Text + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is not None: + value = clean_csv_string(value) + return value + + def process_result_value(self, value, dialect): + if value is not None: + value = clean_csv_string(value) + return value + + +def clean_csv_table(value: str) -> str: + # convert windows line endings to unix first + value = value.replace("\r\n", "\n") + # remove leading and trailing whitespace + value = value.strip() + # clean up each line + return "\n".join([clean_csv_string(line) for line in value.split("\n")]) + + +class CSVTable(TypeDecorator): + impl = Text + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is not None: + value = clean_csv_table(value) + return value + + def process_result_value(self, value, dialect): + if value is not None: + value = clean_csv_table(value) + return value + + class VarcharEnum(CustomEnum): cache_ok = True impl = VARCHAR diff --git a/threedi_schema/domain/models.py b/threedi_schema/domain/models.py index c80b0b1..61dd81e 100644 --- a/threedi_schema/domain/models.py +++ b/threedi_schema/domain/models.py @@ -1,8 +1,8 @@ -from sqlalchemy import Boolean, Column, Float, ForeignKey, Integer, String, Text +from sqlalchemy import Boolean, Column, Float, Integer, String, Text from sqlalchemy.orm import declarative_base from . import constants -from .custom_types import Geometry, IntegerEnum, VarcharEnum +from .custom_types import CSVTable, CSVText, Geometry, IntegerEnum, VarcharEnum Base = declarative_base() # automap_base() @@ -13,12 +13,12 @@ class Lateral2D(Base): code = Column(Text) display_name = Column(Text) type = Column(IntegerEnum(constants.Later2dType)) - timeseries = Column(Text) + timeseries = Column(CSVText) time_units = Column(Text) interpolate = Column(Boolean) offset = Column(Integer) units = Column(Text) - tags = Column(Text) + tags = Column(CSVText) geom = Column(Geometry("POINT"), nullable=False) @@ -28,10 +28,10 @@ class BoundaryConditions2D(Base): code = Column(Text) display_name = Column(Text) type = Column(IntegerEnum(constants.BoundaryType)) - timeseries = Column(Text) + timeseries = Column(CSVText) time_units = Column(Text) interpolate = Column(Boolean) - tags = Column(Text) + tags = Column(CSVText) geom = Column(Geometry("LINESTRING"), nullable=False) @@ -43,7 +43,7 @@ class ControlMeasureLocation(Base): display_name = Column(Text) code = Column(Text) geom = Column(Geometry("POINT"), nullable=False) - tags = Column(Text) + tags = Column(CSVText) class ControlMeasureMap(Base): @@ -56,7 +56,7 @@ class ControlMeasureMap(Base): display_name = Column(Text) code = Column(Text) geom = Column(Geometry("LINESTRING"), nullable=False) - tags = Column(Text) + tags = Column(CSVText) class ControlMemory(Base): @@ -74,13 +74,13 @@ class ControlMemory(Base): display_name = Column(Text) code = Column(Text) geom = Column(Geometry("POINT"), nullable=False) - tags = Column(Text) + tags = Column(CSVText) class ControlTable(Base): __tablename__ = "table_control" id = Column(Integer, primary_key=True) - action_table = Column(Text) + action_table = Column(CSVTable) action_type = Column(VarcharEnum(constants.ControlTableActionTypes)) measure_operator = Column(VarcharEnum(constants.MeasureOperators)) target_type = Column(VarcharEnum(constants.StructureControlTypes)) @@ -88,7 +88,7 @@ class ControlTable(Base): display_name = Column(Text) code = Column(Text) geom = Column(Geometry("POINT"), nullable=False) - tags = Column(Text) + tags = Column(CSVText) class Interflow(Base): @@ -130,7 +130,7 @@ class SurfaceParameter(Base): min_infiltration_capacity = Column(Float, nullable=False) infiltration_decay_constant = Column(Float, nullable=False) infiltration_recovery_constant = Column(Float, nullable=False) - tags = Column(Text) + tags = Column(CSVText) description = Column(Text) @@ -140,14 +140,12 @@ class Surface(Base): code = Column(String(100)) display_name = Column(String(255)) area = Column(Float) - surface_parameters_id = Column( - Integer, ForeignKey(SurfaceParameter.__tablename__ + ".id"), nullable=False - ) + surface_parameters_id = Column(Integer) geom = Column( Geometry("POLYGON"), nullable=True, ) - tags = Column(Text) + tags = Column(CSVText) class DryWeatherFlow(Base): @@ -163,7 +161,7 @@ class DryWeatherFlow(Base): Geometry("POLYGON"), nullable=False, ) - tags = Column(Text) + tags = Column(CSVText) class DryWeatherFlowMap(Base): @@ -178,15 +176,15 @@ class DryWeatherFlowMap(Base): nullable=False, ) percentage = Column(Float) - tags = Column(Text) + tags = Column(CSVText) class DryWeatherFlowDistribution(Base): __tablename__ = "dry_weather_flow_distribution" id = Column(Integer, primary_key=True) description = Column(Text) - tags = Column(Text) - distribution = Column(Text) + tags = Column(CSVText) + distribution = Column(CSVText) class GroundWater(Base): @@ -239,7 +237,7 @@ class GridRefinementLine(Base): grid_level = Column(Integer) geom = Column(Geometry("LINESTRING"), nullable=False) code = Column(String(100)) - tags = Column(Text) + tags = Column(CSVText) class GridRefinementArea(Base): @@ -249,7 +247,7 @@ class GridRefinementArea(Base): grid_level = Column(Integer) code = Column(String(100)) geom = Column(Geometry("POLYGON"), nullable=False) - tags = Column(Text) + tags = Column(CSVText) class ConnectionNode(Base): @@ -257,7 +255,7 @@ class ConnectionNode(Base): id = Column(Integer, primary_key=True) geom = Column(Geometry("POINT"), nullable=False) code = Column(String(100)) - tags = Column(Text) + tags = Column(CSVText) display_name = Column(Text) storage_area = Column(Float) initial_water_level = Column(Float) @@ -276,12 +274,12 @@ class Lateral1d(Base): id = Column(Integer, primary_key=True) code = Column(Text) display_name = Column(Text) - timeseries = Column(Text) + timeseries = Column(CSVText) time_units = Column(Text) interpolate = Column(Boolean) offset = Column(Integer) units = Column(Text) - tags = Column(Text) + tags = Column(CSVText) geom = Column(Geometry("POINT"), nullable=False) connection_node_id = Column(Integer) @@ -354,9 +352,9 @@ class ModelSettings(Base): embedded_cutoff_threshold = Column(Float) epsg_code = Column(Integer) max_angle_1d_advection = Column(Float) - friction_averaging = Column(IntegerEnum(constants.OffOrStandard)) + friction_averaging = Column(Boolean) table_step_size_1d = Column(Float) - use_2d_rain = Column(Integer) + use_2d_rain = Column(Boolean) use_interflow = Column(Boolean) use_interception = Column(Boolean) use_simple_infiltration = Column(Boolean) @@ -409,7 +407,7 @@ class PhysicalSettings(Base): __tablename__ = "physical_settings" id = Column(Integer, primary_key=True) use_advection_1d = Column(IntegerEnum(constants.AdvectionTypes1D)) - use_advection_2d = Column(IntegerEnum(constants.OffOrStandard)) + use_advection_2d = Column(Boolean) class SimulationTemplateSettings(Base): @@ -437,10 +435,10 @@ class BoundaryCondition1D(Base): code = Column(Text) display_name = Column(Text) type = Column(IntegerEnum(constants.BoundaryType)) - timeseries = Column(Text) + timeseries = Column(CSVText) time_units = Column(Text) interpolate = Column(Boolean) - tags = Column(Text) + tags = Column(CSVText) geom = Column(Geometry("POINT"), nullable=False) connection_node_id = Column(Integer) @@ -450,12 +448,10 @@ class SurfaceMap(Base): __tablename__ = "surface_map" id = Column(Integer, primary_key=True) surface_id = Column(Integer, nullable=False) - connection_node_id = Column( - Integer, ForeignKey(ConnectionNode.__tablename__ + ".id"), nullable=False - ) + connection_node_id = Column(Integer) percentage = Column(Float) geom = Column(Geometry("LINESTRING"), nullable=False) - tags = Column(Text) + tags = Column(CSVText) code = Column(String(100)) display_name = Column(String(255)) @@ -465,7 +461,7 @@ class Channel(Base): id = Column(Integer, primary_key=True) display_name = Column(String(255)) code = Column(String(100)) - tags = Column(Text) + tags = Column(CSVText) exchange_type = Column(IntegerEnum(constants.CalculationType)) calculation_point_distance = Column(Float) geom = Column(Geometry("LINESTRING"), nullable=False) @@ -489,14 +485,14 @@ class Windshielding(Base): northwest = Column(Float) geom = Column(Geometry("POINT"), nullable=False) channel_id = Column(Integer) - tags = Column(Text) + tags = Column(CSVText) class CrossSectionLocation(Base): __tablename__ = "cross_section_location" id = Column(Integer, primary_key=True) code = Column(String(100)) - tags = Column(Text) + tags = Column(CSVText) reference_level = Column(Float) friction_type = Column(IntegerEnum(constants.FrictionType)) friction_value = Column(Float) @@ -504,9 +500,9 @@ class CrossSectionLocation(Base): cross_section_shape = Column(IntegerEnum(constants.CrossSectionShape)) cross_section_width = Column(Float) cross_section_height = Column(Float) - cross_section_friction_values = Column(Text) - cross_section_vegetation_table = Column(Text) - cross_section_table = Column(Text) + cross_section_friction_values = Column(CSVText) + cross_section_vegetation_table = Column(CSVTable) + cross_section_table = Column(CSVTable) vegetation_stem_density = Column(Float) vegetation_stem_diameter = Column(Float) vegetation_height = Column(Float) @@ -520,7 +516,7 @@ class Pipe(Base): id = Column(Integer, primary_key=True) display_name = Column(String(255)) code = Column(String(100)) - tags = Column(Text) + tags = Column(CSVText) geom = Column(Geometry("LINESTRING"), nullable=False) sewerage_type = Column(IntegerEnum(constants.SewerageType)) exchange_type = Column(IntegerEnum(constants.PipeCalculationType)) @@ -535,7 +531,7 @@ class Pipe(Base): cross_section_shape = Column(IntegerEnum(constants.CrossSectionShape)) cross_section_width = Column(Float) cross_section_height = Column(Float) - cross_section_table = Column(Text) + cross_section_table = Column(CSVTable) exchange_thickness = Column(Float) hydraulic_conductivity_in = Column(Float) hydraulic_conductivity_out = Column(Float) @@ -546,7 +542,7 @@ class Culvert(Base): id = Column(Integer, primary_key=True) display_name = Column(String(255)) code = Column(String(100)) - tags = Column(Text) + tags = Column(CSVText) exchange_type = Column(IntegerEnum(constants.CalculationTypeCulvert)) friction_value = Column(Float) friction_type = Column(IntegerEnum(constants.FrictionType)) @@ -562,7 +558,7 @@ class Culvert(Base): cross_section_shape = Column(IntegerEnum(constants.CrossSectionShape)) cross_section_width = Column(Float) cross_section_height = Column(Float) - cross_section_table = Column(Text) + cross_section_table = Column(CSVTable) class DemAverageArea(Base): @@ -571,7 +567,7 @@ class DemAverageArea(Base): geom = Column(Geometry("POLYGON"), nullable=False) display_name = Column(Text) code = Column(Text) - tags = Column(Text) + tags = Column(CSVText) class Weir(Base): @@ -580,7 +576,7 @@ class Weir(Base): code = Column(String(100)) display_name = Column(String(255)) geom = Column(Geometry("LINESTRING"), nullable=False) - tags = Column(Text) + tags = Column(CSVText) crest_level = Column(Float) crest_type = Column(IntegerEnum(constants.CrestType)) friction_value = Column(Float) @@ -595,7 +591,7 @@ class Weir(Base): cross_section_shape = Column(IntegerEnum(constants.CrossSectionShape)) cross_section_width = Column(Float) cross_section_height = Column(Float) - cross_section_table = Column(Text) + cross_section_table = Column(CSVTable) class Orifice(Base): @@ -603,7 +599,7 @@ class Orifice(Base): id = Column(Integer, primary_key=True) code = Column(String(100)) display_name = Column(String(255)) - tags = Column(Text) + tags = Column(CSVText) geom = Column(Geometry("LINESTRING"), nullable=False) crest_type = Column(IntegerEnum(constants.CrestType)) crest_level = Column(Float) @@ -618,7 +614,7 @@ class Orifice(Base): cross_section_shape = Column(IntegerEnum(constants.CrossSectionShape)) cross_section_width = Column(Float) cross_section_height = Column(Float) - cross_section_table = Column(Text) + cross_section_table = Column(CSVTable) class Pump(Base): @@ -636,7 +632,7 @@ class Pump(Base): sewerage = Column(Boolean) connection_node_id = Column(Integer) geom = Column(Geometry("POINT"), nullable=False) - tags = Column(Text) + tags = Column(CSVText) class PumpMap(Base): @@ -645,7 +641,7 @@ class PumpMap(Base): pump_id = Column(Integer) connection_node_id_end = Column(Integer) geom = Column(Geometry("LINESTRING"), nullable=False) - tags = Column(Text) + tags = Column(CSVText) code = Column(String(100)) display_name = Column(String(255)) @@ -656,7 +652,7 @@ class Obstacle(Base): code = Column(String(100)) crest_level = Column(Float) geom = Column(Geometry("LINESTRING"), nullable=False) - tags = Column(Text) + tags = Column(CSVText) display_name = Column(String(255)) affects_2d = Column(Boolean) affects_1d2d_open_water = Column(Boolean) @@ -668,7 +664,7 @@ class PotentialBreach(Base): id = Column(Integer, primary_key=True) code = Column(String(100)) display_name = Column(String(255)) - tags = Column(Text) + tags = Column(CSVText) initial_exchange_level = Column(Float) final_exchange_level = Column(Float) levee_material = Column(IntegerEnum(constants.Material)) @@ -684,11 +680,11 @@ class ExchangeLine(Base): exchange_level = Column(Float) display_name = Column(Text) code = Column(Text) - tags = Column(Text) + tags = Column(CSVText) class Tags(Base): - __tablename__ = "tags" + __tablename__ = "tag" id = Column(Integer, primary_key=True) description = Column(Text) diff --git a/threedi_schema/migrations/utils.py b/threedi_schema/migrations/utils.py new file mode 100644 index 0000000..a4d784b --- /dev/null +++ b/threedi_schema/migrations/utils.py @@ -0,0 +1,32 @@ +from typing import List + +import sqlalchemy as sa + + +def drop_geo_table(op, table_name: str): + """ + + Safely drop table, taking into account geometry columns + + Parameters: + op : object + An object representing the database operation. + table_name : str + The name of the table to be dropped. + """ + op.execute(sa.text(f"SELECT DropTable(NULL, '{table_name}');")) + + +def drop_conflicting(op, new_tables: List[str]): + """ + Drop tables from database that conflict with new tables + + Parameters: + op: The SQLAlchemy operation context to interact with the database. + new_tables: A list of new table names to be checked for conflicts with existing tables. + """ + connection = op.get_bind() + existing_tables = [item[0] for item in connection.execute( + sa.text("SELECT name FROM sqlite_master WHERE type='table';")).fetchall()] + for table_name in set(existing_tables).intersection(new_tables): + drop_geo_table(op, table_name) \ No newline at end of file diff --git a/threedi_schema/migrations/versions/0222_upgrade_db_settings.py b/threedi_schema/migrations/versions/0222_upgrade_db_settings.py index bd4afb9..33f615c 100644 --- a/threedi_schema/migrations/versions/0222_upgrade_db_settings.py +++ b/threedi_schema/migrations/versions/0222_upgrade_db_settings.py @@ -15,6 +15,8 @@ from sqlalchemy import Boolean, Column, Float, Integer, String from sqlalchemy.orm import declarative_base +from threedi_schema.migrations.utils import drop_conflicting + # revision identifiers, used by Alembic. revision = "0222" down_revision = "0221" @@ -289,6 +291,12 @@ def set_use_inteception(): ); """)) + op.execute(sa.text(""" + DELETE FROM interception + WHERE (interception IS NULL OR interception = '') + AND (interception_file IS NULL OR interception_file = ''); + """)) + def delete_all_but_matching_id(table, settings_id): op.execute(f"DELETE FROM {table} WHERE id NOT IN (SELECT {settings_id} FROM model_settings);") @@ -369,18 +377,12 @@ def set_flow_variable_values(): op.execute(sa.text(query)) -def drop_conflicting(): - new_tables = list(ADD_TABLES.keys()) + [new_name for _, new_name in RENAME_TABLES] - for table_name in new_tables: - op.execute(f"DROP TABLE IF EXISTS {table_name};") - - def upgrade(): op.get_bind() # Only use first row of global settings delete_all_but_first_row("v2_global_settings") # Remove existing tables (outside of the specs) that conflict with new table names - drop_conflicting() + drop_conflicting(op, list(ADD_TABLES.keys()) + [new_name for _, new_name in RENAME_TABLES]) rename_tables(RENAME_TABLES) # rename columns in renamed tables for table_name, columns in RENAME_COLUMNS.items(): diff --git a/threedi_schema/migrations/versions/0223_upgrade_db_inflow.py b/threedi_schema/migrations/versions/0223_upgrade_db_inflow.py index 5283c1b..0c4749e 100644 --- a/threedi_schema/migrations/versions/0223_upgrade_db_inflow.py +++ b/threedi_schema/migrations/versions/0223_upgrade_db_inflow.py @@ -18,6 +18,7 @@ from threedi_schema.application.threedi_database import load_spatialite from threedi_schema.domain.custom_types import Geometry +from threedi_schema.migrations.utils import drop_conflicting, drop_geo_table # revision identifiers, used by Alembic. revision = "0223" @@ -77,7 +78,7 @@ Column("tags", Text), Column("distribution", Text) ], - "tags": [ + "tag": [ Column("description", Text) ] } @@ -136,7 +137,7 @@ def add_geometry_column(table: str, geocol: Column): def remove_tables(tables: List[str]): for table in tables: - op.drop_table(table) + drop_geo_table(op, table) def copy_values_to_new_table(src_table: str, src_columns: List[str], dst_table: str, dst_columns: List[str]): @@ -188,12 +189,10 @@ def set_map_geometries(basename): # Set geom as a line between point on surface/dry_weather_flow and connection node query = f""" UPDATE {basename}_map AS map - SET geom = ( - SELECT MakeLine(PointOnSurface(obj.geom), vcn.the_geom) + SET geom = MakeLine(PointOnSurface(obj.geom), vcn.the_geom) FROM {basename} obj JOIN v2_connection_nodes vcn ON map.connection_node_id = vcn.id - WHERE obj.id = map.{basename}_id - ); + WHERE obj.id = map.{basename}_id; """ op.execute(sa.text(query)) @@ -208,17 +207,17 @@ def add_map_geometries(src_table: str): WHEN ST_Equals(c.the_geom, PointOnSurface(s.geom)) THEN -- Transform to EPSG:4326 for the projection, then back to the original SRID MakeLine( - c.the_geom, PointOnSurface(ST_Transform( ST_Translate( ST_Transform(s.geom, {srid}), 0, 1, 0 ), 4326 - )) + )), + c.the_geom ) ELSE - MakeLine(c.the_geom, PointOnSurface(s.geom)) + MakeLine(PointOnSurface(s.geom), c.the_geom) END FROM v2_connection_nodes c, {src_table} s WHERE c.id = {src_table}_map.connection_node_id @@ -234,7 +233,6 @@ def add_map_geometries(src_table: str): op.execute(sa.text(query)) - def get_global_srid(): conn = op.get_bind() use_0d_inflow = conn.execute(sa.text("SELECT use_0d_inflow FROM simulation_template_settings LIMIT 1")).fetchone() @@ -256,6 +254,7 @@ def copy_polygons(src_table: str, tmp_geom: str): # - copy the first item of all multipolygons # - add new rows for each extra polygon inside a multipolygon conn = op.get_bind() + # Copy polygons directly op.execute(sa.text(f"UPDATE {src_table} SET {tmp_geom} = the_geom WHERE GeometryType(the_geom) = 'POLYGON';")) # Copy first polygon of each multipolygon and correct the area @@ -304,23 +303,25 @@ def copy_polygons(src_table: str, tmp_geom: str): def create_buffer_polygons(src_table: str, tmp_geom: str): # create circular polygon of area 1 around the connection node surf_id = f"{src_table.strip('v2_')}_id" - op.execute(sa.text(f""" - UPDATE {src_table} - SET {tmp_geom} = ( - SELECT ST_Buffer(v2_connection_nodes.the_geom, 1) - FROM v2_connection_nodes - JOIN {src_table}_map - ON v2_connection_nodes.id = {src_table}_map.connection_node_id - WHERE {src_table}.id = {src_table}_map.{surf_id} - ) - WHERE {tmp_geom} IS NULL - AND id IN ( - SELECT {src_table}_map.{surf_id} - FROM v2_connection_nodes - JOIN {src_table}_map - ON v2_connection_nodes.id = {src_table}_map.connection_node_id - ); - """)) + query = f""" + WITH connection_data AS ( + SELECT + {src_table}_map.{surf_id} AS item_id, + ST_Buffer(v2_connection_nodes.the_geom, 1) AS buffer_geom + FROM + v2_connection_nodes + JOIN + {src_table}_map + ON + v2_connection_nodes.id = {src_table}_map.connection_node_id + ) + UPDATE {src_table} + SET {tmp_geom} = connection_data.buffer_geom + FROM connection_data + WHERE {src_table}.id = connection_data.item_id + AND {tmp_geom} IS NULL; + """ + op.execute(sa.text(query)) def create_square_polygons(src_table: str, tmp_geom: str): @@ -332,39 +333,39 @@ def create_square_polygons(src_table: str, tmp_geom: str): srid = get_global_srid() query_str = f""" WITH center AS ( - SELECT {src_table}.id AS item_id, - ST_Centroid(ST_Collect( - ST_Transform(v2_connection_nodes.the_geom, {srid}))) AS geom - FROM {src_table}_map - JOIN v2_connection_nodes ON {src_table}_map.connection_node_id = v2_connection_nodes.id - JOIN {src_table} ON {src_table}_map.{surf_id} = {src_table}.id - WHERE {src_table}_map.{surf_id} = {src_table}.id - GROUP BY {src_table}.id + SELECT {src_table}.id AS item_id, + ST_Centroid(ST_Collect( + ST_Transform(v2_connection_nodes.the_geom, {srid}))) AS geom + FROM {src_table}_map + JOIN v2_connection_nodes ON {src_table}_map.connection_node_id = v2_connection_nodes.id + JOIN {src_table} ON {src_table}_map.{surf_id} = {src_table}.id + GROUP BY {src_table}.id ), side_length AS ( - SELECT {side_expr} AS side + SELECT id, sqrt(area) AS side + FROM {src_table} ) UPDATE {src_table} - SET {tmp_geom} = ( - SELECT ST_Transform( - SetSRID( - ST_GeomFromText('POLYGON((' || - (ST_X(center.geom) - side_length.side / 2) || ' ' || (ST_Y(center.geom) - side_length.side / 2) || ',' || - (ST_X(center.geom) + side_length.side / 2) || ' ' || (ST_Y(center.geom) - side_length.side / 2) || ',' || - (ST_X(center.geom) + side_length.side / 2) || ' ' || (ST_Y(center.geom) + side_length.side / 2) || ',' || - (ST_X(center.geom) - side_length.side / 2) || ' ' || (ST_Y(center.geom) + side_length.side / 2) || ',' || - (ST_X(center.geom) - side_length.side / 2) || ' ' || (ST_Y(center.geom) - side_length.side / 2) || - '))'), - {srid}), - 4326 - ) AS transformed_geom - FROM center, side_length - WHERE center.item_id = {src_table}.id - ) - WHERE {tmp_geom} IS NULL; + SET {tmp_geom} = ST_Transform( + SetSRID( + ST_GeomFromText('POLYGON((' || + (ST_X(center.geom) - side_length.side / 2) || ' ' || (ST_Y(center.geom) - side_length.side / 2) || ',' || + (ST_X(center.geom) + side_length.side / 2) || ' ' || (ST_Y(center.geom) - side_length.side / 2) || ',' || + (ST_X(center.geom) + side_length.side / 2) || ' ' || (ST_Y(center.geom) + side_length.side / 2) || ',' || + (ST_X(center.geom) - side_length.side / 2) || ' ' || (ST_Y(center.geom) + side_length.side / 2) || ',' || + (ST_X(center.geom) - side_length.side / 2) || ' ' || (ST_Y(center.geom) - side_length.side / 2) || + '))'), + {srid}), + 4326 + ) + FROM center + JOIN side_length ON center.item_id = side_length.id + WHERE {src_table}.id = center.item_id + AND {tmp_geom} IS NULL; """ op.execute(sa.text(query_str)) + def fix_src_geometry(src_table: str, tmp_geom: str, create_polygons): conn = op.get_bind() # create columns to store the derived geometries to @@ -382,7 +383,7 @@ def fix_src_geometry(src_table: str, tmp_geom: str, create_polygons): create_polygons(src_table, tmp_geom) -def remove_invalid_rows(src_table:str): +def remove_invalid_rows(src_table: str): # Remove rows with insufficient data op.execute(sa.text(f"DELETE FROM {src_table} WHERE area = 0 " "AND (nr_of_inhabitants = 0 OR dry_weather_flow = 0);")) @@ -398,6 +399,7 @@ def remove_invalid_rows(src_table:str): f"they are not mapped to a connection node in {src_table}_map: {no_map_id}") warnings.warn(msg, NoMappingWarning) + def populate_surface_and_dry_weather_flow(): conn = op.get_bind() use_0d_inflow = conn.execute(sa.text("SELECT use_0d_inflow FROM simulation_template_settings LIMIT 1")).fetchone() @@ -424,7 +426,6 @@ def populate_surface_and_dry_weather_flow(): # Remove rows in maps that refer to non-existing objects remove_orphans_from_map(basename="surface") remove_orphans_from_map(basename="dry_weather_flow") - # Create geometries in new maps add_map_geometries("surface") add_map_geometries("dry_weather_flow") @@ -435,6 +436,17 @@ def populate_surface_and_dry_weather_flow(): # Populate tables with default values populate_dry_weather_flow_distribution() populate_surface_parameters() + update_use_0d_inflow() + +def update_use_0d_inflow(): + op.execute(sa.text(""" + UPDATE simulation_template_settings + SET use_0d_inflow = 0 + WHERE + (SELECT COUNT(*) FROM surface) = 0 + AND + (SELECT COUNT(*) FROM dry_weather_flow) = 0; + """)) def set_surface_parameters_id(): @@ -484,17 +496,11 @@ def fix_geometry_columns(): op.execute(sa.text(migration_query)) -def drop_conflicting(): - new_tables = list(ADD_TABLES.keys()) + [new_name for _, new_name in RENAME_TABLES] - for table_name in new_tables: - op.execute(f"DROP TABLE IF EXISTS {table_name};") - - def upgrade(): connection = op.get_bind() listen(connection.engine, "connect", load_spatialite) # Remove existing tables (outside of the specs) that conflict with new table names - drop_conflicting() + drop_conflicting(op, list(ADD_TABLES.keys()) + [new_name for _, new_name in RENAME_TABLES]) # create new tables and rename existing tables create_new_tables(ADD_TABLES) rename_tables(RENAME_TABLES) diff --git a/threedi_schema/migrations/versions/0224_upgrade_db_structure_control.py b/threedi_schema/migrations/versions/0224_upgrade_db_structure_control.py index aa39ec8..88d5062 100644 --- a/threedi_schema/migrations/versions/0224_upgrade_db_structure_control.py +++ b/threedi_schema/migrations/versions/0224_upgrade_db_structure_control.py @@ -15,6 +15,7 @@ from sqlalchemy.orm import declarative_base from threedi_schema.domain.custom_types import Geometry +from threedi_schema.migrations.utils import drop_conflicting, drop_geo_table # revision identifiers, used by Alembic. revision = "0224" @@ -335,7 +336,7 @@ def remove_column_from_table(table_name: str, column: str): def remove_tables(tables: List[str]): for table in tables: - op.drop_table(table) + drop_geo_table(op, table) def make_geom_col_notnull(table_name): @@ -357,7 +358,7 @@ def make_geom_col_notnull(table_name): temp_name = f'_temp_224_{uuid.uuid4().hex}' op.execute(sa.text(f"CREATE TABLE {temp_name} ({','.join(cols)});")) op.execute(sa.text(f"INSERT INTO {temp_name} ({','.join(col_names)}) SELECT {','.join(col_names)} FROM {table_name}")) - op.execute(sa.text(f"DROP TABLE {table_name};")) + drop_geo_table(op, table_name) op.execute(sa.text(f"ALTER TABLE {temp_name} RENAME TO {table_name};")) @@ -370,15 +371,20 @@ def fix_geometry_columns(): op.execute(sa.text(migration_query)) -def drop_conflicting(): - new_tables = list(ADD_TABLES.keys()) + [new_name for _, new_name in RENAME_TABLES] - for table_name in new_tables: - op.execute(f"DROP TABLE IF EXISTS {table_name};") +def update_use_structure_control(): + op.execute(""" + UPDATE simulation_template_settings SET use_structure_control = CASE + WHEN + (SELECT COUNT(*) FROM table_control) = 0 AND + (SELECT COUNT(*) FROM memory_control) = 0 THEN 0 + ELSE use_structure_control + END; + """) def upgrade(): # Remove existing tables (outside of the specs) that conflict with new table names - drop_conflicting() + drop_conflicting(op, list(ADD_TABLES.keys()) + [new_name for _, new_name in RENAME_TABLES]) # create new tables and rename existing tables create_new_tables(ADD_TABLES) rename_tables(RENAME_TABLES) @@ -400,6 +406,7 @@ def upgrade(): rename_measure_operator('memory_control') move_setting('model_settings', 'use_structure_control', 'simulation_template_settings', 'use_structure_control') + update_use_structure_control() remove_tables(DEL_TABLES) # Fix geometry columns and also make all but geom column nullable fix_geometry_columns() diff --git a/threedi_schema/migrations/versions/0225_migrate_lateral_boundary_condition_tables.py b/threedi_schema/migrations/versions/0225_migrate_lateral_boundary_condition_tables.py index 08e77d0..357499a 100644 --- a/threedi_schema/migrations/versions/0225_migrate_lateral_boundary_condition_tables.py +++ b/threedi_schema/migrations/versions/0225_migrate_lateral_boundary_condition_tables.py @@ -18,6 +18,7 @@ from sqlalchemy.orm import declarative_base from threedi_schema.domain.custom_types import Geometry +from threedi_schema.migrations.utils import drop_conflicting, drop_geo_table # revision identifiers, used by Alembic. revision = "0225" @@ -119,8 +120,13 @@ def rename_tables(table_sets: List[Tuple[str, str]]): # no checks for existence are done, this will fail if a source table doesn't exist + connection = op.get_bind() + spatialite_version = connection.execute(sa.text("SELECT spatialite_version();")).fetchall()[0][0] for src_name, dst_name in table_sets: - op.rename_table(src_name, dst_name) + if spatialite_version.startswith('5'): + op.execute(sa.text(f"SELECT RenameTable(NULL, '{src_name}', '{dst_name}');")) + else: + op.rename_table(src_name, dst_name) def create_new_tables(new_tables: Dict[str, sa.Column]): @@ -183,7 +189,7 @@ def rename_columns(table_name: str, columns: List[Tuple[str, str]]): create_table_query = f"""CREATE TABLE {temp_name} ({', '.join(new_columns_list_sql_formatted)});""" op.execute(sa.text(create_table_query)) op.execute(sa.text(f"INSERT INTO {temp_name} ({','.join(new_columns_list)}) SELECT {','.join(old_columns_list)} from {table_name};")) - op.execute(sa.text(f"DROP TABLE {table_name};")) + drop_geo_table(op, table_name) op.execute(sa.text(f"ALTER TABLE {temp_name} RENAME TO {table_name};")) for entry in new_columns: @@ -215,15 +221,9 @@ def populate_table(table: str, values: dict): op.execute(sa.text(query)) -def drop_conflicting(): - new_tables = [new_name for _, new_name in RENAME_TABLES] - for table_name in new_tables: - op.execute(f"DROP TABLE IF EXISTS {table_name};") - - def upgrade(): # Drop tables that conflict with new table names - drop_conflicting() + drop_conflicting(op, [new_name for _, new_name in RENAME_TABLES]) # rename existing tables rename_tables(RENAME_TABLES) diff --git a/threedi_schema/migrations/versions/0226_upgrade_db_1d_1d2d.py b/threedi_schema/migrations/versions/0226_upgrade_db_1d_1d2d.py index f12d53e..77d6173 100644 --- a/threedi_schema/migrations/versions/0226_upgrade_db_1d_1d2d.py +++ b/threedi_schema/migrations/versions/0226_upgrade_db_1d_1d2d.py @@ -12,7 +12,7 @@ from sqlalchemy import Boolean, Column, Float, Integer, String, Text from sqlalchemy.orm import declarative_base -from threedi_schema.domain.custom_types import Geometry +from threedi_schema.migrations.utils import drop_conflicting, drop_geo_table # revision identifiers, used by Alembic. revision = "0226" @@ -74,7 +74,7 @@ def add_columns_to_tables(table_columns: List[Tuple[str, Column]]): def remove_tables(tables: List[str]): for table in tables: - op.drop_table(table) + drop_geo_table(op, table) def modify_table(old_table_name, new_table_name): @@ -167,15 +167,9 @@ def set_potential_breach_final_exchange_level(): )) -def drop_conflicting(): - new_tables = [new_name for _, new_name in RENAME_TABLES] - for table_name in new_tables: - op.execute(f"DROP TABLE IF EXISTS {table_name};") - - def upgrade(): # Drop tables that conflict with new table names - drop_conflicting() + drop_conflicting(op, [new_name for _, new_name in RENAME_TABLES]) rem_tables = [] for old_table_name, new_table_name in RENAME_TABLES: modify_table(old_table_name, new_table_name) diff --git a/threedi_schema/migrations/versions/0228_upgrade_db_1D.py b/threedi_schema/migrations/versions/0228_upgrade_db_1D.py index d34b56f..04e71ff 100644 --- a/threedi_schema/migrations/versions/0228_upgrade_db_1D.py +++ b/threedi_schema/migrations/versions/0228_upgrade_db_1D.py @@ -12,11 +12,12 @@ import sqlalchemy as sa from alembic import op -from sqlalchemy import Column, Float, func, Integer, select, String +from sqlalchemy import Column, Float, func, Integer, select, String, Text from sqlalchemy.orm import declarative_base, Session -from threedi_schema.domain import constants, models -from threedi_schema.domain.custom_types import IntegerEnum +from threedi_schema.domain import constants +from threedi_schema.domain.custom_types import Geometry, IntegerEnum +from threedi_schema.migrations.utils import drop_conflicting, drop_geo_table Base = declarative_base() @@ -83,10 +84,42 @@ "pump": ["connection_node_end_id", "zoom_category", "classification"] } +ADD_COLUMNS = [ + ("channel", Column("tags", Text)), + ("cross_section_location", Column("tags", Text)), + ("culvert", Column("tags", Text)), + ("culvert", Column("material_id", Integer)), + ("orifice", Column("tags", Text)), + ("orifice", Column("material_id", Integer)), + ("pipe", Column("tags", Text)), + ("pump", Column("tags", Text)), + ("weir", Column("tags", Text)), + ("weir", Column("material_id", Integer)), + ("windshielding_1d", Column("tags", Text)), +] RETYPE_COLUMNS = {} +def add_columns_to_tables(table_columns: List[Tuple[str, Column]]): + # no checks for existence are done, this will fail if any column already exists + for dst_table, col in table_columns: + if isinstance(col.type, Geometry): + add_geometry_column(dst_table, col) + else: + with op.batch_alter_table(dst_table) as batch_op: + batch_op.add_column(col) + + +def add_geometry_column(table: str, geocol: Column): + # Adding geometry columns via alembic doesn't work + # https://postgis.net/docs/AddGeometryColumn.html + geotype = geocol.type + query = ( + f"SELECT AddGeometryColumn('{table}', '{geocol.name}', {geotype.srid}, '{geotype.geometry_type}', 'XY', 1);") + op.execute(sa.text(query)) + + class Schema228UpgradeException(Exception): pass @@ -100,45 +133,72 @@ def add_columns_to_tables(table_columns: List[Tuple[str, Column]]): def remove_tables(tables: List[str]): for table in tables: - op.drop_table(table) + drop_geo_table(op, table) +def get_geom_type(table_name, geo_col_name): + connection = op.get_bind() + columns = connection.execute(sa.text(f"PRAGMA table_info('{table_name}')")).fetchall() + for col in columns: + if col[1] == geo_col_name: + return col[2] + def modify_table(old_table_name, new_table_name): - # Create a new table named `new_table_name` using the declared models + # Create a new table named `new_table_name` by copying the + # data from `old_table_name`. # Use the columns from `old_table_name`, with the following exceptions: + # * columns in `REMOVE_COLUMNS[new_table_name]` are skipped # * columns in `RENAME_COLUMNS[new_table_name]` are renamed + # * columns in `RETYPE_COLUMNS[new_table_name]` change type # * `the_geom` is renamed to `geom` and NOT NULL is enforced - model = find_model(new_table_name) - # create new table - create_sqlite_table_from_model(model) - # get column names from model and match them to available data in sqlite connection = op.get_bind() - rename_cols = {**RENAME_COLUMNS.get(new_table_name, {}), "the_geom": "geom"} - rename_cols_rev = {v: k for k, v in rename_cols.items()} - col_map = [(col.name, rename_cols_rev.get(col.name, col.name)) for col in get_cols_for_model(model)] - available_cols = [col[1] for col in connection.execute(sa.text(f"PRAGMA table_info('{old_table_name}')")).fetchall()] - new_col_names, old_col_names = zip(*[(new_col, old_col) for new_col, old_col in col_map if old_col in available_cols]) + columns = connection.execute(sa.text(f"PRAGMA table_info('{old_table_name}')")).fetchall() + # get all column names and types + col_names = [col[1] for col in columns] + col_types = [col[2] for col in columns] + # get type of the geometry column + geom_type = get_geom_type(old_table_name, 'the_geom') + # create list of new columns and types for creating the new table + # create list of old columns to copy to new table + skip_cols = ['id', 'the_geom'] + if new_table_name in REMOVE_COLUMNS: + skip_cols += REMOVE_COLUMNS[new_table_name] + old_col_names = [] + new_col_names = [] + new_col_types = [] + for cname, ctype in zip(col_names, col_types): + if cname in skip_cols: + continue + old_col_names.append(cname) + if new_table_name in RENAME_COLUMNS and cname in RENAME_COLUMNS[new_table_name]: + new_col_names.append(RENAME_COLUMNS[new_table_name][cname]) + else: + new_col_names.append(cname) + if new_table_name in RETYPE_COLUMNS and cname in RETYPE_COLUMNS[new_table_name]: + new_col_types.append(RETYPE_COLUMNS[new_table_name][cname]) + else: + new_col_types.append(ctype) + # add to the end manually + old_col_names.append('the_geom') + new_col_names.append('geom') + new_col_types.append(f'{geom_type} NOT NULL') + # Create new table (temp), insert data, drop original and rename temp to table_name + new_col_str = ','.join(['id INTEGER PRIMARY KEY NOT NULL'] + [f'{cname} {ctype}' for cname, ctype in + zip(new_col_names, new_col_types)]) + op.execute(sa.text(f"CREATE TABLE {new_table_name} ({new_col_str});")) # Copy data - # This may copy wrong type data because some types change!! - op.execute(sa.text(f"INSERT INTO {new_table_name} ({','.join(new_col_names)}) " - f"SELECT {','.join(old_col_names)} FROM {old_table_name}")) + op.execute(sa.text(f"INSERT INTO {new_table_name} (id, {','.join(new_col_names)}) " + f"SELECT id, {','.join(old_col_names)} FROM {old_table_name}")) -def find_model(table_name): - for model in models.DECLARED_MODELS: - if model.__tablename__ == table_name: - return model - # This can only go wrong if the migration or model is incorrect - raise - def fix_geometry_columns(): - update_models = [models.Channel, models.ConnectionNode, models.CrossSectionLocation, - models.Culvert, models.Orifice, models.Pipe, models.Pump, - models.PumpMap, models.Weir, models.Windshielding] - for model in update_models: - op.execute(sa.text(f"SELECT RecoverGeometryColumn('{model.__tablename__}', " - f"'geom', {4326}, '{model.geom.type.geometry_type}', 'XY')")) - op.execute(sa.text(f"SELECT CreateSpatialIndex('{model.__tablename__}', 'geom')")) + tables = ['channel', 'connection_node', 'cross_section_location', 'culvert', + 'orifice', 'pipe', 'pump', 'pump_map', 'weir', 'windshielding_1d'] + for table in tables: + geom_type = get_geom_type(table, geo_col_name='geom') + op.execute(sa.text(f"SELECT RecoverGeometryColumn('{table}', " + f"'geom', {4326}, '{geom_type}', 'XY')")) + op.execute(sa.text(f"SELECT CreateSpatialIndex('{table}', 'geom')")) class Temp(Base): @@ -304,29 +364,16 @@ def set_geom_for_v2_pumpstation(): op.execute(sa.text(q)) -def get_cols_for_model(model, skip_cols=None): - from sqlalchemy.orm.attributes import InstrumentedAttribute - if skip_cols is None: - skip_cols = [] - return [getattr(model, item) for item in model.__dict__ - if item not in skip_cols - and isinstance(getattr(model, item), InstrumentedAttribute)] - - -def create_sqlite_table_from_model(model): - cols = get_cols_for_model(model, skip_cols = ["id", "geom"]) - op.execute(sa.text(f""" - CREATE TABLE {model.__tablename__} ( - id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, - {','.join(f"{col.name} {col.type}" for col in cols)}, - geom {model.geom.type.geometry_type} NOT NULL - );""")) - - def create_pump_map(): # Create table - create_sqlite_table_from_model(models.PumpMap) - + query = """ + CREATE TABLE pump_map ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + pump_id INTEGER,connection_node_id_end INTEGER,tags TEXT,code VARCHAR(100),display_name VARCHAR(255), + geom LINESTRING NOT NULL + ); + """ + op.execute(sa.text(query)) # Create geometry op.execute(sa.text(f"SELECT AddGeometryColumn('v2_pumpstation', 'map_geom', 4326, 'LINESTRING', 'XY', 0);")) op.execute(sa.text(""" @@ -357,7 +404,15 @@ def create_pump_map(): def create_connection_node(): - create_sqlite_table_from_model(models.ConnectionNode) + # Create table + query = """ + CREATE TABLE connection_node ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + code VARCHAR(100),tags TEXT,display_name TEXT,storage_area FLOAT,initial_water_level FLOAT,visualisation INTEGER,manhole_surface_level FLOAT,bottom_level FLOAT,exchange_level FLOAT,exchange_type INTEGER,exchange_thickness FLOAT,hydraulic_conductivity_in FLOAT,hydraulic_conductivity_out FLOAT, + geom POINT NOT NULL + ); + """ + op.execute(sa.text(query)) # copy from v2_connection_nodes old_col_names = ["id", "initial_waterlevel", "storage_area", "the_geom", "code"] rename_map = {"initial_waterlevel": "initial_water_level", "the_geom": "geom"} @@ -388,6 +443,15 @@ def create_connection_node(): """)) +# define Material class needed to populate table in create_material +class Material(Base): + __tablename__ = "material" + id = Column(Integer, primary_key=True) + description = Column(Text) + friction_type = Column(IntegerEnum(constants.FrictionType)) + friction_coefficient = Column(Float) + + def create_material(): op.execute(sa.text(""" CREATE TABLE material ( @@ -396,12 +460,13 @@ def create_material(): friction_type INTEGER, friction_coefficient REAL); """)) + connection = op.get_bind() + nof_settings = connection.execute(sa.text("SELECT COUNT(*) FROM model_settings")).scalar() session = Session(bind=op.get_bind()) - nof_settings = session.execute(select(func.count()).select_from(models.ModelSettings)).scalar() if nof_settings > 0: with open(data_dir.joinpath('0228_materials.csv')) as file: reader = csv.DictReader(file) - session.bulk_save_objects([models.Material(**row) for row in reader]) + session.bulk_save_objects([Material(**row) for row in reader]) session.commit() @@ -448,21 +513,12 @@ def fix_material_id(): f"{' '.join([f'WHEN {old} THEN {new}' for old, new in replace_map.items()])} " "ELSE material_id END")) - - -def drop_conflicting(): - new_tables = [new_name for _, new_name in RENAME_TABLES] + ['material', 'pump_map'] - for table_name in new_tables: - op.execute(f"DROP TABLE IF EXISTS {table_name};") - - - def upgrade(): # Empty or non-existing connection node id (start or end) in Orifice, Pipe, Pumpstation or Weir will break # migration, so an error is raised in these cases check_for_null_geoms() # Prevent custom tables in schematisation from breaking migration when they conflict with new table names - drop_conflicting() + drop_conflicting(op, [new_name for _, new_name in RENAME_TABLES] + ['material', 'pump_map']) # Extent cross section definition table (actually stored in temp) extend_cross_section_definition_table() # Migrate data from cross_section_definition to cross_section_location @@ -476,6 +532,7 @@ def upgrade(): set_geom_for_v2_pumpstation() for old_table_name, new_table_name in RENAME_TABLES: modify_table(old_table_name, new_table_name) + add_columns_to_tables(ADD_COLUMNS) # Create new tables create_pump_map() create_material() diff --git a/threedi_schema/migrations/versions/0229_clean_up.py b/threedi_schema/migrations/versions/0229_clean_up.py new file mode 100644 index 0000000..f205680 --- /dev/null +++ b/threedi_schema/migrations/versions/0229_clean_up.py @@ -0,0 +1,120 @@ +""" + +Revision ID: 022 +9Revises: +Create Date: 2024-11-15 14:18 + +""" +import uuid +from typing import List + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "0229" +down_revision = "0228" +branch_labels = None +depends_on = None + + +def get_geom_type(table_name, geo_col_name): + connection = op.get_bind() + columns = connection.execute(sa.text(f"PRAGMA table_info('{table_name}')")).fetchall() + for col in columns: + if col[1] == geo_col_name: + return col[2] + +def change_types_in_settings_table(): + temp_table_name = f'_temp_229_{uuid.uuid4().hex}' + table_name = 'model_settings' + change_types = {'use_d2_rain': 'bool', 'friction_averaging': 'bool'} + connection = op.get_bind() + columns = connection.execute(sa.text(f"PRAGMA table_info('{table_name}')")).fetchall() + # get all column names and types + skip_cols = ['id', 'the_geom'] + col_names = [col[1] for col in columns if col[1] not in skip_cols] + old_col_types = [col[2] for col in columns if col[1] not in skip_cols] + col_types = [change_types.get(col_name, col_type) for col_name, col_type in zip(col_names, old_col_types)] + # Create new table, insert data, drop original and rename temp to table_name + col_str = ','.join(['id INTEGER PRIMARY KEY NOT NULL'] + [f'{cname} {ctype}' for cname, ctype in + zip(col_names, col_types)]) + op.execute(sa.text(f"CREATE TABLE {temp_table_name} ({col_str});")) + # Copy data + op.execute(sa.text(f"INSERT INTO {temp_table_name} (id, {','.join(col_names)}) " + f"SELECT id, {','.join(col_names)} FROM {table_name}")) + op.execute(sa.text(f"DROP TABLE {table_name}")) + op.execute(sa.text(f"ALTER TABLE {temp_table_name} RENAME TO {table_name};")) + + +def remove_tables(tables: List[str]): + for table in tables: + op.drop_table(table) + + +def find_tables_by_pattern(pattern: str) -> List[str]: + connection = op.get_bind() + query = connection.execute( + sa.text(f"select name from sqlite_master where type = 'table' and name like '{pattern}'")) + return [item[0] for item in query.fetchall()] + + +def remove_old_tables(): + remaining_v2_idx_tables = find_tables_by_pattern('idx_v2_%_the_geom') + remaining_alembic = find_tables_by_pattern('%_alembic_%_the_geom') + remove_tables(remaining_v2_idx_tables + remaining_alembic) + + +def clean_geometry_columns(): + """ Remove columns referencing v2 in geometry_columns """ + op.execute(sa.text(""" + DELETE FROM geometry_columns WHERE f_table_name IN ( + SELECT g.f_table_name FROM geometry_columns g + LEFT JOIN sqlite_master m ON g.f_table_name = m.name + WHERE m.name IS NULL AND g.f_table_name like "%v2%" + ); + """)) + + +def clean_by_type(type: str): + connection = op.get_bind() + items = [item[0] for item in connection.execute( + sa.text(f"SELECT tbl_name FROM sqlite_master WHERE type='{type}' AND tbl_name LIKE '%v2%';")).fetchall()] + for item in items: + op.execute(f"DROP {type} IF EXISTS {item};") + + +def update_use_settings(): + # Ensure that use_* settings are only True when there is actual data for them + use_settings = [ + ('use_groundwater_storage', 'groundwater'), + ('use_groundwater_flow', 'groundwater'), + ('use_interflow', 'interflow'), + ('use_simple_infiltration', 'simple_infiltration'), + ('use_vegetation_drag_2d', 'vegetation_drag_2d'), + ('use_interception', 'interception') + ] + connection = op.get_bind() # Get the connection for raw SQL execution + for setting, table in use_settings: + use_row = connection.execute(sa.text(f"SELECT {setting} FROM model_settings")).scalar() + if not use_row: + continue + row = connection.execute(sa.text(f"SELECT * FROM {table}")).first() + use_row = (row is not None) + if use_row: + use_row = not all(item in (None, "") for item in row[1:]) + if not use_row: + connection.execute(sa.text(f"UPDATE model_settings SET {setting} = 0")) + + +def upgrade(): + remove_old_tables() + clean_geometry_columns() + clean_by_type('trigger') + clean_by_type('view') + update_use_settings() + change_types_in_settings_table() + + +def downgrade(): + pass diff --git a/threedi_schema/tests/test_custom_types.py b/threedi_schema/tests/test_custom_types.py new file mode 100644 index 0000000..988c561 --- /dev/null +++ b/threedi_schema/tests/test_custom_types.py @@ -0,0 +1,50 @@ +import pytest + +from threedi_schema.domain.custom_types import clean_csv_string, clean_csv_table + + +@pytest.mark.parametrize( + "value", + [ + "1,2,3", + "1, 2, 3 ", + "1,\t2,3", + "1,\r2,3 ", + "1,\n2,3 ", + "1, 2,3", + "1, 2 ,3", + " 1,2,3 ", + "\n1,2,3", + "\t1,2,3", + "\r1,2,3", + "1,2,3\t", + "1,2,3\n", + "1,2,3\r", + ], +) +def test_clean_csv_string(value): + assert clean_csv_string(value) == "1,2,3" + + +def test_clean_csv_string_with_whitespace(): + assert clean_csv_string("1,2 3,4") == "1,2 3,4" + + +@pytest.mark.parametrize( + "value", + [ + "1,2,3\n4,5,6", + "1,2,3\r\n4,5,6", + "\n1,2,3\n4,5,6", + "1,2,3\n4,5,6\n", + ], +) +def test_clean_csv_table(value): + assert clean_csv_table(value) == "1,2,3\n4,5,6" + + +@pytest.mark.parametrize( + "value", [" ", "0 1", "3;5", "foo", "1,2\n3,", ",2", ",2\n3,4"] +) +def test_clean_csv_table_no_fail(value): + clean_csv_table(value) diff --git a/threedi_schema/tests/test_migration.py b/threedi_schema/tests/test_migration.py index 6581f43..8d58175 100644 --- a/threedi_schema/tests/test_migration.py +++ b/threedi_schema/tests/test_migration.py @@ -64,7 +64,11 @@ def get_columns_from_sqlite(cursor, table_name): for c in cursor.fetchall(): if 'geom' in c[1]: continue - type_str = c[2].lower() if c[2] != 'bool' else 'boolean' + type_str = c[2].lower() + if type_str == 'bool': + type_str = 'boolean' + if type_str == 'int': + type_str = 'integer' col_map[c[1]] = (type_str, not c[3]) return col_map @@ -216,7 +220,7 @@ class TestMigration223: pytestmark = pytest.mark.migration_223 removed_tables = set(['v2_surface', 'v2_surface_parameters', 'v2_surface_map', 'v2_impervious_surface', 'v2_impervious_surface_map']) - added_tables = set(['surface', 'surface_map', 'surface_parameters', 'tags', + added_tables = set(['surface', 'surface_map', 'surface_parameters', 'tag', 'dry_weather_flow', 'dry_weather_flow_map', 'dry_weather_flow_distribution']) def test_tables(self, schema_ref, schema_upgraded): diff --git a/threedi_schema/tests/test_schema.py b/threedi_schema/tests/test_schema.py index 5e24367..6ce44b4 100644 --- a/threedi_schema/tests/test_schema.py +++ b/threedi_schema/tests/test_schema.py @@ -219,7 +219,7 @@ def test_set_spatial_indexes(in_memory_sqlite): connection.execute( text("SELECT DisableSpatialIndex('connection_node', 'geom')") ).scalar() - connection.execute(text("DROP TABLE idx_v2_connection_nodes_the_geom")) + connection.execute(text("DROP TABLE idx_connection_node_geom")) schema.set_spatial_indexes()