Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fine Tune Training #1376

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
18 changes: 16 additions & 2 deletions client/dive-common/apispec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,16 @@ interface Category {
}

interface TrainingConfigs {
configs: string[];
default: string;
training: {
configs: string[];
default: string;
};
models: Record<string, {
name: string;
type: string;
path?: string;
folderId?: string;
}>;
}

type Pipelines = Record<string, Category>;
Expand Down Expand Up @@ -150,6 +158,12 @@ interface Api {
config: string,
annotatedFramesOnly: boolean,
labelText?: string,
fineTuneModel?: {
name: string;
type: string;
path?: string;
folderId?: string;
},
): Promise<unknown>;

loadMetadata(datasetId: string): Promise<DatasetMeta>;
Expand Down
39 changes: 37 additions & 2 deletions client/platform/desktop/backend/native/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,27 @@ async function getPipelineList(settings: Settings): Promise<Pipelines> {
return ret;
}

// Function to recursively traverse a directory and collect files with specified extensions
function getFilesWithExtensions(dir: string, extensions: string[], fileList: string[] = []) {
const files = fs.readdirSync(dir);

files.forEach((file) => {
const filePath = npath.join(dir, file);
const fileStat = fs.statSync(filePath);

if (fileStat.isDirectory()) {
fileList.concat(getFilesWithExtensions(filePath, extensions, fileList));
} else {
const fileExtension = npath.extname(file).toLowerCase();
if (extensions.includes(fileExtension)) {
fileList.push(filePath);
}
}
});

return fileList;
}

/**
* get training configurations
*/
Expand All @@ -433,6 +454,7 @@ async function getTrainingConfigs(settings: Settings): Promise<TrainingConfigs>
const defaultTrainingConfiguration = 'train_detector_default.viame_csv.conf';
const allowedPatterns = /\.viame_csv\.conf$/;
const disallowedPatterns = /.*(_nf|\.continue)\.viame_csv\.conf$/;
const allowedModelExtensions = ['.zip', '.pth', '.pt', '.py', '.weights', '.wt'];
const exists = await fs.pathExists(pipelinePath);
if (!exists) {
throw new Error(`Path does not exist: ${pipelinePath}`);
Expand All @@ -441,9 +463,22 @@ async function getTrainingConfigs(settings: Settings): Promise<TrainingConfigs>
configs = configs
.filter((p) => (p.match(allowedPatterns) && !p.match(disallowedPatterns)))
.sort((a, b) => (a === defaultTrainingConfiguration ? -1 : a.localeCompare(b)));
// Get Model files in the pipeline directory
const modelList = getFilesWithExtensions(pipelinePath, allowedModelExtensions);
const models: TrainingConfigs['models'] = {};
modelList.forEach((model) => {
models[npath.basename(model)] = {
name: npath.basename(model),
type: npath.extname(model),
path: model,
};
});
return {
default: configs[0],
configs,
training: {
default: configs[0],
configs,
},
models,
};
}

Expand Down
5 changes: 5 additions & 0 deletions client/platform/desktop/backend/native/viame.ts
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,11 @@ async function train(
command.push('--gt-frames-only');
}

if (runTrainingArgs.fineTuneModel && runTrainingArgs.fineTuneModel.path) {
command.push('--init-weights');
command.push(runTrainingArgs.fineTuneModel.path);
}

const job = observeChild(spawn(command.join(' '), {
shell: viameConstants.shell,
cwd: jobWorkDir,
Expand Down
7 changes: 7 additions & 0 deletions client/platform/desktop/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ export interface RunTraining {
annotatedFramesOnly: boolean;
// contents of labels.txt file
labelText?: string;
// fine tuning model
fineTuneModel?: {
name: string;
type: string;
path?: string;
folderId?: string;
};
}

export interface ConversionArgs {
Expand Down
7 changes: 7 additions & 0 deletions client/platform/desktop/frontend/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,20 @@ async function runTraining(
config: string,
annotatedFramesOnly: boolean,
labelText?: string,
fineTuneModel?: {
name: string;
type: string;
path?: string;
folderId?: string;
},
): Promise<DesktopJob> {
const args: RunTraining = {
datasetIds: folderIds,
pipelineName,
trainingConfig: config,
annotatedFramesOnly,
labelText,
fineTuneModel,
};
return ipcRenderer.invoke('run-training', args);
}
Expand Down
50 changes: 45 additions & 5 deletions client/platform/desktop/frontend/components/MultiTrainingMenu.vue
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,14 @@ export default defineComponent({
stagedItems: {} as Record<string, DatasetMeta>,
trainingOutputName: '',
selectedTrainingConfig: 'foo.whatever',
fineTuneTraining: false,
selectedFineTune: null as null | string,
trainingConfigurations: {
configs: [],
default: '',
training: {
configs: [],
default: '',
},
models: {},
} as TrainingConfigs,
annotatedFramesOnly: false,
});
Expand Down Expand Up @@ -68,7 +73,16 @@ export default defineComponent({
onBeforeMount(async () => {
const configs = await getTrainingConfigurations();
data.trainingConfigurations = configs;
data.selectedTrainingConfig = configs.default;
data.selectedTrainingConfig = configs.training.default;
});

const modelList = computed(() => {
if (data.trainingConfigurations.models) {
const list = Object.entries(data.trainingConfigurations.models)
.map(([, value]) => value.name);
return list;
}
return [];
});

function toggleStaged(meta: DatasetMeta) {
Expand All @@ -94,12 +108,20 @@ export default defineComponent({
));

async function runTrainingOnFolder() {
// Get the full data for fine tuning
let foundTrainingModel;
if (data.fineTuneTraining) {
foundTrainingModel = Object.values(data.trainingConfigurations.models)
.find((item) => item.name === data.selectedFineTune);
}
try {
await runTraining(
stagedItems.value.map(({ id }) => id),
data.trainingOutputName,
data.selectedTrainingConfig,
data.annotatedFramesOnly,
undefined,
foundTrainingModel,
);
root.$router.push({ name: 'jobs' });
} catch (err) {
Expand All @@ -124,6 +146,7 @@ export default defineComponent({
nameRules,
itemsPerPageOptions,
clientSettings,
modelList,
available: {
items: availableItems,
headers: headersTmpl.concat(
Expand Down Expand Up @@ -163,7 +186,10 @@ export default defineComponent({
<v-card-text>
Add datasets to the staging area and choose a training configuration.
</v-card-text>
<v-row class="mt-4 pt-0">
<v-row
class="mt-4 pt-0"
dense
>
<v-col sm="5">
<v-text-field
v-model="data.trainingOutputName"
Expand All @@ -180,7 +206,7 @@ export default defineComponent({
outlined
dense
label="Configuration File (Required)"
:items="data.trainingConfigurations.configs"
:items="data.trainingConfigurations.training.configs"
:hint="data.selectedTrainingConfig"
persistent-hint
>
Expand Down Expand Up @@ -229,6 +255,20 @@ export default defineComponent({
Train on ({{ staged.items.value.length }}) Datasets
</v-btn>
</div>
<div class="d-flex flex-row mt-7">
<v-checkbox
v-model="data.fineTuneTraining"
label="Fine Tuning"
hint="Fine tune an existing model"
/>
<v-spacer />
<v-select
v-if="data.fineTuneTraining"
v-model="data.selectedFineTune"
:items="modelList"
label="Fine Tune Model"
/>
</div>
</div>
<div>
<v-card-title class="text-h4">
Expand Down
8 changes: 7 additions & 1 deletion client/platform/web-girder/api/rpc.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,14 @@ function runTraining(
config: string,
annotatedFramesOnly: boolean,
labelText?: string,
fineTuneModel?: {
name: string;
type: string;
path?: string;
folderId?: string;
},
) {
return girderRest.post('dive_rpc/train', { folderIds, labelText }, {
return girderRest.post('dive_rpc/train', { folderIds, labelText, fineTuneModel }, {
params: {
pipelineName, config, annotatedFramesOnly,
},
Expand Down
50 changes: 48 additions & 2 deletions client/platform/web-girder/views/RunTrainingMenu.vue
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ export default defineComponent({
const trainingConfigurations = ref<TrainingConfigs | null>(null);
const selectedTrainingConfig = ref<string | null>(null);
const annotatedFramesOnly = ref<boolean>(false);
const fineTuning = ref<boolean>(false);
const selectedFineTune = ref<string>('');
const {
request: _runTrainingRequest,
reset: dismissJobDialog,
Expand All @@ -45,10 +47,30 @@ export default defineComponent({

const successMessage = computed(() => `Started training on ${props.selectedDatasetIds.length} dataset(s)`);

const fineTuneModelList = computed(() => {
const modelList: string[] = [];
if (trainingConfigurations.value?.models) {
Object.entries(trainingConfigurations.value.models)
.forEach(([, value]) => {
modelList.push(value.name);
});
}
return modelList;
});
const selectedFineTuneObject = computed(() => {
if (selectedFineTune.value !== '' && trainingConfigurations.value?.models) {
const found = Object.entries(trainingConfigurations.value.models)
.find(([, value]) => value.name === selectedFineTune.value);
if (found) {
return found[1];
}
}
return undefined;
});
onBeforeMount(async () => {
const resp = await getTrainingConfigurations();
trainingConfigurations.value = resp;
selectedTrainingConfig.value = resp.default;
selectedTrainingConfig.value = resp.training.default;
});

const trainingDisabled = computed(() => props.selectedDatasetIds.length === 0);
Expand All @@ -73,13 +95,16 @@ export default defineComponent({
selectedTrainingConfig.value,
annotatedFramesOnly.value,
labelText.value,
selectedFineTuneObject.value,
);
}
return runTraining(
props.selectedDatasetIds,
outputPipelineName,
selectedTrainingConfig.value,
annotatedFramesOnly.value,
undefined,
selectedFineTuneObject.value,
);
});
menuOpen.value = false;
Expand Down Expand Up @@ -115,6 +140,10 @@ export default defineComponent({
labelFile,
clearLabelText,
simplifyTrainingName,
// Fine-Tuning
fineTuning,
fineTuneModelList,
selectedFineTune,
};
},
});
Expand Down Expand Up @@ -197,7 +226,7 @@ export default defineComponent({
outlined
class="my-4"
label="Configuration File"
:items="trainingConfigurations.configs"
:items="trainingConfigurations.training.configs"
:hint="selectedTrainingConfig"
persistent-hint
>
Expand All @@ -224,6 +253,23 @@ export default defineComponent({
persistent-hint
class="pt-0"
/>
<v-checkbox
v-model="fineTuning"
label="Fine Tune Model"
hint="Fine Tune an existing model"
persistent-hint
class="pt-0"
/>
<v-select
v-if="fineTuning"
v-model="selectedFineTune"
outlined
class="my-4"
label="Fine Tune Model"
:items="fineTuneModelList"
hint="Model to Fine Tune"
persistent-hint
/>
<v-btn
depressed
block
Expand Down
Loading