Skip to content

Commit

Permalink
Merge pull request #1 from bethgelab/fix_plotting_issue
Browse files Browse the repository at this point in the history
fix seaborn style influence on the matplotlib styles
  • Loading branch information
kantharajucn authored Jun 15, 2021
2 parents 7e798b0 + 354a2bf commit 1a2dc99
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
21 changes: 14 additions & 7 deletions modelvshuman/plotting/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,9 @@ def confusion_matrix_helper(data, output_filename,

plt.savefig(output_filename, bbox_inches='tight', dpi=300)
plt.close()
sns.reset_defaults()
sns.reset_orig()
plt.style.use('default')


def plot_shape_bias_matrixplot(datasets,
Expand Down Expand Up @@ -632,21 +635,25 @@ def plot_matrix(datasets, analysis,
f, ax = plt.subplots(figsize=(22, 18))
cmap = sns.diverging_palette(230, 20, as_cmap=True)

heatmap = sns.heatmap(res["matrix"], mask=None, cmap=cmap, vmax=1.0, center=0,
sns.heatmap(res["matrix"], ax=ax, mask=None, cmap=cmap, vmax=1.0, center=0,
square=True, linewidths=2.0, cbar_kws={"shrink": .5},
xticklabels=True, yticklabels=True)

for i, tick_label in enumerate(heatmap.axes.get_yticklabels()):
for i, tick_label in enumerate(ax.axes.get_yticklabels()):
tick_label.set_color(colors[i])
for i, tick_label in enumerate(heatmap.axes.get_xticklabels()):
for i, tick_label in enumerate(ax.axes.get_xticklabels()):
tick_label.set_color(colors[i])

figure_path = pjoin(result_dir,
f"{dataset.name}_{analysis.plotting_name.replace(' ', '-')}_matrix{by_mean_str}.pdf")
heatmap.figure.savefig(figure_path, bbox_inches='tight', pad_inches=0)
plt.cla()
plt.clf()
plt.close('all')
f.savefig(figure_path, bbox_inches='tight', pad_inches=0)
f.clear()
plt.cla()
plt.clf()
plt.close()
sns.reset_defaults()
sns.reset_orig()
plt.style.use('default')


def sort_matrix_by_models_mean(result_dict):
Expand Down
6 changes: 3 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ setup_requires =
# setuptools >=38.3.0 # version with most `setup.cfg` bugfixes
install_requires =
torch==1.7.1
torchvision
torchvision==0.8.2
requests
gdown
scikit-image
Expand All @@ -44,16 +44,16 @@ install_requires =
PySocks
tensorflow_hub
tensorflow-gpu
tensorflow==2.0
tensorflow==2.5.0
matplotlib>=3.3.2
pandas
seaborn
ftfy
regex
tqdm
CLIP @ git+https://github.com/openai/CLIP#egg=CLIP
figshare @ git+https://github.com/cognoma/figshare#egg=figshare
pytorch_pretrained_vit
tensorflow-estimator==2.1.*
tests_require =
pytest
dependency_links =
Expand Down

0 comments on commit 1a2dc99

Please sign in to comment.