Skip to content

Commit

Permalink
Comments addressed
Browse files Browse the repository at this point in the history
  • Loading branch information
pruzko committed Jan 19, 2025
1 parent 1fc8459 commit 4886783
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 44 deletions.
21 changes: 0 additions & 21 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,27 +1286,6 @@ def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))


# def locate_to_strposition(args: t.List) -> exp.Expression:
# return exp.StrPosition(
# this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
# )

# @unsupported_args("occurrence")
# def strposition_to_charindex_sql(self: Generator, expression: exp.StrPosition) -> str:
# return self.func(
# "CHARINDEX",
# expression.args.get("substr"),
# expression.this,
# expression.args.get("position"),
# )

# @unsupported_args("occurrence")
# def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
# return self.func(
# "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
# )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
return self.sql(
exp.Substring(
Expand Down
35 changes: 14 additions & 21 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,16 @@ def build_coalesce(args: t.List, is_nvl: t.Optional[bool] = None) -> exp.Coalesc
return exp.Coalesce(this=seq_get(args, 0), expressions=args[1:], is_nvl=is_nvl)


def build_strposition(args: t.List, func_name: str):
if func_name in ["LOCATE", "CHARINDEX"]:
return exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
)
return exp.StrPosition.from_arg_list(args)


class _Parser(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
Expand Down Expand Up @@ -231,27 +241,10 @@ class Parser(metaclass=_Parser):
"SCOPE_RESOLUTION": lambda args: exp.ScopeResolution(expression=seq_get(args, 0))
if len(args) != 2
else exp.ScopeResolution(this=seq_get(args, 0), expression=seq_get(args, 1)),
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0),
substr=seq_get(args, 1),
position=seq_get(args, 2),
),
"CHARINDEX": lambda args: exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
),
"INSTR": lambda args: exp.StrPosition(
this=seq_get(args, 0),
substr=seq_get(args, 1),
position=seq_get(args, 2),
occurrence=seq_get(args, 3),
),
"LOCATE": lambda args: exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
),
"STRPOS": lambda args: build_strposition(args, func_name="STRPOS"),
"CHARINDEX": lambda args: build_strposition(args, func_name="CHARINDEX"),
"INSTR": lambda args: build_strposition(args, func_name="INSTR"),
"LOCATE": lambda args: build_strposition(args, func_name="LOCATE"),
"TIME_TO_TIME_STR": lambda args: exp.Cast(
this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
Expand Down
18 changes: 16 additions & 2 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1689,24 +1689,31 @@ def test_operators(self):
"presto": "STRPOS(haystack, needle)",
"sqlite": "INSTR(haystack, needle)",
"tableau": "FIND(haystack, needle)",
# "teradata": "INDEX(haystack, needle)",
"tsql": "CHARINDEX(needle, haystack)",
},
write={
"athena": "STRPOS(haystack, needle)",
"bigquery": "INSTR(haystack, needle)",
"clickhouse": "LOCATE(needle, haystack)",
"databricks": "LOCATE(needle, haystack)",
"doris": "LOCATE(needle, haystack)",
"drill": "STRPOS(haystack, needle)",
"duckdb": "STRPOS(haystack, needle)",
"hive": "LOCATE(needle, haystack)",
"materialize": "POSITION(needle IN haystack)",
"mysql": "LOCATE(needle, haystack)",
"oracle": "INSTR(haystack, needle)",
"postgres": "POSITION(needle IN haystack)",
"presto": "STRPOS(haystack, needle)",
"redshift": "POSITION(needle IN haystack)",
"risingwave": "POSITION(needle IN haystack)",
"snowflake": "CHARINDEX(needle, haystack)",
"spark": "LOCATE(needle, haystack)",
"spark2": "LOCATE(needle, haystack)",
"sqlite": "INSTR(haystack, needle)",
"tableau": "FIND(haystack, needle)",
"teradata": "INSTR(haystack, needle)",
"trino": "STRPOS(haystack, needle)",
"tsql": "CHARINDEX(needle, haystack)",
},
)
Expand All @@ -1718,24 +1725,31 @@ def test_operators(self):
"snowflake": "POSITION(needle, haystack, position)",
"sqlite": "INSTR(haystack, needle, position)",
"tableau": "FINDNTH(haystack, needle, position)",
# "teradata": "INDEX(haystack, needle, position)",
"tsql": "CHARINDEX(needle, haystack, position)",
},
write={
"athena": "IF(STRPOS(SUBSTRING(haystack, position), needle) = 0, 0, STRPOS(SUBSTRING(haystack, position), needle) + position - 1)",
"bigquery": "INSTR(haystack, needle, position)",
"clickhouse": "LOCATE(needle, haystack, position)",
"databricks": "LOCATE(needle, haystack, position)",
"doris": "LOCATE(needle, haystack, position)",
"drill": "`IF`(STRPOS(SUBSTRING(haystack, position), needle) = 0, 0, STRPOS(SUBSTRING(haystack, position), needle) + position - 1)",
"duckdb": "CASE WHEN STRPOS(SUBSTRING(haystack, position), needle) = 0 THEN 0 ELSE STRPOS(SUBSTRING(haystack, position), needle) + position - 1 END",
"hive": "LOCATE(needle, haystack, position)",
"materialize": "CASE WHEN POSITION(needle IN SUBSTRING(haystack FROM position)) = 0 THEN 0 ELSE POSITION(needle IN SUBSTRING(haystack FROM position)) + position - 1 END",
"mysql": "LOCATE(needle, haystack, position)",
"oracle": "INSTR(haystack, needle, position)",
"postgres": "CASE WHEN POSITION(needle IN SUBSTRING(haystack FROM position)) = 0 THEN 0 ELSE POSITION(needle IN SUBSTRING(haystack FROM position)) + position - 1 END",
"presto": "IF(STRPOS(SUBSTRING(haystack, position), needle) = 0, 0, STRPOS(SUBSTRING(haystack, position), needle) + position - 1)",
"redshift": "CASE WHEN POSITION(needle IN SUBSTRING(haystack FROM position)) = 0 THEN 0 ELSE POSITION(needle IN SUBSTRING(haystack FROM position)) + position - 1 END",
"risingwave": "CASE WHEN POSITION(needle IN SUBSTRING(haystack FROM position)) = 0 THEN 0 ELSE POSITION(needle IN SUBSTRING(haystack FROM position)) + position - 1 END",
"snowflake": "CHARINDEX(needle, haystack, position)",
"spark": "LOCATE(needle, haystack, position)",
"spark2": "LOCATE(needle, haystack, position)",
"sqlite": "IIF(INSTR(SUBSTRING(haystack, position), needle) = 0, 0, INSTR(SUBSTRING(haystack, position), needle) + position - 1)",
"tableau": "FINDNTH(haystack, needle, position)",
"teradata": "INSTR(haystack, needle, position)",
"trino": "IF(STRPOS(SUBSTRING(haystack, position), needle) = 0, 0, STRPOS(SUBSTRING(haystack, position), needle) + position - 1)",
"tsql": "CHARINDEX(needle, haystack, position)",
},
)
Expand Down

0 comments on commit 4886783

Please sign in to comment.