# setup GPU device if available, move model into configured device
n_gpu_use = config["n_gpu"]
n_gpu = torch.cuda.device_count()
if n_gpu_use > 0 and n_gpu == 0:
    self.logger.warning("Warning: There's no GPU available on this machine, training will be performed on CPU.")
    n_gpu_use = 0
if n_gpu_use > n_gpu:
    msg = "Warning: The number of GPUs configured to use is {}, but only {} are available on this machine.".format(n_gpu_use, n_gpu)
    self.logger.warning(msg)
    n_gpu_use = n_gpu
self.device = torch.device("cuda:0" if n_gpu_use > 0 else "cpu")
self.model = model.to(self.device)
if n_gpu_use > 1:
    self.model = torch.nn.DataParallel(model, device_ids=list(range(n_gpu_use)))
self.loss = loss
self.metrics = metrics
After Change
self.logger = logging.getLogger(self.__class__.__name__)
# setup GPU device if available, move model into configured device
self.device, device_ids = self._prepare_device(config["n_gpu"])
self.model = model.to(self.device)
if len(device_ids) > 1:
    self.model = torch.nn.DataParallel(model, device_ids=device_ids)
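The refactor moves the device-selection logic behind a _prepare_device helper that returns the chosen device together with the list of usable GPU ids. The helper's body is not shown in this diff; the following is a minimal sketch, assuming it carries over the before-change logic unchanged:

def _prepare_device(self, n_gpu_use):
    # Sketch reconstructed from the before-change logic; the actual
    # helper body is not part of this diff.
    # Clamp the requested GPU count to what the machine actually has,
    # warn on mismatches, and return the device plus the GPU ids to use.
    n_gpu = torch.cuda.device_count()
    if n_gpu_use > 0 and n_gpu == 0:
        self.logger.warning("Warning: There's no GPU available on this machine, "
                            "training will be performed on CPU.")
        n_gpu_use = 0
    if n_gpu_use > n_gpu:
        self.logger.warning("Warning: The number of GPUs configured to use is {}, "
                            "but only {} are available on this machine.".format(n_gpu_use, n_gpu))
        n_gpu_use = n_gpu
    device = torch.device("cuda:0" if n_gpu_use > 0 else "cpu")
    device_ids = list(range(n_gpu_use))
    return device, device_ids

Returning both the device and the id list lets the caller decide whether a DataParallel wrapper is needed, instead of re-deriving the GPU count at the call site.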