Skip to content

Commit

Permalink
Merge pull request #29 from Merlin-san/cowvalue-symbol-leak-fix
Browse files Browse the repository at this point in the history
Fix issue with COWValues leaking symbol allocations
  • Loading branch information
MerlinVR authored May 20, 2020
2 parents 82b3995 + 9582481 commit 7278c89
Show file tree
Hide file tree
Showing 11 changed files with 4,556 additions and 31 deletions.
23 changes: 13 additions & 10 deletions Assets/UdonSharp/Editor/UdonSharpASTVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,23 @@ public ASTVisitorContext(ResolverContext resolver, SymbolTable rootTable, LabelT
}
}

public void PopTable()
{
if (symbolTableStack.Count == 1)
throw new System.Exception("Cannot pop root table, mismatched scope entry and exit!");

symbolTableStack.Pop();
}

public void PushTable(SymbolTable newTable)
{
if (newTable.parentSymbolTable != topTable)
throw new System.ArgumentException("Parent symbol table is not valid for given context.");

symbolTableStack.Push(newTable);
newTable.OpenSymbolTable();
}

public void PopTable()
{
if (symbolTableStack.Count == 1)
throw new System.Exception("Cannot pop root table, mismatched scope entry and exit!");

SymbolTable table = symbolTableStack.Pop();
table.CloseSymbolTable();
}

public void PushCaptureScope(ExpressionCaptureScope captureScope)
Expand Down Expand Up @@ -493,7 +496,7 @@ public override void VisitArrayCreationExpression(ArrayCreationExpressionSyntax
using (ExpressionCaptureScope arraySetIdxScope = new ExpressionCaptureScope(visitorContext, null))
{
arraySetIdxScope.SetToLocalSymbol(arraySymbol);
using (SymbolDefinition.COWValue arrayIndex = visitorContext.topTable.CreateConstSymbol(typeof(int), i).GetCOWValue(visitorContext.uasmBuilder, visitorContext.topTable))
using (SymbolDefinition.COWValue arrayIndex = visitorContext.topTable.CreateConstSymbol(typeof(int), i).GetCOWValue(visitorContext))
{
arraySetIdxScope.HandleArrayIndexerAccess(arrayIndex);
}
Expand Down Expand Up @@ -590,7 +593,7 @@ public override void VisitImplicitArrayCreationExpression(ImplicitArrayCreationE
using (ExpressionCaptureScope arrayIdxSetScope = new ExpressionCaptureScope(visitorContext, null))
{
arrayIdxSetScope.SetToLocalSymbol(arraySymbol);
using (SymbolDefinition.COWValue arrayIndex = visitorContext.topTable.CreateConstSymbol(typeof(int), i).GetCOWValue(visitorContext.uasmBuilder, visitorContext.topTable))
using (SymbolDefinition.COWValue arrayIndex = visitorContext.topTable.CreateConstSymbol(typeof(int), i).GetCOWValue(visitorContext))
{
arrayIdxSetScope.HandleArrayIndexerAccess(arrayIndex);
}
Expand Down Expand Up @@ -2190,7 +2193,7 @@ public override void VisitForEachStatement(ForEachStatementSyntax node)
using (ExpressionCaptureScope indexAccessExecuteScope = new ExpressionCaptureScope(visitorContext, null))
{
indexAccessExecuteScope.SetToLocalSymbol(arraySymbol);
using (SymbolDefinition.COWValue arrayIndex = indexSymbol.GetCOWValue(visitorContext.uasmBuilder, visitorContext.topTable))
using (SymbolDefinition.COWValue arrayIndex = indexSymbol.GetCOWValue(visitorContext))
{
indexAccessExecuteScope.HandleArrayIndexerAccess(arrayIndex, valueSymbol);
}
Expand Down
4 changes: 4 additions & 0 deletions Assets/UdonSharp/Editor/UdonSharpCompilationModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ public int Compile(List<ClassDefinition> classDefinitions)
return errorCount;
}

moduleSymbols.OpenSymbolTable();

Profiler.BeginSample("Visit");
UdonSharpFieldVisitor fieldVisitor = new UdonSharpFieldVisitor(fieldsWithInitializers);
fieldVisitor.Visit(tree.GetRoot());
Expand Down Expand Up @@ -141,6 +143,8 @@ public int Compile(List<ClassDefinition> classDefinitions)
}
Profiler.EndSample();

moduleSymbols.CloseSymbolTable();

if (errorCount == 0)
{
compiledClassDefinition = classDefinitions.Find(e => e.userClassType == visitor.visitorContext.behaviourUserType);
Expand Down
5 changes: 5 additions & 0 deletions Assets/UdonSharp/Editor/UdonSharpCompiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,9 @@ private List<ClassDefinition> BuildClassDefinitions()

ResolverContext resolver = new ResolverContext();
SymbolTable classSymbols = new SymbolTable(resolver, null);

classSymbols.OpenSymbolTable();

LabelTable classLabels = new LabelTable();

SyntaxTree tree = CSharpSyntaxTree.ParseText(programSource);
Expand All @@ -394,6 +397,8 @@ private List<ClassDefinition> BuildClassDefinitions()
return null;
}

classSymbols.CloseSymbolTable();

classVisitor.classDefinition.classScript = udonSharpProgram.sourceCsScript;
classDefinitions.Add(classVisitor.classDefinition);
}
Expand Down
16 changes: 8 additions & 8 deletions Assets/UdonSharp/Editor/UdonSharpExpressionCapture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ public SymbolDefinition.COWValue ExecuteGetCOW()
{
if (cowValue == null)
{
cowValue = ExecuteGet().GetCOWValue(visitorContext.uasmBuilder, visitorContext.topTable);
cowValue = ExecuteGet().GetCOWValue(visitorContext);
}

return cowValue;
Expand Down Expand Up @@ -850,7 +850,7 @@ private SymbolDefinition[] GetExpandedInvokeParams(MethodBase targetMethod, Symb
using (ExpressionCaptureScope paramArraySetterScope = new ExpressionCaptureScope(visitorContext, null))
{
paramArraySetterScope.SetToLocalSymbol(paramsArraySymbol);
using (SymbolDefinition.COWValue arrayIndex = arrayIndexSymbol.GetCOWValue(visitorContext.uasmBuilder, visitorContext.topTable)) {
using (SymbolDefinition.COWValue arrayIndex = arrayIndexSymbol.GetCOWValue(visitorContext)) {
paramArraySetterScope.HandleArrayIndexerAccess(arrayIndex);
}
paramArraySetterScope.ExecuteSet(invokeParams[j]);
Expand Down Expand Up @@ -971,7 +971,7 @@ private SymbolDefinition HandleGenericGetComponentSingle(SymbolDefinition compon
using (ExpressionCaptureScope componentArrayGetter = new ExpressionCaptureScope(visitorContext, null))
{
componentArrayGetter.SetToLocalSymbol(componentArray);
using (SymbolDefinition.COWValue arrayIndexValue = arrayIndex.GetCOWValue(visitorContext.uasmBuilder, visitorContext.topTable))
using (SymbolDefinition.COWValue arrayIndexValue = arrayIndex.GetCOWValue(visitorContext))
{
componentArrayGetter.HandleArrayIndexerAccess(arrayIndexValue);
}
Expand Down Expand Up @@ -1093,7 +1093,7 @@ private SymbolDefinition HandleGenericGetComponentArray(SymbolDefinition compone
using (ExpressionCaptureScope componentArrayGetter = new ExpressionCaptureScope(visitorContext, null))
{
componentArrayGetter.SetToLocalSymbol(componentArray);
using (SymbolDefinition.COWValue arrayIndexValue = arrayIndex.GetCOWValue(visitorContext.uasmBuilder, visitorContext.topTable))
using (SymbolDefinition.COWValue arrayIndexValue = arrayIndex.GetCOWValue(visitorContext))
{
componentArrayGetter.HandleArrayIndexerAccess(arrayIndexValue);
}
Expand Down Expand Up @@ -1209,7 +1209,7 @@ private SymbolDefinition HandleGenericGetComponentArray(SymbolDefinition compone
using (ExpressionCaptureScope componentArrayGetter = new ExpressionCaptureScope(visitorContext, null))
{
componentArrayGetter.SetToLocalSymbol(componentArray);
using (SymbolDefinition.COWValue arrayIndexValue = arrayIndex.GetCOWValue(visitorContext.uasmBuilder, visitorContext.topTable))
using (SymbolDefinition.COWValue arrayIndexValue = arrayIndex.GetCOWValue(visitorContext))
{
componentArrayGetter.HandleArrayIndexerAccess(arrayIndexValue);
}
Expand Down Expand Up @@ -1242,15 +1242,15 @@ private SymbolDefinition HandleGenericGetComponentArray(SymbolDefinition compone
using (ExpressionCaptureScope setArrayValueScope = new ExpressionCaptureScope(visitorContext, null))
{
setArrayValueScope.SetToLocalSymbol(resultSymbol);
using (SymbolDefinition.COWValue destIdxValue = destIdxSymbol.GetCOWValue(visitorContext.uasmBuilder, visitorContext.topTable))
using (SymbolDefinition.COWValue destIdxValue = destIdxSymbol.GetCOWValue(visitorContext))
{
setArrayValueScope.HandleArrayIndexerAccess(destIdxValue);
}

using (ExpressionCaptureScope sourceValueGetScope = new ExpressionCaptureScope(visitorContext, null))
{
sourceValueGetScope.SetToLocalSymbol(componentArray);
using (SymbolDefinition.COWValue arrayIndexValue = arrayIndex.GetCOWValue(visitorContext.uasmBuilder, visitorContext.topTable)) {
using (SymbolDefinition.COWValue arrayIndexValue = arrayIndex.GetCOWValue(visitorContext)) {
sourceValueGetScope.HandleArrayIndexerAccess(arrayIndexValue);
}

Expand Down Expand Up @@ -2191,7 +2191,7 @@ public void HandleArrayIndexerAccess(SymbolDefinition.COWValue indexerValue, Sym
else
{
// We needed to do a cast, so create a COW value for the post-cast value
using (SymbolDefinition.COWValue cowIndex = indexerSymbol.GetCOWValue(visitorContext.uasmBuilder, visitorContext.topTable))
using (SymbolDefinition.COWValue cowIndex = indexerSymbol.GetCOWValue(visitorContext))
{
arrayIndexerIndexValue = cowIndex;
}
Expand Down
134 changes: 121 additions & 13 deletions Assets/UdonSharp/Editor/UdonSharpSymbolTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,23 +101,28 @@ public void AssertCOWClosed()
#endif
}

public COWValue GetCOWValue(AssemblyBuilder assemblyBuilder, SymbolTable symbolTable)
public COWValue GetCOWValue(ASTVisitorContext visitorContext)
{
if (cowValue != null)
{
if (cowValue.assemblyBuilder != assemblyBuilder)
if (cowValue.visitorContext != visitorContext)
{
// Hmm... new compilation context? Dirty it and get a new one.
cowValue.MarkDirty();
cowValue = null;
} else if (cowValue.isDirty) {
}
else if (cowValue.isDirty || cowValue.referenceCount == 0)
{
// If the reference count is 0, we've probably moved scopes. We clear out the cowValue here to make sure that a cowValue is only used in one scope at a time.
cowValue = null;
} else {
}
else
{
return new COWValue(cowValue);
}
}

cowValue = new COWValueInternal(assemblyBuilder, this, symbolTable);
cowValue = new COWValueInternal(visitorContext, this);
return new COWValue(cowValue);
}

Expand All @@ -137,25 +142,25 @@ internal class COWValueInternal

public int instanceIndex = index++;

public AssemblyBuilder assemblyBuilder;

public int referenceCount = 0;
public bool isDirty = false;

public SymbolDefinition symbol { get; private set; } = null;
public SymbolDefinition originalSymbol { get; private set; } = null;

public SymbolTable symbolTable;
public ASTVisitorContext visitorContext { get; private set; } = null;

#if UDONSHARP_DEBUG
private HashSet<COWValue> holders = new HashSet<COWValue>();
#endif
private SymbolTable tableCreationScope;

public COWValueInternal(AssemblyBuilder assemblyBuilder, SymbolDefinition symbol, SymbolTable table)
public COWValueInternal(ASTVisitorContext visitorContext, SymbolDefinition symbol)
{
this.assemblyBuilder = assemblyBuilder;
this.symbol = this.originalSymbol = symbol;
this.symbolTable = table;
this.visitorContext = visitorContext;

tableCreationScope = visitorContext.topTable;
}

public void AddRef(COWValue holder)
Expand All @@ -164,6 +169,11 @@ public void AddRef(COWValue holder)
#if UDONSHARP_DEBUG
holders.Add(holder);
#endif

if (visitorContext.topTable != tableCreationScope)
{
throw new Exception($"COWSymbolValue for {originalSymbol} has had ref added from different symbol table scope.");
}
}

public void ClearRef(COWValue holder)
Expand All @@ -175,6 +185,11 @@ public void ClearRef(COWValue holder)
throw new Exception("No matching holder for COWValue");
}
#endif

if (visitorContext.topTable != tableCreationScope)
{
throw new Exception($"COWSymbolValue for {originalSymbol} has been disposed from different symbol table scope.");
}
}

public void AssertNoLeaks()
Expand All @@ -201,8 +216,8 @@ public void MarkDirty()

if (!isDirty)
{
SymbolDefinition temporary = symbolTable.CreateUnnamedSymbol(symbol.internalType, SymbolDeclTypeFlags.Internal | SymbolDeclTypeFlags.Local);
assemblyBuilder.AddCopy(temporary, symbol, " Copy-on-write symbol value dirtied");
SymbolDefinition temporary = visitorContext.topTable.CreateUnnamedSymbol(symbol.internalType, SymbolDeclTypeFlags.Internal | SymbolDeclTypeFlags.Local);
visitorContext.uasmBuilder.AddCopy(temporary, symbol, " Copy-on-write symbol value dirtied");
symbol = temporary;
isDirty = true;
}
Expand Down Expand Up @@ -278,6 +293,14 @@ public class SymbolTable

private Dictionary<string, int> namedSymbolCounters;

private bool IsTableReadOnly = true;

#if UDONSHARP_DEBUG
private System.Diagnostics.StackTrace creationTrace;
#endif

private List<(SymbolTable, Dictionary<string, int>)> initialSymbolCounters = new List<(SymbolTable, Dictionary<string, int>)>();

public SymbolTable GetGlobalSymbolTable()
{
SymbolTable currentTable = this;
Expand All @@ -300,6 +323,30 @@ public SymbolTable(ResolverContext resolverContext, SymbolTable parentTable)

symbolDefinitions = new List<SymbolDefinition>();
namedSymbolCounters = new Dictionary<string, int>();

#if UDONSHARP_DEBUG
creationTrace = new System.Diagnostics.StackTrace(true);
#endif
}

public void OpenSymbolTable()
{
IsTableReadOnly = false;

// Copy the current symbol counters for checking when the symbol table has been closed
SymbolTable currentTable = parentSymbolTable;
while (currentTable != null)
{
initialSymbolCounters.Add((currentTable, new Dictionary<string, int>(currentTable.namedSymbolCounters)));
currentTable = currentTable.parentSymbolTable;
}
}

public void CloseSymbolTable()
{
IsTableReadOnly = true;

ValidateParentTableCounters();
}

protected int IncrementUniqueNameCounter(string symbolName)
Expand Down Expand Up @@ -334,6 +381,38 @@ public int GetUniqueNameCounter(string symbolName)
return -1;
}

public void ValidateParentTableCounters()
{
SymbolTable currentTable = parentSymbolTable;

int tableIdx = 0;

while (currentTable != null)
{
(SymbolTable, Dictionary<string, int>) counterPair = initialSymbolCounters[tableIdx];

if (counterPair.Item1 != currentTable)
throw new Exception("Table mismatch, parent tables have changed during the lifetime of a symbol table.");

Dictionary<string, int> initialCounters = counterPair.Item2;

if (!currentTable.IsGlobalSymbolTable)
{
foreach (var currentCounters in currentTable.namedSymbolCounters)
{
if (!initialCounters.ContainsKey(currentCounters.Key))
throw new Exception($"Counter for symbol {currentCounters.Key} has been added while table is not valid for modification.");

if (initialCounters[currentCounters.Key] != currentCounters.Value)
throw new Exception($"Counter for symbol {currentCounters.Key} in symbol table has been modified while table is not valid for modification.");
}
}

++tableIdx;
currentTable = currentTable.parentSymbolTable;
}
}

/// <summary>
/// This function expects the given symbolName to have some marker to indicate that they are global-only
/// in order to prevent collisions with child symbol table symbols.
Expand Down Expand Up @@ -626,6 +705,35 @@ private SymbolDefinition CreateNamedSymbolInternal(string symbolName, System.Typ
symbolDefinitions.Add(symbolDefinition);
}

#if UDONSHARP_DEBUG
if (IsTableReadOnly)
throw new Exception($"Cannot add symbol {symbolDefinition}, symbol table is readonly. Symbol Table creation stacktrace \n\n{creationTrace}");
#else
if (IsTableReadOnly)
throw new Exception($"Cannot add symbol {symbolDefinition}, symbol table is readonly.");
#endif

if (IsGlobalSymbolTable)
{
bool anyChildTableOpen = false;
foreach (SymbolTable childTable in childSymbolTables)
{
if (!childTable.IsTableReadOnly)
{
anyChildTableOpen = true;
break;
}
}

if (anyChildTableOpen)
{
if (!declType.HasFlag(SymbolDeclTypeFlags.Reflection) &&
!declType.HasFlag(SymbolDeclTypeFlags.Constant) &&
!declType.HasFlag(SymbolDeclTypeFlags.This))
throw new Exception($"Cannot add symbol {symbolDefinition} to root table while other tables are in use.");
}
}

return symbolDefinition;
}

Expand Down
Loading

0 comments on commit 7278c89

Please sign in to comment.