diff --git a/adapters/src/adapters/composition.py b/adapters/src/adapters/composition.py
index 6d37e44b..6a7fd14e 100644
--- a/adapters/src/adapters/composition.py
+++ b/adapters/src/adapters/composition.py
@@ -152,7 +152,7 @@ def validate_composition(adapter_composition: AdapterCompositionBlock, level=0,
                 f"Models of type {model_type} don't support adapter composition using {block_type.__name__}."
             )
     for child in adapter_composition:
-        if not type(child) in ALLOWED_NESTINGS[type(adapter_composition)]:
+        if type(child) not in ALLOWED_NESTINGS[type(adapter_composition)]:
             raise ValueError(f"Adapter setup is invalid. Cannot nest {child} in {adapter_composition}")
         # recursively validate children
         validate_composition(child, level=level + 1)
diff --git a/adapters/src/adapters/head_utils.py b/adapters/src/adapters/head_utils.py
index 2144fbe5..d0caeecf 100644
--- a/adapters/src/adapters/head_utils.py
+++ b/adapters/src/adapters/head_utils.py
@@ -742,6 +742,8 @@ def get_head_config_and_rename_list(model_class_name, head_name, label2id, num_l
             escaped_name = re.escape(name)
             rename_list.append((rf"{escaped_name}\.(\S+)", f"heads.{head_name}.{i}.{{0}}"))
             i += 1
 
-    rename_func = lambda k, rename_list=rename_list: _regex_list_rename_func(k, rename_list)
+    def rename_func(k, rename_list=rename_list):
+        return _regex_list_rename_func(k, rename_list)
+
     return config, rename_func
diff --git a/adapters/src/adapters/heads/model_mixin.py b/adapters/src/adapters/heads/model_mixin.py
index 4e0dfde8..d19beda5 100644
--- a/adapters/src/adapters/heads/model_mixin.py
+++ b/adapters/src/adapters/heads/model_mixin.py
@@ -707,7 +707,10 @@ def _load_pretrained_model(
         if len(model.base_model_prefix) > 0 and not any(
             s.startswith(model.base_model_prefix) for s in loaded_keys
         ):
-            rename_func = lambda x: model.base_model_prefix + "." + x if x not in head_state_dict else x
+            def rename_func(x):
+                if x not in head_state_dict:
+                    return model.base_model_prefix + "." + x
+                return x
             state_dict = {rename_func(k): v for k, v in state_dict.items()}
             loaded_keys = [rename_func(k) for k in loaded_keys]
diff --git a/wtpsplit/evaluation/intrinsic_baselines.py b/wtpsplit/evaluation/intrinsic_baselines.py
index 97ba0fa3..fbf03d67 100644
--- a/wtpsplit/evaluation/intrinsic_baselines.py
+++ b/wtpsplit/evaluation/intrinsic_baselines.py
@@ -152,7 +152,7 @@ class Args:
                     indices[lang][dataset_name][name]["length"] = [metrics.pop("length")]
                     results[lang][dataset_name][name] = metrics
                 except LanguageError as e:
-                    print("Language not supported for", name)
+                    print("Language not supported for", name, e)
                     results[lang][dataset_name][name] = None
 
     json.dump(results, open(Constants.CACHE_DIR / "intrinsic_baselines.json", "w"), indent=4, default=int)
diff --git a/wtpsplit/evaluation/intrinsic_pairwise.py b/wtpsplit/evaluation/intrinsic_pairwise.py
index 18845289..37aac490 100644
--- a/wtpsplit/evaluation/intrinsic_pairwise.py
+++ b/wtpsplit/evaluation/intrinsic_pairwise.py
@@ -9,7 +9,6 @@
 import logging
 
 import h5py
-import skops.io as sio
 import torch
 from datasets import load_dataset
 from tqdm.auto import tqdm
@@ -290,7 +289,7 @@ def main(args):
     print(save_str)
 
     eval_data = torch.load(args.eval_data_path)
-    if "canine" in args.model_path and not "no-adapters" in args.model_path:
+    if "canine" in args.model_path and "no-adapters" not in args.model_path:
         eval_data = split_language_data(eval_data)
     if args.valid_text_path is not None:
         valid_data = load_dataset("parquet", data_files=args.valid_text_path, split="train")
diff --git a/wtpsplit/evaluation/stat_tests/permutation_test_data.py b/wtpsplit/evaluation/stat_tests/permutation_test_data.py
index 2e05e3d4..553318c5 100644
--- a/wtpsplit/evaluation/stat_tests/permutation_test_data.py
+++ b/wtpsplit/evaluation/stat_tests/permutation_test_data.py
@@ -63,10 +63,6 @@
                 try:
                     if isinstance(data_list[0], int):
                         data_list = [data_list]
-                except:
-                    print(data_list)
-                    print(lang, dataset, model_type)
-                    raise Exception
 
                 raw_data[lang][dataset][model + "-" + model_type] = data_list
diff --git a/wtpsplit/extract_batched.py b/wtpsplit/extract_batched.py
index ba7006da..3464aadd 100644
--- a/wtpsplit/extract_batched.py
+++ b/wtpsplit/extract_batched.py
@@ -3,7 +3,6 @@
 import logging
 
 import numpy as np
-from tqdm.auto import tqdm
 from transformers import AutoTokenizer
 from tokenizers import AddedToken
diff --git a/wtpsplit/models.py b/wtpsplit/models.py
index 8579d1b3..90580a99 100644
--- a/wtpsplit/models.py
+++ b/wtpsplit/models.py
@@ -1,12 +1,10 @@
 import copy
 import math
-import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor, nn
 from torch.nn import CrossEntropyLoss
-from torchinfo import summary
 from transformers import AutoModel, AutoModelForTokenClassification
 from transformers.modeling_outputs import (
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -40,7 +38,6 @@
 )
 from transformers.models.xlm_roberta.modeling_xlm_roberta import (
     XLMRobertaEmbeddings,
-    XLMRobertaEncoder,
     XLMRobertaPooler,
     XLMRobertaLayer,
 )
@@ -1395,30 +1392,30 @@ def custom_forward(*inputs):
 AutoModel.register(SubwordXLMConfig, SubwordXLMForTokenClassification)
 AutoModelForTokenClassification.register(SubwordXLMConfig, SubwordXLMForTokenClassification)
 
-if __name__ == "__main__":
-    # test XLM
-    from transformers import AutoConfig, AutoTokenizer
-
-    model_str = "xlm-roberta-base"
-    config = SubwordXLMConfig.from_pretrained(model_str)
-    config.num_labels = 4
-    config.num_hidden_layers = 12
-    config.lookahead = 48
-    config.lookahead_split_layers = 6
-    backbone = SubwordXLMForTokenClassification.from_pretrained(model_str, config=config)
-    print(summary(backbone, depth=4))
-
-    # some sample input
-    text = "A sentence. Now we move on. And on and this is the last sentence. Now, we are starting to move on to the next sentence. This is the last sentence."
-    tokenizer = AutoTokenizer.from_pretrained(model_str)
-
-    tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False, pad_to_multiple_of=512, padding=True)
-    from tokenizers import AddedToken
-
-    tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
-    print(tokenizer.tokenize(text))
-    print(tokenizer.encode(text))
-    print(tokens)
-
-    # forward pass
-    print(backbone(**tokens))
+# if __name__ == "__main__":
+#     # test XLM
+#     from transformers import AutoTokenizer
+
+#     model_str = "xlm-roberta-base"
+#     config = SubwordXLMConfig.from_pretrained(model_str)
+#     config.num_labels = 4
+#     config.num_hidden_layers = 12
+#     config.lookahead = 48
+#     config.lookahead_split_layers = 6
+#     backbone = SubwordXLMForTokenClassification.from_pretrained(model_str, config=config)
+#     print(summary(backbone, depth=4))
+
+#     # some sample input
+#     text = "A sentence. Now we move on. And on and this is the last sentence. Now, we are starting to move on to the next sentence. This is the last sentence."
+#     tokenizer = AutoTokenizer.from_pretrained(model_str)
+
+#     tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False, pad_to_multiple_of=512, padding=True)
+#     from tokenizers import AddedToken
+
+#     tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
+#     print(tokenizer.tokenize(text))
+#     print(tokenizer.encode(text))
+#     print(tokens)
+
+#     # forward pass
+#     print(backbone(**tokens))
diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py
index 96cc7c11..761dc112 100644
--- a/wtpsplit/train/train.py
+++ b/wtpsplit/train/train.py
@@ -4,7 +4,7 @@
 import random
 import shutil
 import sys
-import time
+# import time
 from collections import Counter, defaultdict
 from dataclasses import dataclass
 from functools import partial
@@ -17,7 +17,7 @@
 import torch_xla.core.xla_model as xm
 import transformers
 from datasets import load_dataset
-from datasets.download import DownloadConfig
+# from datasets.download import DownloadConfig
 from tokenizers import AddedToken
 from torchinfo import summary
 from tqdm.auto import tqdm
@@ -35,6 +35,7 @@
 from wtpsplit.train.evaluate import evaluate_sentence
 from wtpsplit.train.trainer import Trainer
 from wtpsplit.train.utils import Model, cleanup_cache_files
+# from wtpsplit.train.utils import cleanup_cache_files
 from wtpsplit.utils import Constants, LabelArgs, corrupt_training, get_label_dict, get_subword_label_dict
 
 logger = logging.getLogger(__name__)