Commit
fix some lint errors
markus583 committed Jun 16, 2024
1 parent c6983d5 commit a4f9e20
Showing 9 changed files with 40 additions and 43 deletions.
2 changes: 1 addition & 1 deletion adapters/src/adapters/composition.py
@@ -152,7 +152,7 @@ def validate_composition(adapter_composition: AdapterCompositionBlock, level=0,
f"Models of type {model_type} don't support adapter composition using {block_type.__name__}."
)
for child in adapter_composition:
- if not type(child) in ALLOWED_NESTINGS[type(adapter_composition)]:
+ if type(child) not in ALLOWED_NESTINGS[type(adapter_composition)]:
raise ValueError(f"Adapter setup is invalid. Cannot nest {child} in {adapter_composition}")
# recursively validate children
validate_composition(child, level=level + 1)
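The change above is likely Ruff's E713 fix (test for membership should be `not in`). Both spellings are equivalent at runtime, but `not in` reads as a single membership operator instead of leaving the reader to work out what `not` binds to. A minimal standalone sketch:

    allowed = {int, str}

    # Flagged by E713: `not` negates the whole comparison expression.
    if not type(3.5) in allowed:
        print("float is not allowed")

    # Preferred: `not in` is one operator, same result.
    if type(3.5) not in allowed:
        print("float is not allowed")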
4 changes: 3 additions & 1 deletion adapters/src/adapters/head_utils.py
@@ -742,6 +742,8 @@ def get_head_config_and_rename_list(model_class_name, head_name, label2id, num_l
escaped_name = re.escape(name)
rename_list.append((rf"{escaped_name}\.(\S+)", f"heads.{head_name}.{i}.{{0}}"))
i += 1
- rename_func = lambda k, rename_list=rename_list: _regex_list_rename_func(k, rename_list)
+ def rename_func(k, rename_list=rename_list):
+     return _regex_list_rename_func(k, rename_list)


return config, rename_func
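This rewrite addresses Ruff's E731 (do not assign a `lambda` to a name): a `def` carries a real `__name__` for tracebacks and debugging. Note that the rewrite keeps the `rename_list=rename_list` default argument: a default is evaluated once, when the `def` executes, so it binds the list built above even if the name is later rebound. A small sketch of the late-binding pitfall this idiom guards against (illustrative only, not the adapters code):

    # Closures capture variables, not values: every function sees the final n.
    late = [lambda: n for n in range(3)]
    print([f() for f in late])  # [2, 2, 2]

    # A default argument is evaluated when the def runs, freezing each value,
    # which is what the rename_list=rename_list default preserves above.
    frozen = []
    for n in range(3):
        def get(n=n):
            return n
        frozen.append(get)
    print([f() for f in frozen])  # [0, 1, 2]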
5 changes: 4 additions & 1 deletion adapters/src/adapters/heads/model_mixin.py
@@ -707,7 +707,10 @@ def _load_pretrained_model(
if len(model.base_model_prefix) > 0 and not any(
s.startswith(model.base_model_prefix) for s in loaded_keys
):
- rename_func = lambda x: model.base_model_prefix + "." + x if x not in head_state_dict else x
+ def rename_func(x):
+     if x not in head_state_dict:
+         return model.base_model_prefix + "." + x
+     return x
state_dict = {rename_func(k): v for k, v in state_dict.items()}
loaded_keys = [rename_func(k) for k in loaded_keys]

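The same E731 fix, here unfolding a conditional-expression lambda into an early-return `def`, which also makes the precedence explicit (the original relied on the conditional expression binding looser than `+`). A self-contained sketch of the renaming behavior, with made-up keys and prefix (the real code takes them from the model and the loaded head):

    base_model_prefix = "bert"
    head_state_dict = {"heads.default.1.weight": None}

    def rename_func(x):
        if x not in head_state_dict:
            return base_model_prefix + "." + x
        return x

    state_dict = {
        "embeddings.word_embeddings.weight": 0,
        "heads.default.1.weight": 1,
    }
    print({rename_func(k): v for k, v in state_dict.items()})
    # {'bert.embeddings.word_embeddings.weight': 0, 'heads.default.1.weight': 1}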
2 changes: 1 addition & 1 deletion wtpsplit/evaluation/intrinsic_baselines.py
@@ -152,7 +152,7 @@ class Args:
indices[lang][dataset_name][name]["length"] = [metrics.pop("length")]
results[lang][dataset_name][name] = metrics
except LanguageError as e:
print("Language not supported for", name)
print("Language not supported for", name, e)
results[lang][dataset_name][name] = None

json.dump(results, open(Constants.CACHE_DIR / "intrinsic_baselines.json", "w"), indent=4, default=int)
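Binding the exception to `e` and never using it is what Ruff reports as F841 (local variable assigned but never used); printing `e` silences the warning and, more usefully, shows why the baseline failed. A runnable sketch with a stand-in exception and a hypothetical baseline name:

    class LanguageError(Exception):
        pass

    def run_baseline(lang):
        raise LanguageError(f"no model for {lang!r}")

    name = "some-baseline"  # hypothetical baseline name
    try:
        run_baseline("tlh")
    except LanguageError as e:
        # Using e keeps F841 quiet and surfaces the actual reason.
        print("Language not supported for", name, e)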
3 changes: 1 addition & 2 deletions wtpsplit/evaluation/intrinsic_pairwise.py
@@ -9,7 +9,6 @@
import logging

import h5py
- import skops.io as sio
import torch
from datasets import load_dataset
from tqdm.auto import tqdm
@@ -290,7 +289,7 @@ def main(args):

print(save_str)
eval_data = torch.load(args.eval_data_path)
if "canine" in args.model_path and not "no-adapters" in args.model_path:
if "canine" in args.model_path and "no-adapters" not in args.model_path:
eval_data = split_language_data(eval_data)
if args.valid_text_path is not None:
valid_data = load_dataset("parquet", data_files=args.valid_text_path, split="train")
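Dropping `import skops.io as sio` is the standard fix for Ruff's F401 (imported but unused); the same hunk also applies the E713 `not in` rewrite seen earlier. When an unused-looking import has to stay, for example for a re-export or an import-time side effect, the convention is to say so explicitly rather than leave the linter guessing:

    # Hypothetical example: keep a deliberate re-export visible to the linter.
    from tokenizers import AddedToken  # noqa: F401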
4 changes: 0 additions & 4 deletions wtpsplit/evaluation/stat_tests/permutation_test_data.py
@@ -63,10 +63,6 @@
try:
if isinstance(data_list[0], int):
data_list = [data_list]
- except:
-     print(data_list)
-     print(lang, dataset, model_type)
-     raise Exception

raw_data[lang][dataset][model + "-" + model_type] = data_list

Check failure on line 67 in wtpsplit/evaluation/stat_tests/permutation_test_data.py (GitHub Actions / build (3.8)):
Ruff (E999): wtpsplit/evaluation/stat_tests/permutation_test_data.py:67:17: E999 SyntaxError: Unexpected token 'raw_data'

Check failure on line 67 in wtpsplit/evaluation/stat_tests/permutation_test_data.py (GitHub Actions / build (3.10)):
Ruff (E999): wtpsplit/evaluation/stat_tests/permutation_test_data.py:67:17: E999 SyntaxError: Unexpected token 'raw_data'
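The deleted handler was a bare `except:` (Ruff E722), but removing it without touching the `try:` leaves a try statement with no `except` or `finally`, which is exactly why both builds now fail with E999 on line 67. A lint-clean variant that keeps the debug output would narrow the handler and re-raise instead; a standalone sketch with stand-in values:

    lang, dataset, model_type = "en", "ud", "u"  # stand-in values
    data_list = []  # an empty list triggers the IndexError below

    try:
        if isinstance(data_list[0], int):
            data_list = [data_list]
    except (IndexError, TypeError):
        # A narrow handler satisfies E722, and a bare raise preserves the
        # original traceback instead of replacing it with a generic Exception.
        print(data_list)
        print(lang, dataset, model_type)
        raise

If the handler truly is not wanted, the `try:` line has to be removed along with it and the body dedented.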
1 change: 0 additions & 1 deletion wtpsplit/extract_batched.py
@@ -3,7 +3,6 @@
import logging

import numpy as np
- from tqdm.auto import tqdm
from transformers import AutoTokenizer
from tokenizers import AddedToken

57 changes: 27 additions & 30 deletions wtpsplit/models.py
@@ -1,12 +1,10 @@
import copy
import math
- import warnings
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss
- from torchinfo import summary
from transformers import AutoModel, AutoModelForTokenClassification
from transformers.modeling_outputs import (
BaseModelOutputWithPoolingAndCrossAttentions,
@@ -40,7 +38,6 @@
)
from transformers.models.xlm_roberta.modeling_xlm_roberta import (
XLMRobertaEmbeddings,
-     XLMRobertaEncoder,
XLMRobertaPooler,
XLMRobertaLayer,
)
@@ -1395,30 +1392,30 @@ def custom_forward(*inputs):
AutoModel.register(SubwordXLMConfig, SubwordXLMForTokenClassification)
AutoModelForTokenClassification.register(SubwordXLMConfig, SubwordXLMForTokenClassification)

- if __name__ == "__main__":
-     # test XLM
-     from transformers import AutoConfig, AutoTokenizer
-
-     model_str = "xlm-roberta-base"
-     config = SubwordXLMConfig.from_pretrained(model_str)
-     config.num_labels = 4
-     config.num_hidden_layers = 12
-     config.lookahead = 48
-     config.lookahead_split_layers = 6
-     backbone = SubwordXLMForTokenClassification.from_pretrained(model_str, config=config)
-     print(summary(backbone, depth=4))
-
-     # some sample input
-     text = "A sentence. Now we move on. And on and this is the last sentence. Now, we are starting to move on to the next sentence. This is the last sentence."
-     tokenizer = AutoTokenizer.from_pretrained(model_str)
-
-     tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False, pad_to_multiple_of=512, padding=True)
-     from tokenizers import AddedToken
-
-     tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
-     print(tokenizer.tokenize(text))
-     print(tokenizer.encode(text))
-     print(tokens)
-
-     # forward pass
-     print(backbone(**tokens))
+ # if __name__ == "__main__":
+ #     # test XLM
+ #     from transformers import AutoTokenizer
+
+ #     model_str = "xlm-roberta-base"
+ #     config = SubwordXLMConfig.from_pretrained(model_str)
+ #     config.num_labels = 4
+ #     config.num_hidden_layers = 12
+ #     config.lookahead = 48
+ #     config.lookahead_split_layers = 6
+ #     backbone = SubwordXLMForTokenClassification.from_pretrained(model_str, config=config)
+ #     print(summary(backbone, depth=4))
+
+ #     # some sample input
+ #     text = "A sentence. Now we move on. And on and this is the last sentence. Now, we are starting to move on to the next sentence. This is the last sentence."
+ #     tokenizer = AutoTokenizer.from_pretrained(model_str)
+
+ #     tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False, pad_to_multiple_of=512, padding=True)
+ #     from tokenizers import AddedToken
+
+ #     tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
+ #     print(tokenizer.tokenize(text))
+ #     print(tokenizer.encode(text))
+ #     print(tokens)
+
+ #     # forward pass
+ #     print(backbone(**tokens))
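Commenting out this `__main__` block is what turned `summary` (torchinfo) and the other trimmed imports above into unused names. If the quick smoke test is still wanted, one option (a sketch assuming a separate script, not something this commit does) is to move it out of the library module so models.py stays lint-clean:

    # hypothetical scripts/smoke_test_models.py, mirroring the removed block
    from torchinfo import summary
    from transformers import AutoTokenizer

    from wtpsplit.models import SubwordXLMConfig, SubwordXLMForTokenClassification

    model_str = "xlm-roberta-base"
    config = SubwordXLMConfig.from_pretrained(model_str)
    config.num_labels = 4
    config.num_hidden_layers = 12
    config.lookahead = 48
    config.lookahead_split_layers = 6
    backbone = SubwordXLMForTokenClassification.from_pretrained(model_str, config=config)
    print(summary(backbone, depth=4))

    text = "A sentence. Now we move on. And on and this is the last sentence."
    tokenizer = AutoTokenizer.from_pretrained(model_str)
    tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False, pad_to_multiple_of=512, padding=True)
    print(backbone(**tokens))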
5 changes: 3 additions & 2 deletions wtpsplit/train/train.py
@@ -4,7 +4,7 @@
import random
import shutil
import sys
- import time
+ # import time
from collections import Counter, defaultdict
from dataclasses import dataclass
from functools import partial
@@ -17,7 +17,7 @@
import torch_xla.core.xla_model as xm
import transformers
from datasets import load_dataset
- from datasets.download import DownloadConfig
+ # from datasets.download import DownloadConfig
from tokenizers import AddedToken
from torchinfo import summary
from tqdm.auto import tqdm
@@ -35,6 +35,7 @@
from wtpsplit.train.evaluate import evaluate_sentence
from wtpsplit.train.trainer import Trainer
from wtpsplit.train.utils import Model, cleanup_cache_files
+ # from wtpsplit.train.utils import cleanup_cache_files
from wtpsplit.utils import Constants, LabelArgs, corrupt_training, get_label_dict, get_subword_label_dict

logger = logging.getLogger(__name__)
