raise ValueError("Expecting a BatchGeneratorOffline.")
if not self._layers_initialized:
self._init_layers_observe_embedding(self._observe_embeddings, example_trace=next(batch_generator_offline.batches(1))[0])
self._init_layers()
self._layers_initialized = True
After Change
super().to(device=device, *args, *kwargs)
def _pre_generate_layers(self, dataset, auto_save_file_name_prefix=None):
print("Layer pre-generation")
if not self._layers_initialized:
self._init_layers_observe_embedding(self._observe_embeddings, example_trace=dataset.__getitem__(0))
self._init_layers()
self._layers_initialized = True