Skip to content
GitLab
菜单
项目
群组
代码片段
帮助
帮助
支持
社区论坛
快捷键
?
提交反馈
登录/注册
切换导航
菜单
打开侧边栏
xia yu
E2EMERN
提交
41ea7f5e
提交
41ea7f5e
编辑于
10月 09, 2021
作者:
Baohang Zhou
浏览文件
add candidate generation
上级
f016f766
变更
2
Hide whitespace changes
Inline
Side-by-side
data_loader.py
浏览文件 @
41ea7f5e
import
copy
import
codecs
from
operator
import
getitem
import
numpy
as
np
from
tqdm
import
trange
import
tensorflow
as
tf
from
utils
import
BASEPATH
from
typing
import
List
,
Dict
...
...
@@ -18,7 +20,7 @@ class DataLoader(object):
self
.
__bert_path
=
f
"
{
bert_path
}
/vocab.txt"
self
.
__tokenizer
=
self
.
__load_vocabulary
()
self
.
__dict_dataset
=
dict_dataset
self
.
__entity_base
=
EntityBase
()
self
.
__entity_base
=
EntityBase
(
bert_path
)
self
.
__dict_ner_label
=
[
"X"
]
self
.
__max_seq_len
=
100
...
...
@@ -52,21 +54,30 @@ class DataLoader(object):
def __parse_idx_sequence(self, pred, label):
    """Decode batches of predicted / gold label-index sequences into NER tag strings.

    Args:
        pred: batch of predicted label-index tensors; each element supports
            ``.numpy().tolist()`` and iteration (assumed TF tensors — TODO confirm).
        label: batch of gold label-index tensors, same shape as ``pred``.

    Returns:
        (res_pred, res_label): per-sample lists of tag strings. Positions whose
        gold tag is "X" (sub-word / padding marker) are dropped from the gold
        sequence; a predicted "X" is mapped to "O" instead.
    """
    res_pred, res_label = [], []
    # Memo cache: identical (pred, label) index pairs decode to identical tag
    # sequences, so each distinct pair is decoded only once.
    records = {}
    for i in range(len(pred)):
        pred_idx = pred[i].numpy().tolist()
        label_idx = label[i].numpy().tolist()
        # BUG FIX: the original key was str_pred + str_label with no separator,
        # which lets different (pred, label) pairs collide. Use an explicit
        # delimiter so the key is unambiguous.
        str_record = (
            " ".join(str(ele) for ele in pred_idx)
            + "|"
            + " ".join(str(ele) for ele in label_idx)
        )
        if str_record in records:
            cached = records[str_record]
            res_pred.append(cached[0])
            res_label.append(cached[1])
        else:
            tmp_pred, tmp_label = [], []
            for p, l in zip(pred[i], label[i]):
                # Drop sub-word/padding positions from the gold sequence only;
                # the predicted sequence keeps its full length (original behavior).
                if self.__dict_ner_label[l] != "X":
                    tmp_label.append(self.__dict_ner_label[l])
                if self.__dict_ner_label[p] == "X":
                    tmp_pred.append("O")
                else:
                    tmp_pred.append(self.__dict_ner_label[p])
            res_pred.append(tmp_pred)
            res_label.append(tmp_label)
            records[str_record] = (tmp_pred, tmp_label)
    return res_pred, res_label
...
...
@@ -75,7 +86,7 @@ class DataLoader(object):
pred
,
true
=
self
.
__parse_idx_sequence
(
pred
,
label
)
y_real
,
pred_real
=
[],
[]
records
=
[]
for
i
in
range
(
len
(
real_len
)):
for
i
in
t
range
(
len
(
real_len
)
,
ascii
=
True
):
record
=
" "
.
join
(
true
[
i
])
+
str
(
real_len
[
i
])
if
record
not
in
records
:
records
.
append
(
record
)
...
...
@@ -133,7 +144,7 @@ class DataLoader(object):
nen_label
=
nen_label
.
numpy
().
tolist
()
tmp_nen_pred
=
[]
tmp_nen_label
=
[]
for
i
in
range
(
len
(
nen_label
)):
for
i
in
t
range
(
len
(
nen_label
)
,
ascii
=
True
):
n_entity
=
0
if
nen_label
[
i
]
==
1
:
for
e
in
ner_label_real
:
...
...
@@ -312,8 +323,9 @@ class DataLoader(object):
parsed_ent_indices
=
[]
parsed_ent_segments
=
[]
for
sentence
,
ner
,
nen
in
zip
(
sentences
,
ner_label
,
nen_label
):
samples
=
self
.
__extend_sample
(
sentence
,
ner
,
nen
,
1
)
for
i
in
trange
(
len
(
sentences
),
ascii
=
True
):
sentence
,
ner
,
nen
=
[
ele
[
i
]
for
ele
in
[
sentences
,
ner_label
,
nen_label
]]
samples
=
self
.
__extend_sample
(
sentence
,
ner
,
nen
,
1
,
"test"
in
path
)
pkd_sentence
=
samples
[
"sentences"
]
pkd_ner_tags
=
samples
[
"ner"
]
...
...
@@ -367,7 +379,12 @@ class DataLoader(object):
return
dataset
def
__extend_sample
(
self
,
sentence
:
List
[
int
],
ner_tag
:
List
[
int
],
nen_tag
:
Dict
,
n_neg
:
int
self
,
sentence
:
List
[
int
],
ner_tag
:
List
[
int
],
nen_tag
:
Dict
,
n_neg
:
int
,
flag
:
bool
=
False
,
)
->
Dict
:
"""
extend the sample to all samples with specific nen tags
...
...
@@ -395,13 +412,25 @@ class DataLoader(object):
pkd_ent_sents
.
append
(
ent_sent
)
pkd_nen_tags
.
append
(
1
)
# Sampling negative entities
for
i
in
range
(
n_neg
):
pkd_ner_tags
.
append
([
"O"
]
*
len
(
ner_tag
))
pkd_cpt_ner_tags
.
append
(
ner_tag
.
copy
())
pkd_sentences
.
append
(
sentence
.
copy
())
ent_sent
=
self
.
__entity_base
.
random_entity
(
list
(
nen_tag
.
keys
()))
pkd_ent_sents
.
append
(
ent_sent
)
pkd_nen_tags
.
append
(
0
)
if
flag
:
cands
=
self
.
__entity_base
.
generate_candidates
(
sentence
,
list
(
nen_tag
.
keys
())
)
for
c
in
cands
:
pkd_ner_tags
.
append
([
"O"
]
*
len
(
ner_tag
))
pkd_cpt_ner_tags
.
append
(
ner_tag
.
copy
())
pkd_sentences
.
append
(
sentence
.
copy
())
ent_sent
=
self
.
__entity_base
.
getItem
(
c
)
pkd_ent_sents
.
append
(
ent_sent
)
pkd_nen_tags
.
append
(
0
)
else
:
for
i
in
range
(
n_neg
):
pkd_ner_tags
.
append
([
"O"
]
*
len
(
ner_tag
))
pkd_cpt_ner_tags
.
append
(
ner_tag
.
copy
())
pkd_sentences
.
append
(
sentence
.
copy
())
ent_sent
=
self
.
__entity_base
.
random_entity
(
list
(
nen_tag
.
keys
()))
pkd_ent_sents
.
append
(
ent_sent
)
pkd_nen_tags
.
append
(
0
)
return
{
"ner"
:
pkd_ner_tags
,
...
...
entitybase/entity_base_loader.py
浏览文件 @
41ea7f5e
import
codecs
import
numpy
as
np
from
keras_bert
import
Tokenizer
class EntityBase(object):
    """In-memory knowledge base of MeSH and OMIM entities for entity normalization.

    ``self.entities`` maps an entity id (e.g. a MeSH id, or ``"OMIM:<num>"``)
    to a tuple ``(entity_name, token_ids)`` where ``token_ids`` is the BERT
    token-id sequence of the name (including [CLS]/[SEP] ids).
    """

    def __init__(self, bert_path: str):
        """Load the BERT vocabulary and the MeSH/OMIM entity tables.

        Args:
            bert_path: directory of a BERT checkpoint; ``vocab.txt`` is read
                from inside it.
        """
        self.base_path = "./entitybase"
        self.mesh_path = f"{self.base_path}/mesh.tsv"
        self.omim_path = f"{self.base_path}/omim.tsv"
        self.__bert_path = f"{bert_path}/vocab.txt"
        self.__tokenizer = self.__load_vocabulary()
        self.entities = {}
        self.__load_mesh()
        self.__load_omim()

    def __load_vocabulary(self):
        """Build a keras_bert ``Tokenizer`` from the BERT ``vocab.txt`` file."""
        token_dict = {}
        with codecs.open(self.__bert_path, "r", "utf8") as reader:
            for line in reader:
                token = line.strip()
                token_dict[token] = len(token_dict)
        return Tokenizer(token_dict)

    def __tokenize_entity(self, entity):
        """Return the token-id sequence for an entity name string."""
        ind, _ = self.__tokenizer.encode(first=entity)
        return ind

    def random_entity(self, idxs):
        """Select the name of one entity whose id is NOT in ``idxs``.

        Generates a negative sample for entity normalization.
        NOTE(review): loops forever if ``idxs`` covers every known entity id —
        callers must guarantee at least one id remains available.
        """
        dicts = list(self.entities.keys())
        idx = np.random.choice(dicts)
        while idx in idxs:
            idx = np.random.choice(dicts)
        return self.getItem(idx)

    def generate_candidates(self, sentence, idxs):
        """Return ids of entities sharing more than 2 token ids with ``sentence``.

        Args:
            sentence: list of word strings; joined and tokenized as one text.
            idxs: entity ids to exclude (typically the gold ids).
        """
        cands = []
        parsed_sentence = self.__tokenize_entity(" ".join(sentence))
        for k, v in self.entities.items():
            ent = v[1]
            # [1:-1] strips the leading/trailing special-token ids before
            # measuring token overlap.
            if len(set(ent[1:-1]) & set(parsed_sentence[1:-1])) > 2 and k not in idxs:
                cands.append(k)
        return cands

    def getVocabs(self):
        """Accumulate entity names character-wise into one flat list.

        NOTE(review): ``v[0]`` is a string, so ``+=`` extends ``names`` with
        single characters rather than whole names — matches the committed
        behavior, but confirm this is intended.
        """
        names = []
        for k, v in self.entities.items():
            names += v[0]
        return names

    def getItem(self, uid):
        """Return the entity name for ``uid``, or ``"none"`` when unknown."""
        result = self.entities.get(uid, ["none"])
        return result[0]

    @property
    def Entity(self):
        """Read-only accessor for the full id -> (name, token_ids) mapping."""
        return self.entities

    def __load_mesh(self):
        """Load mesh.tsv rows as id -> (name, token ids) entries."""
        data = np.loadtxt(self.mesh_path, delimiter="\t", dtype=str)
        for ele in data.tolist():
            self.entities[ele[0]] = (ele[1], self.__tokenize_entity(ele[1]))

    def __load_omim(self):
        """Load omim.tsv rows as "OMIM:<id>" -> (name, token ids) entries."""
        data = np.loadtxt(self.omim_path, delimiter="\t", dtype=str, encoding="utf-8")
        for ele in data.tolist():
            self.entities[f"OMIM:{ele[0]}"] = (ele[1], self.__tokenize_entity(ele[1]))
if __name__ == "__main__":
    import sys

    # BUG FIX: this commit changed EntityBase.__init__ to require a bert_path
    # argument, but the call here was left as EntityBase() and raised a
    # TypeError. Take the path from the command line, with a default.
    bert_path = sys.argv[1] if len(sys.argv) > 1 else "./bert"
    eb = EntityBase(bert_path)
    vocabs = eb.getVocabs()
\ No newline at end of file
编辑
预览
Supports
Markdown
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录