Introduced logging level (output)
This commit is contained in:
@@ -8,6 +8,7 @@ class PhraseGenerator(textgenrnn):
|
||||
input_model_file_path='./WillieBotModel_weights.hdf5', logging_level=str(2)):
|
||||
# Set logging for Tensorflow
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(logging_level)
|
||||
self.logging_level = logging_level
|
||||
|
||||
# Init vars
|
||||
self.training_file_path = input_training_file_path
|
||||
@@ -19,7 +20,8 @@ class PhraseGenerator(textgenrnn):
|
||||
super().__init__(weights_path=self.model_file_path, allow_growth=True, name='WillieBotModel')
|
||||
|
||||
def pg_train(self):
|
||||
self.train_from_file(self.training_file_path, num_epochs=self.epochs, verbose=0, top_n=5, return_as_list=True)
|
||||
self.train_from_file(self.training_file_path, num_epochs=self.epochs,
|
||||
verbose=0 if self.logging_level == '2' else 1, top_n=5, return_as_list=True)
|
||||
|
||||
def pg_generate(self):
|
||||
generated_text = self.generate(1, temperature=self.temperature, return_as_list=True)
|
||||
Reference in New Issue
Block a user