def eval(self, niter=1):
    """Run ``niter`` forward passes of the model plus the loss computation.

    Each iteration does the following:

    * moves the single example tensor to ``self.device``;
    * drops index 0 along dim 1 to obtain ``sources`` and augments them;
    * sums the sources along dim 1 into a ``mix``;
    * runs the model on the mix;
    * center-trims the sources to match the estimates;
    * evaluates ``self.criterion``.

    Nothing is returned and the loss value is discarded.

    Args:
        niter: Number of iterations to run.
    """
    # TODO: implement the eval version (this block does not itself
    # disable autograd).
    for _ in range(niter):
        # ``example_inputs`` must contain exactly one element.
        streams, = self.example_inputs
        streams = streams.to(self.device)
        # NOTE(review): index 0 along dim 1 presumably holds the reference
        # mixture and the remaining entries the individual sources — confirm.
        sources = streams[:, 1:]
        sources = self.augment(sources)
        mix = sources.sum(dim=1)
        estimates = self.model(mix)
        sources = center_trim(sources, estimates)
        # The loss value is not used; only the computation is performed.
        self.criterion(estimates, sources)
def train(self, niter=1):
After Change
def eval(self, niter=1):
    """Run ``niter`` model forward passes plus the loss computation.

    Each iteration calls ``self.model`` on the unpacked example inputs,
    which returns a ``(sources, estimates)`` pair. It then center-trims the
    sources to match the estimates and evaluates ``self.criterion``.
    Nothing is returned and the loss value is discarded.

    Args:
        niter: Number of iterations to run.
    """
    # TODO: implement the eval version (this block does not itself
    # disable autograd).
    for _ in range(niter):
        sources, estimates = self.model(*self.example_inputs)
        sources = center_trim(sources, estimates)
        # The loss value is not used; only the computation is performed.
        self.criterion(estimates, sources)