Skip to content

Commit

Permalink
RegisterTableFunction errors handling #234
Browse files Browse the repository at this point in the history
  • Loading branch information
VitaliyMF committed Nov 25, 2024
1 parent 079dc76 commit 7ea612c
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ public static class TableFunction
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_function_get_bind_data")]
public static extern unsafe IntPtr DuckDBFunctionGetBindData(IntPtr info);

[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_function_set_error")]
public static extern unsafe void DuckDBFunctionSetError(IntPtr info, SafeUnmanagedMemoryHandle error);

#endregion
}
}
127 changes: 73 additions & 54 deletions DuckDB.NET.Data/DuckDBConnection.TableFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,45 +102,54 @@ private unsafe void RegisterTableFunctionInternal(string name, Func<IReadOnlyLis
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
public static unsafe void Bind(IntPtr info)
{
var handle = GCHandle.FromIntPtr(NativeMethods.TableFunction.DuckDBBindGetExtraInfo(info));

if (handle.Target is not TableFunctionInfo functionInfo)
IDuckDBValueReader[]? parameters = null;
try
{
throw new InvalidOperationException("User defined table function bind failed. Bind extra info is null");
}
var handle = GCHandle.FromIntPtr(NativeMethods.TableFunction.DuckDBBindGetExtraInfo(info));

var parameters = new IDuckDBValueReader[NativeMethods.TableFunction.DuckDBBindGetParameterCount(info)];
if (handle.Target is not TableFunctionInfo functionInfo)
{
throw new InvalidOperationException("User defined table function bind failed. Bind extra info is null");
}

for (var i = 0; i < parameters.Length; i++)
{
var value = NativeMethods.TableFunction.DuckDBBindGetParameter(info, (ulong)i);
parameters[i] = value;
}
parameters = new IDuckDBValueReader[NativeMethods.TableFunction.DuckDBBindGetParameterCount(info)];

for (var i = 0; i < parameters.Length; i++)
{
var value = NativeMethods.TableFunction.DuckDBBindGetParameter(info, (ulong)i);
parameters[i] = value;
}

var tableFunctionData = functionInfo.Bind(parameters);

TableFunction tableFunctionData;
try {
tableFunctionData = functionInfo.Bind(parameters);
} catch (Exception ex) {
foreach (var columnInfo in tableFunctionData.Columns)
{
using var logicalType = DuckDBTypeMap.GetLogicalType(columnInfo.Type);
NativeMethods.TableFunction.DuckDBBindAddResultColumn(info, columnInfo.Name.ToUnmanagedString(), logicalType);
}

var bindData = new TableFunctionBindData(tableFunctionData.Columns, tableFunctionData.Data.GetEnumerator());

NativeMethods.TableFunction.DuckDBBindSetBindData(info, bindData.ToHandle(), &DestroyExtraInfo);
}
catch (Exception ex)
{
using (var errMsgHandle = ex.Message.ToUnmanagedString())
{
NativeMethods.TableFunction.DuckDBBindSetError(info, errMsgHandle);
}
return;
} finally {
foreach (var parameter in parameters) {
((DuckDBValue)parameter).Dispose();
}
}

foreach (var columnInfo in tableFunctionData.Columns)
finally
{
using var logicalType = DuckDBTypeMap.GetLogicalType(columnInfo.Type);
NativeMethods.TableFunction.DuckDBBindAddResultColumn(info, columnInfo.Name.ToUnmanagedString(), logicalType);
if (parameters!=null)
foreach (var parameter in parameters)
{
if (parameter != null)
((DuckDBValue)parameter).Dispose();
}
}

var bindData = new TableFunctionBindData(tableFunctionData.Columns, tableFunctionData.Data.GetEnumerator());

NativeMethods.TableFunction.DuckDBBindSetBindData(info, bindData.ToHandle(), &DestroyExtraInfo);
}

[UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
Expand All @@ -149,46 +158,56 @@ public static void Init(IntPtr info) { }
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
public static void TableFunction(IntPtr info, IntPtr chunk)
{
var bindData = GCHandle.FromIntPtr(NativeMethods.TableFunction.DuckDBFunctionGetBindData(info));
var extraInfo = GCHandle.FromIntPtr(NativeMethods.TableFunction.DuckDBFunctionGetExtraInfo(info));

if (bindData.Target is not TableFunctionBindData tableFunctionBindData)
try
{
throw new InvalidOperationException("User defined table function failed. Function bind data is null");
}
var bindData = GCHandle.FromIntPtr(NativeMethods.TableFunction.DuckDBFunctionGetBindData(info));
var extraInfo = GCHandle.FromIntPtr(NativeMethods.TableFunction.DuckDBFunctionGetExtraInfo(info));

if (extraInfo.Target is not TableFunctionInfo tableFunctionInfo)
{
throw new InvalidOperationException("User defined table function failed. Function extra info is null");
}
if (bindData.Target is not TableFunctionBindData tableFunctionBindData)
{
throw new InvalidOperationException("User defined table function failed. Function bind data is null");
}

var dataChunk = new DuckDBDataChunk(chunk);
if (extraInfo.Target is not TableFunctionInfo tableFunctionInfo)
{
throw new InvalidOperationException("User defined table function failed. Function extra info is null");
}

var writers = new VectorDataWriterBase[tableFunctionBindData.Columns.Count];
for (var columnIndex = 0; columnIndex < tableFunctionBindData.Columns.Count; columnIndex++)
{
var column = tableFunctionBindData.Columns[columnIndex];
var vector = NativeMethods.DataChunks.DuckDBDataChunkGetVector(dataChunk, columnIndex);
var dataChunk = new DuckDBDataChunk(chunk);

using var logicalType = DuckDBTypeMap.GetLogicalType(column.Type);
writers[columnIndex] = VectorDataWriterFactory.CreateWriter(vector, logicalType);
}
var writers = new VectorDataWriterBase[tableFunctionBindData.Columns.Count];
for (var columnIndex = 0; columnIndex < tableFunctionBindData.Columns.Count; columnIndex++)
{
var column = tableFunctionBindData.Columns[columnIndex];
var vector = NativeMethods.DataChunks.DuckDBDataChunkGetVector(dataChunk, columnIndex);

ulong size = 0;
using var logicalType = DuckDBTypeMap.GetLogicalType(column.Type);
writers[columnIndex] = VectorDataWriterFactory.CreateWriter(vector, logicalType);
}

for (; size < DuckDBGlobalData.VectorSize; size++)
{
if (tableFunctionBindData.DataEnumerator.MoveNext())
ulong size = 0;

for (; size < DuckDBGlobalData.VectorSize; size++)
{
tableFunctionInfo.Mapper(tableFunctionBindData.DataEnumerator.Current, writers, size);
if (tableFunctionBindData.DataEnumerator.MoveNext())
{
tableFunctionInfo.Mapper(tableFunctionBindData.DataEnumerator.Current, writers, size);
}
else
{
break;
}
}
else

NativeMethods.DataChunks.DuckDBDataChunkSetSize(dataChunk, size);
}
catch (Exception ex)
{
using (var errMsgHandle = ex.Message.ToUnmanagedString())
{
break;
NativeMethods.TableFunction.DuckDBFunctionSetError(info, errMsgHandle);
}
}

NativeMethods.DataChunks.DuckDBDataChunkSetSize(dataChunk, size);
}
#endif
}
17 changes: 7 additions & 10 deletions DuckDB.NET.Test/TableFunctionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,10 @@ public void RegisterTableFunctionWithErrors() {
}, (item, writer, rowIndex) => {
});

Assert.Throws<DuckDBException>(() => {
try {
Assert.Contains("bind_err_msg",
Assert.Throws<DuckDBException>(() => {
var data = Connection.Query<int>($"SELECT * FROM bind_err('')").ToList();
} catch (Exception ex) {
Assert.Contains("bind_err_msg", ex.Message);
throw;
}
});
}).Message);

Connection.RegisterTableFunction<string>("map_err", parameters => {
return new TableFunction(
Expand All @@ -199,9 +195,10 @@ public void RegisterTableFunctionWithErrors() {
}, (item, writer, rowIndex) => {
throw new NotSupportedException("map_err_msg");
});
Assert.Throws<NotSupportedException>(() => {
var data = Connection.Query<int>($"SELECT * FROM map_err('')").ToList();
});
Assert.Contains("map_err_msg",
Assert.Throws<DuckDBException>(() => {
var data = Connection.Query<int>($"SELECT * FROM map_err('')").ToList();
}).Message);
}

}

0 comments on commit 7ea612c

Please sign in to comment.