Introduced logging level (output)

This commit is contained in:
Logan Cusano
2022-01-22 23:56:23 -05:00
parent de5e7574c8
commit dafc4b729b

View File

@@ -8,6 +8,7 @@ class PhraseGenerator(textgenrnn):
input_model_file_path='./WillieBotModel_weights.hdf5', logging_level=str(2)):
# Set logging for Tensorflow
os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(logging_level)
self.logging_level = logging_level
# Init vars
self.training_file_path = input_training_file_path
@@ -19,7 +20,8 @@ class PhraseGenerator(textgenrnn):
super().__init__(weights_path=self.model_file_path, allow_growth=True, name='WillieBotModel')
def pg_train(self):
self.train_from_file(self.training_file_path, num_epochs=self.epochs, verbose=0, top_n=5, return_as_list=True)
self.train_from_file(self.training_file_path, num_epochs=self.epochs,
verbose=0 if self.logging_level == '2' else 1, top_n=5, return_as_list=True)
def pg_generate(self):
generated_text = self.generate(1, temperature=self.temperature, return_as_list=True)