Skip to content

Commit

Permalink
Reformat with ruff-format
Browse files Browse the repository at this point in the history
  • Loading branch information
valohai-bot authored and akx committed Feb 22, 2024
1 parent 1671c21 commit f025366
Show file tree
Hide file tree
Showing 21 changed files with 91 additions and 181 deletions.
6 changes: 1 addition & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@ repos:
- id: ruff
args:
- --fix

- repo: https://github.com/psf/black
rev: 23.9.1
hooks:
- id: black
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
Expand Down
4 changes: 1 addition & 3 deletions tests/test_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ def test_parameter_value_interpolation(example1_config):
],
)
command = " && ".join(command)
assert (
command == "asdf hello {parameter-value:hello} hello && dsfargeg 840 && hello"
)
assert command == "asdf hello {parameter-value:hello} hello && dsfargeg 840 && hello"


def test_parameter_value_with_falsy_values(example1_config):
Expand Down
9 changes: 2 additions & 7 deletions tests/test_get_step_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ def test_get_step_by_name_doesnt_exist():

def test_get_step_by_command():
config = Config.parse([echo_step, list_step])
assert (
echo_step["step"] == config.get_step_by(command="echo HELLO WORLD").serialize()
)
assert echo_step["step"] == config.get_step_by(command="echo HELLO WORLD").serialize()
assert list_step["step"] == config.get_step_by(command="ls").serialize()


Expand All @@ -34,10 +32,7 @@ def test_get_step_by_name_and_command():
config = Config.parse([echo_step, list_step])
assert not config.get_step_by(name="greeting", command="echo HELLO MORDOR")
assert not config.get_step_by(name="farewell", command="echo HELLO WORLD")
assert (
echo_step["step"]
== config.get_step_by(name="greeting", command="echo HELLO WORLD").serialize()
)
assert echo_step["step"] == config.get_step_by(name="greeting", command="echo HELLO WORLD").serialize()


def test_get_step_by_non_existing_attribute():
Expand Down
16 changes: 4 additions & 12 deletions tests/test_linter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
def test_optional_flag():
items = lint_file("./examples/flag-optional-example.yaml")
warning = "Step test, parameter case-insensitive: `optional` has no effect on flag-type parameters"
assert any(
(warning in item["message"]) for item in items.warnings
) # pragma: no branch
assert any((warning in item["message"]) for item in items.warnings) # pragma: no branch


@pytest.mark.parametrize(
Expand All @@ -33,9 +31,7 @@ def test_optional_flag():
def test_invalid_parameter_default(file, expected_message):
items = lint_file(get_warning_example_path(file))
messages = [item["message"] for item in chain(items.warnings, items.errors)]
assert any(
expected_message in message for message in messages
), messages # pragma: no branch
assert any(expected_message in message for message in messages), messages # pragma: no branch


@pytest.mark.parametrize(
Expand All @@ -53,9 +49,7 @@ def test_invalid_parameter_default(file, expected_message):
def test_invalid_indentation(file, expected_message):
items = lint_file(get_error_example_path(file))
messages = [item["message"] for item in chain(items.hints, items.errors)]
assert any(
expected_message in message for message in messages
), messages # pragma: no branch
assert any(expected_message in message for message in messages), messages # pragma: no branch


@pytest.mark.parametrize(
Expand All @@ -67,9 +61,7 @@ def test_invalid_indentation(file, expected_message):
)
def test_expression_lint_ok(file_path):
items = lint_file(get_valid_example_path(file_path))
assert items.is_valid(), [
item["message"] for item in chain(items.hints, items.errors)
]
assert items.is_valid(), [item["message"] for item in chain(items.hints, items.errors)]


@pytest.mark.parametrize(
Expand Down
76 changes: 44 additions & 32 deletions tests/test_multiple_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,50 @@

def test_multiple_param_build(multiple_param_config):
step = multiple_param_config.steps["example"]
assert step.build_command({}) == [
"echo "
"--unlabeled-samples=60000 "
"--encoder-layers=1000-500-250-250-250-10 " # dash-separated
"--denoising-cost-x=1000,1,0.01,0.01,0.01,0.01,0.01 " # comma-separated
"--decoder-spec gauss " # repeated
"--decoder-spec railgun", # repeated
]
assert step.build_command({"encoder-layers": 6, "denoising-cost-x": (2, "e")}) == [
"echo "
"--unlabeled-samples=60000 "
"--encoder-layers=6 " # dash-separated
"--denoising-cost-x=2,e " # comma-separated
"--decoder-spec gauss " # repeated
"--decoder-spec railgun", # repeated
]
assert step.build_command(
{"encoder-layers": None, "denoising-cost-x": (2, "e")},
) == [
"echo "
"--unlabeled-samples=60000 "
"--denoising-cost-x=2,e " # comma-separated
"--decoder-spec gauss " # repeated
"--decoder-spec railgun", # repeated
]
assert step.build_command({"encoder-layers": [], "denoising-cost-x": (2, "e")}) == [
"echo "
"--unlabeled-samples=60000 "
"--denoising-cost-x=2,e " # comma-separated
"--decoder-spec gauss " # repeated
"--decoder-spec railgun", # repeated
]
assert (
step.build_command({})
== [
"echo "
"--unlabeled-samples=60000 "
"--encoder-layers=1000-500-250-250-250-10 " # dash-separated
"--denoising-cost-x=1000,1,0.01,0.01,0.01,0.01,0.01 " # comma-separated
"--decoder-spec gauss " # repeated
"--decoder-spec railgun", # repeated
]
)
assert (
step.build_command({"encoder-layers": 6, "denoising-cost-x": (2, "e")})
== [
"echo "
"--unlabeled-samples=60000 "
"--encoder-layers=6 " # dash-separated
"--denoising-cost-x=2,e " # comma-separated
"--decoder-spec gauss " # repeated
"--decoder-spec railgun", # repeated
]
)
assert (
step.build_command(
{"encoder-layers": None, "denoising-cost-x": (2, "e")},
)
== [
"echo "
"--unlabeled-samples=60000 "
"--denoising-cost-x=2,e " # comma-separated
"--decoder-spec gauss " # repeated
"--decoder-spec railgun", # repeated
]
)
assert (
step.build_command({"encoder-layers": [], "denoising-cost-x": (2, "e")})
== [
"echo "
"--unlabeled-samples=60000 "
"--denoising-cost-x=2,e " # comma-separated
"--decoder-spec gauss " # repeated
"--decoder-spec railgun", # repeated
]
)


def test_multiple_param_validate(multiple_param_config):
Expand Down
5 changes: 1 addition & 4 deletions tests/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,4 @@ def test_unknown_parse():
cfg = parse("[{ city_name: Constantinople }]")
lint_result = cfg.lint()
assert lint_result.warning_count == 1
assert (
list(lint_result.warnings)[0]["message"]
== "No parser for {'city_name': 'Constantinople'}"
)
assert list(lint_result.warnings)[0]["message"] == "No parser for {'city_name': 'Constantinople'}"
23 changes: 5 additions & 18 deletions tests/test_pipeline_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ def test_pipeline_conversion_no_output_timeout(pipeline_config: Config):
commit_identifier="latest",
).convert_pipeline(pipeline_config.pipelines["My medium pipeline"])
train_node = next(node for node in result["nodes"] if node["name"] == "train")
assert (
train_node["template"]["runtime_config"]["no_output_timeout"]
== parse_duration_string("6h").total_seconds()
)
assert train_node["template"]["runtime_config"]["no_output_timeout"] == parse_duration_string("6h").total_seconds()


def test_pipeline_conversion_override_inputs(pipeline_overridden_config: Config):
Expand All @@ -48,18 +45,12 @@ def test_pipeline_conversion_override_inputs(pipeline_overridden_config: Config)
assert merged["template"]["image"] == "merge node image"
assert isinstance(merged["template"]["inputs"], dict)
assert len(merged["template"]["inputs"].get("training-images", [])) == 2
assert (
merged["template"]["inputs"].get("training-images", [])[0]
== "merged node image 1"
)
assert merged["template"]["inputs"].get("training-images", [])[0] == "merged node image 1"
assert len(merged["template"]["parameters"].items()) == 3

overridden = next(node for node in result["nodes"] if node["name"] == "overridden")
assert isinstance(overridden["template"]["inputs"], dict)
assert (
overridden["template"]["inputs"].get("training-images", [])[0]
== "overridden node image"
)
assert overridden["template"]["inputs"].get("training-images", [])[0] == "overridden node image"
assert len(overridden["template"]["inputs"].get("training-images", [])) == 1
assert len(overridden["template"]["parameters"].items()) == 3

Expand All @@ -80,11 +71,7 @@ def test_pipeline_parameter_conversion(pipeline_with_parameters_config):
assert isinstance(parameter["config"]["targets"], list)

# When pipeline parameter has no default value, the expression should be empty
parameter_config = next(
param for param in pipe.parameters if param.name == parameter_name
)
expression_value = (
parameter_config.default if parameter_config.default is not None else ""
)
parameter_config = next(param for param in pipe.parameters if param.name == parameter_name)
expression_value = parameter_config.default if parameter_config.default is not None else ""
assert parameter["expression"] == expression_value
assert type(parameter["expression"]) == type(expression_value)
4 changes: 1 addition & 3 deletions tests/test_serialize_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ def test_serialize_workload_resources(step_with_resources):
def test_serialize_partial_resources(step_with_partial_resources):
"""Serialized data only contains keys found in the config."""
config = step_with_partial_resources
resources = config.steps["contains partial workload resources"].serialize()[
"resources"
]
resources = config.steps["contains partial workload resources"].serialize()["resources"]

assert "min" in resources["cpu"]
assert "max" not in resources["cpu"]
35 changes: 7 additions & 28 deletions tests/test_step_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,15 @@ def test_parse_inputs(example2_config):
config = example2_config
step = config.steps["run training"]
assert len(step.inputs) == 5
assert (
len([inp.description for inp in step.inputs.values() if inp.description]) == 4
)
assert len([inp.description for inp in step.inputs.values() if inp.description]) == 4


def test_parse_input_defaults(example3_config):
config = example3_config
step = config.steps["batch inference"]
assert len(step.inputs) == 2
assert step.inputs["model"].default == "s3://foo/model.pb"
assert (
isinstance(step.inputs["images"].default, list)
and len(step.inputs["images"].default) == 2
)
assert isinstance(step.inputs["images"].default, list) and len(step.inputs["images"].default) == 2


def test_parse(example1_config):
Expand Down Expand Up @@ -167,35 +162,19 @@ def test_timeouts(timeouts_config):
assert timeouts_config.steps["short-time-limit"].time_limit.total_seconds() == 300
assert timeouts_config.steps["short-time-limit"].no_output_timeout is None
assert timeouts_config.steps["big-no-output-timeout"].time_limit is None
assert (
timeouts_config.steps["big-no-output-timeout"].no_output_timeout.total_seconds()
== 86400
)
assert (
timeouts_config.steps["human-readable-time-limit"].time_limit.total_seconds()
== 5405
)
assert (
timeouts_config.steps[
"human-readable-time-limit"
].no_output_timeout.total_seconds()
== 86400 * 2
) # 48h
assert timeouts_config.steps["big-no-output-timeout"].no_output_timeout.total_seconds() == 86400
assert timeouts_config.steps["human-readable-time-limit"].time_limit.total_seconds() == 5405
assert timeouts_config.steps["human-readable-time-limit"].no_output_timeout.total_seconds() == 86400 * 2 # 48h


def test_bling(example1_config: Config) -> None:
assert example1_config.steps["run training"].category == "Training"
assert (
example1_config.steps["run training"].icon
== "https://valohai.com/assets/img/valohai-logo.svg"
)
assert example1_config.steps["run training"].icon == "https://valohai.com/assets/img/valohai-logo.svg"


def test_widget(example1_config: Config) -> None:
parameters = example1_config.steps["run training"].parameters
assert (
parameters["sql-query"].widget and parameters["sql-query"].widget.type == "sql"
)
assert parameters["sql-query"].widget and parameters["sql-query"].widget.type == "sql"
widget = parameters["output-alias"].widget
assert widget and widget.type == "datumalias"
assert widget and widget.settings and widget.settings["width"] == 123
Expand Down
4 changes: 1 addition & 3 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ def test_raise():

def test_error_list():
errs = [f"{err}" for err in validate(invalid_obj, raise_exc=False)]
assert any(
("Additional properties are not allowed" in err) for err in errs
) # pragma: no branch
assert any(("Additional properties are not allowed" in err) for err in errs) # pragma: no branch
assert any(("required property" in err) for err in errs) # pragma: no branch
assert any(("0 is not of type 'string'" in err) for err in errs)
10 changes: 2 additions & 8 deletions valohai_yaml/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,7 @@ def _validate_json_schema(
key=lambda error: (relevance(error), repr(error.path)),
)
for error in errors:
simplified_schema_path = [
el
for el in list(error.relative_schema_path)
if el not in ("properties", "items")
]
simplified_schema_path = [el for el in list(error.relative_schema_path) if el not in ("properties", "items")]
obj_path = [str(el) for el in error.path]
styled_validator = styler(error.validator.title(), bold=True)
styled_schema_path = styler(".".join(simplified_schema_path), bold=True)
Expand Down Expand Up @@ -158,9 +154,7 @@ def lint(
except pyyaml.YAMLError as err:
if hasattr(err, "problem_mark"):
mark = err.problem_mark
indent_error = (
f"Indentation Error at line {mark.line + 1}, column {mark.column + 1}"
)
indent_error = f"Indentation Error at line {mark.line + 1}, column {mark.column + 1}"
lr.add_error(indent_error)
else:
lr.add_error(str(err))
Expand Down
6 changes: 1 addition & 5 deletions valohai_yaml/objs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,7 @@ def serialize(self) -> Any: # type = Any because subclasses may override
@classmethod
def parse(cls: Type[T], data: SerializedDict) -> T:
inst = cls(
**{
key.replace("-", "_"): value
for (key, value) in data.items()
if not key.startswith("_")
},
**{key.replace("-", "_"): value for (key, value) in data.items() if not key.startswith("_")},
)
inst._original_data = data
return inst
Expand Down
27 changes: 9 additions & 18 deletions valohai_yaml/objs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,8 @@ def serialize(self) -> List[SerializedDict]:
chain(
({"step": step.serialize()} for (key, step) in self.steps.items()),
({"task": task.serialize()} for (key, task) in self.tasks.items()),
(
{"endpoint": endpoint.serialize()}
for (key, endpoint) in sorted(self.endpoints.items())
),
(
{"pipeline": pipeline.serialize()}
for (key, pipeline) in sorted(self.pipelines.items())
),
({"endpoint": endpoint.serialize()} for (key, endpoint) in sorted(self.endpoints.items())),
({"pipeline": pipeline.serialize()} for (key, pipeline) in sorted(self.pipelines.items())),
),
)

Expand Down Expand Up @@ -179,14 +173,11 @@ def default_merge(cls, a: "Config", b: "Config") -> "Config":
return result

def __repr__(self) -> str: # pragma: no cover # noqa: D105
return (
"<Config with %d steps (%r), %d endpoints (%r), and %d pipelines (%r)>"
% (
len(self.steps),
self.steps,
len(self.endpoints),
sorted(self.endpoints),
len(self.pipelines),
sorted(self.pipelines),
)
return "<Config with %d steps (%r), %d endpoints (%r), and %d pipelines (%r)>" % (
len(self.steps),
self.steps,
len(self.endpoints),
sorted(self.endpoints),
len(self.pipelines),
sorted(self.pipelines),
)
Loading

0 comments on commit f025366

Please sign in to comment.