{self.var[0]: self._get_sample(reparam=reparam,
sample_shape=sample_shape)}
elif x is not None: // conditional
if type(x) is torch.Tensor:
x = {self.cond_var[0]: x}
elif type(x) is list:
x = dict(zip(self.cond_var, x))
elif type(x) is dict:
if not set(list(x.keys())) == set(self.cond_var):
raise ValueError("Input"s keys are not valid.")
else:
raise ValueError("Invalid input")
params = self._get_forward(x)
self._set_dist(params)