size_req = max(size_req, replay_burnin)
if trainer_preprocessor is None:
    sig = inspect.signature(trainer.train)
    logger.info(f"Deriving trainer_preprocessor from {sig.parameters}")
    # Assuming training_batch is the first parameter (excluding self)
    assert (
        list(sig.parameters.keys())[0] == "training_batch"
    ), f"{sig.parameters} doesn't have training_batch in first position."
    training_batch_type = sig.parameters["training_batch"].annotation
    assert training_batch_type != inspect.Parameter.empty
    if not hasattr(training_batch_type, "from_replay_buffer"):
        raise NotImplementedError(
            f"{training_batch_type} does not implement from_replay_buffer"
        )
    def trainer_preprocessor(batch):
        # Build the trainer's expected batch type from the raw replay-buffer
        # sample, moving it to the target device if one was given.
        retval = training_batch_type.from_replay_buffer(batch)
        if device is not None:
            retval = retval.to(device)
        return retval