Skip to content

Commit

Permalink
Merge pull request #57 from nd-ball/debugging
Browse files Browse the repository at this point in the history
bugfix issue #51
  • Loading branch information
jplalor authored Dec 21, 2023
2 parents 771f523 + fadb080 commit 2417485
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/)
and this project adheres to [Semantic Versioning](http://semver.org/).

## [0.4.11] - 2023-12-21

- Fix an issue with cli.evaluate.


## [0.4.10] - 2023-04-12

- Fix an issue with codecov, and also allow for Python 3.10 and 3.11.
Expand Down
23 changes: 17 additions & 6 deletions py_irt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,22 +266,33 @@ def evaluate(
# load saved params
irt_params = read_json(parameter_path)

# reverse dict for lookup
subjectIDX = {}
for (key, item) in irt_params["subject_ids"].items():
subjectIDX[item] = int(key)

itemIDX = {}
for (key, item) in irt_params["item_ids"].items():
itemIDX[item] = int(key)

# load subject, item pairs we want to test
subject_item_pairs = read_jsonlines(test_pairs_path)

# calculate predictions and write them to disk
config = IrtConfig(model_type=model_type, epochs=epochs,
initializers=initializers)

observation_subjects = [subjectIDX[entry["subject_id"]]
for entry in subject_item_pairs]
observation_items = [itemIDX[entry["item_id"]] for entry in subject_item_pairs]

irt_model = IrtModel.from_name(model_type)(
priors=config.priors,
priors="vague",
device=device,
num_items=len(irt_params["item_ids"]),
num_subjects=len(irt_params["subject_ids"]),
num_items=len(set(observation_items)),
num_subjects=len(set(observation_subjects)),
)

observation_subjects = [entry["subject_id"]
for entry in subject_item_pairs]
observation_items = [entry["item_id"] for entry in subject_item_pairs]
preds = irt_model.predict(observation_subjects,
observation_items, irt_params)
outputs = []
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "py-irt"
version = "0.4.10"
version = "0.4.11"
readme = "README.md"
homepage = "https://github.com/nd-ball/py-irt/"
description = "Bayesian IRT models in Python"
Expand Down

0 comments on commit 2417485

Please sign in to comment.