Skip to content

Commit

Permalink
style: format with black and remove unuse var
Browse files Browse the repository at this point in the history
  • Loading branch information
corentin-hrflow committed Dec 22, 2023
1 parent 3ce6c04 commit a741f47
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 15 deletions.
4 changes: 2 additions & 2 deletions hrflow/hrflow/text/tagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def post(
output_lang=output_lang,
top_n=top_n,
)

if texts is None and text is not None:
payload["text"] = text
elif text is None and texts is not None:
Expand All @@ -79,6 +79,6 @@ def post(
raise ValueError("Either text or texts must be provided.")
else:
raise ValueError("Only one of text or texts must be provided.")

response = self.client.post("text/tagging", json=payload)
return validate_response(response)
2 changes: 2 additions & 0 deletions tests/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
_var_from_env_get,
)


@pytest.fixture(scope="module")
def hrflow_client():
return Hrflow(
api_secret=_var_from_env_get("HRFLOW_API_KEY"),
api_user=_var_from_env_get("HRFLOW_USER_EMAIL"),
)


def _job_get() -> t.Dict[str, t.Any]:
return dict(
reference=str(uuid4()),
Expand Down
22 changes: 17 additions & 5 deletions tests/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@
_ASYNC_RETRY_INTERVAL_SECONDS = 5
_ASYNC_TIMEOUT_SECONDS = 60


@pytest.fixture(scope="module")
def hrflow_client():
return Hrflow(
api_secret=_var_from_env_get("HRFLOW_API_KEY"),
api_user=_var_from_env_get("HRFLOW_USER_EMAIL"),
)


def _profile_get() -> t.Dict[str, t.Any]:
return dict(
reference=str(uuid4()),
Expand Down Expand Up @@ -330,7 +332,9 @@ def test_profile_parsing_file_quicksilver_async_basic(hrflow_client):
assert _ASYNC_RETRY_INTERVAL_SECONDS > 0
for _ in range(max(0, _ASYNC_TIMEOUT_SECONDS // _ASYNC_RETRY_INTERVAL_SECONDS)):
model = ProfileIndexingResponse.model_validate(
hrflow_client.profile.storing.get(source_key=SOURCE_KEY, reference=reference)
hrflow_client.profile.storing.get(
source_key=SOURCE_KEY, reference=reference
)
)
if model.code == http_codes.ok:
break
Expand Down Expand Up @@ -408,7 +412,9 @@ def test_profile_parsing_file_mozart_async_basic(hrflow_client):
assert _ASYNC_RETRY_INTERVAL_SECONDS > 0
for _ in range(max(0, _ASYNC_TIMEOUT_SECONDS // _ASYNC_RETRY_INTERVAL_SECONDS)):
model = ProfileIndexingResponse.model_validate(
hrflow_client.profile.storing.get(source_key=SOURCE_KEY, reference=reference)
hrflow_client.profile.storing.get(
source_key=SOURCE_KEY, reference=reference
)
)
if model.code == http_codes.ok:
break
Expand Down Expand Up @@ -514,7 +520,9 @@ def test_profile_asking_basic(hrflow_client):
model = ProfileAskingResponse.model_validate(
hrflow_client.profile.asking.get(
source_key=SOURCE_KEY,
key=_indexed_response_get(hrflow_client, SOURCE_KEY, _profile_get()).data.key,
key=_indexed_response_get(
hrflow_client, SOURCE_KEY, _profile_get()
).data.key,
questions=[
"What is the full name of the profile ?",
],
Expand All @@ -539,7 +547,9 @@ def test_profile_asking_multiple_questions(hrflow_client):
hrflow_client.profile.asking.get(
source_key=SOURCE_KEY,
questions=questions,
key=_indexed_response_get(hrflow_client, SOURCE_KEY, _profile_get()).data.key,
key=_indexed_response_get(
hrflow_client, SOURCE_KEY, _profile_get()
).data.key,
)
)
assert model.code == http_codes.ok
Expand All @@ -556,7 +566,9 @@ def test_profile_asking_no_question(hrflow_client):
model = ProfileAskingResponse.model_validate(
hrflow_client.profile.asking.get(
source_key=SOURCE_KEY,
key=_indexed_response_get(hrflow_client, SOURCE_KEY, _profile_get()).data.key,
key=_indexed_response_get(
hrflow_client, SOURCE_KEY, _profile_get()
).data.key,
questions=None,
)
)
Expand Down
16 changes: 10 additions & 6 deletions tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@
),
]


@pytest.fixture(scope="module")
def hrflow_client():
return Hrflow(
api_secret=_var_from_env_get("HRFLOW_API_KEY"),
api_user=_var_from_env_get("HRFLOW_USER_EMAIL"),
)


@pytest.mark.text
@pytest.mark.embedding
def test_embedding_basic(hrflow_client):
Expand Down Expand Up @@ -171,6 +173,7 @@ def test_tagger_rome_family_with_text_param(hrflow_client):
assert model.code == requests.codes.ok
assert isinstance(model.data, TextTaggingDataItem)


@pytest.mark.text
@pytest.mark.tagging
def test_tagger_rome_family_with_texts_param(hrflow_client):
Expand All @@ -185,6 +188,7 @@ def test_tagger_rome_family_with_texts_param(hrflow_client):
assert isinstance(model.data, list)
assert len(model.data) == len(TAGGING_TEXTS)


@pytest.mark.text
@pytest.mark.tagging
def test_tagger_rome_family_with_text_and_texts_param(hrflow_client):
Expand All @@ -198,9 +202,10 @@ def test_tagger_rome_family_with_text_and_texts_param(hrflow_client):
)
)
pytest.fail("Should have raised a ValueError")
except ValueError as e:
except ValueError:
pass



@pytest.mark.text
@pytest.mark.tagging
def test_tagger_rome_family_without_text_or_texts_param(hrflow_client):
Expand All @@ -212,9 +217,10 @@ def test_tagger_rome_family_without_text_or_texts_param(hrflow_client):
)
)
pytest.fail("Should have raised a ValueError")
except ValueError as e:
except ValueError:
pass


def _tagging_test(
hrflow_client: Hrflow,
algorithm_key: TAGGING_ALGORITHM,
Expand Down Expand Up @@ -359,9 +365,7 @@ def test_ocr_basic(hrflow_client):
495c951bbae6b/profiles/52e3c23a5f21190c59f53c41b5630ecb5d414f94/parsing/resume.pdf"""
file = _file_get(s3_url, "ocr")
assert file is not None
model = TextOCRResponse.model_validate(
hrflow_client.text.ocr.post(file=file)
)
model = TextOCRResponse.model_validate(hrflow_client.text.ocr.post(file=file))
assert model.code == requests.codes.ok
assert "ocr" in model.message.lower()

Expand Down
3 changes: 1 addition & 2 deletions tests/utils/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ class TextTaggingDataItem(BaseModel):
@classmethod
def _check(cls, values: t.Dict[str, t.List[t.Any]]) -> t.Dict[str, t.List[t.Any]]:
if isinstance(values, list):
return [cls._check(item) for item in values
]
return [cls._check(item) for item in values]
li = len(values.get("ids"))
lp = len(values.get("predictions"))
lt = len(values.get("tags"))
Expand Down

0 comments on commit a741f47

Please sign in to comment.