8212e735fd187936161e64026eb46c599e1937d4,dlpy/network.py,Network,load_weights_from_file_with_labels,#Network#Any#Any#Any#Any#Any#Any#,1033
Before Change
"""
from dlpy.model_conversion.model_conversion_utils import query_action_parm
cas_lib_name, file_name, tmp_caslib = caslibify(self.conn, path, task="load")
has_gpu_model,act_parms = query_action_parm(self.conn, "dlImportModelWeights", "deepLearn", "gpuModel")
if (not has_gpu_model) and use_gpu:
raise DLPyError("A GPU model was specified, but your Viya installation does not support"
After Change
"""
from dlpy.model_conversion.model_conversion_utils import query_action_parm
with caslibify_context(self.conn, path, task = "load") as (cas_lib_name, file_name):
has_gpu_model,act_parms = query_action_parm(self.conn, "dlImportModelWeights", "deepLearn", "gpuModel")
if (not has_gpu_model) and use_gpu:
raise DLPyError("A GPU model was specified, but your Viya installation does not support"
"importing GPU models.")
if label_file_name:
from dlpy.utils import get_user_defined_labels_table
label_table = get_user_defined_labels_table(self.conn, label_file_name, label_length)
else:
from dlpy.utils import get_imagenet_labels_table
label_table = get_imagenet_labels_table(self.conn, label_length)
if data_spec:
has_data_spec = query_action_parm(self.conn, "dlImportModelWeights", "deepLearn", "gpuModel")
if has_data_spec:
// run action with dataSpec option
if has_gpu_model:
with sw.option_context(print_messages = False):
rt = self._retrieve_("deeplearn.dlimportmodelweights",
model=self.model_table,
modelWeights=dict(replace=True, name=self.model_name + "_weights"),
dataSpecs=data_spec,
gpuModel=use_gpu,
formatType=format_type, weightFilePath=file_name, caslib=cas_lib_name,
labelTable=label_table)
else:
with sw.option_context(print_messages = False):
rt = self._retrieve_("deeplearn.dlimportmodelweights",
model=self.model_table,
modelWeights=dict(replace=True, name=self.model_name + "_weights"),
dataSpecs=data_spec,
formatType=format_type, weightFilePath=file_name, caslib=cas_lib_name,
labelTable=label_table)
else:
if has_gpu_model:
with sw.option_context(print_messages = False):
rt = self._retrieve_("deeplearn.dlimportmodelweights", model=self.model_table,
modelWeights=dict(replace=True, name=self.model_name + "_weights"),
formatType=format_type, weightFilePath=file_name, caslib=cas_lib_name,
gpuModel=use_gpu,
labelTable=label_table)
else:
with sw.option_context(print_messages = False):
rt = self._retrieve_("deeplearn.dlimportmodelweights", model=self.model_table,
modelWeights=dict(replace=True, name=self.model_name + "_weights"),
formatType=format_type,
weightFilePath=file_name,
caslib=cas_lib_name,
labelTable=label_table)
// handle error or create necessary attributes
if rt.severity > 1:
for msg in rt.messages:
print(msg)
raise DLPyError("Cannot import model weights, there seems to be a problem.")
// create attributes if necessary
if not has_data_spec:
from dlpy.attribute_utils import create_extended_attributes
create_extended_attributes(self.conn, self.model_name, self.layers, data_spec)
else:
print("NOTE: no dataspec(s) provided - creating image classification model.")
if has_gpu_model:
self._retrieve_("deeplearn.dlimportmodelweights", model=self.model_table,
modelWeights=dict(replace=True, name=self.model_name + "_weights"),
formatType=format_type, weightFilePath=file_name, caslib=cas_lib_name,
gpuModel=use_gpu,
labelTable=label_table,
)
else:
self._retrieve_("deeplearn.dlimportmodelweights", model=self.model_table,
modelWeights=dict(replace=True, name=self.model_name + "_weights"),
formatType=format_type, weightFilePath=file_name, caslib=cas_lib_name,
labelTable=label_table,
)
self.set_weights(self.model_name + "_weights")
def load_weights_from_table(self, path):
"""
In pattern: SUPERPATTERN
Frequency: 5
Non-data size: 4
Instances Project Name: sassoftware/python-dlpy
Commit Name: 8212e735fd187936161e64026eb46c599e1937d4
Time: 2019-09-05
Author: Wenyu.Shi@sas.com
File Name: dlpy/network.py
Class Name: Network
Method Name: load_weights_from_file_with_labels
Project Name: sassoftware/python-dlpy
Commit Name: 8212e735fd187936161e64026eb46c599e1937d4
Time: 2019-09-05
Author: Wenyu.Shi@sas.com
File Name: dlpy/network.py
Class Name: Network
Method Name: load_weights_from_table
Project Name: sassoftware/python-dlpy
Commit Name: 8212e735fd187936161e64026eb46c599e1937d4
Time: 2019-09-05
Author: Wenyu.Shi@sas.com
File Name: dlpy/network.py
Class Name: Network
Method Name: save_weights_csv
Project Name: sassoftware/python-dlpy
Commit Name: 8212e735fd187936161e64026eb46c599e1937d4
Time: 2019-09-05
Author: Wenyu.Shi@sas.com
File Name: dlpy/network.py
Class Name: Network
Method Name: load_weights_from_file
Project Name: sassoftware/python-dlpy
Commit Name: 8212e735fd187936161e64026eb46c599e1937d4
Time: 2019-09-05
Author: Wenyu.Shi@sas.com
File Name: dlpy/network.py
Class Name: Network
Method Name: save_to_table_with_caslibify