-
Notifications
You must be signed in to change notification settings - Fork 521
/
GPTQ.py
417 lines (348 loc) · 15.5 KB
/
GPTQ.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.fx as fx
import torch.nn as nn
import torch.nn.functional as F
from torch.utils._pytree import tree_flatten, tree_unflatten
# Shorthand for the ATen operator namespace; GenericGPTQRunner.call_function
# compares graph targets against aten.linear.default to find linears to quantize.
aten = torch.ops.aten
from eval import (
setup_cache_padded_seq_input_pos_max_seq_length_for_prefill,
GPTFastEvalWrapper
)
class InputRecorder(GPTFastEvalWrapper):
    """
    This is a fake evaluation wrapper that just records the inputs
    so that they can be used in calibration.

    If pad_calibration_inputs is enabled, the input recorder will take
    each input and pad/truncate it down to the calibration_seq_length.
    It will also edit the model embeddings to be zero for the 0 token used
    in padding and avoid any inputs with the 0 token.

    If not, it will only truncate inputs to the desired length.

    Recorded samples accumulate in ``self.inputs`` as a list with one
    MultiInput per positional model argument (the token tensor and
    ``input_pos``); ``get_recorded_inputs`` returns that list.
    """

    def __init__(
        self,
        model,
        tokenizer,
        calibration_seq_length,
        pad_calibration_inputs=False,
    ):
        super().__init__(model, tokenizer, calibration_seq_length)
        self._model = model
        self._tokenizer = tokenizer
        self._device = torch.device("cpu")
        self.vocab_size = model.config.vocab_size
        self.calibration_seq_length = calibration_seq_length
        self.pad_calibration_inputs = pad_calibration_inputs
        # None until the first sample is recorded, then a list of MultiInput.
        self.inputs = None

        if self.pad_calibration_inputs:
            # This is needed for the pad_calibration_inputs option
            # to work properly, the 0 token's embeddings are set to 0 so that
            # the padded inputs will not affect the model numerics. This token isn't used
            # commonly in the eval tasks for the meta-llama tokenizer and we skip any inputs
            # where it appears.
            # getattr with a default only absorbs a missing attribute, unlike a
            # bare except which also hid the previous `self.mod` typo.
            wte = getattr(getattr(self._model, "transformer", None), "wte", None)
            if isinstance(wte, nn.Embedding):
                # zero_() (unlike `*= 0`) also clears non-finite values.
                wte.weight.data[0, :].zero_()
            else:
                # Padding without a zeroed pad embedding would perturb the
                # calibration activations, so fall back to truncation only.
                print(
                    "Did not find embeddings in model.transformer.wte, disabling padding"
                )
                self.pad_calibration_inputs = False

    def add_input(self, args):
        """Record one sample: append each positional arg to its MultiInput."""
        if self.inputs is None:
            self.inputs = [MultiInput([arg]) for arg in args]
        else:
            self.inputs = [
                multi.add_input(arg) for multi, arg in zip(self.inputs, args)
            ]

    def get_recorded_inputs(self):
        """Return the recorded list of MultiInput (None if nothing was recorded)."""
        return self.inputs

    def _dummy_logits(self, seq_len):
        """Random bf16 tensor of shape (1, seq_len, vocab_size).

        Only its shape matters: it keeps the eval harness going while the
        real purpose of the call (recording inputs) happens on the side.
        """
        return torch.randn(
            (1, seq_len, self.vocab_size), dtype=torch.bfloat16, device=self._device
        )

    def _model_call(self, inps):
        """Record a padded/truncated copy of ``inps`` and return dummy logits."""
        inps = inps.squeeze(0)
        T = len(inps)
        if (
            # can't use inputs that are too short when padding disabled
            (T < self.calibration_seq_length and not self.pad_calibration_inputs)
            or
            # can't use inputs that actually use token we use for padding
            (self.pad_calibration_inputs and 0 in inps)
        ):
            return self._dummy_logits(T)

        # pad or truncate to the right size
        if T >= self.calibration_seq_length:
            inps = inps[: self.calibration_seq_length]
        else:
            inps = F.pad(inps, (0, self.calibration_seq_length - T))

        max_new_tokens = 1
        # self.max_length comes from GPTFastEvalWrapper (presumably derived from
        # the calibration_seq_length passed to super().__init__ -- verify).
        seq, input_pos, _ = setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
            self._model, inps, max_new_tokens, self.max_length
        )
        x = seq.index_select(0, input_pos).view(1, -1)
        self.add_input((x, input_pos))

        # output `something` with correct shape to keep eval going
        return self._dummy_logits(T)
class MultiInput:
    """
    Container holding one value per calibration sample for a single input.

    GenericGPTQRunner unpacks ``.values`` to run each sample separately and
    wraps per-sample outputs back into a MultiInput.
    """

    def __init__(self, inputs):
        self.values = list(inputs)

    def add_input(self, input):
        """Append one more sample value and return self (fluent)."""
        self.values.append(input)
        return self

    def __getitem__(self, index):
        """Return a new MultiInput holding the selected sample(s).

        A slice selects a sub-range of samples. An integer selects a single
        sample, which is wrapped in a one-element list rather than iterated.
        """
        selected = self.values[index]
        if not isinstance(index, slice):
            selected = [selected]
        return MultiInput(selected)

    def cuda(self):
        """Move every tensor sample to CUDA in place; return self."""
        self.values = [
            val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values
        ]
        return self
class GenericGPTQRunner(fx.Interpreter):
    """
    This is a generic GPTQ runner that takes an existing model and applies GPTQ.
    It uses torch._dynamo.export to obtain a graph of the model and then hooks
    into function calls and when it detects a linear, it applies GPTQ to the weight
    given the calibration of inputs passed in at initialization. It puts the results
    into the state_dict so that the quantized model weights/qparams can be loaded
    directly into the model.

    This class is expected to work in concert with a GPTQSimpleQuantizer
    class to define the specific type of quantization being done.
    """

    def __init__(
        self,
        model,
        inputs: "list[MultiInput]",
        blocksize=128,
        percdamp=0.01,
        groupsize=128,
    ):
        """
        Args:
            model: module to quantize. Note ``model.cpu()`` is called for
                tracing, which moves the module to CPU in place.
            inputs: one MultiInput per positional model argument, each holding
                that argument for every calibration sample.
            blocksize: number of columns processed together in faster_quant.
            percdamp: Hessian damping, as a fraction of its mean diagonal.
            groupsize: columns per qparams group; -1 uses one set of qparams
                for the whole weight.
        """
        # Parameter identity -> fully-qualified name, so call_function can tell
        # which model weight a traced linear is using.
        self.id_to_name = {id(value): name for name, value in model.named_parameters()}

        # trace model for one input
        one_input = [multi.values[0].cpu() for multi in inputs]
        exported_model = torch._dynamo.export(
            model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
        )(*one_input)
        super().__init__(exported_model.graph_module)

        self.new_state_dict = model.state_dict()
        self.blocksize = blocksize
        self.percdamp = percdamp
        self.groupsize = groupsize
        self.inputs = inputs
        self.gptq_done = False
        self.debug = False

        # Set by configure_quantization_mode(). They are initialised here so
        # that run()'s precondition check reports its message instead of
        # raising AttributeError.
        self.get_qparams_func = None
        self.quantize_func = None
        self.dequantize_func = None
        self.combine_qparams_list_func = None
        self.make_names_and_values_dict_func = None
        self.skip_layer_func = None

    def configure_quantization_mode(
        self,
        get_qparams_func,
        quantize_func,
        dequantize_func,
        combine_qparams_list_func,
        make_names_and_values_dict_func,
        skip_layer_func,
    ):
        """
        Install the quantization-scheme callbacks and return self.

        These functions need to already be curried with all inputs other than
        weight, qparams:

        - get_qparams_func: accepts [2d weight tensor], outputs qparams.
        - quantize_func: accepts [2d weight tensor], [qparams], outputs a 2d
          quantized tensor of desired dtype.
        - dequantize_func: accepts [quantized] tensor and [qparams], outputs a
          2d dequantized tensor of type float; assumes this output
          .to(w_orig_dtype) is ~eventual desired dequant behavior.
        - combine_qparams_list_func: accepts [`list` of qparams] from
          quantizing one group at a time, outputs a qparams object that could
          be passed into quant/dequantize_func.
        - make_names_and_values_dict_func: accepts [2d quantized tensor],
          [qparams], returns a dict of names, values to put in state_dict.
          Note any final packing for storage should happen here.
        - skip_layer_func: accepts [weight tensor], returns True to SKIP GPTQ
          for that layer; may be None to quantize every eligible linear.
        """
        self.get_qparams_func = get_qparams_func
        self.quantize_func = quantize_func
        self.dequantize_func = dequantize_func
        self.combine_qparams_list_func = combine_qparams_list_func
        self.skip_layer_func = skip_layer_func
        self.make_names_and_values_dict_func = make_names_and_values_dict_func
        return self

    def run(self):
        """Interpret the traced graph over all calibration samples, applying GPTQ.

        Returns the interpreter's output. ``gptq_done`` is only set once the
        whole graph has run successfully.
        """
        assert (
            self.get_qparams_func is not None
        ), "need to configure quantization mode before running"
        result = super().run(*self.inputs)
        self.gptq_done = True
        return result

    def get_quantized_state_dict(self):
        """Return the updated state_dict with kv_cache entries removed (in place)."""
        assert (
            self.gptq_done
        ), "need to run GPTQRunner before you can get_quantized_state_dict"
        quantized_state_dict = self.new_state_dict
        # Don't want to store/load the kv_cache so remove it from the state_dict.
        # Collect the keys first: a dict can't be resized while being iterated.
        for param_fqn in [k for k in quantized_state_dict if "kv_cache" in k]:
            quantized_state_dict.pop(param_fqn)
        return quantized_state_dict

    def call_function(self, target, args, kwargs, skip_quant=False):
        """Run one graph node per calibration sample, intercepting linears.

        For a quantizable linear, the per-sample inputs are accumulated into
        the Hessian H instead of being run. The weight is then GPTQ-quantized,
        the state_dict is updated, and the linear is re-run with the
        dequantized weight so downstream nodes see the quantized numerics.
        """

        def tensors_to_cuda(values):
            return [x.cuda() if isinstance(x, torch.Tensor) else x for x in values]

        # flatten args and kwargs together
        flat_args, spec = tree_flatten((args, kwargs))
        # move all single tensors to cuda, will move MultiInputs to cuda one at a time
        flat_args = tensors_to_cuda(flat_args)

        has_multi_input = any(isinstance(x, MultiInput) for x in flat_args)
        if has_multi_input:
            # Just some trickery to convert
            # [MultiInput[a, a, a], MultiInput(b, b, b)] => [a, b], [a, b], [a, b]
            multi_input_count = max(
                len(x.values) if isinstance(x, MultiInput) else 1 for x in flat_args
            )
            transposed_args = list(
                zip(
                    *[
                        x.values if isinstance(x, MultiInput) else [x] * multi_input_count
                        for x in flat_args
                    ]
                )
            )
        else:
            transposed_args = [flat_args]
        outputs = []

        # check whether we apply GPTQ to this module
        quantize_linear = (
            (target == aten.linear.default)  # if its a linear
            and id(args[1]) in self.id_to_name  # and if we know the layer name
            and not skip_quant  # and if we weren't told to skip quantization
            # and if the skip_layer_func doesn't say we should skip
            and not (self.skip_layer_func is not None and self.skip_layer_func(args[1]))
        )  # then we will quantize this linear layer/weight

        if quantize_linear:  # instantiate variables for GPTQ
            H = 0
            total_batches = 0

        for inp in transposed_args:
            inp = tensors_to_cuda(inp)
            cur_args, cur_kwargs = tree_unflatten(inp, spec)

            if quantize_linear:
                # calculate H instead of output (will run the linear eventually with
                # updated weight). Running mean of 2 * X^T X over all samples.
                x = cur_args[0].float()
                shape = x.shape
                n = 1 if len(shape) == 2 else shape[0]
                H *= total_batches / (total_batches + n)
                total_batches += n
                x = ((2 / total_batches) ** (1 / 2)) * x.reshape(
                    -1, shape[-1]
                ).t().float()
                H += x.matmul(x.t())
            else:
                # get output if its not a linear
                out = super().call_function(target, cur_args, cur_kwargs)
                if isinstance(out, torch.Tensor):
                    outputs.append(out.cpu())
                else:
                    outputs.append(out)

        if quantize_linear:
            mod_fqn = ".".join(self.id_to_name[id(args[1])].split(".")[:-1])
            W = args[1].to(H.device)
            Q, DQ, qparams = self.faster_quant(H, W.detach())
            print(mod_fqn)
            names_and_values_dict = self.make_names_and_values_dict_func(Q, qparams)

            # delete old weight
            self.new_state_dict.pop(mod_fqn + ".weight", None)
            # A None bias must not end up as a state_dict value.
            if len(args) > 2 and args[2] is not None:
                self.new_state_dict[mod_fqn + ".bias"] = args[2]
            for name, value in names_and_values_dict.items():
                self.new_state_dict[mod_fqn + "." + name] = value

            # run linear with new weight to get corrected output
            new_out = self.call_function(
                target, (args[0], DQ, *args[2:]), kwargs, skip_quant=True
            )

            if self.debug:
                # Debug path compares outputs on the first two samples only; it
                # assumes args[0] is a MultiInput (uses .values).
                old_out = self.call_function(
                    target, (args[0][:2], args[1], *args[2:]), kwargs, skip_quant=True
                )

                def SQNR(x, y):
                    return 20 * torch.log10(torch.norm(x) / torch.norm(x - y))

                DQ_after = self.dequantize_func(Q, qparams).to(W.dtype)
                print(
                    "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after)
                )  # matches

                print(
                    "SQNR for weight (can be low)", SQNR(W, DQ.cuda())
                )  # fine to not match
                print(
                    "SQNR for output with GPTQ (hopefully 35+)",
                    torch.cat(
                        [
                            SQNR(old.cpu(), new.cpu()).unsqueeze(0)
                            for (old, new) in zip(old_out.values, new_out.values[:2])
                        ]
                    ).mean(),
                )

                qparams2 = self.get_qparams_func(W)
                Q2 = self.quantize_func(W, qparams2)
                DQ2 = self.dequantize_func(Q2, qparams2).to(W.dtype)
                old_q_out = self.call_function(
                    target, (args[0][:2], DQ2, *args[2:]), kwargs, skip_quant=True
                )

                print(
                    "SQNR for output without GPTQ (should be less than above)",
                    torch.cat(
                        [
                            SQNR(old.cpu(), old_q.cpu()).unsqueeze(0)
                            for (old, old_q) in zip(old_out.values, old_q_out.values)
                        ]
                    ).mean(),
                )

            return new_out

        return MultiInput(outputs) if has_multi_input else outputs[0]

    def faster_quant(self, H, W):
        """GPTQ-quantize weight ``W`` (rows x columns) given Hessian ``H``.

        ``H`` is modified in place (dead-column fix-up and damping). ``W`` is
        not modified.

        Returns:
            (Q, DQ, qparams): the quantized weight, the dequantized weight in
            W's original dtype, and the combined qparams.
        """
        percdamp = self.percdamp
        blocksize = self.blocksize
        groupsize = self.groupsize
        orig_dtype = W.dtype
        # Clone: .float() returns the same tensor for float32 input, and W is
        # modified in place below (dead columns, error propagation).
        W = W.detach().float().clone()
        columns = W.shape[1]
        device = W.device

        if groupsize == -1:
            cur_qparams = self.get_qparams_func(W)
        # Columns whose input never activates: make H invertible there and drop them.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        DQ = torch.zeros_like(W)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(columns, device=device)
        H[diag, diag] += damp
        # Upper Cholesky factor of H^-1, as used by the GPTQ column updates.
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        all_qparams = []
        for i1 in range(0, columns, blocksize):
            i2 = min(i1 + blocksize, columns)
            count = i2 - i1
            W1 = W[:, i1:i2].clone()
            DQ1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1 and (i1 + i) % groupsize == 0:  # start of new group
                    cur_qparams = self.get_qparams_func(
                        W[:, (i1 + i) : (i1 + i + groupsize)]
                    )
                    all_qparams.append(cur_qparams)

                q = self.quantize_func(w.unsqueeze(1), cur_qparams).flatten()
                dq = self.dequantize_func(q.unsqueeze(1), cur_qparams).flatten()

                DQ1[:, i] = dq
                # Spread this column's quantization error over the remaining
                # columns of the block.
                err1 = (w - dq) / d
                W1[:, i:] -= (
                    err1.to(Hinv1.dtype).unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                )
                Err1[:, i] = err1

            DQ[:, i1:i2] = DQ1
            # Lazily apply the whole block's error to all later columns.
            W[:, i2:] -= Err1.to(Hinv.dtype).matmul(Hinv[i1:i2, i2:])

        if W.is_cuda:
            torch.cuda.synchronize()

        if not all_qparams:
            all_qparams.append(cur_qparams)

        # convert a list of qparams objects into a single one. generally by
        # concatenating a bunch of n,1 scale/zeros tensors into a n,num_groups tensor
        all_qparams = self.combine_qparams_list_func(all_qparams)
        Q = self.quantize_func(DQ, all_qparams)
        return Q, DQ.to(orig_dtype), all_qparams