Commit 41ea7f5e authored by Baohang Zhou

add candidate generation

parent f016f766
 import copy
 import codecs
+from operator import getitem
 import numpy as np
+from tqdm import trange
 import tensorflow as tf
 from utils import BASEPATH
 from typing import List, Dict
@@ -18,7 +20,7 @@ class DataLoader(object):
         self.__bert_path = f"{bert_path}/vocab.txt"
         self.__tokenizer = self.__load_vocabulary()
         self.__dict_dataset = dict_dataset
-        self.__entity_base = EntityBase()
+        self.__entity_base = EntityBase(bert_path)
         self.__dict_ner_label = ["X"]
         self.__max_seq_len = 100
@@ -52,21 +54,30 @@ class DataLoader(object):
     def __parse_idx_sequence(self, pred, label):
         res_pred, res_label = [], []
+        records = {}
         for i in range(len(pred)):
             tmp_pred, tmp_label = [], []
-            for p, l in zip(pred[i], label[i]):
-                if self.__dict_ner_label[l] != "X":
-                    tmp_label.append(self.__dict_ner_label[l])
-                if self.__dict_ner_label[p] == "X":
-                    tmp_pred.append("O")
-                else:
-                    tmp_pred.append(self.__dict_ner_label[p])
-            res_pred.append(tmp_pred)
-            res_label.append(tmp_label)
+            str_pred = " ".join([str(ele) for ele in pred[i].numpy().tolist()])
+            str_label = " ".join([str(ele) for ele in label[i].numpy().tolist()])
+            str_record = str_pred + str_label
+            if str_record in records:
+                tmp = records[str_record]
+                res_pred.append(tmp[0])
+                res_label.append(tmp[1])
+            else:
+                for p, l in zip(pred[i], label[i]):
+                    if self.__dict_ner_label[l] != "X":
+                        tmp_label.append(self.__dict_ner_label[l])
+                    if self.__dict_ner_label[p] == "X":
+                        tmp_pred.append("O")
+                    else:
+                        tmp_pred.append(self.__dict_ner_label[p])
+                res_pred.append(tmp_pred)
+                res_label.append(tmp_label)
+                records[str_record] = (tmp_pred, tmp_label)
         return res_pred, res_label
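The rewrite above memoizes sequence decoding: each (pred, label) row is serialized into a string key, and rows already seen reuse the cached tag lists instead of being decoded token by token again. A minimal sketch of the same caching pattern; the names `decode_cached` and `id2label` are illustrative, not from this repository:

    from typing import Dict, List

    def decode_cached(
        rows: List[List[int]],
        id2label: List[str],
        cache: Dict[str, List[str]],
    ) -> List[List[str]]:
        decoded = []
        for row in rows:
            key = " ".join(str(i) for i in row)  # serialize the row as a cache key
            if key not in cache:
                # decode once, skipping padding positions labelled "X"
                cache[key] = [id2label[i] for i in row if id2label[i] != "X"]
            decoded.append(cache[key])
        return decoded

    labels = ["X", "O", "B", "I"]
    print(decode_cached([[2, 3, 1, 0], [2, 3, 1, 0]], labels, {}))
    # the second row reuses the cached ['B', 'I', 'O']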
@@ -75,7 +86,7 @@ class DataLoader(object):
         pred, true = self.__parse_idx_sequence(pred, label)
         y_real, pred_real = [], []
         records = []
-        for i in range(len(real_len)):
+        for i in trange(len(real_len), ascii=True):
             record = " ".join(true[i]) + str(real_len[i])
             if record not in records:
                 records.append(record)
@@ -133,7 +144,7 @@ class DataLoader(object):
         nen_label = nen_label.numpy().tolist()
         tmp_nen_pred = []
         tmp_nen_label = []
-        for i in range(len(nen_label)):
+        for i in trange(len(nen_label), ascii=True):
             n_entity = 0
             if nen_label[i] == 1:
                 for e in ner_label_real:
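The two hunks above only swap `range` for `tqdm.trange`, which behaves like `range` while drawing a progress bar; `ascii=True` restricts the bar to ASCII characters for terminals that cannot render Unicode block glyphs. A usage sketch:

    from tqdm import trange

    total = 0
    for i in trange(10_000, ascii=True):  # ASCII progress bar on stderr
        total += i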
@@ -312,8 +323,9 @@ class DataLoader(object):
         parsed_ent_indices = []
         parsed_ent_segments = []
-        for sentence, ner, nen in zip(sentences, ner_label, nen_label):
-            samples = self.__extend_sample(sentence, ner, nen, 1)
+        for i in trange(len(sentences), ascii=True):
+            sentence, ner, nen = [ele[i] for ele in [sentences, ner_label, nen_label]]
+            samples = self.__extend_sample(sentence, ner, nen, 1, "test" in path)
             pkd_sentence = samples["sentences"]
             pkd_ner_tags = samples["ner"]
@@ -367,7 +379,12 @@ class DataLoader(object):
         return dataset
 
     def __extend_sample(
-        self, sentence: List[int], ner_tag: List[int], nen_tag: Dict, n_neg: int
+        self,
+        sentence: List[int],
+        ner_tag: List[int],
+        nen_tag: Dict,
+        n_neg: int,
+        flag: bool = False,
     ) -> Dict:
         """
         extend the sample to all samples with specific nen tags
@@ -395,13 +412,25 @@ class DataLoader(object):
             pkd_ent_sents.append(ent_sent)
             pkd_nen_tags.append(1)
         # Sampling negative entities
-        for i in range(n_neg):
-            pkd_ner_tags.append(["O"] * len(ner_tag))
-            pkd_cpt_ner_tags.append(ner_tag.copy())
-            pkd_sentences.append(sentence.copy())
-            ent_sent = self.__entity_base.random_entity(list(nen_tag.keys()))
-            pkd_ent_sents.append(ent_sent)
-            pkd_nen_tags.append(0)
+        if flag:
+            cands = self.__entity_base.generate_candidates(
+                sentence, list(nen_tag.keys())
+            )
+            for c in cands:
+                pkd_ner_tags.append(["O"] * len(ner_tag))
+                pkd_cpt_ner_tags.append(ner_tag.copy())
+                pkd_sentences.append(sentence.copy())
+                ent_sent = self.__entity_base.getItem(c)
+                pkd_ent_sents.append(ent_sent)
+                pkd_nen_tags.append(0)
+        else:
+            for i in range(n_neg):
+                pkd_ner_tags.append(["O"] * len(ner_tag))
+                pkd_cpt_ner_tags.append(ner_tag.copy())
+                pkd_sentences.append(sentence.copy())
+                ent_sent = self.__entity_base.random_entity(list(nen_tag.keys()))
+                pkd_ent_sents.append(ent_sent)
+                pkd_nen_tags.append(0)
         return {
             "ner": pkd_ner_tags,
......
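The new `flag` parameter of `__extend_sample` (the caller passes `"test" in path`) selects between two negative-sampling modes: during training each sentence gets `n_neg` randomly drawn non-gold entities, while at test time it is paired with every candidate returned by `generate_candidates`, so evaluation scores all plausible dictionary entities rather than one random negative. A condensed sketch of that branch; the wrapper function is illustrative, but `generate_candidates`, `random_entity`, and `getItem` are the `EntityBase` methods changed in the second file below:

    def negative_entities(entity_base, sentence, gold_ids, n_neg, is_test):
        # Illustrative wrapper around the EntityBase methods in this commit.
        if is_test:
            # test time: every dictionary entry that overlaps the sentence
            # and is not a gold entity becomes a scored negative
            return [entity_base.getItem(c)
                    for c in entity_base.generate_candidates(sentence, gold_ids)]
        # training: a fixed number of uniformly sampled non-gold entities
        return [entity_base.random_entity(gold_ids) for _ in range(n_neg)]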
+import codecs
 import numpy as np
+from keras_bert import Tokenizer
 
 
 class EntityBase(object):
-    def __init__(self):
-        self.base_path = './entitybase'
+    def __init__(self, bert_path: str):
+        self.base_path = "./entitybase"
         self.mesh_path = f"{self.base_path}/mesh.tsv"
         self.omim_path = f"{self.base_path}/omim.tsv"
+        self.__bert_path = f"{bert_path}/vocab.txt"
+        self.__tokenizer = self.__load_vocabulary()
         self.entities = {}
         self.__load_mesh()
         self.__load_omim()
 
+    def __load_vocabulary(self):
+        token_dict = {}
+        with codecs.open(self.__bert_path, "r", "utf8") as reader:
+            for line in reader:
+                token = line.strip()
+                token_dict[token] = len(token_dict)
+        return Tokenizer(token_dict)
+
+    def __tokenize_entity(self, entity):
+        ind, _ = self.__tokenizer.encode(first=entity)
+        return ind
+
     def random_entity(self, idxs):
-        '''
+        """
         select one entity which are not in the list
         generate negative sample for entity normalization
-        '''
+        """
         dicts = list(self.entities.keys())
         idx = np.random.choice(dicts)
         while idx in idxs:
             idx = np.random.choice(dicts)
         return self.getItem(idx)
 
+    def generate_candidates(self, sentence, idxs):
+        cands = []
+        parsed_sentence = self.__tokenize_entity(" ".join(sentence))
+        for k, v in self.entities.items():
+            ent = v[1]
+            if len(set(ent[1:-1]) & set(parsed_sentence[1:-1])) > 2 and k not in idxs:
+                cands.append(k)
+        return cands
+
     def getVocabs(self):
         names = []
         for k, v in self.entities.items():
-            names += v
+            names += v[0]
         return names
 
     def getItem(self, uid):
-        result = self.entities.get(uid, 'none')
+        result = self.entities.get(uid, ["none"])
         # if result == 'none':
         #     print(f"UNDEFINED WARNING! [{uid}]")
-        return result
+        return result[0]
 
     @property
     def Entity(self):
         return self.entities
 
     def __load_mesh(self):
-        data = np.loadtxt(self.mesh_path, delimiter='\t', dtype=str)
+        data = np.loadtxt(self.mesh_path, delimiter="\t", dtype=str)
         for ele in data.tolist():
-            self.entities[ele[0]] = ele[1]
+            self.entities[ele[0]] = (ele[1], self.__tokenize_entity(ele[1]))
 
     def __load_omim(self):
-        data = np.loadtxt(self.omim_path, delimiter='\t', dtype=str, encoding='utf-8')
+        data = np.loadtxt(self.omim_path, delimiter="\t", dtype=str, encoding="utf-8")
         for ele in data.tolist():
-            self.entities[f"OMIM:{ele[0]}"] = ele[1]
+            self.entities[f"OMIM:{ele[0]}"] = (ele[1], self.__tokenize_entity(ele[1]))
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     eb = EntityBase()
     vocabs = eb.getVocabs()
\ No newline at end of file
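The heart of the commit is `generate_candidates`: each MeSH/OMIM entity name is tokenized once at load time with the `keras_bert` Tokenizer (hence the new `bert_path` constructor argument), and an entity qualifies as a candidate when its WordPiece ids share more than two ids with the tokenized sentence, the `[1:-1]` slices dropping the `[CLS]`/`[SEP]` positions. A self-contained sketch of that overlap test, with toy token ids standing in for a real BERT `vocab.txt`:

    from typing import Dict, List

    def overlap_candidates(
        entities: Dict[str, List[int]],  # uid -> token ids incl. [CLS]/[SEP]
        sentence_ids: List[int],         # token ids incl. [CLS]/[SEP]
        gold_ids: List[str],
    ) -> List[str]:
        sent = set(sentence_ids[1:-1])  # drop [CLS] and [SEP]
        return [
            uid for uid, ids in entities.items()
            if uid not in gold_ids and len(set(ids[1:-1]) & sent) > 2
        ]

    entities = {"MESH:D003920": [101, 7, 8, 9, 10, 102],
                "OMIM:222100": [101, 7, 8, 20, 21, 102]}
    print(overlap_candidates(entities, [101, 7, 8, 9, 11, 102], gold_ids=[]))
    # ['MESH:D003920'] -- only this entry shares more than two token ids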