Skip to content

Commit

Permalink
when X_ref input is None, use X_best_f. whether plot X_ref only depen…
Browse files Browse the repository at this point in the history
…ds on plotRef indicator
  • Loading branch information
xuyuting committed Sep 17, 2024
1 parent ca94e42 commit 30dc2ec
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions obsidian/plotting/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,13 +280,12 @@ def factor_plot(optimizer: Optimizer,
# Create a dataframe of test samples for plotting
n_samples = 100
if X_ref is None:
df_mean = optimizer.X_best_f
X_test = pd.concat([df_mean]*n_samples, axis=0).reset_index(drop=True)
X_ref = optimizer.X_best_f
else:
if not isinstance(X_ref, pd.DataFrame):
raise TypeError('X_ref must be a DataFrame')
X_test = pd.concat([X_ref]*n_samples, axis=0).reset_index(drop=True)

X_test = pd.concat([X_ref]*n_samples, axis=0).reset_index(drop=True)
# Vary the indicated column
X_name = X_test.columns[feature_id]
param_i = optimizer.X_space.params[feature_id]
Expand Down Expand Up @@ -316,7 +315,7 @@ def factor_plot(optimizer: Optimizer,
line={'color': obsidian_colors.teal},
name='Mean'),
)
if (X_ref is not None) and plotRef:
if plotRef:
Y_pred_ref = optimizer.predict(X_ref, return_f_inv=not f_transform)
Y_mu_ref = Y_pred_ref[y_name+('_t (pred)' if f_transform else ' (pred)')].values
fig.add_trace(go.Scatter(x=X_ref.iloc[:, feature_id].values, y=Y_mu_ref,
Expand Down Expand Up @@ -382,8 +381,7 @@ def surface_plot(optimizer: Optimizer,

# Create a dataframe of test samples for plotting
n_grid = 100
df_mean = optimizer.X_best_f
X_test = pd.concat([df_mean]*(n_grid**2), axis=0).reset_index(drop=True)
X_test = pd.concat([optimizer.X_best_f]*(n_grid**2), axis=0).reset_index(drop=True)

# Create a mesh grid which is necessary for the 3D plot
X0_name = X_test.columns[feature_ids[0]]
Expand Down

0 comments on commit 30dc2ec

Please sign in to comment.