Commit 41ea7f5e authored by Baohang Zhou

add candidate generation

Parent f016f766
import copy
import codecs
from operator import getitem
import numpy as np
from tqdm import trange
import tensorflow as tf
from utils import BASEPATH
from typing import List, Dict
@@ -18,7 +20,7 @@ class DataLoader(object):
        self.__bert_path = f"{bert_path}/vocab.txt"
        self.__tokenizer = self.__load_vocabulary()
        self.__dict_dataset = dict_dataset
-       self.__entity_base = EntityBase()
+       self.__entity_base = EntityBase(bert_path)
        self.__dict_ner_label = ["X"]
        self.__max_seq_len = 100
@@ -52,10 +54,19 @@ class DataLoader(object):
    def __parse_idx_sequence(self, pred, label):
        res_pred, res_label = [], []
+       records = {}
        for i in range(len(pred)):
            tmp_pred, tmp_label = [], []
+           str_pred = " ".join([str(ele) for ele in pred[i].numpy().tolist()])
+           str_label = " ".join([str(ele) for ele in label[i].numpy().tolist()])
+           str_record = str_pred + str_label
+           if str_record in records:
+               tmp = records[str_record]
+               res_pred.append(tmp[0])
+               res_label.append(tmp[1])
+           else:
                for p, l in zip(pred[i], label[i]):
                    if self.__dict_ner_label[l] != "X":
                        tmp_label.append(self.__dict_ner_label[l])
@@ -64,9 +75,9 @@ class DataLoader(object):
tmp_pred.append("O")
else:
tmp_pred.append(self.__dict_ner_label[p])
res_pred.append(tmp_pred)
res_label.append(tmp_label)
records[str_record] = (tmp_pred, tmp_label)
return res_pred, res_label
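
The hunk above memoizes decoded tag sequences: each (pred, label) pair is flattened into a string key, and a repeated key reuses the lists decoded on first sight instead of re-walking the indices. A minimal sketch of the same pattern, with hypothetical names (decode() stands in for the per-index label lookup):

def decode_cached(index_batches, decode):
    # Cache decoded results keyed by the flattened index sequence.
    cache = {}
    decoded = []
    for seq in index_batches:
        key = " ".join(str(i) for i in seq)
        if key not in cache:
            cache[key] = decode(seq)  # decode only the first occurrence
        decoded.append(cache[key])
    return decoded

Since str_record is str_pred + str_label with no delimiter, the key is unambiguous only while the pred and label sequences keep a fixed length; a separator between the two halves would make it robust to varying lengths.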
@@ -75,7 +86,7 @@ class DataLoader(object):
        pred, true = self.__parse_idx_sequence(pred, label)
        y_real, pred_real = [], []
        records = []
-       for i in range(len(real_len)):
+       for i in trange(len(real_len), ascii=True):
            record = " ".join(true[i]) + str(real_len[i])
            if record not in records:
                records.append(record)
@@ -133,7 +144,7 @@ class DataLoader(object):
        nen_label = nen_label.numpy().tolist()
        tmp_nen_pred = []
        tmp_nen_label = []
-       for i in range(len(nen_label)):
+       for i in trange(len(nen_label), ascii=True):
            n_entity = 0
            if nen_label[i] == 1:
                for e in ner_label_real:
@@ -312,8 +323,9 @@ class DataLoader(object):
        parsed_ent_indices = []
        parsed_ent_segments = []
-       for sentence, ner, nen in zip(sentences, ner_label, nen_label):
-           samples = self.__extend_sample(sentence, ner, nen, 1)
+       for i in trange(len(sentences), ascii=True):
+           sentence, ner, nen = [ele[i] for ele in [sentences, ner_label, nen_label]]
+           samples = self.__extend_sample(sentence, ner, nen, 1, "test" in path)
            pkd_sentence = samples["sentences"]
            pkd_ner_tags = samples["ner"]
@@ -367,7 +379,12 @@ class DataLoader(object):
        return dataset

    def __extend_sample(
-       self, sentence: List[int], ner_tag: List[int], nen_tag: Dict, n_neg: int
+       self,
+       sentence: List[int],
+       ner_tag: List[int],
+       nen_tag: Dict,
+       n_neg: int,
+       flag: bool = False,
    ) -> Dict:
        """
        extend the sample to all samples with specific nen tags
@@ -395,6 +412,18 @@ class DataLoader(object):
                pkd_ent_sents.append(ent_sent)
                pkd_nen_tags.append(1)
        # Sampling negative entities
+       if flag:
+           cands = self.__entity_base.generate_candidates(
+               sentence, list(nen_tag.keys())
+           )
+           for c in cands:
+               pkd_ner_tags.append(["O"] * len(ner_tag))
+               pkd_cpt_ner_tags.append(ner_tag.copy())
+               pkd_sentences.append(sentence.copy())
+               ent_sent = self.__entity_base.getItem(c)
+               pkd_ent_sents.append(ent_sent)
+               pkd_nen_tags.append(0)
+       else:
            for i in range(n_neg):
                pkd_ner_tags.append(["O"] * len(ner_tag))
                pkd_cpt_ner_tags.append(ner_tag.copy())
......
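
For context, the loop change above threads "test" in path down to __extend_sample as the new flag argument: at test time negatives come from lexical-overlap candidates, while training keeps the n_neg random draws. A condensed, free-standing sketch of that switch (the helper name is illustrative; the EntityBase calls are the ones added in this commit):

def negative_entity_names(entity_base, sentence, gold_ids, n_neg, is_test):
    if is_test:
        # Hard negatives: entities sharing wordpieces with the sentence.
        cands = entity_base.generate_candidates(sentence, gold_ids)
        return [entity_base.getItem(c) for c in cands]
    # Training: n_neg random entities drawn from outside the gold id set.
    return [entity_base.random_entity(gold_ids) for _ in range(n_neg)]

Each returned name is then packed like a positive sample, but with an all-"O" NER row and a NEN label of 0.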
import codecs
import numpy as np
+from keras_bert import Tokenizer

class EntityBase(object):
-   def __init__(self):
-       self.base_path = './entitybase'
+   def __init__(self, bert_path: str):
+       self.base_path = "./entitybase"
        self.mesh_path = f"{self.base_path}/mesh.tsv"
        self.omim_path = f"{self.base_path}/omim.tsv"
+       self.__bert_path = f"{bert_path}/vocab.txt"
+       self.__tokenizer = self.__load_vocabulary()
        self.entities = {}
        self.__load_mesh()
        self.__load_omim()

+   def __load_vocabulary(self):
+       token_dict = {}
+       with codecs.open(self.__bert_path, "r", "utf8") as reader:
+           for line in reader:
+               token = line.strip()
+               token_dict[token] = len(token_dict)
+       return Tokenizer(token_dict)
+
+   def __tokenize_entity(self, entity):
+       ind, _ = self.__tokenizer.encode(first=entity)
+       return ind
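
For reference, keras_bert's Tokenizer.encode returns a (token_indices, segment_ids) pair, and __tokenize_entity keeps only the indices, [CLS]/[SEP] included. A toy vocabulary makes the behavior concrete (the ids below are illustrative; real ids depend on vocab.txt):

from keras_bert import Tokenizer

token_dict = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "breast": 4, "cancer": 5}
ind, seg = Tokenizer(token_dict).encode(first="breast cancer")
# ind -> [2, 4, 5, 3]   i.e. [CLS] breast cancer [SEP]
# seg -> [0, 0, 0, 0]   single-segment input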

    def random_entity(self, idxs):
-       '''
+       """
        select one entity that is not in the given list, to serve as a
        negative sample for entity normalization
-       '''
+       """
        dicts = list(self.entities.keys())
        idx = np.random.choice(dicts)
        while idx in idxs:
            idx = np.random.choice(dicts)
        return self.getItem(idx)

+   def generate_candidates(self, sentence, idxs):
+       cands = []
+       parsed_sentence = self.__tokenize_entity(" ".join(sentence))
+       for k, v in self.entities.items():
+           ent = v[1]
+           if len(set(ent[1:-1]) & set(parsed_sentence[1:-1])) > 2 and k not in idxs:
+               cands.append(k)
+       return cands
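
An entity qualifies as a candidate when its wordpiece ids share more than two tokens with the tokenized sentence (the [1:-1] slices drop [CLS]/[SEP] before intersecting) and it is not already among the gold ids. A hypothetical call, assuming the ./entitybase TSVs and a BERT vocab directory are in place (the path and MeSH id are illustrative):

eb = EntityBase("./biobert_v1.1_pubmed")
sentence = ["breast", "cancer", "risk", "factors"]
cands = eb.generate_candidates(sentence, idxs=["MESH:D001943"])
# cands -> ids of entities overlapping the sentence in >2 wordpieces

Because the overlap test scans every entry in self.entities, candidate generation is linear in the size of the entity base per sentence, which is likely why the calling loops gained trange progress bars in this commit.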

    def getVocabs(self):
        names = []
        for k, v in self.entities.items():
-           names += v
+           names += v[0]
        return names

    def getItem(self, uid):
-       result = self.entities.get(uid, 'none')
+       result = self.entities.get(uid, ["none"])
        # if result == 'none':
        #     print(f"UNDEFINED WARNING! [{uid}]")
-       return result
+       return result[0]

    @property
    def Entity(self):
        return self.entities

    def __load_mesh(self):
-       data = np.loadtxt(self.mesh_path, delimiter='\t', dtype=str)
+       data = np.loadtxt(self.mesh_path, delimiter="\t", dtype=str)
        for ele in data.tolist():
-           self.entities[ele[0]] = ele[1]
+           self.entities[ele[0]] = (ele[1], self.__tokenize_entity(ele[1]))

    def __load_omim(self):
-       data = np.loadtxt(self.omim_path, delimiter='\t', dtype=str, encoding='utf-8')
+       data = np.loadtxt(self.omim_path, delimiter="\t", dtype=str, encoding="utf-8")
        for ele in data.tolist():
-           self.entities[f"OMIM:{ele[0]}"] = ele[1]
+           self.entities[f"OMIM:{ele[0]}"] = (ele[1], self.__tokenize_entity(ele[1]))

-if __name__ == '__main__':
+if __name__ == "__main__":
    eb = EntityBase()
    vocabs = eb.getVocabs()
\ No newline at end of file
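
After this commit every self.entities value is a (name, wordpiece_ids) tuple rather than a bare name, and getItem unwraps just the name; note that the unchanged __main__ block above still calls EntityBase() with no argument, which now raises a TypeError. A sketch of the post-commit interface (path and id illustrative):

eb = EntityBase("./biobert_v1.1_pubmed")     # bert_path is now required
name, piece_ids = eb.Entity["MESH:D001943"]  # stored value: (name, wordpiece ids)
assert eb.getItem("MESH:D001943") == name    # getItem returns the name only
print(eb.getItem("MESH:UNKNOWN"))            # -> "none" fallback for missing ids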