def __init__(self, parser=None):
    """Parse and validate the command line, then build the training state.

    Args:
        parser: Optional argument parser to extend. When omitted (or falsy)
            a fresh ``ArgumentParser`` is created.

    Invalid argument combinations are reported through ``parser.error``.

    Attributes set here: ``args``, ``model``, ``train``, ``test``,
    ``epoch``, ``model_base``, ``samples``, ``hash_to_ind`` and
    ``callbacks``.
    """
    parser = parser or ArgumentParser()
    add_to_parser(parser, self.usage, True)
    args = TrainData.parse_args(parser)
    # process_args may return a replacement namespace; if it returns a
    # falsy value the originally parsed arguments are kept.
    self.args = args = self.process_args(args) or args

    if args.invert_samples and not args.samples_file:
        parser.error("You must specify --samples-file when using --invert-samples")
    if args.samples_file and not isfile(args.samples_file):
        # Report the path that failed the isfile() check. The previous
        # expression `(args.invert_samples or args.samples_file)` put the
        # flag value into the message whenever --invert-samples was set.
        parser.error("No such file: " + args.samples_file)
    if not 0.0 <= args.sensitivity <= 1.0:
        parser.error("sensitivity must be between 0.0 and 1.0")

    # NOTE(review): presumably these load and persist the parameters
    # stored alongside the model file -- confirm against their definitions.
    inject_params(args.model)
    save_params(args.model)

    params = ModelParams(skip_acc=args.no_validation, extra_metrics=args.extra_metrics,
                         loss_bias=1.0 - args.sensitivity, freeze_till=args.freeze_till)
    self.model = create_model(args.model, params)
    self.train, self.test = self.load_data(self.args)

    # Imported locally, as in the original, rather than at module level.
    from keras.callbacks import ModelCheckpoint, TensorBoard

    checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
                                 save_best_only=args.save_best)

    # The epoch counter lives in "<model path without extension>.epoch";
    # 0 is passed as the fallback value to read().
    epoch_fiti = Fitipy(splitext(args.model)[0] + ".epoch")
    self.epoch = epoch_fiti.read().read(0, int)

    def on_epoch_end(a, b):
        # Advance the counter and write it back to the .epoch file.
        self.epoch += 1
        epoch_fiti.write().write(self.epoch, str)

    self.model_base = splitext(self.args.model)[0]

    if args.samples_file:
        self.samples, self.hash_to_ind = self.load_sample_data(args.samples_file, self.train)
    else:
        self.samples = set()
        self.hash_to_ind = {}

    self.callbacks = [
        checkpoint, TensorBoard(
            log_dir=self.model_base + ".logs",
        ), LambdaCallback(on_epoch_end=on_epoch_end)
    ]
After Change
# Hand the already-parsed arguments to the base class initialiser.
super().__init__(args)

# Validate argument combinations; problems are raised as ValueError.
if args.invert_samples and not args.samples_file:
    raise ValueError("You must specify --samples-file when using --invert-samples")
if args.samples_file and not isfile(args.samples_file):
    # Report the path that failed the isfile() check. The previous
    # expression `(args.invert_samples or args.samples_file)` put the flag
    # value into the message whenever --invert-samples was set.
    raise ValueError("No such file: " + args.samples_file)
if not 0.0 <= args.sensitivity <= 1.0:
    raise ValueError("sensitivity must be between 0.0 and 1.0")