macos12 full build (x86) (#125)
* macos12 full build (x86)

* add support for setting precision via --dtype
mikekgfb authored and malfet committed Jul 17, 2024
1 parent 1b8289c commit ea5f01a
Showing 6 changed files with 29 additions and 6 deletions.
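
The hunks below thread a new --dtype flag through export.py and generate.py and route its value into a shared precision setting in quantize.py. As a rough usage sketch (only --dtype comes from this commit; the checkpoint flag name and path are illustrative):

    # hypothetical invocation; --dtype accepts the names in name_to_dtype_dict
    python3 generate.py --checkpoint_path checkpoints/stories15M/model.pth --dtype bf16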
2 changes: 1 addition & 1 deletion .github/workflows/compile.yml
@@ -11,7 +11,7 @@ jobs:
   run-tinystories:
     strategy:
       matrix:
-        runner: [ubuntu-latest, macos-14]
+        runner: [ubuntu-latest, macos-14, macos-12]
     runs-on: ${{matrix.runner}}
     steps:
       - name: Checkout repo
7 changes: 5 additions & 2 deletions export.py
@@ -12,6 +12,8 @@
 import torch.nn as nn
 from torch.export import Dim, export
 
+from quantize import quantize_model, name_to_dtype, set_precision, get_precision
+
 try:
     executorch_export_available = True
     from export_et import export_model as export_model_et
@@ -62,8 +64,9 @@ def main(checkpoint_path, device, quantize = "{ }", args = None):
     assert checkpoint_path.is_file(), checkpoint_path
 
     print(f"Using device={device}")
-    precision = torch.float # bfloat16
+    precision = name_to_dtype(args.dtype) # torch.float # bfloat16
+    set_precision(precision)
 
     print("Loading model ...")
     t0 = time.time()
     model = _load_model(
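
The main() hunk above consumes args.dtype, but the argparse declaration for the flag sits outside the visible diff context. A minimal sketch of how such a flag could be declared (assumed wiring, not shown in this commit):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dtype",
        type=str,
        default="float",  # assumed default; "float" maps to torch.float via name_to_dtype()
        help="precision: fp32/fp16/bf16 or float/half/float32/float16/bfloat16",
    )
    args = parser.parse_args()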
5 changes: 3 additions & 2 deletions generate.py
@@ -13,7 +13,7 @@
 import torch._dynamo.config
 import torch._inductor.config
 
-from quantize import quantize_model, name_to_dtype
+from quantize import quantize_model, name_to_dtype, set_precision, get_precision
 
 
 def device_sync(device):
@@ -344,7 +344,8 @@ def main(
     # print = lambda *args, **kwargs: None
 
     print(f"Using device={device}")
-    precision = torch.float # bfloat16
+    precision = name_to_dtype(model_dtype)
+    set_precision(precision)
     is_speculative = draft_checkpoint_path is not None
     is_chat = "chat" in str(checkpoint_path)
 
6 changes: 5 additions & 1 deletion model.py
@@ -11,6 +11,7 @@
 from torch import Tensor
 from torch.nn import functional as F
 
+from quantize import get_precision
 
 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
@@ -99,8 +100,11 @@ def from_name(cls, name: str):
 
 class KVCache(nn.Module):
     def __init__(
-        self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.float): # bfloat16 ):
+        self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=None):
+        # torch.float): # bfloat16 ):
         super().__init__()
+        if not dtype:
+            dtype=get_precision()
         cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
         self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
         self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
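
With the default changed from torch.float to None, a KVCache constructed without an explicit dtype now inherits whatever precision was last registered globally. A small sketch of the intended flow, assuming the imports shown in this diff (the shapes here are arbitrary):

    import torch
    from quantize import set_precision
    from model import KVCache

    set_precision(torch.bfloat16)  # e.g. after parsing --dtype bf16
    cache = KVCache(max_batch_size=1, max_seq_length=2048, n_heads=32, head_dim=64)
    # the cache buffers pick up the global precision via get_precision()
    assert cache.k_cache.dtype == torch.bfloat16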
File renamed without changes.
15 changes: 15 additions & 0 deletions quantize.py
@@ -23,6 +23,16 @@
 ##########################################################################
 ### dtype name to torch.dtype mapping ###
 
+precision = torch.float
+
+def set_precision(dtype):
+    global precision
+    precision = dtype
+
+def get_precision():
+    global precision
+    return precision
+
 def name_to_dtype(name):
     if name in name_to_dtype_dict:
         return name_to_dtype_dict[name]
@@ -33,6 +43,11 @@ def name_to_dtype(name):
     "fp32" : torch.float,
     "fp16" : torch.float16,
     "bf16" : torch.bfloat16,
+    "float" : torch.float,
+    "half" : torch.float16,
+    "float32" : torch.float,
+    "float16" : torch.float16,
+    "bfloat16" : torch.bfloat16,
 }
 
 ##########################################################################
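
Taken together, these additions give quantize.py a small module-level precision registry: name_to_dtype() translates a CLI string into a torch.dtype, set_precision() stores it, and get_precision() hands it back to callers such as KVCache. A quick round trip:

    import torch
    from quantize import name_to_dtype, set_precision, get_precision

    set_precision(name_to_dtype("bfloat16"))  # "bfloat16" is accepted after this change
    assert get_precision() == torch.bfloat16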
