54 lines
2.4 KiB
Python
54 lines
2.4 KiB
Python
import os
|
|
import argparse
|
|
from textgenrnn import textgenrnn
|
|
|
|
|
|
class PhraseGenerator(textgenrnn):
    """textgenrnn subclass that bundles the settings this script trains and generates with.

    Args:
        input_training_file_path: Text file read by ``pg_train``.
        input_epochs: Number of epochs ``pg_train`` runs.
        input_temperature: Sampling temperature used by ``pg_generate``.
        input_model_file_path: Weights file passed to textgenrnn as ``weights_path``.
        logging_level: Value stored in the ``TF_CPP_MIN_LOG_LEVEL`` environment variable
            (converted with ``str`` because environment values must be strings).
    """

    def __init__(self, input_training_file_path='./lyrics.txt', input_epochs=1, input_temperature=.5,
                 input_model_file_path='./WillieBotModel_weights.hdf5', logging_level=str(2)):
        # TensorFlow C++ log verbosity.
        # NOTE(review): tensorflow is presumably already imported via the textgenrnn
        # import at the top of the file, so messages emitted before this point may not
        # be affected — confirm.
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(logging_level)

        # Remember the run configuration (assigned in the same order as before).
        self.training_file_path, self.model_file_path = input_training_file_path, input_model_file_path
        self.epochs, self.temperature = input_epochs, input_temperature

        # Build the underlying textgenrnn model from the stored weights file.
        super().__init__(weights_path=self.model_file_path, allow_growth=True, name='WillieBotModel')

    def pg_train(self):
        """Train the model on ``self.training_file_path`` for ``self.epochs`` epochs."""
        options = dict(num_epochs=self.epochs, verbose=0, top_n=5, return_as_list=True)
        self.train_from_file(self.training_file_path, **options)

    def pg_generate(self):
        """Generate a single sample at ``self.temperature``, print it, and return it as ``str``."""
        first = self.generate(1, temperature=self.temperature, return_as_list=True)[0]
        print(first)
        return str(first)
|
|
|
|
|
|
def build_arg_parser():
    """Return the command-line parser for this script.

    Every option is optional. Unset value options parse as ``None`` and unset flags as
    ``False``; the fallback defaults are applied when the PhraseGenerator is built.
    """
    parser = argparse.ArgumentParser(description='Description of your program')
    parser.add_argument('-t', '--train', action='store_true', help='Train the model', required=False)
    parser.add_argument('-g', '--generate', action='store_true', help='Generate text', required=False)
    # The fallback applied below when --epochs is unset is 1; the help text states that.
    parser.add_argument('-e', '--epochs', action='store', type=int, help='Set amount of epochs (defaults to 1)',
                        required=False)
    # Temperature is fractional (default .5), so parse it as float; int rejected e.g. "0.7".
    parser.add_argument('-p', '--temp', action='store', type=float,
                        help='Set temperature for generation (defaults to .5)', required=False)
    parser.add_argument('-f', '--training_file', action='store', type=str,
                        help='Set the training file (defaults to \'./lyrics.txt\')', required=False)
    return parser


if __name__ == '__main__':
    args = vars(build_arg_parser().parse_args())
    print(args)
    print('Starting')

    # Falsy values (unset, 0, empty string) fall back to the script defaults.
    pg = PhraseGenerator(input_epochs=args['epochs'] or 1,
                         input_training_file_path=args['training_file'] or './lyrics.txt',
                         input_temperature=args['temp'] or .5,
                         # '2' (quieter TensorFlow C++ logging) when generating, '0' otherwise.
                         logging_level=str(2) if args['generate'] else str(0))

    if args['train']:
        pg.pg_train()

    if args['generate']:
        pg.pg_generate()
|