"For more information, see "
"https://github.com/pytorch/examples/issues/467."))
# Validate constructor arguments; each check raises on the first problem it
# finds, so at most one error is reported per call.
if not (callable(model_creator) and callable(optimizer_creator)):
    raise ValueError(
        "Must provide a callable model_creator and optimizer_creator.")
# Deprecated arguments are rejected by raising DeprecationWarning (not via
# warnings.warn), so passing one fails immediately instead of being ignored.
if num_replicas is not None:
    raise DeprecationWarning(
        "num_replicas is deprecated. Use num_workers instead.")
if batch_size is not None:
    # The example dict uses single quotes so the double-quoted literal is
    # not terminated early.
    raise DeprecationWarning(
        "batch_size is deprecated. Use config={'batch_size': N} "
        "to specify a batch size for each worker or "
        "config={ray.util.sgd.utils.BATCH_SIZE: N} to specify a "
        "batch size to be used across all workers.")
# NOTE(review): unlike the deprecation checks above this raises ValueError;
# kept as-is because callers may already catch ValueError here.
if data_loader_args:
    raise ValueError(
        "data_loader_args is deprecated. You can return a "
        "torch.utils.data.DataLoader in data_creator. Ray will "
        "automatically set a DistributedSampler if a DataLoader is "
        "returned and num_workers > 1.")
# Keep the user-supplied creator callables and operator class on the instance
# exactly as passed in.
self.model_creator = model_creator
self.optimizer_creator = optimizer_creator
self.loss_creator = loss_creator
self.data_creator = data_creator
self.scheduler_creator = scheduler_creator
self.training_operator_cls = training_operator_cls

# At least one of loss_creator / training_operator_cls must be truthy; the
# attributes above are already assigned by the time this may raise.
if not (training_operator_cls or loss_creator):
    raise ValueError("If a loss_creator is not provided, you must "
                     "provide a custom training operator.")

self.initialization_hook = initialization_hook
# A missing config becomes a fresh empty dict; the local `config` name itself
# is left untouched.
self.config = config if config is not None else {}
# Resolve the "auto" setting to a concrete value: whatever
# torch.cuda.is_available() reports in the current process.
if use_gpu == "auto":
use_gpu = torch.cuda.is_available()