Skip to content

Commit

Permalink
Add implementation for grammar.Frame
Browse files Browse the repository at this point in the history
Co-authored-by: nikhilkhatri <nikhil.khatri@quantinuum.com>
  • Loading branch information
neiljdo and nikhilkhatri committed Oct 23, 2024
1 parent 8109d95 commit ef3fe12
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 1 deletion.
103 changes: 102 additions & 1 deletion lambeq/backend/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def rotate(self, z: int) -> Diagrammable:
"""

def dagger(self) -> Diagrammable:
"""Implements conjugation of diagrams."""
"""Apply the dagger operation."""

def __matmul__(self, rhs: Diagrammable | Ty) -> Diagrammable:
"""Implements the tensor operator `@` with another diagram."""
Expand Down Expand Up @@ -750,6 +750,12 @@ def is_id(self) -> bool:
def boxes(self) -> list[Box]:
return [layer.box for layer in self.layers]

@property
def has_frames(self) -> bool:
return any([(isinstance(box, Frame)
or isinstance(box, DaggeredFrame))
for box in self.boxes])

@classmethod
def create_pregroup_diagram(
cls,
Expand Down Expand Up @@ -1986,3 +1992,98 @@ def ar(self, ar: Box) -> Diagrammable:
'use the functor on boxes.')

return self.custom_ar(self, ar)


@Diagram.register_special_box('frame')
@dataclass
class Frame(Box):
"""A frame in the grammar category.
It can contain other diagrams as its components. Frame is
an abstract container, which means that the relationship
between its domain/codomain with those of the individual
nested diagrams remains undefined at this level, and is
left to be implemented by the application of purpose-specific
ansatze and rewriters.
Frames can be nested to an arbitrary depth.
Parameters
----------
name : str
The name of the frame.
dom : Ty
The domain of the frame.
cod : Ty
The codomain of the frame.
z : int, optional
The winding number of the frame, by default 0.
components : list of `Diagrammable`
The components inside this frame.
"""

name: str
dom: Ty
cod: Ty
components: list[Diagrammable] = field(default_factory=list)
z: int = 0

def __repr__(self):
return (f'Frame({self.name}, '
+ f'dom={self.dom}, '
+ f'cod={self.cod}, '
+ f'z={self.z}, '
+ 'components=['
+ ' @ '.join(map(repr, self.components)) + ']')

def rotate(self, z: int) -> Self:
"""Rotate the box, changing the winding number."""
return replace(self,
dom=self.dom.rotate(z),
cod=self.cod.rotate(z),
z=self.z + z,
components=[c.rotate(z) for c in
reversed(self.components)])

def dagger(self) -> DaggeredFrame | Frame:
return DaggeredFrame(self)

def __hash__(self) -> int:
return hash(repr(self))


@dataclass
class DaggeredFrame(Frame):
"""A daggered frame.
Parameters
----------
frame : Frame
The frame to be daggered.
"""

frame: Frame
name: str = field(init=False)
dom: Ty = field(init=False)
cod: Ty = field(init=False)
z: int = field(init=False)
components: list[Diagrammable] = field(init=False)

def __post_init__(self) -> None:
self.name = self.frame.name + '†'
self.dom = self.frame.cod
self.cod = self.frame.dom
self.z = self.frame.z
self.components = [c.dagger() for c in self.frame.components]

def rotate(self, z: int) -> Self:
"""Rotate the daggered frame."""
return type(self)(self.frame.rotate(z))

def dagger(self) -> Frame:
return self.frame

def __hash__(self) -> int:
return hash(repr(self))
84 changes: 84 additions & 0 deletions tests/backend/test_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,3 +399,87 @@ def test_to_from_json():
assert case_json['category'] == 'grammar'
assert 'entity' in case_json
assert grammar.from_json(json.dumps(case_json)) == case


def test_frame():
n = Ty('n')
s = Ty('s')
d = ((Word('Alice', s) @ Word('runs', s.r @ s))
>> (Cup(s, s.r) @ Id(s)))

f = Frame(
'f1', n @ n, n @ n,
components=[
Box('b1', Ty(), n),
Box('b1', Ty(), Ty()),
Box('b1', n, Ty()),
d,
]
)
assert f.name == 'f1'
assert len(f.components) == 4


def test_diagram_has_frame():
n = Ty('n')
s = Ty('s')
d = ((Word('Alice', s) @ Word('runs', s.r @ s))
>> (Cup(s, s.r) @ Id(s)))

assert not d.has_frames

f = Frame(
'f1', n @ n, n @ n,
components=[
Box('b1', Ty(), n),
Box('b1', Ty(), Ty()),
Box('b1', n, Ty()),
d,
]
)
d @= f
assert d.has_frames


def test_frame_manipulation():
n, s = Ty('n'), Ty('s')
ba = Box('A', n, n @ s)
bb = Box('B', s, Ty())
f = Frame('F', n, s, 0, [ba >> n @ bb, bb])
fl = Frame('F', n.l, s.l, -1, [bb.l, (ba >> n @ bb).l])

assert f.l == fl
assert f.dagger().dagger() == f
assert f.dagger().l == f.l.dagger()


def test_frame_functor():
n, s = Ty('n'), Ty('s')
ba = Box('A', n, n @ s)
bb = Box('B', s, Ty())
ba_BOX = Box('BOX', n, n @ s)
bb_BOX = Box('BOX', s, Ty())
d = Frame('F', n, s, 0, [ba >> n @ bb, bb])
rename_d = Frame('F', n, s, 0, [ba_BOX >> n @ bb_BOX, bb_BOX])

# Identity on all elements
f_id = Functor(grammar,
ob=lambda _, ty: ty,
ar=lambda _, ob: ob)

def nested_ar(functor, ob):
if isinstance(ob, Frame):
return Frame(ob.name,
ob.dom,
ob.cod,
ob.z,
[functor(c) for c in ob.components])
return Box("BOX", ob.dom, ob.cod)

# Identity on types and frames, rename boxes
f_rename_boxes = Functor(grammar,
ob=lambda _, ty: ty,
ar=nested_ar)

assert f_id(d) == d
assert f_rename_boxes(d) == rename_d

0 comments on commit ef3fe12

Please sign in to comment.