Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[2/n] MultiEmbeddingTensor: Add MultiEmbeddingTensor #181

Merged
merged 17 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `AmazonFineFoodReviews` dataset and OpenAI embedding example ([#182](https://github.com/pyg-team/pytorch-frame/pull/182))
- Added save and load logic for `FittableBaseTransform` ([#178](https://github.com/pyg-team/pytorch-frame/pull/178))
- Added `MultiEmbeddingTensor` ([#181](https://github.com/pyg-team/pytorch-frame/pull/181))
- Added `to_dense()` for `MultiNestedTensor` ([#170](https://github.com/pyg-team/pytorch-frame/pull/170))
- Added example for `multicategorical` stype ([#162](https://github.com/pyg-team/pytorch-frame/pull/162))
- Added `sequence_numerical` stype ([#159](https://github.com/pyg-team/pytorch-frame/pull/159))
Expand Down
118 changes: 118 additions & 0 deletions test/data/test_multi_embedding_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import random
from typing import List, Tuple

import pytest
import torch

from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor


def assert_equal(
tensor_list: List[torch.Tensor],
met: MultiEmbeddingTensor,
) -> None:
assert len(tensor_list) == met.num_cols
assert len(tensor_list[0]) == met.num_rows
for i in range(met.num_rows):
for j in range(met.num_cols):
# Note: tensor_list[j] is a tensor of j-th column of size
# [num_rows, dim_emb_j]. See the docs for more info.
assert torch.allclose(tensor_list[j][i], met[i, j])


def get_fake_multi_embedding_tensor(
num_rows: int,
num_cols: int,
) -> Tuple[MultiEmbeddingTensor, List[torch.Tensor]]:
tensor_list = []
for _ in range(num_cols):
embedding_dim = random.randint(1, 5)
tensor = torch.randn((num_rows, embedding_dim))
tensor_list.append(tensor)
return MultiEmbeddingTensor.from_list(tensor_list), tensor_list


def test_size():
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
num_rows = 8
num_cols = 3
lengths = torch.tensor([1, 2, 3])
offset = torch.tensor([0, 1, 3, 6])
values = torch.rand((num_rows, lengths.sum().item()))
met = MultiEmbeddingTensor(
num_rows=num_rows,
num_cols=num_cols,
values=values,
offset=offset,
)

assert met.size(0) == num_rows
assert met.size(1) == num_cols
with pytest.raises(IndexError, match="not have a fixed length"):
met.size(2)
with pytest.raises(IndexError, match="Dimension out of range"):
met.size(3)

assert met.shape[0] == num_rows
assert met.shape[1] == num_cols
assert met.shape[2] == -1


def test_from_list():
num_rows = 2
num_cols = 3
tensor_list = [
torch.tensor([[0, 1, 2], [3, 4, 5]]),
torch.tensor([[6, 7], [8, 9]]),
torch.tensor([[10], [11]]),
]
met = MultiEmbeddingTensor.from_list(tensor_list)
assert met.num_rows == num_rows
assert met.num_cols == num_cols
expected_values = torch.tensor([
[0, 1, 2, 6, 7, 10],
[3, 4, 5, 8, 9, 11],
])
assert torch.allclose(met.values, expected_values)
expected_offset = torch.tensor([0, 3, 5, 6])
assert torch.allclose(met.offset, expected_offset)
assert_equal(tensor_list, met)

# case: empty list
with pytest.raises(AssertionError):
MultiEmbeddingTensor.from_list([])

# case: list of non-2d tensors
with pytest.raises(AssertionError):
MultiEmbeddingTensor.from_list([torch.rand(1)])

# case: list of tensors having different num_rows
with pytest.raises(AssertionError):
MultiEmbeddingTensor.from_list([torch.rand(2, 1), torch.rand(3, 1)])

# case: list of tensors on different devices
with pytest.raises(AssertionError):
MultiEmbeddingTensor.from_list([
torch.rand(2, 1, device="cpu"),
torch.rand(2, 1, device="meta"),
])


def test_index():
met, tensor_list = get_fake_multi_embedding_tensor(
num_rows=2,
num_cols=3,
)
# case met[i, j]: a tuple of two integers
assert_equal(tensor_list, met)


def test_clone():
met, _ = get_fake_multi_embedding_tensor(
num_rows=2,
num_cols=3,
)
met_clone = met.clone()
met.values[0, 0] = 12345.
assert met_clone.values[0, 0] != 12345.
met.offset[0] = -1
assert met_clone.offset[0] != -1
2 changes: 1 addition & 1 deletion test/data/test_multi_nested_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_multi_nested_tensor_basic():
assert multi_nested_tensor.size(0) == num_rows
assert multi_nested_tensor.shape[1] == num_cols
assert multi_nested_tensor.size(1) == num_cols
with pytest.raises(ValueError, match="not have a fixed length"):
with pytest.raises(IndexError, match="not have a fixed length"):
multi_nested_tensor.size(2)
with pytest.raises(IndexError, match="Dimension out of range"):
multi_nested_tensor.size(3)
Expand Down
2 changes: 2 additions & 0 deletions torch_frame/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# flake8: noqa

from .tensor_frame import TensorFrame
from .multi_embedding_tensor import MultiEmbeddingTensor
from .multi_nested_tensor import MultiNestedTensor
from .stats import StatType
from .dataset import Dataset, DataFrameToTensorFrameConverter
Expand All @@ -9,6 +10,7 @@

data_classes = [
'TensorFrame',
'MultiEmbeddingTensor',
'MultiNestedTensor',
'Dataset',
]
Expand Down
117 changes: 117 additions & 0 deletions torch_frame/data/multi_embedding_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from typing import Any, List, Sequence, Union

import torch
from torch import Tensor

from torch_frame.data.multi_tensor import _MultiTensor


class MultiEmbeddingTensor(_MultiTensor):
r"""A PyTorch tensor-based data structure that stores
:obj:`[num_rows, num_cols, *]`, where the size of last dimension can be
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
different for different column.

Note that the last dimension is the same within each column across rows
while in :class:`MultiNestedTensor`, the last dimension can be different
across both rows and columns.

Args:
num_rows (int): Number of rows.
num_cols (int): Number of columns.
values (torch.Tensor): The values :class:`torch.Tensor` of size
:obj:`[num_rows, dim1+dim2+...+dimN]`.
offset (torch.Tensor): The offset :class:`torch.Tensor` of size
:obj:`[num_cols+1,]`.

Example:
>>> num_rows = 2
>>> tensor_list = [
... torch.tensor([[0, 1, 2], [3, 4, 5]]), # col0
... torch.tensor([[6, 7], [8, 9]]), # col1
... torch.tensor([[10], [11]]), # col2
... ]
>>> out = MultiEmbeddingTensor.from_list(tensor_list)
>>> out
MultiEmbeddingTensor(num_rows=2, num_cols=3, device='cpu')
>>> out[0, 2]
tensor([10])
"""
def __init__(
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
self,
num_rows: int,
num_cols: int,
values: Tensor,
offset: Tensor,
) -> None:
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(num_rows, num_cols, values, offset)
assert offset[0] == 0
assert len(offset) == num_cols + 1

def __getitem__(
self,
index: Any,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
index: Any,
index: IndexSelectType,

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will address in a follow-up!

) -> Union['MultiEmbeddingTensor', Tensor]:
if isinstance(index, tuple) and len(index) == 2 and isinstance(
index[0], int) and isinstance(index[1], int):
# return self.values[self.offset[index[1]]:self.offset[index[1] + 1]]
# self.values: [num_rows, dim1+dim2+...+dimN]
i = index[0]
j = index[1]
return self.values[i, self.offset[j]:self.offset[j + 1]]

# TODO(akihironitta): Support more index types
raise NotImplementedError

@classmethod
def from_list(
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
cls,
tensor_list: List[Tensor],
) -> 'MultiEmbeddingTensor':
r"""Creates a :class:`MultiEmbeddingTensor` from a list of :class:`torch.Tensor`.

Args:
tensor_list (List[Tensor]): A list of tensors, where each tensor
has the same number of rows and can have a different number of
columns.

Returns:
MultiEmbeddingTensor: A :class:`MultiEmbeddingTensor` instance.

Example:
>>> num_rows = 2
>>> tensor_list = [
... torch.tensor([[0, 1, 2], [3, 4, 5]]), # col0
... torch.tensor([[6, 7], [8, 9]]), # col1
... torch.tensor([[10], [11]]), # col2
... ]
>>> out = MultiEmbeddingTensor.from_list(tensor_list)
>>> out
MultiEmbeddingTensor(num_rows=2, num_cols=3, device='cpu')
>>> out[0, 0]
tensor([0, 1, 2])
"""
assert isinstance(tensor_list, list) and len(tensor_list) > 0
num_rows = tensor_list[0].size(0)
device = tensor_list[0].device
for tensor in tensor_list:
msg = "tensor_list must be a list of tensors."
assert isinstance(tensor, torch.Tensor), msg
msg = "tensor_list must be a list of 2D tensors."
assert tensor.dim() == 2, msg
msg = "num_rows must be the same across a list of input tensors."
assert tensor.size(0) == num_rows, msg
msg = "device must be the same across a list of input tensors."
assert tensor.device == device, msg

offset_list = []
accum_idx = 0
offset_list.append(accum_idx)
for tensor in tensor_list:
accum_idx += tensor.size(1)
offset_list.append(accum_idx)

num_cols = len(tensor_list)
values = torch.cat(tensor_list, dim=1)
assert values.size() == (num_rows, offset_list[-1])
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
offset = torch.LongTensor(offset_list)
return cls(num_rows, num_cols, values, offset)
Loading