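You concatenate the original z so that the encoder still gets updated, right? After all, swapping elements is not a differentiable operation. When you were building this model, did you try a straight-through estimator?
Thanks.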
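For readers unfamiliar with the term, the straight-through estimator asked about above is usually written as follows. This is a generic sketch, not taken from the repository: $s(\cdot)$ stands for the non-differentiable swap and $\operatorname{sg}(\cdot)$ for a stop-gradient operator, both notation assumed here for illustration.

```latex
\tilde{z} = z + \operatorname{sg}\bigl(s(z) - z\bigr),
\qquad
\text{forward: } \tilde{z} = s(z),
\qquad
\text{backward: } \frac{\partial \tilde{z}}{\partial z} = I
```

The forward pass sees the swapped value, while the backward pass treats the swap as the identity, so gradients reach the encoder without concatenating the original z.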
By the way, I have been experimenting with running this model in TensorFlow.
This is how I implement the swap and the straight-through estimator mentioned above:
import tensorflow as tf

def get_idx(z, y):
    # Build (b, h, w, y[b]) indices for every spatial position.
    # Assumes z has a static shape (B, H, W, C) and y is an int32 tensor of shape (B,).
    s = [tf.range(z.shape[i]) for i in range(3)]
    d1, d2, d3 = tf.meshgrid(s[1], s[0], s[2])
    idx = tf.stack([d2, d1, d3], axis=-1)
    _, h, w, _ = idx.shape
    y = tf.repeat(tf.repeat(y[:, None, None], h, axis=1), w, axis=2)[..., None]
    idx = tf.concat([idx, y], axis=-1)
    return idx

def get_corr_ele(z, y1, y2):
    # Gather the elements of z in channels y1 and y2, plus their indices.
    idx1 = get_idx(z, y1)
    idx2 = get_idx(z, y2)
    idx = tf.concat([idx1, idx2], axis=0)
    ele = tf.gather_nd(z, idx)
    return ele, idx

def swap(z1, z2, y1, y2):
    z1y, idx = get_corr_ele(z1, y1, y2)
    z2y, _ = get_corr_ele(z2, y1, y2)
    # Overwrite channels y1 and y2 of z1 with the values from z2.
    z12 = tf.tensor_scatter_nd_update(z1, idx, z2y)
    # z21 = tf.tensor_scatter_nd_update(z2, idx, z1y)
    # Straight-through estimator: forward value is z12, gradient flows to z1.
    z12 = z1 + tf.stop_gradient(z12 - z1)
    # z21 = z2 + tf.stop_gradient(z21 - z2)
    z12 = tf.concat([z12, z1], axis=-1)
    return z12  # , z21
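To make the indexing in the snippet above easier to check, here is a minimal plain-Python sketch of what the gather/scatter step appears to compute, assuming `z1` and `z2` are laid out as `[B][H][W][C]` and `y1`, `y2` hold one channel index per batch item. The name `swap_reference` is hypothetical and not part of the original code. It covers only the scatter update, not the straight-through trick or the final concat.

```python
import copy

def swap_reference(z1, z2, y1, y2):
    """Return a copy of z1 whose channels y1[b] and y2[b] are taken from z2.

    z1, z2: nested lists shaped [B][H][W][C] (assumed layout).
    y1, y2: lists of B integer channel indices.
    """
    out = copy.deepcopy(z1)  # leave the input untouched
    for b in range(len(z1)):
        for h in range(len(z1[b])):
            for w in range(len(z1[b][h])):
                # Both selected channels are overwritten with z2's values.
                for c in (y1[b], y2[b]):
                    out[b][h][w][c] = z2[b][h][w][c]
    return out

# Tiny example: one batch item, one pixel, three channels.
z1 = [[[[1, 2, 3]]]]
z2 = [[[[10, 20, 30]]]]
print(swap_reference(z1, z2, y1=[0], y2=[2]))  # [[[[10, 2, 30]]]]
```

Comparing the pre-concat `z12` from the TensorFlow version against this reference on small inputs is one way to confirm that the meshgrid ordering produces the intended `(b, h, w, channel)` indices.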