diff --git a/output_new/epoch-3000-mu-corr.png b/output_new/epoch-3000-mu-corr.png new file mode 100644 index 0000000..7308714 Binary files /dev/null and b/output_new/epoch-3000-mu-corr.png differ diff --git a/output_new/epoch-3000-mu-dist.png b/output_new/epoch-3000-mu-dist.png new file mode 100644 index 0000000..a7a7ba4 Binary files /dev/null and b/output_new/epoch-3000-mu-dist.png differ diff --git a/output_new/epoch-3000-mu-rot-dist.png b/output_new/epoch-3000-mu-rot-dist.png new file mode 100644 index 0000000..ce4ca8b Binary files /dev/null and b/output_new/epoch-3000-mu-rot-dist.png differ diff --git a/output_new/epoch-3000-subdec-mu-dist.png b/output_new/epoch-3000-subdec-mu-dist.png new file mode 100644 index 0000000..c27617b Binary files /dev/null and b/output_new/epoch-3000-subdec-mu-dist.png differ diff --git a/output/fa_log.txt b/output_new/fa_log.txt similarity index 100% rename from output/fa_log.txt rename to output_new/fa_log.txt diff --git a/preprocess_db/add_rare_types_3.py b/preprocess_db/add_rare_types_3.py index c4ba7c3..8480f03 100644 --- a/preprocess_db/add_rare_types_3.py +++ b/preprocess_db/add_rare_types_3.py @@ -5,7 +5,7 @@ from pandas import DataFrame MISSING_TYPES: Tuple[int, ...] = tuple(range(1, 17)) -EXTRA_NO_SELF_TYPES = True +EXTRA_NO_SELF_TYPES = False MALE_LABEL_SHIFT = 16 # `get_weight` function assumes this @@ -139,12 +139,13 @@ def types_tal_good_mask(df: DataFrame, def smart_coincide_2( tal_profs: Array, types_self: Array, types_tal: Array, males: Array, - threshold: int = -80, # was 90, + threshold: int = 95, # was 90, thresholds_males: Tuple[Tuple[int, Tuple[int, ...]], ...] = ( - (85, (4,)), (95, (3, 10, 16)), (-90, (12, 13)) + # (85, (4,)), (95, (3, 10, 16)), (-90, (12, 13)) + (85, (4,)), ), thresholds_females: Tuple[Tuple[int, Tuple[int, ...]], ...] = ( - (95, (16,)), + # (95, (16,)), )) -> Array: return smart_coincide_1(tal_profs=tal_profs, types_self=types_self, types_tal=types_tal, males=males, threshold=threshold, diff --git a/train/betatcvae/kld.py b/train/betatcvae/kld.py index a328944..636b93a 100644 --- a/train/betatcvae/kld.py +++ b/train/betatcvae/kld.py @@ -204,8 +204,8 @@ def forward(self, z: Tensor, *qz_params: Tensor) -> Tuple[Tensor, Dict[str, Tens def extra_repr(self) -> str: return f'dataset_size={self.dataset_size}' - # def kld(self, z: Tensor, *qz_params: Tensor) -> Tensor: - # n = z.shape[0] # batch_size - # logqz_x = self.q_dist(z, *qz_params).view(n, -1).sum(dim=1) - # logpz = self.prior_dist(z).view(n, -1).sum(dim=1) - # return logqz_x - logpz + def kld(self, z: Tensor, *qz_params: Tensor) -> Tensor: + n = z.shape[0] # batch_size + logqz_x = self.q_dist(z, *qz_params).view(n, -1).sum(dim=1) + logpz = self.prior_dist(z).view(n, -1).sum(dim=1) + return logqz_x - logpz diff --git a/train/jats/callbacks.py b/train/jats/callbacks.py index 5812da9..1573e05 100644 --- a/train/jats/callbacks.py +++ b/train/jats/callbacks.py @@ -8,21 +8,20 @@ class PlotCallback(pl.Callback): - def __init__(self, df: DataFrame, interesting_ids_ast_file_path: str, plot_every_n_epoch: int = 0, verbose=False): + def __init__(self, df: DataFrame, interesting_ids_ast_file_path: str, plot_every_n_epoch: int = 0): super(PlotCallback, self).__init__() self.plot_vae_args_weighted_batch = get_plot_vae_args_weighted_batch(df) self.plot_vae_args__y__w = get_plot_vae_args__y__w(df) with open(interesting_ids_ast_file_path, 'r', encoding='utf-8') as f: self.interesting_ids: Tuple[int, ...] = tuple(ast.literal_eval(f.read())) self.df = df - self.verbose = verbose self.plot_every_n_epoch = plot_every_n_epoch def plot(self, trainer, pl_module): if trainer.current_epoch == 0: return if self.plot_every_n_epoch > 0: - if trainer.current_epoch % self.plot_every_n_epoch != 0: + if (trainer.current_epoch != 10) and (trainer.current_epoch % self.plot_every_n_epoch != 0): return prefix = path.join(pl_module.logger.log_dir, f'epoch-{pl_module.current_epoch}-') @@ -35,8 +34,8 @@ def plot(self, trainer, pl_module): plot_dist(mu_wb, z_beta_wb, mu, y, prefix + 'mu-dist', axis_name='μ') jats = pl_module.jatsregularizer - mu_rot = jats.cat_rot_2d(mu) - plot_dist(jats.cat_rot_2d(mu_wb), jats.cat_rot_2d(z_beta_wb), mu_rot, y, + mu_rot = jats.cat_rot_np(mu) + plot_dist(jats.cat_rot_np(mu_wb), jats.cat_rot_np(z_beta_wb), mu_rot, y, prefix + 'mu-rot-dist', axis_name='μ_{rot}') plot_dist(subdec_mu_wb, None, subdec_mu, y, prefix + 'subdec-mu-dist', axis_name='s(μ)') diff --git a/train/jats/load.py b/train/jats/load.py index ba80ba8..e3175ee 100644 --- a/train/jats/load.py +++ b/train/jats/load.py @@ -89,7 +89,7 @@ def get_labeled_mask(df: DataFrame) -> Array: return df['smart_coincide'].values > 0 -def get_loader(df: DataFrame, mode: str, batch_size: int, num_workers: int = 0) -> DataLoader: +def get_loader(df: DataFrame, mode: str, batch_size: int, num_workers: int = 0, verbose=False) -> DataLoader: """ Returns a loader of the output format: @@ -100,23 +100,26 @@ def get_loader(df: DataFrame, mode: str, batch_size: int, num_workers: int = 0) """ if mode not in ('unlbl', 'lbl', 'both', 'plot'): raise ValueError + def get_(x_: Array, verbose_=verbose) -> Tuple[Tensor, ...]: + return (Tensor(x_).to(dtype=tr.long),) if verbose_ else () + if mode is 'unlbl': - _, passthr, x, e_ext, _, weights, _ = get_data(df) - dataset = TensorDataset(Tensor(x), Tensor(e_ext), Tensor(passthr)) + ids, passthr, x, e_ext, _, weights, wtarget = get_data(df) + dataset = TensorDataset(Tensor(x), Tensor(e_ext), Tensor(passthr), *get_(ids), *get_(wtarget)) sampler = WeightedRandomSampler(weights=weights.astype(np.float64), num_samples=len(x)) return DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers) if mode == 'lbl': - _, passthr, x, e_ext, target, weights, _ = get_data(df[get_labeled_mask(df)]) + ids, passthr, x, e_ext, target, weights, wtarget = get_data(df[get_labeled_mask(df)]) dataset = TensorDataset(Tensor(x), Tensor(e_ext), - Tensor(target).to(dtype=tr.long), Tensor(passthr)) + Tensor(target).to(dtype=tr.long), Tensor(passthr), *get_(ids), *get_(wtarget)) sampler = WeightedRandomSampler(weights=weights.astype(np.float64), num_samples=len(x)) return DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers) if mode == 'both': - _, passthr, x, e_ext, target, weights, _ = get_data(df) + ids, passthr, x, e_ext, target, weights, wtarget = get_data(df) weights_lbl = np.copy(weights) mask_lbl = get_labeled_mask(df) weights_lbl[mask_lbl] = get_data(df[mask_lbl])[5] @@ -125,12 +128,57 @@ def get_loader(df: DataFrame, mode: str, batch_size: int, num_workers: int = 0) Tensor(target).to(dtype=tr.long), Tensor(passthr), Tensor(weights.astype(np.float64)), Tensor(weights_lbl.astype(np.float64)), - Tensor(mask_lbl).to(dtype=tr.bool)) + Tensor(mask_lbl).to(dtype=tr.bool), + *get_(ids), *get_(wtarget)) return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) if mode == 'plot': - _, passthr, x, e_ext, _, weights, _ = get_data(df[get_labeled_mask(df)]) - dataset = TensorDataset(Tensor(x), Tensor(e_ext), Tensor(passthr)) + ids, passthr, x, e_ext, _, weights, wtarget = get_data(df[get_labeled_mask(df)]) + dataset = TensorDataset(Tensor(x), Tensor(e_ext), Tensor(passthr), *get_(ids), *get_(wtarget)) sampler = WeightedRandomSampler(weights=weights.astype(np.float64), num_samples=len(x)) return DataLoader(dataset, batch_size=batch_size, sampler=sampler) + + +def test_loader(df_len: int, loader_verbose: DataLoader): + """ + :param df_len: len(df) expected. + :param loader_verbose: get_loader(df, 'unlbl', BATCH_SIZE, num_workers=NUM_WORKERS, verbose=True) expected. + :return: + """ + ids_, types_ = [], [] + for smpl in loader_verbose: + id_i, t_i = smpl[-2], smpl[-1] + ids_.append(id_i.view(-1)) + types_.append(t_i.view(-1)) + + ids = tr.cat(ids_).numpy() + uni, counts = np.unique(ids, return_counts=True) + print(df_len, len(ids), len(uni), list(reversed(sorted(counts)))[:100]) + types = tr.cat(types_).numpy() + uni, counts = np.unique(types, return_counts=True) + print(df_len, counts) + + for smpl in loader_verbose: + id_i, t_i = smpl[-2], smpl[-1] + ids_.append(id_i.view(-1)) + types_.append(t_i.view(-1)) + + ids = tr.cat(ids_).numpy() + uni, counts = np.unique(ids, return_counts=True) + print(df_len, len(ids), len(uni), list(reversed(sorted(counts)))[:100]) + types = tr.cat(types_).numpy() + uni, counts = np.unique(types, return_counts=True) + print(df_len, counts) + + for smpl in loader_verbose: + id_i, t_i = smpl[-2], smpl[-1] + ids_.append(id_i.view(-1)) + types_.append(t_i.view(-1)) + + ids = tr.cat(ids_).numpy() + uni, counts = np.unique(ids, return_counts=True) + print(df_len, len(ids), len(uni), list(reversed(sorted(counts)))[:100]) + types = tr.cat(types_).numpy() + uni, counts = np.unique(types, return_counts=True) + print(df_len, counts) diff --git a/train/jats/utils.py b/train/jats/utils.py index 44dc1f0..91b4785 100644 --- a/train/jats/utils.py +++ b/train/jats/utils.py @@ -12,6 +12,7 @@ def probs_temper(probs: Tensor) -> Tensor: probs_[masks[4 - 1] | masks[8 - 1] | masks[12 - 1] | masks[16 - 1]] = 3 return probs_ + def probs_quadraclub(probs: Tensor) -> Tensor: """ (NTL, SFL, STC, NFC, SFC, NTC, NFL, STL) """ probs_ = probs.clone() @@ -26,16 +27,19 @@ def probs_quadraclub(probs: Tensor) -> Tensor: probs_[masks[15 - 1] | masks[16 - 1]] = 7 return probs_ + def expand_quadraclub(probs: Tensor) -> Tensor: """ (NTL, SFL, STC, NFC, SFC, NTC, NFL, STL) """ n, m = probs.shape return probs.view(n, m, 1).expand(n, m, 2).reshape(n, m * 2) + def expand_temper(probs: Tensor) -> Tensor: """ (EP/-IR, IJ/+IR, IP/-ER, EJ/+ER) """ n, m = probs.shape return probs.view(n, 1, m).expand(n, 4, m).reshape(n, m * 4) + def expand_temper_to_stat_dyn(probs: Tensor) -> Tensor: """ (EP/-IR, IJ/+IR, IP/-ER, EJ/+ER) """ n, m = probs.shape[0], probs.shape[1] // 2 diff --git a/train/jatsregularizertested.py b/train/jatsregularizer.py similarity index 65% rename from train/jatsregularizertested.py rename to train/jatsregularizer.py index 03af464..4b95462 100644 --- a/train/jatsregularizertested.py +++ b/train/jatsregularizer.py @@ -1,5 +1,4 @@ from typing import Tuple, List - import numpy as np from numpy.typing import NDArray as Array import torch as tr @@ -16,10 +15,8 @@ T: Ax = ((3, 4, 7, 8, 9, 10, 13, 14), (1, 2, 5, 6, 11, 12, 15, 16)) AD: Ax = ((5, 6, 7, 8, 9, 10, 11, 12), (1, 2, 3, 4, 13, 14, 15, 16)) AB: Ax = ((9, 10, 11, 12, 13, 14, 15, 16), (1, 2, 3, 4, 5, 6, 7, 8)) -SAB: Ax = ((11, 12, 13, 14), (3, 4, 5, 6)) -SAB_OTHER: Ax = ((1, 2, 7, 8, 9, 10, 15, 16), (1, 2, 7, 8, 9, 10, 15, 16)) -TAD: Ax = ((7, 8, 9, 10), (1, 2, 15, 16)) -TAD_OTHER: Ax = ((3, 4, 5, 6, 11, 12, 13, 14), (3, 4, 5, 6, 11, 12, 13, 14)) +# SAB: Ax = ((11, 12, 13, 14), (3, 4, 5, 6)) +# SAB_OTHER: Ax = ((1, 2, 7, 8, 9, 10, 15, 16), (1, 2, 7, 8, 9, 10, 15, 16)) SIR: Ax = ((1, 13), (6, 10)) # (NERr, SIR) NIR: Ax = ((5, 9), (2, 14)) # (SERr, NIR) @@ -63,175 +60,33 @@ def missing_types(ax: Ax) -> Tuple[int, ...]: inv(DIR), QIR, inv(DER), QER)) -class JATSRegularizerTested(nn.Module): +# there was an axis found that has something like 12,13,11/14 <=> 5,6,8 +class JATSRegularizer(nn.Module): q_subsets: Tuple[Tuple[int, ...], ...] thr: Tuple[Tensor] - - thrs_: Tuple[tuple, ...] = ( - (1.8, 1.8, tuple(1.8 for _ in range(7))), - (0.5, 1., tuple(1.5 for _ in range(7))), - (0., 1., tuple(1.2 for _ in range(7))) - ) disabled: bool = False sigmoid_scl: float = 1 - def __init__(self): - """ Loss reduction is sum. """ - super(JATSRegularizerTested, self).__init__() - self.zero = tr.tensor([0.]).mean() - self.inv_rot_45 = self.get_inv_rot(tr.tensor(0.25)) - self.thrs = Tensor(self.expand_thrs(self.thrs_)) - self.thr = (self.thrs[-1],) - - self.q_subsets = self.get_q_subsets() - if len(self.q_subsets) != 16: raise RuntimeError - - def expand_thrs(self, thrs_: Tuple[tuple, ...]) -> Tuple[Tuple[float, ...], ...]: - return tuple( - tuple(thrs[0] for _ in range(KHAXDIM)) - + tuple(thrs[1] for _ in range(KHAXDIM)) - + thrs[2] - for thrs in thrs_ - ) - - @staticmethod - def get_khaxes() -> Tuple[Tuple[Tuple[int, ...], ...], Tuple[Tuple[int, ...], ...]]: - return (KHAX0[:KHAXDIM] + KHAX_OTHER[:KHAXDIM], - KHAX1[:KHAXDIM] + KHAX_OTHER[:KHAXDIM]) - - def get_q_types(self) -> Tuple[Tuple[int, ...], ...]: - khaxes, axes = self.get_khaxes(), self.get_axes() - return (khaxes[1] + axes[1] + # MAP POS - khaxes[0] + axes[0]) # MAP NEG - - def get_q_subsets(self, verbose: bool = False) -> Tuple[Tuple[int, ...], ...]: - """ - Classes from ``self.get_q_types()`` should be from 1..16. - :returns q_subsets spec for each of 16 types. - """ - q_subsets_: List[List[int]] = [[] for _ in range(1, 17)] - for axis_i, axis_types in enumerate(self.get_q_types()): - for type_ in axis_types: - q_subsets_[type_ - 1].append(axis_i) - q_subsets = tuple(tuple(sorted(subset)) for subset in q_subsets_) - if verbose: - print('types specs: ', [[i + 1, spec] for i, spec in enumerate(q_subsets)]) - return q_subsets - - # @staticmethod - # def get_transform() -> Tensor: - # cos, sin, pi = math.cos, math.sin, math.pi - # rot45 = tr.tensor([[cos(pi / 4), -sin(pi / 4)], - # [sin(pi / 4), +cos(pi / 4)]]) - # rot_inv_45 = rot45.transpose(0, 1) - # return rot_inv_45 - - @staticmethod - def get_inv_rot(angle_pi_mult: Tensor) -> Tensor: - angle = angle_pi_mult * tr.pi - return tr.tensor([[tr.cos(angle), -tr.sin(angle)], - [tr.sin(angle), +tr.cos(angle)]]).transpose(0, 1) - - def rot_transform(self, z: Tensor) -> Tensor: - """ - >>> ''' - >>> [TBg] {T} [TAd] [FAd] {Ad} [TAd] - >>> Bg=>{Ad} F=>{T} - >>> FBg F FAd FBg Bg TBg - >>> [-ER] {-R} [-IR] [+ER] {E} [-IR] - >>> I=>{E} +R=>{-R} - >>> +IR +R +ER +IR I -ER - >>> [IAd] {Ad} [EAd] [EBg] {E} [EAd] - >>> I=>{E} Bg=>{Ad} - >>> IBg Bg EBg IBg I IAd - >>> ''' - Original axes => New axes after rot(-45) - 0. S=>[N] => S=>{N} - 1. FBg=>[TAd] => Bg=>{Ad} - 2. FAd=>[TBg] => F=>{T} - 3. +IR=>[-IR] => +R=>{-R} - 4. -ER=>[+ER] => I=>{E} - 5. IBg=>[EAd] => Bg=>{Ad} - 6. IAd=>[EBg] => I=>{E} - """ - return tr.cat([ - z[:, (0,)], - z[:, (1, 2)] @ self.inv_rot_45, - z[:, (3, 4)] @ self.inv_rot_45, - z[:, (5, 6)] @ self.inv_rot_45, - ], dim=1) - - @staticmethod - def get_axes() -> Tuple[Tuple[Tuple[int, ...], ...], Tuple[Tuple[int, ...], ...]]: - axes = tuple(tpl( - N[i], - AD[i], T[i], # rot (TAd, TBg) - inv(R)[i], E[i], # rot (-IR, +ER) - AD[i], E[i] # rot (EAd, EBg) - ) for i in (0, 1)) - return axes[0], axes[1] - - def set_threshold(self, value: int) -> None: - """ - Sets threshold. - - :param value: (value - 1) - index of the thresholds set from self.thrs. - If value == 0 then set self.disabled=True - """ - if value == 0: - self.disabled = True - return - if value >= 1: - self.thr = (self.thrs[value - 1],) - self.disabled = False - return - raise ValueError('value < 0') - - def get_q(self, z_: Tensor) -> Tensor: - """ - Returns array of quasi-probabilities vectors (batch_size, len(q_types)). - Elements of the quasi-probability vectors correspond with ``self.q_types`` (same length). - """ - q_pos = tr.sigmoid((z_ + self.thr[0]) * self.sigmoid_scl) - q_neg = tr.sigmoid((-z_ + self.thr[0]) * self.sigmoid_scl) - return tr.cat([q_pos, q_neg], dim=1) - - def axes_cross_entropy(self, z_: Tensor, y: Tensor) -> Tensor: - if self.disabled: - return self.zero - q = self.get_q(z_) - axes_neg_cross_entropy_ = [ - tr.log(q[mask][:, subset] + 1e-8).sum() - for subset, mask in zip(self.q_subsets, (y == i for i in range(16))) - if z_[mask].shape[0] > 0 - ] # list of tensors of shape (batch_subset_size,) - return sum(axes_neg_cross_entropy_) * (-1) - - def forward(self, z: Tensor, subdec_z: Tensor, y: Tensor) -> Tensor: - """ Loss reduction is sum. """ - return self.axes_cross_entropy(tr.cat([subdec_z, subdec_z, self.rot_transform(z)], dim=1), y) - - def extra_repr(self) -> str: - return (f'thrs_={self.thrs_}, disabled={self.disabled}' + - f'khaxes={self.get_khaxes()}, axes={self.get_axes()}, sigmoid_scl={self.sigmoid_scl}') - - -# there was an axis found that has something like 12,13,14 <=> 6,8,(16,10) -class JATSRegularizerUntested(JATSRegularizerTested): thrs_: Tuple[tuple, ...] = ((1.8, 1.8, (1.8, 4.0, 1.8, 1.8, 1.8, 1.8, 1.8, 1.8)), + (1.0, 1.2, (1.8, 4.0, 1.8, 1.8, 1.8, 1.8, 1.8, 1.8)), (0.5, 1.0, (1.5, 4.0, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5)), (0.0, 1.0, (1.2, 4.0, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2)), (0.0, 1.0, (1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2)), ) - angle_pi_mults = [0.167, 0.167, 0.25, 0.25] # [0.0833, 0.167, 0.25] + angle_pi_mults = [0.1, 0.167, 0.25, 0.25] def __init__(self): """ Loss reduction is sum. """ - super(JATSRegularizerUntested, self).__init__() - self.inv_rotations = [self.get_inv_rot(tr.tensor(angle_pi_mult)) for angle_pi_mult in self.angle_pi_mults] + super(JATSRegularizer, self).__init__() + self.zero = tr.tensor([0.]).mean() + self.inv_rotations = [self.get_inv_rot_2d(tr.tensor(angle_pi_mult)) for angle_pi_mult in self.angle_pi_mults] + self.thrs = Tensor(self.expand_thrs(self.thrs_)) self.thr = (self.thrs[-1],) self.lastthr = True + self.q_subsets = self.get_q_subsets() + if len(self.q_subsets) != 16: raise RuntimeError + def rot_transform(self, z: Tensor) -> Tensor: """ >>> ''' @@ -266,20 +121,58 @@ def rot_transform(self, z: Tensor) -> Tensor: IAd=>[EBg] => I=>{E} """ return tr.cat([ - z[:, (0, 7)] @ self.inv_rotations[0] if self.lastthr else z[:, (0, 1)], + z[:, (0, 7)] @ self.inv_rotations[1] if self.lastthr else z[:, (0, 7)], z[:, (1, 2)] @ self.inv_rotations[2 if self.lastthr else 1], z[:, (3, 4)] @ self.inv_rotations[2], z[:, (5, 6)] @ self.inv_rotations[3], ], dim=1) - def cat_rot_2d(self, z: Array) -> Array: + def cat_rot(self, z: Tensor) -> Tensor: + return tr.cat([ + z[:, (0, 7)], + z[:, (1, 2)] @ self.inv_rotations[1], + z[:, (3, 4)] @ self.inv_rotations[2], + z[:, (5, 6)] @ self.inv_rotations[3], + ], dim=1) + + def cat_rot_np(self, z: Array) -> Array: return np.concatenate([ - z[:, (0, 7)], # @ self.inv_rotations[0].numpy(), + z[:, (0, 7)], z[:, (1, 2)] @ self.inv_rotations[1].numpy(), z[:, (3, 4)] @ self.inv_rotations[2].numpy(), z[:, (5, 6)] @ self.inv_rotations[3].numpy(), ], axis=1) + @staticmethod + def get_khaxes() -> Tuple[Tuple[Tuple[int, ...], ...], Tuple[Tuple[int, ...], ...]]: + return (KHAX0[:KHAXDIM] + KHAX_OTHER[:KHAXDIM], + KHAX1[:KHAXDIM] + KHAX_OTHER[:KHAXDIM]) + + def get_q_types(self) -> Tuple[Tuple[int, ...], ...]: + khaxes, axes = self.get_khaxes(), self.get_axes() + return (khaxes[1] + axes[1] + # MAP POS + khaxes[0] + axes[0]) # MAP NEG + + def get_q_subsets(self, verbose: bool = False) -> Tuple[Tuple[int, ...], ...]: + """ + Classes from ``self.get_q_types()`` should be from 1..16. + :returns q_subsets spec for each of 16 types. + """ + q_subsets_: List[List[int]] = [[] for _ in range(1, 17)] + for axis_i, axis_types in enumerate(self.get_q_types()): + for type_ in axis_types: + q_subsets_[type_ - 1].append(axis_i) + q_subsets = tuple(tuple(sorted(subset)) for subset in q_subsets_) + if verbose: + print('types specs: ', [[i + 1, spec] for i, spec in enumerate(q_subsets)]) + return q_subsets + + @staticmethod + def get_inv_rot_2d(angle_pi_mult: Tensor) -> Tensor: + angle = angle_pi_mult * tr.pi + return tr.tensor([[tr.cos(angle), -tr.sin(angle)], + [tr.sin(angle), +tr.cos(angle)]]).transpose(0, 1) + @staticmethod def get_axes() -> Tuple[Tuple[Tuple[int, ...], ...], Tuple[Tuple[int, ...], ...]]: axes = tuple(tpl( @@ -287,13 +180,11 @@ def get_axes() -> Tuple[Tuple[Tuple[int, ...], ...], Tuple[Tuple[int, ...], ...] AD[i], T[i], inv(R)[i], E[i], AD[i], E[i], - # N[i], SAB[i], SAB_OTHER[i], - # TAD_OTHER[i], TAD[i], T[i], - # E[i], inv(R)[i], ) for i in (0, 1)) return axes[0], axes[1] - def expand_thrs(self, thrs_: Tuple[tuple, ...]) -> Tuple[Tuple[float, ...], ...]: + @staticmethod + def expand_thrs(thrs_: Tuple[tuple, ...]) -> Tuple[Tuple[float, ...], ...]: return tuple( tuple(thrs[0] for _ in range(KHAXDIM)) + tuple(thrs[1] for _ in range(KHAXDIM)) @@ -310,6 +201,7 @@ def set_threshold(self, value: int) -> None: """ if value == 0: self.disabled = True + self.lastthr = False return if value >= 1: self.thr = (self.thrs[value - 1],) @@ -321,5 +213,31 @@ def set_threshold(self, value: int) -> None: return raise ValueError('value < 0') + def get_q(self, z_: Tensor) -> Tensor: + """ + Returns array of quasi-probabilities vectors (batch_size, len(q_types)). + Elements of the quasi-probability vectors correspond with ``self.q_types`` (same length). + """ + q_pos = tr.sigmoid((z_ + self.thr[0]) * self.sigmoid_scl) + q_neg = tr.sigmoid((-z_ + self.thr[0]) * self.sigmoid_scl) + return tr.cat([q_pos, q_neg], dim=1) + + def axes_cross_entropy(self, z_: Tensor, y: Tensor) -> Tensor: + if self.disabled: + return self.zero + q = self.get_q(z_) + axes_neg_cross_entropy_ = [ + tr.log(q[mask][:, subset] + 1e-8).sum() + for subset, mask in zip(self.q_subsets, (y == i for i in range(16))) + if z_[mask].shape[0] > 0 + ] # list of tensors of shape (batch_subset_size,) + return sum(axes_neg_cross_entropy_) * (-1) + + def forward(self, z: Tensor, subdec_z: Tensor, y: Tensor) -> Tensor: + """ Loss reduction is sum. """ + return self.axes_cross_entropy(tr.cat([subdec_z, subdec_z, self.rot_transform(z)], dim=1), y) + def extra_repr(self) -> str: - return super(JATSRegularizerUntested, self).extra_repr() + f'angle_pi_mults={self.angle_pi_mults}' + return (f'thrs_={self.thrs_}, disabled={self.disabled}, ' + + f'khaxes={self.get_khaxes()}, axes={self.get_axes()}, sigmoid_scl={self.sigmoid_scl}, ' + + f'angle_pi_mults={self.angle_pi_mults}') diff --git a/train/kldcheckdims.py b/train/kldcheckdims.py index 20cf16e..f5de9c0 100644 --- a/train/kldcheckdims.py +++ b/train/kldcheckdims.py @@ -7,7 +7,7 @@ class CheckKLDDims(Module): def __init__(self, switch_from_min_to_max_thr=False, thr: float = 0, subset: Tuple[int, ...] = (), check_interv: Tuple[int, int] = ()): super(CheckKLDDims, self).__init__() - self.thr = thr # 714 had 0.19 but 0.15 was also successfully tried + self.thr = thr self.subset = subset self.check_interv = check_interv # if check_interv[0] > check_interv[1]: raise ValueError diff --git a/train/main.py b/train/main.py index d71f57e..37e4acb 100644 --- a/train/main.py +++ b/train/main.py @@ -3,12 +3,15 @@ import subprocess from pathlib import Path +# README: # Mode #1: Simply run but make sure that ../train_log is empty # Mode #2: Simply run but make sure that ../train_log has logs -# with successful or good enough to try runs copied from -# ../train_log_search_model . Then set OFFSET_EP below to -# the last epoch from latest checkpoint + 1. +# with successful or good enough to try runs copied from ../train_log_search_model # (Placing "skip" file into the directory would skip it). +# +# Tensorboard info: Don't start from CWD or parent, use for example ~ +# > conda activate nn +# > tensorboard --logdir MAX_ERROR = 1000 diff --git a/train/mmdloss.py b/train/mmdloss.py index 87275da..a62395e 100644 --- a/train/mmdloss.py +++ b/train/mmdloss.py @@ -51,3 +51,7 @@ def forward(self, z: Tensor) -> Tensor: def extra_repr(self) -> str: return f'min_mmd_batch_size={self.batch_size}, ' + 'mu={:.3f}, sigma={:.3f}'.format( self.mu[0].item(), self.log_sigma.exp()[0].item()) + + +def mmd_sym_swap(z: Tensor) -> Tensor: + return mmd(z, -z[:, tr.randperm(z.shape[1])]) diff --git a/train/training.py b/train/training.py index a8eea01..64d41cb 100644 --- a/train/training.py +++ b/train/training.py @@ -2,6 +2,7 @@ from os import path from os.path import join import argparse +# import numpy as np import pandas as pd from sklearn.model_selection import train_test_split import torch as tr @@ -10,23 +11,21 @@ # from torch.optim.lr_scheduler import MultiStepLR import pytorch_lightning as pl from pytorch_lightning.loggers.tensorboard import TensorBoardLogger -# from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import ModelCheckpoint from torchmetrics import Accuracy, MeanMetric, MaxMetric, SumMetric, CatMetric -from lightningfix import get_last_checkpoint, GitDirsSHACallback, DeterministicWarmup, tensor_to_dict, LightningModule2 +from lightningfix import (get_last_checkpoint, GitDirsSHACallback, DeterministicWarmup, tensor_to_dict, + LightningModule2) from mmdloss import MMDNormalLoss from betatcvae.kld import Distrib, StdNormal, KLDTCLoss -from jatsregularizertested import JATSRegularizerUntested # , JATSRegularizerTested +from jatsregularizer import JATSRegularizer from kldcheckdims import CheckKLDDims from jats.load import get_target_stratify, get_loader, MAIN_QUEST_N, EXTRA_QUEST from jats.callbacks import PlotCallback from jats.utils import probs_temper, probs_quadraclub, expand_quadraclub, expand_temper, expand_temper_to_stat_dyn -# README: run ./main.py script in order to do first part of the training. - -# Tensorboard info: Don't start from CWD or parent, use for example ~ -# > conda activate nn -# > tensorboard --logdir +# README: +# run ./main.py script and read it's README DEFAULTNAME = 'train_log' NUM_WORKERS = 4 @@ -40,10 +39,8 @@ LAT_D = 8 SUB_D = 12 -# K = 27 # 6; 11; 27 -MAXEPOCHS = 100 + 3000 # the first part: 100; + 3000 -# the first part: 220 + 20 * K; + 3000 -OFFSET_EP = 0 # 0 +MAXEPOCHS = 3001 # 350 # 1500 +OFFSET_EP = 0 REAL_MAX_EP = True if X_EXT_D != len(EXTRA_QUEST) or X_D != MAIN_QUEST_N: raise RuntimeError('Inconsistent constants.') @@ -84,15 +81,13 @@ class PlotCallback2(PlotCallback): def plot(self, trainer, pl_module): if isinstance(pl_module, LightningModule2): - if (not pl_module.successful_run) and (not self.verbose): - return if pl_module.dummy_validation: return super(PlotCallback2, self).plot(trainer, pl_module) plot_callback = PlotCallback2(jats_df, join(preprocess_db, 'ids_interesting.ast'), - plot_every_n_epoch=20, verbose=True) + plot_every_n_epoch=20) git_dir_sha = GitDirsSHACallback(preprocess_db) @@ -110,30 +105,31 @@ def forward(self, x: Tensor): # noinspection PyAbstractClass class LightVAE(LightningModule2): - def __init__(self, learning_rate=1e-3, offset_step: int = 0): # 1e-4 max_epochs: int, + def __init__(self, learning_rate=1e-4, offset_step: int = 0): super().__init__() self.learning_rate = learning_rate - # self.lr_warmup = False - # originally all 48 were 32; all 96 were 64 - self.encoder = nn.Sequential(nn.Linear(X_D + X_EXT_D, 96), nn.SELU(), - nn.Linear(96, 48), nn.SELU(), - nn.Linear(48, LAT_D * 3), + self.encoder = nn.Sequential(nn.Linear(X_D + X_EXT_D, 64), nn.SELU(), + nn.Linear(64, 32), nn.SELU(), + nn.Linear(32, LAT_D * 3), DoubleVAESampler(StdNormal())) - self.sub_decoder = nn.Sequential(nn.Linear(LAT_D, 48), nn.SELU(), - nn.Linear(48, 48), nn.SELU(), - nn.Linear(48, SUB_D)) - self.decoder = nn.Sequential(nn.Linear(SUB_D, 48), nn.SELU(), - nn.Linear(48, 96), nn.SELU(), - nn.Linear(96, X_D)) - - self.cls_pos_shifts = nn.Parameter(tr.zeros(LAT_D + SUB_D) - 0.01) - self.cls_neg_shifts = nn.Parameter(tr.zeros(LAT_D + SUB_D) + 0.01) - self.selu = nn.SELU() - self.classifier_logprobs = nn.Sequential(nn.Linear(PSS_D + (LAT_D + SUB_D) * 3, 16), nn.LogSoftmax(dim=1)) + self.sub_decoder = nn.Sequential(nn.Linear(LAT_D, 32), nn.SELU(), + nn.Linear(32, 32), nn.SELU(), + nn.Linear(32, SUB_D)) + self.decoder = nn.Sequential(nn.Linear(SUB_D, 32), nn.SELU(), + nn.Linear(32, 64), nn.SELU(), + nn.Linear(64, X_D)) + self.decoder_beta = nn.Sequential(nn.Linear(SUB_D, 32), nn.SELU(), + nn.Linear(32, 64), nn.SELU(), + nn.Linear(64, X_D)) + + clsdim = LAT_D * 2 + SUB_D + self.cls_shifts = nn.Parameter(tr.zeros(clsdim * 2).view(2, clsdim) - 0.01) + self.cls_logits = nn.Linear(PSS_D + clsdim * 3, 16) # nn.Sequential(nn.Linear(, 16), nn.LeakyReLU(), nn.Linear(16, 16), nn.LogSoftmax(dim=1)) + self.cls_linear = nn.Linear(PSS_D + clsdim * 3, PSS_D + clsdim * 3) - self.jatsregularizer = JATSRegularizerUntested() + self.jatsregularizer = JATSRegularizer() self.nll_logprobs = nn.NLLLoss(reduction='none') self.bce_logits = nn.BCEWithLogitsLoss(reduction='none') self.mse_probs = nn.MSELoss(reduction='none') @@ -165,109 +161,66 @@ def __init__(self, learning_rate=1e-3, offset_step: int = 0): # 1e-4 max_epoch # after old-840-ep ~ mse 0.142 / bce 86.7 # TC = 0.48 @ 640ep; 0.33 @ 840ep # @640ep NLL_t=88.07 raises max d0.5 then goes down. + # MSE_test=0.1417 was right before the last rho switch. - # k = 57.1875 - # def s(x): return round(x * 57.1875) - # def a(x): return x + 80 # 80=>25 160=>50 - # x + 96 + 32 - # def b(x): return a(x) - 320 # 640=>200 d64=>d20 - # a(x) - 96 + 32 + 64 * K - # def c(x): return b(x) + 120 # 760=>237.5 840=>262.5 - # b(x) + 120 * 4 + scl(200) / k - # def d(x): return c(x) # 1040=>325 1120=>350 - # c(x) + scl(200) / k - def scl(x): return round(x * 183) + def scl(x): return round(x * 183) + offset_step + def inv_scl(scl_x): return scl_x // 183 k25, k50 = scl(25), scl(50) - k100, k175 = scl(100), scl(175) - k200 = scl(200) + k200, k300 = scl(200), scl(300) + k500, k4000 = scl(500), scl(4000) + k100, k1000 = scl(100), scl(1000) + # k3000 = scl(3000) self.warmup = DeterministicWarmup( 1, offset_step, - # dict(alpha=[[0, 0.005], [s(a(80)), 0.5], [s(d(1120)), 0.5]]), # orig - # dict(alpha=[[0, 0.005], [s(a(80)), 0.25], [s(d(1120)), 0.25]]), - # dict(beta=[[0, 0.005], [s(a(80)), 1], [s(d(1120)), 1]]), # orig - dict(beta=[[0, 0.005], [k50, 0.5], [k200, 0.5]]), - dict(gamma=[[0, 0], [k50, 7], [k200, 7]]), - dict(epsilon=[[0, 0.75], [k100, 0.75], [k175, 0.33], [k200, 0.33]]), # orig - # dict(epsilon=[[0, 0.75], [s(b(640)), 0.75], [s(c(760)), 0.2], [s(d(1120)), 0.2]]), - dict(eta=[[0, 21], [k50, 7], [k100, 7], [k175, 3], [k200, 3]]), # should be @ (1-mu) - # dict(eta=[[0, 30], [s(a(80)), 10], [s(b(640)), 10], [s(c(760)), 5], [s(d(1120)), 5]]), # was @ (1-mu)/2 - dict(mu=[[0, 0.57], [k100, 0.57], [scl(175), 0.33], [k200, 0.33]]), # should be @ (1-mu) - # dict(mu=[[0, 0.4], [s(b(640)), 0.4], [s(c(760)), 0.2], [s(d(1120)), 0.2]]), # was @ (1-mu)/2 bug - # dict(rho=[[0, 3], [s(b(640)), 3], [s(b(641)), 2], [s(d(1040)), 2], [s(d(1041)), 1], [s(d(1120)), 1]]), - dict(rho=[[0, 4], [k25, 4], [k25 + 1, 3], [k100, 3], [k100 + 1, 2], [k175, 2], - [k175 + 1, 1], [k200, 1]]), - dict(omega=[[0, 2], [k100, 2], [k175, 4], [k200, 4]]), + dict(alpha=[[0, 0.005], [k50, 0.25], [k100, 0.5], [k4000, 0.5]]), + dict(beta=[[0, 0.005], [k50, 0.25], [k100, 1], [k4000, 1]]), + dict(gamma=[[0, 0], [k50, 7], [k4000, 7]]), + dict(delta=[[0, 0], [k4000, 0]]), # [k3000, 0], [k3000 + 1, 1], [k4000, 1] + dict(epsilon=[[0, 0.85], [k50, 0.85], [k100, 0.75], [k200, 0.75], [k300, 0.1], [k4000, 0.1]]), + dict(eta=[[0, 21], [k50, 7], [k200, 7], [k300, 3], [k4000, 3]]), + dict(mu=[[0, 0.57], [k200, 0.57], [k300, 0.33], [k4000, 0.33]]), + dict(rho=[[0, 5], [k25, 5], [k25 + 1, 4], [k200, 4], [k200 + 1, 3], [k500, 3], [k500 + 1, 2], [k4000, 2]]), + dict(omega=[[0, 2], [k200, 2], [k300, 4], [k4000, 4]]), ) - # self.max1 = 540 - # untested version: + # Untested version: # ----------------- - # sh = offset_step - self.check_kld_dims_max = nn.ModuleList([ - CheckKLDDims(True, thr=1.30, subset=(1,), check_interv=(4, 5)), - CheckKLDDims(True, thr=0.55, subset=(1,), check_interv=(35, 100)), - CheckKLDDims(True, thr=0.55, subset=(4,), check_interv=(65, 100)), + check_kld_dims_delta_min = nn.ModuleList([ + CheckKLDDims(thr=0.001, subset=(0, 1, 4, 5), check_interv=(5 - 1, 1000)), + CheckKLDDims(thr=0.001, subset=(2, 3), check_interv=(50 - 1, 1000)), ]) - self.check_kld_dims_min = nn.ModuleList([ - CheckKLDDims(thr=1.20, subset=(5,), check_interv=(4, 5)), - CheckKLDDims(thr=0.50, subset=(0, 2, 3, 5, 6), check_interv=(65, 100)), - # CheckKLDDims(thr=0.32, subset=(0, 2, 6), check_interv=(100, 140), - # CheckKLDDims(thr=0.42, subset=(0, 2, 6), check_interv=(140, max_epochs)), - # CheckKLDDims(thr=0.42, subset=(3,), check_interv=(140, max_epochs)), - # CheckKLDDims(thr=0.23, subset=(0, 1, 2, 3, 4, 5, 6), check_interv=(140, max_epochs)), - # CheckKLDDims(thr=0.20, subset=(0, 1, 2, 3, 4, 5, 6), check_interv=(220 + 20 * K + sh, max_epochs)), - # CheckKLDDims(thr=0.14, subset=(7,), check_interv=(220 + 20 * K + sh, max_epochs)), # 220 + check_kld_dims_max = nn.ModuleList([ + CheckKLDDims(True, thr=1.30, subset=(1, 4), check_interv=(5 - 1, 5)), + CheckKLDDims(True, thr=1.10, subset=(1, 4), check_interv=(10 - 1, 10)), + CheckKLDDims(True, thr=0.55, subset=(1, 4), check_interv=(25 - 1, 1000)), ]) + check_kld_dims_min = nn.ModuleList([ + CheckKLDDims(thr=0.42, subset=(0, 2, 3, 5, 6), check_interv=(80 - 1, inv_scl(k200))), + CheckKLDDims(thr=0.20, subset=(0, 1, 2, 3, 4, 5, 6), check_interv=(inv_scl(k300) - 1, 1000)), + CheckKLDDims(thr=0.14, subset=(7,), check_interv=(inv_scl(k300) + 100 - 1, 1000)), + ]) + self.check_kld_dims = nn.ModuleList([check_kld_dims_delta_min, check_kld_dims_max, check_kld_dims_min]) - # tested version: - # --------------- - # self.check_kld_dims_0 = CheckKLDDims(thr=0.005, subset=(7,), check_interv=(o(a(160)), max_epochs)) - # self.check_kld_dims_0_i = 0 - - # self.check_kld_dims_delta_max = nn.ModuleList([ - # CheckKLDDims(True, thr=0.07, subset=(7,), check_interv=(o(a(100)), max_epochs)) - # ]) - # self.check_kld_dims_max = nn.ModuleList([ - # CheckKLDDims(True, thr=0.5, subset=(1,), check_interv=(o(a(200)), max_epochs)), - # CheckKLDDims(True, thr=0.55, subset=(4,), check_interv=(o(a(200)), max_epochs)), - # ]) - # self.check_kld_dims_min = nn.ModuleList([ - # CheckKLDDims(thr=0.32, subset=(0, 2, 5, 6), check_interv=(o(a(200)), o(a(320)))), - # CheckKLDDims(thr=0.42, subset=(0, 2, 5, 6), check_interv=(o(a(320)), max_epochs)), - # CheckKLDDims(thr=0.42, subset=(3,), check_interv=(o(a(320)), max_epochs)), - # CheckKLDDims(thr=0.23, subset=(0, 1, 2, 3, 4, 5, 6), check_interv=(o(a(320)), o(b(640)))), - # CheckKLDDims(thr=0.25, subset=(0, 1, 2, 3, 4, 5, 6), check_interv=(o(b(640)), max_epochs)), - # CheckKLDDims(thr=0.15, subset=(7,), check_interv=(o(b(640)), max_epochs)), - # ]) - - # def classify_logprobs__(self, pss, z, subdec_z): - # z_subdec_z = tr.cat([z, subdec_z], dim=1) - # return self.classifier_logprobs(tr.cat([ - # pss, - # z_subdec_z, - # self.selu(z_subdec_z - self.cls_pos_shifts), - # self.selu(-z_subdec_z - self.cls_neg_shifts) - # ], dim=1)) - # - # def nll_classify_logprobs__(self, pss, z, subdec_z, y) -> Tensor: - # return self.nll_logprobs(self.classify_logprobs__(pss, z, subdec_z), y) - # - # def classify_probs__(self, pss, z, subdec_z) -> Tensor: - # return tr.exp(self.classify_logprobs__(pss, z, subdec_z)) + @staticmethod + def get_kld_dims_args(kldvec: Tensor, kldvec_beta: Tensor): + kld_greater = kldvec_beta.view(1, -1)[:, (5, 6)].view(-1) + dkld = tr.cat((kld_greater - kldvec_beta[1], kld_greater - kldvec_beta[4], kld_greater - kldvec_beta[7])) + return (dkld,), (kldvec_beta,), (kldvec, kldvec_beta) - def classify_logprobs(self, pss, z, subdec_z) -> Tuple[Tensor, Tensor]: - linear = self.classifier_logprobs[0] - logsoftmax = self.classifier_logprobs[1] + def load_state_dict(self, state_dict, strict=False): + return super(LightVAE, self).load_state_dict(state_dict, strict=strict) - z_subdec_z = tr.cat([z, subdec_z], dim=1) - logits = linear(tr.cat([ + def classify_logprobs(self, pss, z, subdec_z) -> Tuple[Tensor, Tensor]: + z_z_rot_subdec_z = tr.cat([z, self.jatsregularizer.cat_rot(z), subdec_z], dim=1) + # expects 1 + (8 + 12) * 3; we have 1 + (8 + 6) * 3 + logits = self.cls_logits(tr.cat([ pss, - z_subdec_z, - self.selu(z_subdec_z - self.cls_pos_shifts), - self.selu(-z_subdec_z - self.cls_neg_shifts) + z_z_rot_subdec_z, + tr.selu(z_z_rot_subdec_z - self.cls_shifts[0]), + tr.selu(-z_z_rot_subdec_z - self.cls_shifts[1]) ], dim=1))[:, :12] - return logsoftmax(logits[:, :8]), logsoftmax(logits[:, 8:]) + return tr.log_softmax((logits[:, :8]), dim=1), tr.log_softmax((logits[:, 8:]), dim=1) def nll_classify_logprobs(self, pss, z, subdec_z, y) -> Tensor: logprobs8, logprobs4 = self.classify_logprobs(pss, z, subdec_z) @@ -281,15 +234,22 @@ def classify_probs(self, pss, z, subdec_z) -> Tensor: def sub_decode(self, *z: Tensor) -> Tensor: return self.sub_decoder(z[0]) - def decode__sub_decode(self, *z: Tensor) -> Tensor: + def decode_sub_decode(self, *z: Tensor) -> Tensor: return self.decoder(self.sub_decoder(z[0])) + def decode_beta_sub_decode(self, *z: Tensor) -> Tensor: + return self.decoder_beta(self.sub_decoder(z[0])) + # def sub_decode(self, *z: Tensor) -> Tensor: # return self.sub_decoder(tr.cat(z, dim=1)) # - # def decode__sub_decode(self, *z: Tensor) -> Tensor: + # def decode_sub_decode(self, *z: Tensor) -> Tensor: # subdec_z = self.sub_decode(*z) # return self.decoder(tr.cat((subdec_z,) + z[1:], dim=1)) + # + # def decode_beta_sub_decode(self, *z: Tensor) -> Tensor: + # subdec_z = self.sub_decode(*z) + # return self.decoder_beta(tr.cat((subdec_z,) + z[1:], dim=1)) def forward(self, x: Tensor, x_ext: Tensor, passth: Tensor): """ @@ -309,7 +269,7 @@ def training_step(self, batch, batch_idx): # Setting warmup coefficients # ================= - (beta, gamma, epsilon, # (alpha, + (alpha, beta, gamma, delta, epsilon, eta, mu, rho, omega), warmup_log = self.warmup(self.global_step) self.logger.log_metrics(warmup_log, step=self.global_step) @@ -321,8 +281,8 @@ def training_step(self, batch, batch_idx): # x.sum(dim=1).mean() == x.mean(dim=0).sum() == x.sum() / n mu1, log_sigma, log_sigma_beta, z1, z1_beta = self.encoder(tr.cat([x1, x1ext], dim=1)) subdec_mu1 = self.sub_decode(mu1, pss1) - bce = self.bce_logits(self.decode__sub_decode(z1, pss1), x1).sum() / n1 - bce_b = self.bce_logits(self.decode__sub_decode(z1_beta, pss1), x1).sum() / n1 + bce = self.bce_logits(self.decode_sub_decode(z1, pss1), x1).sum() / n1 + bce_b = self.bce_logits(self.decode_beta_sub_decode(z1_beta, pss1), x1).sum() / n1 kld = self.kld_tc.q_dist.kld((mu1, log_sigma)).sum() / n1 # kld = self.kld_tc.kld(z1, mu1, log_sigma).mean() @@ -335,9 +295,8 @@ def training_step(self, batch, batch_idx): self.metr_tc(tc_b) self.metr_tc_max(tc_b_unrdcd.max() if (n1 == self.mmd.batch_size) else tc_b) - mmd = self.mmd(mu1) + mmd_mu1 = self.mmd(mu1) trim_loss = (tr.relu(mu1 - 2.5) + tr.relu(-mu1 - 2.5)).sum() / n1 - # trim_loss was *200 for each of twins. But actually they are the same hence 400 trim_loss_subdec = (tr.relu(subdec_mu1 - 3) + # # not neg. thr. at right tr.relu(-subdec_mu1 - 3)).sum() / n1 # not pos. thr. at left @@ -350,28 +309,30 @@ def training_step(self, batch, batch_idx): jats_z_b = self.jatsregularizer(z2_beta, self.sub_decode(z2_beta, pss2), y2) / n2 jats_mu_b = self.jatsregularizer(mu2, self.sub_decode(mu2, pss2), y2) / n2 - nll = self.nll_classify_logprobs(pss2, z2.detach(), subdec_z2.detach(), y2).sum() / n2 + if delta < 0.5: + nll = self.nll_classify_logprobs(pss2, z2.detach(), subdec_z2.detach(), y2).sum() / n2 + else: + nll = self.nll_classify_logprobs(pss2, z2, subdec_z2, y2).sum() / n2 - # Only trim_loss coefficients are OK: # ================= loss = ( - (bce + kld * beta) * epsilon + + (bce + kld * alpha) * epsilon + (bce_b + kld_b * beta + tc_b * gamma) * (1 - epsilon) + trim_loss * 400 + # was 200 + 200 trim_loss_subdec * 200 + - mmd * 1000 + - ( - jats_mu_b * mu + jats_z * (epsilon*(1 - mu)) + jats_z_b * ((1 - epsilon)*(1 - mu)) - ) * eta + - nll + mmd_mu1 * 1000 + + (jats_mu_b * mu + (jats_z * epsilon + jats_z_b * (1 - epsilon)) * (1 - mu) + ) * eta + + nll * 0.01 ) + # trim_loss was *200 for each of twins. But actually they are the same hence 400 if tr.isnan(loss).any(): raise NotImplementedError('NaN spotted in the objective.') self.log("train_loss", loss) self.metr_loss(loss) with tr.no_grad(): - xlogits_mu1 = self.decode__sub_decode(mu1, pss1) + xlogits_mu1 = self.decode_sub_decode(mu1, pss1) self.metr_bce_l(self.bce_logits(xlogits_mu1, x1).sum() / n1) self.metr_mse_l(self.mse_probs(tr.sigmoid(xlogits_mu1), x1).mean()) self.metr_acc_l(self.classify_probs(pss2, mu2, self.sub_decode(mu2, pss2)), y2) @@ -386,10 +347,10 @@ def validation_step(self, batch, batch_idx): # Unlabelled data: mu1, log_sigma, log_sigma_beta, z1, z1_beta = self.encoder(tr.cat([x1, x1ext], dim=1)) - bce_z1 = (self.bce_logits(self.decode__sub_decode(z1, pss1), x1).sum(dim=1) * w1).sum() - bce_z1_b = (self.bce_logits(self.decode__sub_decode(z1_beta, pss1), x1).sum(dim=1) * w1).sum() + bce_z1 = (self.bce_logits(self.decode_sub_decode(z1, pss1), x1).sum(dim=1) * w1).sum() + bce_z1_b = (self.bce_logits(self.decode_beta_sub_decode(z1_beta, pss1), x1).sum(dim=1) * w1).sum() - xlogits_mu1 = self.decode__sub_decode(mu1, pss1) + xlogits_mu1 = self.decode_sub_decode(mu1, pss1) self.metr_bce_mu_v((self.bce_logits(xlogits_mu1, x1).sum(dim=1) * w1).sum()) self.metr_mse_v((self.mse_probs(tr.sigmoid(xlogits_mu1), x1).mean(dim=1) * w1).sum()) @@ -434,29 +395,10 @@ def validation_epoch_end(self, outputs): # self.log("cls_shifts", {**tensor_to_dict(self.cls_pos_shifts, 'p'), # **tensor_to_dict(self.cls_neg_shifts, 'n')}) - - # self.logg("kld_val", self.metr_kld, lambda m: tensor_to_dict(m.sum(dim=0))) - # self.logg("kld_beta_val", self.metr_kld_beta, lambda m: tensor_to_dict(m.sum(dim=0))) klvec = self.logg("kld_val", self.metr_kld, lambda m: tensor_to_dict(m.sum(dim=0))).sum(dim=0) klvec_b = self.logg("kld_beta_val", self.metr_kld_beta, lambda m: tensor_to_dict(m.sum(dim=0))).sum(dim=0) - - # j = 0. - # if not self.check_kld_dims_0(self.current_epoch, klvec, klvec_b): - # if not self.dummy_validation: - # self.check_kld_dims_0_i += 1 - # if self.check_kld_dims_0_i >= 8: - # self.log("kld_check_failed", j) - # self.successful_run = False - # self.trainer.should_stop = True - # return - # elif not self.dummy_validation: - # self.check_kld_dims_0_i = 0 - - # for checks, args_ in zip([self.check_kld_dims_delta_max, self.check_kld_dims_max, self.check_kld_dims_min], - # [(klvec_b - klvec,), (klvec_b,), (klvec, klvec_b)]): i = -1. - for checks, args_ in zip([self.check_kld_dims_max, self.check_kld_dims_min], - [(klvec_b,), (klvec, klvec_b)]): + for checks, args_ in zip(self.check_kld_dims, self.get_kld_dims_args(klvec, klvec_b)): for check in checks: i += 1. if not check(self.current_epoch, *args_): @@ -468,25 +410,17 @@ def validation_epoch_end(self, outputs): def configure_optimizers(self): return tr.optim.Adam(self.parameters(), lr=self.learning_rate) - # if not self.lr_warmup: - # return tr.optim.Adam(self.parameters(), lr=self.learning_rate * 0.1) - # optimizer = tr.optim.Adam(self.parameters(), lr=self.learning_rate) - # scheduler = MultiStepLR(optimizer, milestones=[self.max1], gamma=0.1) - # return [optimizer], [dict( - # scheduler=scheduler, - # interval='epoch', - # frequency=1, - # )] - - -autoencoder = LightVAE(offset_step=183 * OFFSET_EP) # max_epochs=maxepochs, + + +autoencoder = LightVAE(offset_step=183 * OFFSET_EP) trainer_ = pl.Trainer(max_epochs=maxepochs, logger=logger, check_val_every_n_epoch=5, callbacks=[ plot_callback, git_dir_sha, - # ModelCheckpoint(every_n_epochs=10, save_top_k=-1), + ModelCheckpoint(every_n_epochs=10, save_top_k=-1), ]) if __name__ == '__main__': trainer_.fit(autoencoder, train_loaders, test_loader, ckpt_path=checkpoint_path) - if not autoencoder.successful_run or (MAXEPOCHS < 3000): + if not autoencoder.successful_run: raise RuntimeError('Unsuccessful. Skipping this train run.') + raise RuntimeError('Force skip all train runs.') diff --git a/train/wfa.py b/train/wfa.py index 7406255..9aed147 100644 --- a/train/wfa.py +++ b/train/wfa.py @@ -113,7 +113,6 @@ def get_bce_mse(x, x_rec, w) -> Tuple[float, float]: for j, i in enumerate(n_factors_list): fa = FactorAnalyzer(n_factors=i, is_corr_matrix=True, method='ml', rotation=(None, 'varimax', 'oblimax', 'quartimax', 'equamax')[0]) - # rotation=('equamax', None)[1] fa.fit(w_sigma_mat) fa.mean_ = np.zeros(x_ext_lrn.shape[1]) fa.std_ = fa.mean_ + 1.