Skip to content

Commit

Permalink
Ensure users do not use the CPU mirror for non-CPU installs (#871)
Browse files Browse the repository at this point in the history
* Fix cannot start app - uv cmd invalid URL

* nit

* Ensure users do not use the CPU mirror for non-CPU installs
  • Loading branch information
webfiltered authored Feb 8, 2025
1 parent e6348ac commit a880481
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,4 @@ export enum DownloadStatus {

export const CUDA_TORCH_URL = 'https://download.pytorch.org/whl/cu124';
export const NIGHTLY_CPU_TORCH_URL = 'https://download.pytorch.org/whl/nightly/cpu';
export const DEFAULT_PYPI_INDEX_URL = 'https://pypi.org/simple/';
22 changes: 16 additions & 6 deletions src/virtualEnvironment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { rm } from 'node:fs/promises';
import os, { EOL } from 'node:os';
import path from 'node:path';

import { CUDA_TORCH_URL, NIGHTLY_CPU_TORCH_URL } from './constants';
import { CUDA_TORCH_URL, DEFAULT_PYPI_INDEX_URL, NIGHTLY_CPU_TORCH_URL } from './constants';
import type { TorchDeviceType } from './preload';
import { HasTelemetry, ITelemetry, trackEvent } from './services/telemetry';
import { getDefaultShell, getDefaultShellArgs } from './shell/util';
Expand Down Expand Up @@ -64,16 +64,26 @@ export function getPipInstallArgs(config: PipInstallConfig): string[] {
* @param device The device type
* @returns The default torch mirror
*/
const getDefaultTorchMirror = (device: TorchDeviceType): string => {
function getDefaultTorchMirror(device: TorchDeviceType): string {
log.debug('Falling back to default torch mirror');
switch (device) {
case 'mps':
return NIGHTLY_CPU_TORCH_URL;
case 'nvidia':
return CUDA_TORCH_URL;
default:
return '';
return DEFAULT_PYPI_INDEX_URL;
}
};
}

/** Disallows using the default mirror (CPU torch) when the selected device is not CPU. */
function fixDeviceMirrorMismatch(device: TorchDeviceType, mirror: string | undefined) {
if (mirror === DEFAULT_PYPI_INDEX_URL) {
if (device === 'nvidia') return CUDA_TORCH_URL;
else if (device === 'mps') return NIGHTLY_CPU_TORCH_URL;
}
return mirror;
}

/**
* Manages a virtual Python environment using uv.
Expand Down Expand Up @@ -145,7 +155,7 @@ export class VirtualEnvironment implements HasTelemetry {
this.selectedDevice = selectedDevice ?? 'cpu';
this.pythonMirror = pythonMirror;
this.pypiMirror = pypiMirror;
this.torchMirror = torchMirror;
this.torchMirror = fixDeviceMirrorMismatch(selectedDevice!, torchMirror);

// uv defaults to .venv
this.venvPath = path.join(venvPath, '.venv');
Expand Down Expand Up @@ -460,7 +470,7 @@ export class VirtualEnvironment implements HasTelemetry {
const config: PipInstallConfig = {
packages: ['torch', 'torchvision', 'torchaudio'],
indexUrl: torchMirror,
prerelease: torchMirror?.includes('nightly'),
prerelease: torchMirror.includes('nightly'),
};

const installArgs = getPipInstallArgs(config);
Expand Down

0 comments on commit a880481

Please sign in to comment.