fix bug
horheynm committed Jan 10, 2025
1 parent f0c369a commit 0a33bc2
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions tests/test_quantization/lifecycle/test_apply.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import re
+from collections import defaultdict
 from typing import Optional
 from unittest.mock import MagicMock
 
@@ -114,31 +115,33 @@ def test_apply_quantization_config_tinyllama():
     for module in model.modules():
         _test_layer_quantization_status(module, inputs=False, weights=False)
 
+    count_layer_names = ("Linear", "Embedding", "LlamaRotaryEmbedding")
+    count_layer_num = defaultdict(int)
+
+    for name, module in model.named_modules():
+        if name in quant_config.ignore:
+            continue
+        module_type = module.__class__.__name__
+        if module_type in count_layer_names:
+            count_layer_num[module_type] += 1
+
     # apply quant config to model
     apply_quantization_config(model, quant_config)
 
     # check for correct application of quant config
-    num_linears = 0
-    num_embeddings = 0
-    num_rotary_embeddings = 0
     for name, module in model.named_modules():
         if name in quant_config.ignore:
             continue
         module_type = module.__class__.__name__
-        if module_type == "Linear":
-            num_linears += 1
-            _test_layer_quantization_status(module, inputs=True, weights=True)
-        elif module_type == "Embedding":
-            num_embeddings += 1
-            _test_layer_quantization_status(module, inputs=False, weights=True)
-        elif module_type == "LlamaRotaryEmbedding":
-            num_rotary_embeddings += 1
-            _test_layer_quantization_status(module, inputs=False, weights=False)
-
-    # sanity check correct number of layers targeted
-    assert num_linears == 154  # 155 Linear layers - 1 that gets ignored
-    assert num_embeddings == 1
-    assert num_rotary_embeddings == 23  # model updated, now has model.rotary_embedding
+        if module_type in count_layer_names:
+            count_layer_num[module_type] -= 1
+            _inputs = module_type == "Linear"
+            _weights = not module_type == "LlamaRotaryEmbedding"
+            _test_layer_quantization_status(module, inputs=_inputs, weights=_weights)
+
+    assert all(
+        value == 0 for value in count_layer_num.values()
+    ), "Not all values are zero"
 
     # test quantization compression
     # sample forward pass to fill scales, zps
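For readers skimming the diff, below is a minimal, self-contained sketch of the count-then-decrement pattern the new test logic uses. The tiny nn.Sequential model, its layer sizes, and the placeholder transformation step are illustrative assumptions only; the real test walks a TinyLlama model and calls apply_quantization_config between the two passes. The idea is to count the tracked module types before the transformation, decrement once for each tracked module seen afterwards, and assert every counter returns to zero rather than hard-coding model-specific layer counts.

from collections import defaultdict

import torch.nn as nn

# Illustrative stand-in model; the real test loads a TinyLlama checkpoint.
model = nn.Sequential(nn.Linear(8, 8), nn.Embedding(10, 8), nn.Linear(8, 4))
count_layer_names = ("Linear", "Embedding")
count_layer_num = defaultdict(int)

# First pass: record how many modules of each tracked type exist.
for module in model.modules():
    if module.__class__.__name__ in count_layer_names:
        count_layer_num[module.__class__.__name__] += 1

# ... the transformation under test (e.g. apply_quantization_config) would run here ...

# Second pass: decrement for every tracked module still present afterwards.
for module in model.modules():
    if module.__class__.__name__ in count_layer_names:
        count_layer_num[module.__class__.__name__] -= 1

# Each tracked layer type was seen the same number of times before and after,
# so every counter is back at zero.
assert all(value == 0 for value in count_layer_num.values()), "Not all values are zero"

Compared with the previous hard-coded assertions (e.g. num_linears == 154), this check does not need updating whenever the upstream model gains or loses layers, which appears to be the motivation for the change.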
