base_llama.py
import torch
from torch import dtype, nn

from config import LlamaConfig
from utils import *


class LlamaPreTrainedModel(nn.Module):
    config_class = LlamaConfig
    base_model_prefix = "llama"

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.n_layers = config.n_layers

    def init_weights(self):
        # Recursively apply _init_weights to this module and every submodule.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize the weights of linear and embedding layers."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    @property
    def dtype(self) -> dtype:
        # Infer the model's dtype from its parameters (helper from utils).
        return get_parameter_dtype(self)
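

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal subclass showing how
# init_weights and the dtype property behave. It assumes LlamaConfig accepts
# vocab_size and n_layers as keyword arguments; those field names are inferred
# from the attributes read above and may differ in the actual config module.
if __name__ == "__main__":
    class TinyLlama(LlamaPreTrainedModel):
        def __init__(self, config: LlamaConfig):
            super().__init__(config)
            # A single embedding layer so the model has parameters to
            # initialize; init_weights applies _init_weights to it.
            self.embed = nn.Embedding(config.vocab_size, 16)
            self.init_weights()

    model = TinyLlama(LlamaConfig(vocab_size=128, n_layers=2))
    print(model.dtype)  # torch.float32 unless the parameters were cast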