lora.py
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops.operations as P


class lora_block(nn.Cell):
    """Low-rank adaptation (LoRA) block: out = alpha * B(A(x))."""

    def __init__(self, in_dims, out_dims, hid_dims=16, shard=None):
        super().__init__()
        self.dim = in_dims
        self.alpha = 1
        # A projects down to the low-rank space and is initialized from a normal
        # distribution; B projects back up and starts at zero, so the block
        # contributes nothing before training.
        self.A = nn.Dense(in_dims, hid_dims, has_bias=False, weight_init='normal')
        self.B = nn.Dense(hid_dims, out_dims, has_bias=False, weight_init='zeros')
        self.mm = P.MatMul()
        if shard is not None:
            dp, mp = shard
            self.set_shard(dp, mp)

    def set_shard(self, dp, mp):
        # Shard only the batch dimension of the inputs (data parallel); the LoRA
        # weights themselves are replicated, so mp is accepted but not applied here.
        self.A.matmul.shard(((dp, 1), (1, 1)))
        self.B.matmul.shard(((dp, 1), (1, 1)))

    def update_weight(self):
        # Merged weight delta: alpha * (B @ A), shaped (out_dims, in_dims), ready to
        # be added to the weight of the Dense layer being adapted.
        wa = self.A.weight.astype(ms.float32)
        wb = self.B.weight.astype(ms.float32)
        return self.mm(wb, wa) * self.alpha

    def construct(self, x):
        # Apply the low-rank path x -> A -> B and scale by alpha.
        x = self.A(x)
        x = self.B(x)
        return x * self.alpha
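
A minimal usage sketch follows. The 1024-dimensional sizes, the random input tensor, and the frozen base nn.Dense layer are illustrative assumptions, not part of the original file: the block is applied to activations during training, and update_weight() yields the merged delta that can later be folded into the base weight.

import numpy as np
import mindspore as ms
import mindspore.nn as nn

# Hypothetical example: adapt a 1024 -> 1024 projection with a rank-16 LoRA block.
lora = lora_block(in_dims=1024, out_dims=1024, hid_dims=16)
x = ms.Tensor(np.random.randn(2, 1024), ms.float32)

delta_out = lora(x)              # low-rank contribution added to the base layer's output
delta_w = lora.update_weight()   # alpha * (B @ A), shape (1024, 1024)

# Fold the trained delta into the weight of a (hypothetical) frozen base layer.
base = nn.Dense(1024, 1024)
base.weight.set_data(base.weight + delta_w.astype(base.weight.dtype))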