Skip to content

Commit

Permalink
more support for concrete syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
robertmuth committed Jun 1, 2024
1 parent c4117c8 commit dd1ca38
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 48 deletions.
95 changes: 50 additions & 45 deletions FrontEnd/LangTest/sum_untagged_test.cw
Original file line number Diff line number Diff line change
@@ -1,101 +1,116 @@
@doc "union"
@doc "union"
(module main [] :
(import test)


(@wrapped type t1 s32)


(@wrapped type t2 void)


(@wrapped type t3 void)


(type type_ptr (ptr! s32))


(type UntaggedUnion1 (@untagged union [
s32
void
type_ptr]))
(type UntaggedUnion1 (@untagged union [s32 void type_ptr]))


(static_assert (== (sizeof UntaggedUnion1) 8))


(type UntaggedUnion2 (@untagged union [
s32
void
(@untagged union [UntaggedUnion1 u8])]))
(type UntaggedUnion2 (@untagged union [s32 void (@untagged union [UntaggedUnion1 u8])]))


(static_assert (== (sizeof UntaggedUnion2) 8))


(type UntaggedUnion3 (@untagged union [
bool
s32
s64]))
(type UntaggedUnion3 (@untagged union [bool s32 s64]))


(static_assert (== (sizeof UntaggedUnion2) 8))


(type UntaggedUnion4 (@untagged union [bool s32]))


(static_assert (== (sizeof UntaggedUnion4) 4))


@pub (type UntaggedUnion5 (@untagged union [
t2
t3
s8]))
@pub (type UntaggedUnion5 (@untagged union [t2 t3 s8]))


(static_assert (== (sizeof UntaggedUnion5) 1))


(type UntaggedUnion6 (@untagged union [bool u16]))
(type UntaggedUnion6 (@untagged union [bool u16]))


(static_assert (== (sizeof UntaggedUnion6) 2))


(type UntaggedUnion (@untagged union [
bool
u64
u32
r32
r64
(array 32 u8)]))


(type TaggedUnion (union [
bool
u64
u32
r32
r64
(array 32 u8)]))

(type UntaggedUnion (@untagged union [bool u64 u32 r32 r64 (array 32 u8)]))
(type TaggedUnion (union [bool u64 u32 r32 r64 (array 32 u8)]))

(static_assert (== (sizeof UntaggedUnion) 32))


(defrec RecordWithUntaggedUnion :
(field t1 bool)
(field t2 u32)
(field t3 UntaggedUnion)
(field t4 bool))



(fun with_union_result [(param a bool) (param b u32) (param c r32)] UntaggedUnion :
(fun with_union_result [
(param a bool)
(param b u32)
(param c r32)] UntaggedUnion :
(let! out UntaggedUnion undef)
(if a :
(= out b)
: (= out c))
(return out)
)
:
(= out c))
(return out))


(fun test_untagged_union [] void :
@doc "straight up union"
(let! u1 UntaggedUnion)
(let! u2 UntaggedUnion undef)
(let! u3 UntaggedUnion 2.0_r32)
(let! u4 UntaggedUnion 777_u32)

(let s1 u32 (narrowto u3 u32))
(test::AssertEq# s1 0x40000000_u32)
(test::AssertEq# (at (narrowto u3 (array 32 u8)) 0) 0_u8)
(test::AssertEq# (at (narrowto u3 (array 32 u8)) 1) 0_u8)
(test::AssertEq# (at (narrowto u3 (array 32 u8)) 2) 0_u8)
(test::AssertEq# (at (narrowto u3 (array 32 u8)) 3) 0x40_u8)

(= (at (narrowto u3 (array 32 u8)) 2) 0x28_u8)
(= (at (narrowto u3 (array 32 u8)) 3) 0x42_u8)
(test::AssertEq# (narrowto u3 u32) 0x42280000_u32)
(test::AssertEq# (narrowto u3 r32) 42_r32)

(= u3 2.0_r64)
(test::AssertEq# (narrowto u3 u64) 0x4000000000000000_u64)
(test::AssertEq# (at (narrowto u3 (array 32 u8)) 3) 0_u8)
(test::AssertEq# (at (narrowto u3 (array 32 u8)) 7) 0x40_u8)

@doc "union embedded in record"
(let! rec1 RecordWithUntaggedUnion undef)
(= (. rec1 t3) 2.0_r32)
Expand All @@ -104,30 +119,25 @@
(test::AssertEq# (at (narrowto (. rec1 t3) (array 32 u8)) 1) 0_u8)
(test::AssertEq# (at (narrowto (. rec1 t3) (array 32 u8)) 2) 0_u8)
(test::AssertEq# (at (narrowto (. rec1 t3) (array 32 u8)) 3) 0x40_u8)

@doc "union embedded in record 2"
(let! rec2 auto (rec_val RecordWithUntaggedUnion [
(field_val false)
(field_val 0x12344321)
(field_val 2.0_r32)
(field_val true)]))
(field_val false)
(field_val 0x12344321)
(field_val 2.0_r32)
(field_val true)]))
(test::AssertEq# (. rec2 t1) false)
(test::AssertEq# (. rec2 t2) 0x12344321_u32)
(test::AssertEq# (narrowto (. rec2 t3) u32) 0x40000000_u32)
(test::AssertEq# (. rec2 t4) true)


@doc ""
(= (at (narrowto (. rec1 t3) (array 32 u8)) 2) 0x28_u8)
(= (at (narrowto (. rec1 t3) (array 32 u8)) 3) 0x42_u8)
(test::AssertEq# (narrowto (. rec1 t3) u32) 0x42280000_u32)
(test::AssertEq# (narrowto (. rec1 t3) r32) 42_r32)

(= (. rec1 t3) 2.0_r64)
(test::AssertEq# (narrowto (. rec1 t3) u64) 0x4000000000000000_u64)
(test::AssertEq# (at (narrowto (. rec1 t3) (array 32 u8)) 3) 0_u8)
(test::AssertEq# (at (narrowto (. rec1 t3) (array 32 u8)) 7) 0x40_u8)

@doc "array of union"
(let! array1 (array 16 UntaggedUnion) undef)
(= (at array1 13) 2.0_r32)
Expand All @@ -136,29 +146,24 @@
(test::AssertEq# (at (narrowto (at array1 13) (array 32 u8)) 1) 0_u8)
(test::AssertEq# (at (narrowto (at array1 13) (array 32 u8)) 2) 0_u8)
(test::AssertEq# (at (narrowto (at array1 13) (array 32 u8)) 3) 0x40_u8)

(= (at (narrowto (at array1 13) (array 32 u8)) 2) 0x28_u8)
(= (at (narrowto (at array1 13) (array 32 u8)) 3) 0x42_u8)
(test::AssertEq# (narrowto (at array1 13) u32) 0x42280000_u32)
(test::AssertEq# (narrowto (at array1 13) r32) 42_r32)

(= u1 (with_union_result [true 10 2.0]))
(test::AssertEq# (narrowto u1 u32) 10_u32)
(= u1 (with_union_result [false 10 2.0]))
(test::AssertEq# (narrowto u1 u32) 0x40000000_u32)

(= (at array1 13) 2.0_r64)
(test::AssertEq# (narrowto (at array1 13) u64) 0x4000000000000000_u64)
(test::AssertEq# (at (narrowto (at array1 13) (array 32 u8)) 3) 0_u8)
(test::AssertEq# (at (narrowto (at array1 13) (array 32 u8)) 7) 0x40_u8)
)
(test::AssertEq# (at (narrowto (at array1 13) (array 32 u8)) 7) 0x40_u8))


@cdecl (fun main [(param argc s32) (param argv (ptr (ptr u8)))] s32 :
(shed (test_untagged_union []))

@doc "test end"
(test::Success#)
(return 0))
)

)
1 change: 1 addition & 0 deletions FrontEnd/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ TESTS_CONCRETE = \
LangTest/defer_test.cw \
LangTest/enum_test.cw \
LangTest/sum_tagged_test.cw \
LangTest/sum_untagged_test.cw \
TestData/asciiquarium.cw \
TestData/binary_tree.cw \
TestData/cast.cw \
Expand Down
8 changes: 5 additions & 3 deletions FrontEnd/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,10 +616,12 @@ def _PParseFieldAccess(inp: Lexer, rec, _tk: TK, _precedence) -> Any:
field = inp.match_or_die(TK_KIND.ID)
return cwast.ExprField(rec, field.text)


def _PParseDerefFieldAccess(inp: Lexer, rec, _tk: TK, _precedence) -> Any:
field = inp.match_or_die(TK_KIND.ID)
return cwast.MacroInvoke("^.", [rec, cwast.Id(field.text)])


def _PParseTernary(inp: Lexer, cond, _tk: TK, _precedence) -> Any:
expr_t = _ParseExpr(inp)
inp.match_or_die(TK_KIND.COLON)
Expand Down Expand Up @@ -697,7 +699,7 @@ def _ParseTypeExpr(inp: Lexer):
inp.match_or_die(TK_KIND.COMMA)
first = False
members.append(_ParseTypeExpr(inp))
return cwast.TypeUnion(members)
return cwast.TypeUnion(members, **_ExtractAnnotations(tk))
kind = cwast.KeywordToBaseTypeKind(tk.text)
assert kind is not cwast.BASE_TYPE_KIND.INVALID, f"{tk}"
return cwast.TypeBase(kind)
Expand Down Expand Up @@ -837,7 +839,7 @@ def _ParseStatement(inp: Lexer):
kind = inp.next()
rhs = _ParseExpr(inp)
if kind.kind is TK_KIND.ASSIGN:
return cwast.StmtAssignment(lhs, rhs)
return cwast.StmtAssignment(lhs, rhs, ** _ExtractAnnotations(kw))
else:
assert kind.kind is TK_KIND.COMPOUND_ASSIGN, f"{kind}"
op = cwast.ASSIGNMENT_SHORTCUT[kind.text]
Expand Down Expand Up @@ -878,7 +880,7 @@ def _ParseStatement(inp: Lexer):
return cwast.MacroFor(var.text, container.text, stmts)
elif kw.text == "shed":
expr = _ParseExpr(inp)
return cwast.StmtExpr(expr,**_ExtractAnnotations(kw))
return cwast.StmtExpr(expr, **_ExtractAnnotations(kw))
elif kw.text == "trap":
return cwast.StmtTrap()
elif kw.text == "defer":
Expand Down

0 comments on commit dd1ca38

Please sign in to comment.