A small question #24

Open
tom99763 opened this issue Aug 19, 2022 · 1 comment
Comments

tom99763 commented Aug 19, 2022

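I assume you concatenate the original z so that the encoder still gets updated? After all, swapping elements is not a differentiable operation. When you were building this model, did you ever try a straight-through estimator?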

Thanks.
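For context, a straight-through estimator keeps the result of a non-differentiable step in the forward pass but routes gradients through it as if the step were the identity. A minimal TensorFlow sketch follows; `straight_through` is a hypothetical helper name, and `tf.round` merely stands in for a non-differentiable step such as the swap.

```python
import tensorflow as tf


def straight_through(x, hard_fn):
    # Forward: value of hard_fn(x). Backward: gradient of the identity w.r.t. x,
    # because the correction term (hard_fn(x) - x) is wrapped in stop_gradient.
    return x + tf.stop_gradient(hard_fn(x) - x)


x = tf.Variable([0.2, 0.7, 1.6])
with tf.GradientTape() as tape:
    y = straight_through(x, tf.round)  # forward values are approximately [0., 1., 2.]
    loss = tf.reduce_sum(3.0 * y)
grad = tape.gradient(loss, x)  # d(loss)/dy = 3 passes straight through to x
```

The same `x + tf.stop_gradient(hard - x)` pattern applies to a swap-based model, with the swapped tensor taking the place of the rounded one.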

tom99763 commented Aug 19, 2022

By the way, I'd like to share that I've been experimenting with running this model in TensorFlow.

Here's how I implemented the swap and the straight-through estimator I just mentioned:

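```python
import tensorflow as tf


def get_idx(z, y):
    # z: (B, H, W, C) latent; y: (B,) integer channel index per sample.
    # Returns indices [b, h, w, y[b]] with shape (B, H, W, 4).
    shape = tf.shape(z)
    b, h, w = shape[0], shape[1], shape[2]
    # Coordinates of every (batch, row, column) position.
    d_b, d_h, d_w = tf.meshgrid(tf.range(b), tf.range(h), tf.range(w), indexing="ij")
    idx = tf.stack([d_b, d_h, d_w], axis=-1)
    # Broadcast each sample's channel index over its H x W grid.
    y = tf.broadcast_to(tf.cast(y, idx.dtype)[:, None, None, None], tf.stack([b, h, w, 1]))
    return tf.concat([idx, y], axis=-1)


def get_corr_ele(z, y1, y2):
    # Gather z at channels y1 and y2; idx has shape (2B, H, W, 4).
    idx = tf.concat([get_idx(z, y1), get_idx(z, y2)], axis=0)
    ele = tf.gather_nd(z, idx)
    return ele, idx


def swap(z1, z2, y1, y2):
    # z1y is only needed by the commented-out z21 branch below.
    z1y, idx = get_corr_ele(z1, y1, y2)
    z2y, _ = get_corr_ele(z2, y1, y2)
    # Overwrite channels y1 and y2 of z1 with the matching values from z2.
    z12 = tf.tensor_scatter_nd_update(z1, idx, z2y)
    # z21 = tf.tensor_scatter_nd_update(z2, idx, z1y)

    # Straight-through estimator: the forward pass returns the swapped tensor,
    # while the backward pass treats the swap as the identity with respect to z1.
    z12 = z1 + tf.stop_gradient(z12 - z1)
    # z21 = z2 + tf.stop_gradient(z21 - z2)

    # Append the original z1 along the channel axis.
    z12 = tf.concat([z12, z1], axis=-1)
    return z12  # , z21
```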
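To make the indexing easier to check, here is a loop-based NumPy sketch of what the forward pass of `swap` computes, under my reading of the code: `z1` and `z2` are assumed to be laid out as (B, H, W, C), and `y1`, `y2` hold one integer channel index per sample. `swap_reference` is a hypothetical name, and it ignores gradients entirely; it is only meant as a slow cross-check against the TensorFlow output on small inputs.

```python
import numpy as np


def swap_reference(z1, z2, y1, y2):
    # For each sample b, overwrite channels y1[b] and y2[b] of z1 with the
    # corresponding channels of z2, then append the untouched z1.
    z12 = z1.copy()
    for b in range(z1.shape[0]):
        for c in (y1[b], y2[b]):
            z12[b, :, :, c] = z2[b, :, :, c]
    return np.concatenate([z12, z1], axis=-1)


z1 = np.zeros((2, 2, 2, 3), dtype=np.float32)
z2 = np.ones((2, 2, 2, 3), dtype=np.float32)
out = swap_reference(z1, z2, np.array([0, 1]), np.array([2, 1]))
print(out.shape)  # (2, 2, 2, 6)
```

Here sample 0 takes channels 0 and 2 from `z2`, while sample 1 has `y1 == y2`, so only channel 1 is replaced.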
