# Build the run configuration: check the caller's keyword overrides, start from
# the library defaults, then apply the requested base model (if one was given).
# NOTE(review): assert_valid_config presumably raises on bad kwargs — confirm.
assert_valid_config(**kwargs)
config = get_default_config()
default_base_model = config.base_model
config.base_model = kwargs.get("base_model", default_base_model)

# float_16_predict is switched off (with a warning) for this family of models.
gpt_family_models = [GPTModel, GPT2Model, GPTModelSmall, OSCAR]
if config.base_model in gpt_family_models:
    if config.float_16_predict:
        LOGGER.warning("float_16_predict not supported by GPT and GPT2")
        config.float_16_predict = False

# Starts empty; not read within the visible lines (presumably filled in later).
auto_keys = []
After Change
# Ask the selected model class for its preferred parameter overrides given the
# current config; the result is not consumed within the visible lines.
overrides = config.base_model.get_optimal_params(config)

# float_16_predict is only kept for models deriving from _BaseBert; any other
# base model has it disabled with a warning.
is_bert_based = issubclass(config.base_model, _BaseBert)
if not is_bert_based:
    if config.float_16_predict:
        LOGGER.warning("float_16_predict only supported by bert based models")
        config.float_16_predict = False