Skip to content

Commit

Permalink
Merge branch 'main' into fix/remove_update_iceberg_ts
Browse files Browse the repository at this point in the history
  • Loading branch information
aiss93 authored Oct 3, 2024
2 parents 07fa648 + abf8c52 commit fa84a57
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
## New version
- Allow to load big seed files
- Add a configuration to disable adding the hard-coded update_iceberg_ts column when using the merge incremental strategy with iceberg.

## v1.8.6
Expand Down
55 changes: 44 additions & 11 deletions dbt/adapters/glue/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,9 +534,37 @@ def create_csv_table(self, model, agate_table):
else:
mode = "False"

code = f'''
custom_glue_code_for_dbt_adapter
csv = {json.loads(f.getvalue())}
csv_chunks = self._split_csv_records_into_chunks(json.loads(f.getvalue()))
statements = self._map_csv_chunks_to_code(csv_chunks, session, model, mode)
try:
cursor = session.cursor()
for statement in statements:
cursor.execute(statement)
except DbtDatabaseError as e:
raise DbtDatabaseError(msg="GlueCreateCsvFailed") from e
except Exception as e:
logger.error(e)

def _map_csv_chunks_to_code(self, csv_chunks: List[List[dict]], session: GlueConnection, model, mode):
statements = []
for i, csv_chunk in enumerate(csv_chunks):
is_first = i == 0
is_last = i == len(csv_chunks) - 1
code = "custom_glue_code_for_dbt_adapter\n"
if is_first:
code += f"""
csv = {csv_chunk}
"""
else:
code += f"""
csv.extend({csv_chunk})
"""
if not is_last:
code += f'''
SqlWrapper2.execute("""select 1""")
'''
else:
code += f'''
df = spark.createDataFrame(csv)
table_name = '{model["schema"]}.{model["name"]}'
if (spark.sql("show tables in {model["schema"]}").where("tableName == lower('{model["name"]}')").count() > 0):
Expand All @@ -551,20 +579,25 @@ def create_csv_table(self, model, agate_table):
.saveAsTable(table_name)
SqlWrapper2.execute("""select * from {model["schema"]}.{model["name"]} limit 1""")
'''
try:
session.cursor().execute(code)
except DbtDatabaseError as e:
raise DbtDatabaseError(msg="GlueCreateCsvFailed") from e
except Exception as e:
logger.error(e)
statements.append(code)
return statements

def _split_csv_records_into_chunks(self, records: List[dict], target_size=60000):
chunks = [[]]
for record in records:
if len(str([*chunks[-1], record])) > target_size:
chunks.append([record])
else:
chunks[-1].append(record)
return chunks

def _update_additional_location(self, target_relation, location):
session, client = self.get_connection()
table_input = {}
try:
table_input = client.get_table(
DatabaseName=f'{target_relation.schema}',
Name=f'{session.credentials.delta_athena_prefix}_{target_relation.name}',
DatabaseName=f"{target_relation.schema}",
Name=f"{session.credentials.delta_athena_prefix}_{target_relation.name}",
).get("Table", {})
except client.exceptions.EntityNotFoundException as e:
logger.debug(e)
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Any, Dict, Optional
import unittest
from unittest import mock
from unittest.mock import Mock
from multiprocessing import get_context
from botocore.client import BaseClient
from moto import mock_aws

import agate
from dbt.config import RuntimeConfig

import dbt.flags as flags
Expand Down Expand Up @@ -86,3 +88,15 @@ def test_get_table_type(self):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load
self.assertEqual(adapter.get_table_type(target_relation), "iceberg_table")

def test_create_csv_table_slices_big_datasets(self):
    """Verify create_csv_table splits an oversized seed into multiple execute calls."""
    adapter = GlueAdapter(self._get_config(), get_context("spawn"))
    mock_session = Mock()
    # Bypass real connection setup; the client side is never exercised here.
    adapter.get_connection = lambda: (mock_session, 'mock_client')
    seed_model = {"name": "mock_model", "schema": "mock_schema"}
    rows = [(f'mock_value_{i}', f'other_mock_value_{i}') for i in range(2000)]
    big_table = agate.Table(rows, column_names=['value', 'other_value'])

    adapter.create_csv_table(seed_model, big_table)

    # The serialized table weighs in between 120000 and 180000 characters;
    # with a 60000-character chunk cap it must be sent as exactly 3 statements.
    self.assertEqual(mock_session.cursor().execute.call_count, 3)

0 comments on commit fa84a57

Please sign in to comment.