def project_data(self, data):
    """Project ``data`` through the model and binarise the result at zero.

    Projection only applies when the model's ``"mel_projection"`` layer
    reports ``num_mels > 0``. In that case:

    1. ``data`` is wrapped with ``torch.from_numpy``.
    2. It is moved to ``self.device`` and cast to float.
    3. A leading axis of size 1 is added.
    4. It is passed to ``self.model.project_data(..., clamp=False)``.
    5. The output is thresholded (``> 0``) and the leading axis is squeezed
       away.

    Otherwise the input is handed back untouched.

    Args:
        data: NumPy array (must be accepted by ``torch.from_numpy``).

    Returns:
        A boolean NumPy array when projection applies, otherwise the very
        same ``data`` object that was passed in.
    """
    projection_layer = self.model.layers["mel_projection"]
    if not projection_layer.num_mels > 0:
        # Nothing to project: return the caller's object as-is.
        return data

    batched = torch.from_numpy(data).to(self.device).float().unsqueeze(0)
    projected = self.model.project_data(batched, clamp=False)
    positive_mask = (projected > 0).squeeze(0)
    return positive_mask.cpu().detach().numpy().astype(bool)
After Change
# Revised project_data body: format the input, project it through the model,
# reorder the axes, and return a boolean mask of the positive entries.
# NOTE(review): unlike the earlier version there is no num_mels > 0 guard and
# no explicit unsqueeze(0); presumably self._format(..., "rnn") already adds
# whatever batch axis the model expects -- confirm against _format.
data = self._format(data, "rnn")
# Assumes _format returns a NumPy array, since torch.from_numpy requires one.
data = torch.from_numpy(data).to(self.device).float()
# clamp=False is passed explicitly; its exact effect is defined by
# model.project_data.
data = self.model.project_data(data, clamp=False)
# squeeze(-1) removes the last axis only if it has size 1. permute(2, 1, 0)
# then reverses the axis order and needs exactly three dimensions here.
# NOTE(review): presumably the projection output is 4-D with a trailing
# singleton axis -- confirm the shape contract of model.project_data.
data = data.squeeze(-1).permute(2, 1, 0)
# Threshold at zero, move to CPU, and convert to a NumPy bool array.
data = (data > 0).cpu().data.numpy().astype(bool)
return data
def extract_features(self):