Leo / NER-Trip

main2.py 5.71 KB
Leo committed on 2021-05-19 22:27: [Add] Repository migration
#!/usr/bin/env python
# -*- coding:utf-8 -*-
#@Time : 2021/5/8 0008 8:51
#@Author : tb_youth
#@FileName: main2.py
#@SoftWare: PyCharm
#@Blog : https://blog.csdn.net/tb_youth
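"""Entry point for a BiLSTM-CRF Chinese NER model (TensorFlow 1.x).

The --mode flag selects training, evaluation, or an interactive demo.
Typical invocations (a sketch; data paths come from config.setting, and the
test/demo model id defaults to the --demo_model value below):

    python main2.py --mode train
    python main2.py --mode test --demo_model 1620921350
    python main2.py --mode demo --demo_model 1620921350
"""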
import tensorflow as tf
import numpy as np
import os, argparse, time
from models.bilstm_crf2 import BiLSTM_CRF
from preprocess.data_utils import str2bool
from preprocess.utils import get_logger
from preprocess.data_utils import read_dictionary, tag2label, random_embedding
from preprocess.utils import read_pkl
from config.setting import PathConfig as path_cfg, root_path
from models.utils import get_Na_dct
CUR_PATH = root_path
## Session configuration
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # default: 0
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.2 # need ~700MB GPU memory
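# Note: ConfigProto, Session, and Saver below are TensorFlow 1.x APIs; running
# this script under TF 2.x would additionally require
# `import tensorflow.compat.v1 as tf` and `tf.disable_v2_behavior()`.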
## hyperparameters
parser = argparse.ArgumentParser(description='BiLSTM-CRF for Chinese NER task')
parser.add_argument('--train_data', type=str, default=path_cfg.inputs_path, help='train data source')
parser.add_argument('--test_data', type=str, default=path_cfg.inputs_path, help='test data source')
parser.add_argument('--batch_size', type=int, default=64, help='#sample of each minibatch')
parser.add_argument('--epoch', type=int, default=40, help='#epoch of training')
parser.add_argument('--hidden_dim', type=int, default=300, help='#dim of hidden state')
parser.add_argument('--optimizer', type=str, default='Adam', help='Adam/Adadelta/Adagrad/RMSProp/Momentum/SGD')
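# str2bool is used instead of type=bool because argparse would treat any
# non-empty string (including 'False') as truthy.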
parser.add_argument('--CRF', type=str2bool, default=True, help='use CRF at the top layer. if False, use Softmax')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
parser.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
parser.add_argument('--dropout', type=float, default=0.5, help='dropout keep_prob')
parser.add_argument('--update_embedding', type=str2bool, default=True, help='update embedding during training')
parser.add_argument('--pretrain_embedding', type=str, default='random', help='use pretrained char embedding or init it randomly')
parser.add_argument('--embedding_dim', type=int, default=300, help='random init char embedding_dim')
parser.add_argument('--shuffle', type=str2bool, default=True, help='shuffle training data before each epoch')
parser.add_argument('--mode', type=str, default='train', help='train/test/demo')
parser.add_argument('--demo_model', type=str, default='1620921350', help='model for test and demo')
parser.add_argument('--Na_dct', type=dict, default={}, help='entity info in test data')
args = parser.parse_args()
# test data info
args.Na_dct = get_Na_dct(path_cfg.process_test_data_path, path_cfg.original_data_path)
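# Na_dct is filled programmatically rather than via the CLI: argparse's
# type=dict cannot parse a command-line string, so --Na_dct effectively
# serves as a placeholder for this assignment.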
## get char embeddings
word2id = read_dictionary(os.path.join(CUR_PATH, args.train_data, 'word2id.pkl'))
if args.pretrain_embedding == 'random':
    embeddings = random_embedding(word2id, args.embedding_dim)
else:
    embedding_path = 'pretrain_embedding.npy'
    embeddings = np.array(np.load(embedding_path), dtype='float32')
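# random_embedding comes from preprocess.data_utils; a minimal sketch of what
# it presumably returns (the uniform range is an assumption, not this repo's
# actual code) is a float32 matrix of shape (len(word2id), embedding_dim):
#
#     def random_embedding(vocab, embedding_dim):
#         mat = np.random.uniform(-0.25, 0.25, (len(vocab), embedding_dim))
#         return np.float32(mat)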
## read corpus and get training data
if args.mode != 'demo':
    train_path = os.path.join(CUR_PATH, args.train_data, 'train.pkl')
    test_path = os.path.join(CUR_PATH, args.test_data, 'test.pkl')
    train_data = read_pkl(train_path)['train']
    test_data = read_pkl(test_path)['test']
    test_size = len(test_data)
## paths setting
paths = {}
timestamp = str(int(time.time())) if args.mode == 'train' else args.demo_model
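# Each training run writes to a fresh, timestamped output directory; test and
# demo modes reuse the run directory named by --demo_model.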
output_path = os.path.join(CUR_PATH, path_cfg.outputs_path, timestamp)
if not os.path.exists(output_path): os.makedirs(output_path)
summary_path = os.path.join(output_path, "summaries")
paths['summary_path'] = summary_path
if not os.path.exists(summary_path): os.makedirs(summary_path)
model_path = os.path.join(output_path, "checkpoints/")
if not os.path.exists(model_path): os.makedirs(model_path)
ckpt_prefix = os.path.join(model_path, "model")
paths['model_path'] = ckpt_prefix
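# "model" is a checkpoint filename prefix, not a directory: BiLSTM_CRF
# presumably saves with tf.train.Saver().save(sess, paths['model_path'], ...),
# producing files like checkpoints/model-<step> that
# tf.train.latest_checkpoint(model_path) can locate in test/demo mode.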
result_path = os.path.join(output_path, "results")
paths['result_path'] = result_path
if not os.path.exists(result_path): os.makedirs(result_path)
log_path = os.path.join(result_path, "log.txt")
paths['log_path'] = log_path
get_logger(log_path).info(str(args))
## training model
if args.mode == 'train':
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    ## train model on the whole training data
    print("train data: {}".format(len(train_data)))
    model.train(train=train_data, dev=test_data)  # use test_data as the dev set to watch for overfitting
## testing model
elif args.mode == 'test':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    print("test data: {}".format(test_size))
    model.test(test_data)
## demo
elif args.mode == 'demo':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        while True:
            print('Please input your sentence:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                demo_sent = list(demo_sent.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                tag = model.demo_one(sess, demo_data)
                print(tag)
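# Example demo exchange, assuming demo_one returns one BIO-style tag per input
# character (the actual label set is defined by preprocess.data_utils.tag2label):
#
#     Please input your sentence:
#     我想去北京故宫
#     ['O', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'I-LOC']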