Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds ByTermSplit splitter #499

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

falquaddoomi
Copy link
Collaborator

@falquaddoomi falquaddoomi commented Nov 15, 2024

This PR adds a new splitter, ByTermSplit, that like the existing splitters divides a set of genes into multiple "splits", typically a training/testing split, or a training, validation, and test split.

The ByTermSplit splitter takes the following arguments in its constructor:

  • labelset, which should be an instance of LabelSetCollection
  • split_terms, which should be an iterable, which contains as many iterables as you want splits.
  • exclusive, a boolean that, if true, only allows a specific gene to occur in the first split to which it's been assigned

For example, to produce three splits, you'd supply the following for split_terms:

[
  ('GO:0000281', 'GO:0032755', 'GO:0065003'),
  ('GO:0006412', 'GO:0004930', 'GO:0000398'),
  ('GO:0006886', 'GO:0045766', 'GO:0051091')
]

The resulting splits will contain all genes to which any term in the split corresponds. Note that while it's expected that you'd make the splits disjoint (i.e., with no terms shared between splits), you don't have to; the same genes would then end up in both splits. Also, if the same gene happens to be identified by two different terms in different splits, it will be included in both (unless exclusive is specified).

Optionally, you can specify up to one "catch-all" split, identified by a single "*". The "catch-all" split will contain any genes that weren't assigned to other splits. For example:

[
  ('GO:0000281', 'GO:0032755', 'GO:0065003'),
  ('*'), # this split contains all genes not in the first or last splits
  ('GO:0006886', 'GO:0045766', 'GO:0051091')
]

Finally, the exclusive flag ensures that no split contains genes present in another split. Preference is given to the first split in which the gene occurs, so if gene ID 3841 occurred in splits 1 and 3, it would just be retained in 1.

Partially addresses #498.

@kmanpearl
Copy link

Hey Faisal,
Sorry for not getting to this review sooner. I ran into an issue when trying to create an OBNBDataset object. This is the code I used:

from obnb.dataset import Dataset as OBNBDataset
from obnb.data import GOBP, BioGRID
from obnb.label.filters import Compose, EntityExistenceFilter, LabelsetRangeFilterSize
from obnb.label.split import ByTermSplit

def setup_data(root: str):
    lsc = GOBP(root=root)
    g = BioGRID(root=root)

    split = ByTermSplit(
            lsc,
            split_terms=[
                #('*'),
                #('GO:0000281', 'GO:0032755', 'GO:0065003', 'GO:0007420', 'GO:0032543', 'GO:0045727', 'GO:0046330', 'GO:0008360', 'GO:0000981', 'GO:0007283', 'GO:0051607', 'GO:0005125', 'GO:0051865', 'GO:0006470', 'GO:0005200', 'GO:0000165'),
                 ('GO:0006412', 'GO:0004930', 'GO:0000398', 'GO:0016525', 'GO:0090263', 'GO:0008283', 'GO:0098586', 'GO:0043433', 'GO:0006511', 'GO:0030512', 'GO:0016567', 'GO:0016579', 'GO:0008083'),
                 ('*'),
                 ('GO:0006886', 'GO:0045766', 'GO:0051091', 'GO:0007166', 'GO:0001934', 'GO:0016477', 'GO:0000724', 'GO:0016241', 'GO:0007179', 'GO:0032088', 'GO:0001558', 'GO:0050852', 'GO:0051897'),
            ],
            exclusive=False
        )
    # Standard Preprocessing
    # Removed check for num genes per split since not needed with term split
    # Lowered network genes to 15 since all genes for a term will be in the same split 
    lsc.iapply(
        Compose(
            # Only use genes that are present in the network
            EntityExistenceFilter(list(g.node_ids)),
            # Remove any labelsets with less than 15 network genes
            LabelsetRangeFilterSize(min_val=15)
        ),
    )
    
    return OBNBDataset(
        graph=g,
        feature=g.to_dense_graph().to_feature(),
        label=lsc,
        splitter=split
    )
dataset = setup_data('data')

and the error:

{
	"name": "IndexError",
	"message": "arrays used as indices must be of integer (or boolean) type",
	"stack": "---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[6], line 1
----> 1 dataset = setup_data('data')

Cell In[5], line 32, in setup_data(root)
     11 split = ByTermSplit(
     12         lsc,
     13         split_terms=[
   (...)
     20         exclusive=False
     21     )
     23 lsc.iapply(
     24     Compose(
     25         # Only use genes that are present in the network
   (...)
     29     ),
     30 )
---> 32 return OBNBDataset(
     33     graph=g,
     34     feature=g.to_dense_graph().to_feature(),
     35     label=lsc,
     36     splitter=split
     37 )

File ~/Desktop/obnb/src/obnb/dataset/base.py:79, in Dataset.__init__(self, graph, feature, label, auto_generate_feature, dual, transform, transform_kwargs, splitter, **split_kwargs)
     77     self.masks = None
     78 else:
---> 79     self.y, self.masks = label.split(
     80         splitter,
     81         target_ids=tuple(self.idmap.lst),
     82         **split_kwargs,
     83     )
     85 # TODO: replace consider_negative option in label.split with this
     86 _, self.y_mask = label.get_y(
     87     target_ids=tuple(self.idmap.lst),
     88     return_y_mask=True,
     89 )

File ~/Desktop/obnb/src/obnb/label/collection.py:425, in LabelsetCollection.split(self, splitter, target_ids, labelset_name, mask_names, consider_negative, **kwargs)
    423     mask = np.zeros((len(target_ids), len(split)), dtype=bool)
    424     for i, j in enumerate(split):
--> 425         mask[to_target_idx[j], i] = True
    426     masks[mask_name] = mask
    428 if consider_negative:

IndexError: arrays used as indices must be of integer (or boolean) type"
}

@falquaddoomi
Copy link
Collaborator Author

falquaddoomi commented Dec 4, 2024

Hey @kmanpearl, thanks for trying it out, and for the complete example and bug report! The issue occurred because I was putting literal gene IDs (which also were strings, rather than ints) in the splits rather than indices into the ids array that's passed to ByTermSplit.__call__(), which I now realize is how the rest of the splitters work.

My latest commit changes it to integer indices, and with that change I was able to run your example without any obvious errors. If you could try an example again and let me know if the output looks reasonable, I'd appreciate it, since I don't know how to interpret it myself.

@kmanpearl
Copy link

hey @falquaddoomi,

When I try to run the same code after the latest commit I am now getting the following different error. I thought possibly it was related to having the ('*') catchall but I get the same error when I proved 3 sets of terms or two sets and one catch all:

IndexError                                Traceback (most recent call last)
Cell In[2], line 40
     25     lsc.iapply(
     26         Compose(
     27             # Only use genes that are present in the network
   (...)
     31         ),
     32     )
     34     return OBNBDataset(
     35         graph=g,
     36         feature=g.to_dense_graph().to_feature(),
     37         label=lsc,
     38         splitter=split
     39     )
---> 40 dataset = setup_data('data')

Cell In[2], line 34, in setup_data(root)
     22 # Standard Preprocessing
     23 # Removed check for num genes per split since not needed with term split
     24 # Lowered network genes to 15 since all genes for a term will be in the same split 
     25 lsc.iapply(
     26     Compose(
     27         # Only use genes that are present in the network
   (...)
     31     ),
     32 )
---> 34 return OBNBDataset(
     35     graph=g,
     36     feature=g.to_dense_graph().to_feature(),
     37     label=lsc,
     38     splitter=split
     39 )

File ~/Desktop/obnb/src/obnb/dataset/base.py:79, in Dataset.__init__(self, graph, feature, label, auto_generate_feature, dual, transform, transform_kwargs, splitter, **split_kwargs)
     77     self.masks = None
     78 else:
---> 79     self.y, self.masks = label.split(
     80         splitter,
     81         target_ids=tuple(self.idmap.lst),
     82         **split_kwargs,
     83     )
     85 # TODO: replace consider_negative option in label.split with this
     86 _, self.y_mask = label.get_y(
     87     target_ids=tuple(self.idmap.lst),
     88     return_y_mask=True,
     89 )

File ~/Desktop/obnb/src/obnb/label/collection.py:401, in LabelsetCollection.split(self, splitter, target_ids, labelset_name, mask_names, consider_negative, **kwargs)
    398     y[list(map(entity_idmap.get, labelset))] = True
    400 # Iterate over splits generated by splitter and align with target_ids
--> 401 splits = list(zip(*[*splitter(self.entity_ids, y)]))
    402 split_size = len(splits)
    403 if mask_names is not None:

File ~/Desktop/obnb/src/obnb/label/split/explicit.py:85, in ByTermSplit.__call__(self, ids, y)
     81 gdf = self.gene_id_to_terms
     83 # for each split, filter to the gene IDs that have at least one
     84 # term in the split
---> 85 result = [
     86     (
     87         {
     88             gene_id
     89             for gene_id in ids
     90             if gdf[gdf[\"GeneID\"] == str(id)][\"Terms\"].values[0] & terms
     91         }
     92         if terms != {\"*\"}
     93         else None
     94     )
     95     for terms in self.split_terms
     96 ]
     98 # if one of the resulting splits ended up as 'None', we need to
     99 # fill in that split with any gene that wasn't matched by any of
    100 # the other splits
    101 for idx in range(len(result)):

File ~/Desktop/obnb/src/obnb/label/split/explicit.py:87, in <listcomp>(.0)
     81 gdf = self.gene_id_to_terms
     83 # for each split, filter to the gene IDs that have at least one
     84 # term in the split
     85 result = [
     86     (
---> 87         {
     88             gene_id
     89             for gene_id in ids
     90             if gdf[gdf[\"GeneID\"] == str(id)][\"Terms\"].values[0] & terms
     91         }
     92         if terms != {\"*\"}
     93         else None
     94     )
     95     for terms in self.split_terms
     96 ]
     98 # if one of the resulting splits ended up as 'None', we need to
     99 # fill in that split with any gene that wasn't matched by any of
    100 # the other splits
    101 for idx in range(len(result)):

File ~/Desktop/obnb/src/obnb/label/split/explicit.py:90, in <setcomp>(.0)
     81 gdf = self.gene_id_to_terms
     83 # for each split, filter to the gene IDs that have at least one
     84 # term in the split
     85 result = [
     86     (
     87         {
     88             gene_id
     89             for gene_id in ids
---> 90             if gdf[gdf[\"GeneID\"] == str(id)][\"Terms\"].values[0] & terms
     91         }
     92         if terms != {\"*\"}
     93         else None
     94     )
     95     for terms in self.split_terms
     96 ]
     98 # if one of the resulting splits ended up as 'None', we need to
     99 # fill in that split with any gene that wasn't matched by any of
    100 # the other splits
    101 for idx in range(len(result)):

IndexError: index 0 is out of bounds for axis 0 with size 0

@falquaddoomi
Copy link
Collaborator Author

Hey @kmanpearl, thanks again for the error traceback. I accidentally introduced that error in commit f1e5dd7 when I was making changes to get the testing suites to pass; specifically, I changed id to gene_id in several places, but forgot one.

It's fixed as of e2b28f0 and I re-ran your code to make sure it's actually working this time, but please feel free to run it on your end as well.

@kmanpearl
Copy link

@falquaddoomi I was able to get the code to run without error but I'm not sure it is behaving as desired. I was comparing the results of the term split and the original gene split using the following two functions:

def setup_term_data(root: str):
    lsc = GOBP(root=root)
    g = BioGRID(root=root)

    split = ByTermSplit(
            lsc,
            split_terms=[
                #('*'),
                ('GO:0000281', 'GO:0032755', 'GO:0065003', 'GO:0007420', 'GO:0032543', 'GO:0045727', 'GO:0046330', 'GO:0008360', 'GO:0000981', 'GO:0007283', 'GO:0051607', 'GO:0005125', 'GO:0051865', 'GO:0006470', 'GO:0005200', 'GO:0000165'),
                ('GO:0006412', 'GO:0004930', 'GO:0000398', 'GO:0016525', 'GO:0090263', 'GO:0008283', 'GO:0098586', 'GO:0043433', 'GO:0006511', 'GO:0030512', 'GO:0016567', 'GO:0016579', 'GO:0008083'),
                #('*'),
                ('GO:0006886', 'GO:0045766', 'GO:0051091', 'GO:0007166', 'GO:0001934', 'GO:0016477', 'GO:0000724', 'GO:0016241', 'GO:0007179', 'GO:0032088', 'GO:0001558', 'GO:0050852', 'GO:0051897'),
            ],
            exclusive=False
        )
    # Standard Preprocessing
    # Removed check for num genes per split since not needed with term split
    # Lowered network genes to 15 since all genes for a term will be in the same split 
    lsc.iapply(
        Compose(
            # Only use genes that are present in the network
            EntityExistenceFilter(list(g.node_ids)),
            # Remove any labelsets with less than 15 network genes
            LabelsetRangeFilterSize(min_val=15)
        ),
    )
    
    return OBNBDataset(
        graph=g,
        feature=g.to_dense_graph().to_feature(),
        label=lsc,
        splitter=split
    )
term_dataset = setup_term_data('data')

def setup_gene_data(root: str):
    lsc = GOBP(root=root)
    g = BioGRID(root=root)
    pubmedcnt_converter = GenePropertyConverter(root, name="PubMedCount")
    splitter = RatioPartition(0.6, 0.2, 0.2, ascending=False,
                                    property_converter=pubmedcnt_converter)
    # Apply in-place filters to the labelset collection
    lsc.iapply(
        Compose(
            # Only use genes that are present in the network
            EntityExistenceFilter(list(g.node_ids)),
            # Remove any labelsets with less than 50 network genes
            LabelsetRangeFilterSize(min_val=50),
            # Make sure each split has at least 5 positive examples
            LabelsetRangeFilterSplit(min_val=5, splitter=splitter),
        ),
    )
    return OBNBDataset(
        graph=g,
        feature=g.to_dense_graph().to_feature(),
        label=lsc,
        splitter=splitter
    )
gene_dataset = setup_gene_data('data')

The shape of y_mask was as expected. gene_dataset.y_mask.shape = (19835, 889) and term_dataset.y_mask.shape = (19835, 121) The difference in number of terms is due to a difference in filtering steps. In the gene_dataset each row is always True or always False since all labels for a gene are placed in the same split as expected. However for term_dataset since I set exclusive=False this does not necessarily have to be true. What I would expect is that each column (representing a term's labels) are all the same (all True or all False).

However, when I look at the contents of term_dataset.y_mask it behaves how I would expect the gene_dataset to behave:

array([[ True,  True,  True, ...,  True,  True,  True],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       ...,
       [False, False, False, ..., False, False, False],
       [ True,  True,  True, ...,  True,  True,  True],
       [False, False, False, ..., False, False, False]])

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants