# Next, compute the contribution from the process noise that is injected at each timestep.
# (A cumulative sum is needed to integrate the z-state contribution across time.)
eye = torch.eye(self.state_dim, device=fs_cov.device, dtype=fs_cov.dtype)
z_process_covar = self.log_trans_noise_scale_sq.exp() * eye
N_trans_obs_shift = torch.cat([self.z_obs_matrix.unsqueeze(0), N_trans_obs[0:-1]])
predicted_covar2z = torch.matmul(N_trans_obs_shift.transpose(-1, -2),
torch.matmul(z_process_covar, N_trans_obs_shift)) // N O O
predicted_covar = predicted_covar1z + predicted_covar1gp + gp_process_covar + \
torch.cumsum(predicted_covar2z, dim=0)
if include_observation_noise:
eye = torch.eye(self.obs_dim, device=fs_cov.device, dtype=fs_cov.dtype)
predicted_covar = predicted_covar + self._get_obs_noise_scale().pow(2.0) * eye
return predicted_mean, predicted_covar
def forecast(self, targets, N_timesteps):