From 69f258cac3bf0cbf40fa94482b30fba08cf60079 Mon Sep 17 00:00:00 2001 From: Andreea Popescu Date: Tue, 17 Sep 2024 10:53:49 +0100 Subject: [PATCH] better build push --- .github/workflows/build_push_image.yml | 21 ++++++++++++++----- .../download_model.py | 1 + 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build_push_image.yml b/.github/workflows/build_push_image.yml index c38630c..5192dca 100644 --- a/.github/workflows/build_push_image.yml +++ b/.github/workflows/build_push_image.yml @@ -2,7 +2,9 @@ name: "CD: build & push image" on: push: - branches: [build-image] + branches: + - build-push-llama3-image + - build-push-phi3-image workflow_dispatch: env: @@ -12,7 +14,7 @@ env: jobs: deploy: - timeout-minutes: 15 + timeout-minutes: 30 runs-on: group: bulkier steps: @@ -31,13 +33,22 @@ jobs: run: | python -m pip install transformers torch + - name: Set environment variables based on branch + run: | + if [[ "${{ github.ref }}" == "refs/heads/build-push-llama3-image" ]]; then + echo "MODEL_NAME=llama3" >> $GITHUB_ENV + elif [[ "${{ github.ref }}" == "refs/heads/build-push-phi3-image" ]]; then + echo "MODEL_NAME=phi3" >> $GITHUB_ENV + fi + - name: Docker build and push run: | - df -h - IMAGE_NAME="${DOCKER_REPO_NAME}:${TAG_VERSION}" + IMAGE_NAME="${DOCKER_REPO_NAME}-{MODEL_NAME}:${TAG_VERSION}" cd src/compute_horde_prompt_gen - python download_model.py --model_name phi3 --huggingface_token "${{ secrets.HUGGINGFACE_API_KEY }}" + python download_model.py --model_name ${{ env.MODEL_NAME }} --huggingface_token "${{ secrets.HUGGINGFACE_API_KEY }}" docker build -t $IMAGE_NAME . + + docker push $IMAGE_NAME diff --git a/src/compute_horde_prompt_gen/download_model.py b/src/compute_horde_prompt_gen/download_model.py index 0975e23..8e062b5 100644 --- a/src/compute_horde_prompt_gen/download_model.py +++ b/src/compute_horde_prompt_gen/download_model.py @@ -43,6 +43,7 @@ model_name = MODEL_PATHS[args.model_name] print(f"Saving {model_name} model to {save_path}") + quantization_config = None if args.quantize: import torch from transformers import BitsAndBytesConfig