remove prints + refactor
jsj authored and mrjsj committed Dec 21, 2024
1 parent 9f7ab6e commit 2f939af
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions src/msfabricutils/etl/transform/transforms.py
@@ -1,7 +1,6 @@
from typing import Callable

import polars as pl
import polars.selectors as cs
from polars.exceptions import ColumnNotFoundError

from msfabricutils.etl.config import Column
@@ -302,6 +301,11 @@ def apply_scd_type_2(
    if isinstance(primary_key_columns, str):
        primary_key_columns = [primary_key_columns]

    target_columns = target_df.schema.names() if isinstance(target_df, pl.DataFrame) else target_df.collect_schema().names()

    # Find the records in the target table which potentially need to be updated.
    # In combination with the following filter, this effectively performs a non-equi join.
    # Essentially, we want to find the rows in the target table that "surround" the row from the source table with respect to the valid_from column.
    target_records_to_be_updated = (
        target_df.join(
            source_df,
@@ -310,7 +314,7 @@
            suffix="__source",
        )
    )
    print(target_records_to_be_updated)

    target_records_to_be_updated = (
        target_records_to_be_updated
        .filter(
@@ -325,29 +329,28 @@
            )
        )
    )
    print(target_records_to_be_updated)

    # The join above can produce duplicates if a target row surrounds multiple source rows, so those duplicates are removed.
    target_records_to_be_updated = target_records_to_be_updated.unique(
        subset=primary_key_columns + [valid_from_column]
    )
    print(target_records_to_be_updated)

    # Remove the columns joined in from the source df.
    target_records_to_be_updated = (
        target_records_to_be_updated
        .select(~cs.ends_with("__source"))
        .join(
            source_df,
            on=primary_key_columns + [valid_from_column],
            how="anti",
        )
        .select(target_columns)
    )
    print(target_records_to_be_updated)


    upsert_df: PolarsFrame = pl.concat([target_records_to_be_updated, source_df])

    # Calculate the valid_to column.
    # We do this for both the source and target rows, because existing target rows may need their valid_to updated.
    upsert_df = upsert_df.with_columns(
        pl.col(valid_from_column)
        .shift(-1)
        .over(partition_by=primary_key_columns, order_by=valid_from_column)
        .alias(valid_to_column)
    )
    print(upsert_df.sort(primary_key_columns + [valid_from_column]))

    return upsert_df
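Aside: the actual join keys and filter conditions that implement the "surrounding rows" lookup are not shown in the hunks above. The following is only an illustrative sketch of the pattern the comments describe, not the committed code, with placeholder column names ("id", "valid_from", "valid_to") standing in for the function's configurable parameters. An equi-join on the primary key followed by a range filter on the validity window acts as a non-equi join that picks out the target rows surrounding an incoming source row.

import polars as pl
from datetime import datetime

# Toy frames; "id", "valid_from" and "valid_to" are placeholders for the
# configurable primary-key and validity columns of apply_scd_type_2.
target_df = pl.DataFrame(
    {
        "id": [1, 1],
        "valid_from": [datetime(2023, 1, 1), datetime(2024, 1, 1)],
        "valid_to": [datetime(2024, 1, 1), None],
    }
)
source_df = pl.DataFrame({"id": [1], "valid_from": [datetime(2024, 6, 1)]})

# Equi-join on the key, then filter on the validity window: together these act
# as a non-equi join that keeps the target rows whose window contains the
# incoming source row's valid_from.
surrounding = (
    target_df.join(source_df, on="id", how="inner", suffix="__source")
    .filter(
        (pl.col("valid_from") <= pl.col("valid_from__source"))
        & (
            pl.col("valid_to").is_null()
            | (pl.col("valid_to") > pl.col("valid_from__source"))
        )
    )
)
print(surrounding)  # keeps only the open-ended 2024-01-01 version of id 1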

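The final with_columns step, computing valid_to as the next valid_from within each key, is easy to demo in isolation. A minimal sketch, again with placeholder column names rather than the function's actual parameters:

import polars as pl
from datetime import datetime

# Two versions of key 1 and a single version of key 2.
upsert_df = pl.DataFrame(
    {
        "id": [1, 1, 2],
        "valid_from": [datetime(2024, 1, 1), datetime(2024, 6, 1), datetime(2024, 1, 1)],
    }
)

# Within each key, ordered by valid_from, a row is valid until the next
# version starts; the latest version gets a null valid_to, i.e. it is current.
result = upsert_df.with_columns(
    pl.col("valid_from")
    .shift(-1)
    .over(partition_by="id", order_by="valid_from")
    .alias("valid_to")
)
print(result.sort("id", "valid_from"))

Because the shift runs over the concatenated target-and-source frame, a target row that was previously open-ended picks up a valid_to as soon as a newer source version lands in its partition, which is exactly why the comment notes that target rows may need to be updated.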