Skip to content

Commit

Permalink
Update: モデルファイル名からエポック数とステップ数を抽出
Browse files Browse the repository at this point in the history
  • Loading branch information
tsukumijima committed Oct 22, 2024
1 parent 4d41824 commit 13e44a8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
29 changes: 29 additions & 0 deletions aivmlib/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

import re
import rich
import traceback
import typer
Expand Down Expand Up @@ -100,10 +101,24 @@ def create_aivm(
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
return

# AIVM メタデータを生成
with hyper_parameters_path.open('rb') as hyper_parameters_file:
with style_vectors_path.open('rb') as style_vectors_file:
metadata = aivmlib.generate_aivm_metadata(model_architecture, hyper_parameters_file, style_vectors_file)

# モデルファイル名からエポック数とステップ数を抽出
model_file_name = safetensors_model_path.name
epoch_match = re.search(r'e(\d{2,})', model_file_name) # "e" の後ろに2桁以上の数字
step_match = re.search(r's(\d{2,})', model_file_name) # "s" の後ろに2桁以上の数字

# エポック数を設定
if epoch_match:
metadata.manifest.training_epochs = int(epoch_match.group(1))
# ステップ数を設定
if step_match:
metadata.manifest.training_steps = int(step_match.group(1))

# AIVM ファイルを生成
with safetensors_model_path.open('rb') as safetensors_file:
new_aivm_file_content = aivmlib.write_aivm_metadata(safetensors_file, metadata)
with output_path.open('wb') as f:
Expand Down Expand Up @@ -171,10 +186,24 @@ def create_aivmx(
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
return

# AIVM メタデータを生成
with hyper_parameters_path.open('rb') as hyper_parameters_file:
with style_vectors_path.open('rb') as style_vectors_file:
metadata = aivmlib.generate_aivm_metadata(model_architecture, hyper_parameters_file, style_vectors_file)

# モデルファイル名からエポック数とステップ数を抽出
model_file_name = onnx_model_path.name
epoch_match = re.search(r'e(\d{2,})', model_file_name) # "e" の後ろに2桁以上の数字
step_match = re.search(r's(\d{2,})', model_file_name) # "s" の後ろに2桁以上の数字

# エポック数を設定
if epoch_match:
metadata.manifest.training_epochs = int(epoch_match.group(1))
# ステップ数を設定
if step_match:
metadata.manifest.training_steps = int(step_match.group(1))

# AIVMX ファイルを生成
with onnx_model_path.open('rb') as onnx_file:
new_aivmx_file_content = aivmlib.write_aivmx_metadata(onnx_file, metadata)
with output_path.open('wb') as f:
Expand Down
1 change: 0 additions & 1 deletion aivmlib/schemas/aivm_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class ModelArchitecture(StrEnum):
StyleBertVITS2 = 'Style-Bert-VITS2'
StyleBertVITS2JPExtra = 'Style-Bert-VITS2 (JP-Extra)'


class ModelFormat(StrEnum):
Safetensors = 'Safetensors'
ONNX = 'ONNX'
Expand Down

0 comments on commit 13e44a8

Please sign in to comment.