Commit 41ea7f5e authored by Baohang Zhou

add candidate generation

parent f016f766
 import copy
 import codecs
 from operator import getitem
 import numpy as np
+from tqdm import trange
 import tensorflow as tf
 from utils import BASEPATH
 from typing import List, Dict
@@ -18,7 +20,7 @@ class DataLoader(object):
         self.__bert_path = f"{bert_path}/vocab.txt"
         self.__tokenizer = self.__load_vocabulary()
         self.__dict_dataset = dict_dataset
-        self.__entity_base = EntityBase()
+        self.__entity_base = EntityBase(bert_path)
         self.__dict_ner_label = ["X"]
         self.__max_seq_len = 100
@@ -52,21 +54,30 @@ class DataLoader(object):
     def __parse_idx_sequence(self, pred, label):
         res_pred, res_label = [], []
+        records = {}
         for i in range(len(pred)):
             tmp_pred, tmp_label = [], []
-            for p, l in zip(pred[i], label[i]):
-                if self.__dict_ner_label[l] != "X":
-                    tmp_label.append(self.__dict_ner_label[l])
-                    if self.__dict_ner_label[p] == "X":
-                        tmp_pred.append("O")
-                    else:
-                        tmp_pred.append(self.__dict_ner_label[p])
-            res_pred.append(tmp_pred)
-            res_label.append(tmp_label)
+            str_pred = " ".join([str(ele) for ele in pred[i].numpy().tolist()])
+            str_label = " ".join([str(ele) for ele in label[i].numpy().tolist()])
+            str_record = str_pred + str_label
+            if str_record in records:
+                tmp = records[str_record]
+                res_pred.append(tmp[0])
+                res_label.append(tmp[1])
+            else:
+                for p, l in zip(pred[i], label[i]):
+                    if self.__dict_ner_label[l] != "X":
+                        tmp_label.append(self.__dict_ner_label[l])
+                        if self.__dict_ner_label[p] == "X":
+                            tmp_pred.append("O")
+                        else:
+                            tmp_pred.append(self.__dict_ner_label[p])
+                res_pred.append(tmp_pred)
+                res_label.append(tmp_label)
+                records[str_record] = (tmp_pred, tmp_label)
         return res_pred, res_label
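A note on the caching added here: identical (pred, label) tensor pairs recur often during evaluation, so the stringified indices act as a memo key and repeated pairs skip the inner decoding loop entirely. Below is a minimal self-contained sketch of the same pattern over plain lists, with a toy label map (hypothetical; the real map is built from the dataset). One deviation worth noting: the sketch joins the two keys with a separator, since plain concatenation as above can let distinct pairs collide.

dict_ner_label = ["X", "O", "B-Disease", "I-Disease"]  # toy label map

def parse_idx_sequences(pred, label):
    res_pred, res_label = [], []
    records = {}  # memo: stringified index pair -> decoded (pred, label)
    for p_seq, l_seq in zip(pred, label):
        # a separator avoids collisions that plain concatenation permits
        key = " ".join(map(str, p_seq)) + "|" + " ".join(map(str, l_seq))
        if key in records:
            tmp_pred, tmp_label = records[key]  # cache hit: skip decoding
        else:
            tmp_pred, tmp_label = [], []
            for p, l in zip(p_seq, l_seq):
                if dict_ner_label[l] != "X":  # "X" marks subword positions
                    tmp_label.append(dict_ner_label[l])
                    tmp_pred.append("O" if dict_ner_label[p] == "X"
                                    else dict_ner_label[p])
            records[key] = (tmp_pred, tmp_label)
        res_pred.append(tmp_pred)
        res_label.append(tmp_label)
    return res_pred, res_label

# the second pair is identical to the first, so it is served from the cache
print(parse_idx_sequences([[2, 3, 0], [2, 3, 0]], [[2, 3, 0], [2, 3, 0]]))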
@@ -75,7 +86,7 @@ class DataLoader(object):
         pred, true = self.__parse_idx_sequence(pred, label)
         y_real, pred_real = [], []
         records = []
-        for i in range(len(real_len)):
+        for i in trange(len(real_len), ascii=True):
             record = " ".join(true[i]) + str(real_len[i])
             if record not in records:
                 records.append(record)
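The range → trange swaps here and below are cosmetic: trange is tqdm's drop-in wrapper around range that renders a progress bar, and ascii=True keeps the bar to plain ASCII so it survives dumb terminals and log files. For reference:

from tqdm import trange
import time

# trange(n, ...) == tqdm(range(n), ...); ascii=True draws the bar with
# ASCII characters only, and desc labels it (desc is illustrative).
for i in trange(50, ascii=True, desc="eval"):
    time.sleep(0.01)  # stand-in for the per-example work above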
@@ -133,7 +144,7 @@ class DataLoader(object):
         nen_label = nen_label.numpy().tolist()
         tmp_nen_pred = []
         tmp_nen_label = []
-        for i in range(len(nen_label)):
+        for i in trange(len(nen_label), ascii=True):
             n_entity = 0
             if nen_label[i] == 1:
                 for e in ner_label_real:
@@ -312,8 +323,9 @@ class DataLoader(object):
         parsed_ent_indices = []
         parsed_ent_segments = []
-        for sentence, ner, nen in zip(sentences, ner_label, nen_label):
-            samples = self.__extend_sample(sentence, ner, nen, 1)
+        for i in trange(len(sentences), ascii=True):
+            sentence, ner, nen = [ele[i] for ele in [sentences, ner_label, nen_label]]
+            samples = self.__extend_sample(sentence, ner, nen, 1, "test" in path)
             pkd_sentence = samples["sentences"]
             pkd_ner_tags = samples["ner"]
@@ -367,7 +379,12 @@ class DataLoader(object):
         return dataset

     def __extend_sample(
-        self, sentence: List[int], ner_tag: List[int], nen_tag: Dict, n_neg: int
+        self,
+        sentence: List[int],
+        ner_tag: List[int],
+        nen_tag: Dict,
+        n_neg: int,
+        flag: bool = False,
     ) -> Dict:
         """
         extend the sample to all samples with specific nen tags
@@ -395,13 +412,25 @@ class DataLoader(object):
             pkd_ent_sents.append(ent_sent)
             pkd_nen_tags.append(1)
         # Sampling negative entities
-        for i in range(n_neg):
-            pkd_ner_tags.append(["O"] * len(ner_tag))
-            pkd_cpt_ner_tags.append(ner_tag.copy())
-            pkd_sentences.append(sentence.copy())
-            ent_sent = self.__entity_base.random_entity(list(nen_tag.keys()))
-            pkd_ent_sents.append(ent_sent)
-            pkd_nen_tags.append(0)
+        if flag:
+            cands = self.__entity_base.generate_candidates(
+                sentence, list(nen_tag.keys())
+            )
+            for c in cands:
+                pkd_ner_tags.append(["O"] * len(ner_tag))
+                pkd_cpt_ner_tags.append(ner_tag.copy())
+                pkd_sentences.append(sentence.copy())
+                ent_sent = self.__entity_base.getItem(c)
+                pkd_ent_sents.append(ent_sent)
+                pkd_nen_tags.append(0)
+        else:
+            for i in range(n_neg):
+                pkd_ner_tags.append(["O"] * len(ner_tag))
+                pkd_cpt_ner_tags.append(ner_tag.copy())
+                pkd_sentences.append(sentence.copy())
+                ent_sent = self.__entity_base.random_entity(list(nen_tag.keys()))
+                pkd_ent_sents.append(ent_sent)
+                pkd_nen_tags.append(0)
         return {
             "ner": pkd_ner_tags,
......
+import codecs
 import numpy as np
+from keras_bert import Tokenizer


 class EntityBase(object):
-    def __init__(self):
-        self.base_path = './entitybase'
+    def __init__(self, bert_path: str):
+        self.base_path = "./entitybase"
         self.mesh_path = f"{self.base_path}/mesh.tsv"
         self.omim_path = f"{self.base_path}/omim.tsv"
+        self.__bert_path = f"{bert_path}/vocab.txt"
+        self.__tokenizer = self.__load_vocabulary()
         self.entities = {}
         self.__load_mesh()
         self.__load_omim()

+    def __load_vocabulary(self):
+        token_dict = {}
+        with codecs.open(self.__bert_path, "r", "utf8") as reader:
+            for line in reader:
+                token = line.strip()
+                token_dict[token] = len(token_dict)
+        return Tokenizer(token_dict)
+
+    def __tokenize_entity(self, entity):
+        ind, _ = self.__tokenizer.encode(first=entity)
+        return ind
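The two new methods give EntityBase the same WordPiece pipeline as DataLoader, so entity names are tokenized into the very id space the model sees. keras_bert's Tokenizer.encode returns (indices, segment ids) with [CLS]/[SEP] wrapped around the content; only the indices are kept. A toy check (vocabulary invented for illustration):

from keras_bert import Tokenizer

# Toy WordPiece vocabulary; a real vocab.txt has ~30k entries.
token_dict = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
              "un": 4, "##aff": 5, "##able": 6}
tokenizer = Tokenizer(token_dict)

# encode returns (indices, segment_ids); indices include [CLS] ... [SEP].
ind, seg = tokenizer.encode(first="unaffable")
print(ind)  # [2, 4, 5, 6, 3]
print(seg)  # [0, 0, 0, 0, 0]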
     def random_entity(self, idxs):
-        '''
-        select one entity which are not in the list
-        generate negative sample for entity normalization
-        '''
+        """
+        select one entity which is not in the list;
+        used to generate negative samples for entity normalization
+        """
         dicts = list(self.entities.keys())
         idx = np.random.choice(dicts)
         while idx in idxs:
             idx = np.random.choice(dicts)
         return self.getItem(idx)

+    def generate_candidates(self, sentence, idxs):
+        cands = []
+        parsed_sentence = self.__tokenize_entity(" ".join(sentence))
+        for k, v in self.entities.items():
+            ent = v[1]
+            if len(set(ent[1:-1]) & set(parsed_sentence[1:-1])) > 2 and k not in idxs:
+                cands.append(k)
+        return cands
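generate_candidates is the heart of the commit: an entity id k becomes a candidate for a sentence when its pre-tokenized name shares more than two wordpiece ids with the sentence (both sliced [1:-1] to drop [CLS]/[SEP]) and k is not already one of the sentence's gold ids. The set intersection makes the test order-insensitive; isolated below over hypothetical id lists:

# Overlap test as in generate_candidates: ids[0] and ids[-1] are the
# [CLS]/[SEP] positions, so [1:-1] keeps only content wordpieces.
def is_candidate(entity_ids, sentence_ids, uid, gold_ids):
    overlap = set(entity_ids[1:-1]) & set(sentence_ids[1:-1])
    return len(overlap) > 2 and uid not in gold_ids

sentence = [101, 7, 8, 9, 10, 11, 102]   # hypothetical wordpiece ids
entity_a = [101, 7, 8, 9, 102]           # 3 shared ids -> candidate
entity_b = [101, 7, 8, 20, 102]          # 2 shared ids -> rejected

print(is_candidate(entity_a, sentence, "D0001", {"D0002"}))  # True
print(is_candidate(entity_b, sentence, "D0003", {"D0002"}))  # False

Scanning every dictionary entry per sentence is linear in the dictionary size, which is exactly why __load_mesh and __load_omim below now pre-tokenize each name once at load time.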
     def getVocabs(self):
         names = []
         for k, v in self.entities.items():
-            names += v
+            names += v[0]
         return names

     def getItem(self, uid):
-        result = self.entities.get(uid, 'none')
+        result = self.entities.get(uid, ["none"])
         # if result == 'none':
         #     print(f"UNDEFINED WARNING! [{uid}]")
-        return result
+        return result[0]

     @property
     def Entity(self):
         return self.entities

     def __load_mesh(self):
-        data = np.loadtxt(self.mesh_path, delimiter='\t', dtype=str)
+        data = np.loadtxt(self.mesh_path, delimiter="\t", dtype=str)
         for ele in data.tolist():
-            self.entities[ele[0]] = ele[1]
+            self.entities[ele[0]] = (ele[1], self.__tokenize_entity(ele[1]))

     def __load_omim(self):
-        data = np.loadtxt(self.omim_path, delimiter='\t', dtype=str, encoding='utf-8')
+        data = np.loadtxt(self.omim_path, delimiter="\t", dtype=str, encoding="utf-8")
         for ele in data.tolist():
-            self.entities[f"OMIM:{ele[0]}"] = ele[1]
+            self.entities[f"OMIM:{ele[0]}"] = (ele[1], self.__tokenize_entity(ele[1]))
-if __name__ == '__main__':
+if __name__ == "__main__":
     eb = EntityBase()
     vocabs = eb.getVocabs()
\ No newline at end of file