super(MembershipInferenceBlackBox, self).__init__(estimator=classifier)
if input_type not in ["prediction", "loss"]:
raise ValueError("Illegal value for parameter `input_type`.")
self.input_type = input_type
if attack_model:
if ClassifierMixin not in type(attack_model).__mro__:
raise TypeError("Attack model must be of type Classifier.")
self.attack_model = attack_model
self.default_model = False
self.attack_model_type = None
else:
self.default_model = True
self.attack_model_type = attack_model_type
if attack_model_type == "nn":
if input_type == "prediction":
self.attack_model = MembershipInferenceAttackModel(classifier.nb_classes)
else:
self.attack_model = MembershipInferenceAttackModel(classifier.nb_classes, 1)
self.epochs = 100
self.bs = 100
self.lr = 0.0001
elif attack_model_type == "rf":
self.attack_model = RandomForestClassifier()
elif attack_model_type == "gb":
self.attack_model = GradientBoostingClassifier()
else:
raise ValueError("Illegal value for parameter `attack_model_type`.")
def fit(self, x: np.ndarray, y: np.ndarray, test_x: np.ndarray, test_y: np.ndarray, **kwargs) -> np.ndarray:
Infer membership in the training set of the target estimator.