Skip to content

Commit

Permalink
add 8bit quant all reduce support
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Aug 2, 2024
1 parent 2de80a3 commit 32395eb
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
4 changes: 3 additions & 1 deletion open_diloco/train_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ class HvConfig(BaseConfig):
announce_maddrs: list[str] | None = None
matchmaking_time: float | None = None
averaging_timeout: float | None = None
hivemind_compression: Literal["none", "fp16", "scaled-fp16"] = "none"
hivemind_compression: Literal["none", "fp16", "scaled-fp16", "uniform8bit", "quantile8bit", "blockwise8bit"] = (
"none"
)
all_reduce_strategy: AllReduceStrategy = AllReduceStrategy.WAIT_FOR_ALL
timeout_waiting_for_peers: float | None = None
skip_load_from_peers: bool = False
Expand Down
20 changes: 17 additions & 3 deletions open_diloco/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,24 @@ def get_compression_kwargs(hivemind_compression: str) -> dict:

ret_kwargs["grad_compression"] = NoCompression()
ret_kwargs["state_averaging_compression"] = NoCompression()
elif hivemind_compression == "uniform8bit":
from hivemind import Uniform8BitQuantization

ret_kwargs["grad_compression"] = Uniform8BitQuantization()
ret_kwargs["state_averaging_compression"] = Uniform8BitQuantization()
elif hivemind_compression == "quantile8bit":
from hivemind import Quantile8BitQuantization

ret_kwargs["grad_compression"] = Quantile8BitQuantization()
ret_kwargs["state_averaging_compression"] = Quantile8BitQuantization()

elif hivemind_compression == "blockwise8bit":
from hivemind import BlockwiseQuantization

ret_kwargs["grad_compression"] = BlockwiseQuantization()
ret_kwargs["state_averaging_compression"] = BlockwiseQuantization()
else:
raise ValueError(
f"Invalid hivemind_compression: {hivemind_compression}. Please choose 'none', 'fp16', or 'scaled-fp16'."
)
raise ValueError(f"Invalid hivemind_compression: {hivemind_compression}")
return ret_kwargs


Expand Down

0 comments on commit 32395eb

Please sign in to comment.