batch["states"][state] = self.states[state].take(indices)
for action in self.actions:
batch["actions"][action] = self.actions[action].take(indices)
batch["rewards"] = self.rewards.take(indices)
batch["terminals"] = self.rewards.take(indices)
for n in range(len(self.internals)):
batch["internals"].append(self.internals.take(indices))
After Change
actions={name: action.take(indices) for name, action in self.actions.items()},
rewards=self.rewards.take(indices),
terminals=self.terminals.take(indices),
internals=[internal.take(indices, axis=0) for internal in self.internals]
)