# Convert each field of `timestep` to a float tensor placed on the device
# returned by tu.global_device(). Indexing with [None, None, ...] prepends
# two singleton axes to every field before conversion.
o = torch.as_tensor(timestep.observation[None, None, ...],
device=tu.global_device()).float()
a = torch.as_tensor(timestep.action[None, None, ...],
device=tu.global_device()).float()
# The reward is wrapped in a one-element array first (presumably because it
# is a scalar -- confirm), so it also gains a trailing axis of size 1.
r = torch.as_tensor(np.array([timestep.reward])[None, None, ...],
device=tu.global_device()).float()
no = torch.as_tensor(timestep.next_observation[None, None, ...],
device=tu.global_device()).float()
# Join the fields along axis 2; the next observation is included only when
# self._use_next_obs is truthy.
if self._use_next_obs:
data = torch.cat([o, a, r, no], dim=2)
else:
data = torch.cat([o, a, r], dim=2)
# When no context is stored yet, this entry becomes the context; otherwise
# it is appended along axis 1 (presumably the step axis -- confirm).
if self._context is None:
self._context = data
else:
self._context = torch.cat([self._context, data], dim=1)
def infer_posterior(self, context):
rCompute :math:`q(z \| c)` as a function of input context and sample new z.
After Change
# Convert each field of `timestep` to a float tensor placed on the device
# returned by global_device(). Indexing with [None, None, ...] prepends two
# singleton axes to every field before conversion.
a = torch.as_tensor(timestep.action[None, None, ...],
device=global_device()).float()
# The reward is wrapped in a one-element array first (presumably because it
# is a scalar -- confirm), so it also gains a trailing axis of size 1.
r = torch.as_tensor(np.array([timestep.reward])[None, None, ...],
device=global_device()).float()
no = torch.as_tensor(timestep.next_observation[None, None, ...],
device=global_device()).float()
# Join the fields along axis 2; the next observation is included only when
# self._use_next_obs is truthy.
if self._use_next_obs:
data = torch.cat([o, a, r, no], dim=2)
else:
data = torch.cat([o, a, r], dim=2)
# When no context is stored yet, this entry becomes the context; otherwise
# it is appended along axis 1 (presumably the step axis -- confirm).
if self._context is None:
self._context = data
else:
self._context = torch.cat([self._context, data], dim=1)
def infer_posterior(self, context):
rCompute :math:`q(z \| c)` as a function of input context and sample new z.