Skip to content
GitLab
菜单
项目
群组
代码片段
帮助
帮助
支持
社区论坛
快捷键
?
提交反馈
登录/注册
切换导航
菜单
打开侧边栏
xia yu
E2EMERN
提交
41ea7f5e
提交
41ea7f5e
编辑于
10月 09, 2021
作者:
Baohang Zhou
浏览文件
add candidate generation
上级
f016f766
变更
2
Show whitespace changes
Inline
Side-by-side
data_loader.py
浏览文件 @
41ea7f5e
import
copy
import
codecs
from
operator
import
getitem
import
numpy
as
np
from
tqdm
import
trange
import
tensorflow
as
tf
from
utils
import
BASEPATH
from
typing
import
List
,
Dict
...
...
@@ -18,7 +20,7 @@ class DataLoader(object):
self
.
__bert_path
=
f
"
{
bert_path
}
/vocab.txt"
self
.
__tokenizer
=
self
.
__load_vocabulary
()
self
.
__dict_dataset
=
dict_dataset
self
.
__entity_base
=
EntityBase
()
self
.
__entity_base
=
EntityBase
(
bert_path
)
self
.
__dict_ner_label
=
[
"X"
]
self
.
__max_seq_len
=
100
...
...
@@ -52,10 +54,19 @@ class DataLoader(object):
def
__parse_idx_sequence
(
self
,
pred
,
label
):
res_pred
,
res_label
=
[],
[]
records
=
{}
for
i
in
range
(
len
(
pred
)):
tmp_pred
,
tmp_label
=
[],
[]
str_pred
=
" "
.
join
([
str
(
ele
)
for
ele
in
pred
[
i
].
numpy
().
tolist
()])
str_label
=
" "
.
join
([
str
(
ele
)
for
ele
in
label
[
i
].
numpy
().
tolist
()])
str_record
=
str_pred
+
str_label
if
str_record
in
records
:
tmp
=
records
[
str_record
]
res_pred
.
append
(
tmp
[
0
])
res_label
.
append
(
tmp
[
1
])
else
:
for
p
,
l
in
zip
(
pred
[
i
],
label
[
i
]):
if
self
.
__dict_ner_label
[
l
]
!=
"X"
:
tmp_label
.
append
(
self
.
__dict_ner_label
[
l
])
...
...
@@ -64,9 +75,9 @@ class DataLoader(object):
tmp_pred
.
append
(
"O"
)
else
:
tmp_pred
.
append
(
self
.
__dict_ner_label
[
p
])
res_pred
.
append
(
tmp_pred
)
res_label
.
append
(
tmp_label
)
records
[
str_record
]
=
(
tmp_pred
,
tmp_label
)
return
res_pred
,
res_label
...
...
@@ -75,7 +86,7 @@ class DataLoader(object):
pred
,
true
=
self
.
__parse_idx_sequence
(
pred
,
label
)
y_real
,
pred_real
=
[],
[]
records
=
[]
for
i
in
range
(
len
(
real_len
)):
for
i
in
t
range
(
len
(
real_len
)
,
ascii
=
True
):
record
=
" "
.
join
(
true
[
i
])
+
str
(
real_len
[
i
])
if
record
not
in
records
:
records
.
append
(
record
)
...
...
@@ -133,7 +144,7 @@ class DataLoader(object):
nen_label
=
nen_label
.
numpy
().
tolist
()
tmp_nen_pred
=
[]
tmp_nen_label
=
[]
for
i
in
range
(
len
(
nen_label
)):
for
i
in
t
range
(
len
(
nen_label
)
,
ascii
=
True
):
n_entity
=
0
if
nen_label
[
i
]
==
1
:
for
e
in
ner_label_real
:
...
...
@@ -312,8 +323,9 @@ class DataLoader(object):
parsed_ent_indices
=
[]
parsed_ent_segments
=
[]
for
sentence
,
ner
,
nen
in
zip
(
sentences
,
ner_label
,
nen_label
):
samples
=
self
.
__extend_sample
(
sentence
,
ner
,
nen
,
1
)
for
i
in
trange
(
len
(
sentences
),
ascii
=
True
):
sentence
,
ner
,
nen
=
[
ele
[
i
]
for
ele
in
[
sentences
,
ner_label
,
nen_label
]]
samples
=
self
.
__extend_sample
(
sentence
,
ner
,
nen
,
1
,
"test"
in
path
)
pkd_sentence
=
samples
[
"sentences"
]
pkd_ner_tags
=
samples
[
"ner"
]
...
...
@@ -367,7 +379,12 @@ class DataLoader(object):
return
dataset
def
__extend_sample
(
self
,
sentence
:
List
[
int
],
ner_tag
:
List
[
int
],
nen_tag
:
Dict
,
n_neg
:
int
self
,
sentence
:
List
[
int
],
ner_tag
:
List
[
int
],
nen_tag
:
Dict
,
n_neg
:
int
,
flag
:
bool
=
False
,
)
->
Dict
:
"""
extend the sample to all samples with specific nen tags
...
...
@@ -395,6 +412,18 @@ class DataLoader(object):
pkd_ent_sents
.
append
(
ent_sent
)
pkd_nen_tags
.
append
(
1
)
# Sampling negative entities
if
flag
:
cands
=
self
.
__entity_base
.
generate_candidates
(
sentence
,
list
(
nen_tag
.
keys
())
)
for
c
in
cands
:
pkd_ner_tags
.
append
([
"O"
]
*
len
(
ner_tag
))
pkd_cpt_ner_tags
.
append
(
ner_tag
.
copy
())
pkd_sentences
.
append
(
sentence
.
copy
())
ent_sent
=
self
.
__entity_base
.
getItem
(
c
)
pkd_ent_sents
.
append
(
ent_sent
)
pkd_nen_tags
.
append
(
0
)
else
:
for
i
in
range
(
n_neg
):
pkd_ner_tags
.
append
([
"O"
]
*
len
(
ner_tag
))
pkd_cpt_ner_tags
.
append
(
ner_tag
.
copy
())
...
...
entitybase/entity_base_loader.py
浏览文件 @
41ea7f5e
import
codecs
import
numpy
as
np
from
keras_bert
import
Tokenizer
class EntityBase(object):
    """Lookup base of biomedical entities loaded from MeSH and OMIM TSV files.

    Each entity name is tokenized once with the BERT vocabulary so that
    candidate entities can later be matched against tokenized sentences.
    ``self.entities`` maps uid -> (entity name, BERT token-id sequence).
    """

    def __init__(self, bert_path: str):
        """Load both entity sources and build the tokenizer.

        Args:
            bert_path: directory containing the BERT ``vocab.txt`` file.
        """
        self.base_path = "./entitybase"
        self.mesh_path = f"{self.base_path}/mesh.tsv"
        self.omim_path = f"{self.base_path}/omim.tsv"
        self.__bert_path = f"{bert_path}/vocab.txt"
        self.__tokenizer = self.__load_vocabulary()
        # uid -> (entity name, token-id sequence of the name)
        self.entities = {}
        self.__load_mesh()
        self.__load_omim()

    def __load_vocabulary(self):
        """Read the BERT vocab file and build a keras_bert Tokenizer."""
        token_dict = {}
        with codecs.open(self.__bert_path, "r", "utf8") as reader:
            for line in reader:
                token = line.strip()
                token_dict[token] = len(token_dict)
        return Tokenizer(token_dict)

    def __tokenize_entity(self, entity):
        """Return the BERT token-id sequence (incl. special tokens) for a string."""
        ind, _ = self.__tokenizer.encode(first=entity)
        return ind

    def random_entity(self, idxs):
        """Select one entity whose uid is not in ``idxs``.

        Used to generate a negative sample for entity normalization.
        Returns the entity name (see :meth:`getItem`).
        """
        dicts = list(self.entities.keys())
        idx = np.random.choice(dicts)
        while idx in idxs:
            idx = np.random.choice(dicts)
        return self.getItem(idx)

    def generate_candidates(self, sentence, idxs):
        """Return uids of entities that share more than 2 token ids with
        ``sentence``, excluding the gold uids listed in ``idxs``.

        Args:
            sentence: list of word strings; joined and tokenized as one text.
            idxs: uids to exclude (the entities already linked to the sentence).
        """
        cands = []
        parsed_sentence = self.__tokenize_entity(" ".join(sentence))
        # [1:-1] strips the special boundary tokens ([CLS]/[SEP]).
        # Hoisted out of the loop: this set is invariant across entities.
        sent_tokens = set(parsed_sentence[1:-1])
        for uid, value in self.entities.items():
            ent_tokens = value[1]
            if len(set(ent_tokens[1:-1]) & sent_tokens) > 2 and uid not in idxs:
                cands.append(uid)
        return cands

    def getVocabs(self):
        """Concatenate the characters of every entity name.

        Note: ``names += v[0]`` extends the list character by character,
        yielding a character-level vocabulary list.
        """
        names = []
        for value in self.entities.values():
            names += value[0]
        return names

    def getItem(self, uid):
        """Return the entity name for ``uid``, or ``"none"`` if unknown."""
        result = self.entities.get(uid, ["none"])
        return result[0]

    @property
    def Entity(self):
        """The raw uid -> (name, token ids) mapping."""
        return self.entities

    def __load_mesh(self):
        """Load MeSH entities: uid in column 0, name in column 1."""
        data = np.loadtxt(self.mesh_path, delimiter="\t", dtype=str)
        for ele in data.tolist():
            self.entities[ele[0]] = (ele[1], self.__tokenize_entity(ele[1]))

    def __load_omim(self):
        """Load OMIM entities; uids are namespaced with an ``OMIM:`` prefix."""
        data = np.loadtxt(
            self.omim_path, delimiter="\t", dtype=str, encoding="utf-8"
        )
        for ele in data.tolist():
            self.entities[f"OMIM:{ele[0]}"] = (
                ele[1],
                self.__tokenize_entity(ele[1]),
            )
if __name__ == "__main__":
    import sys

    # Bug fix: EntityBase.__init__ now requires the BERT model directory
    # (this commit changed the signature), so EntityBase() with no argument
    # raised TypeError. Take the path from the command line, defaulting to
    # the current directory.
    bert_path = sys.argv[1] if len(sys.argv) > 1 else "."
    eb = EntityBase(bert_path)
    vocabs = eb.getVocabs()
\ No newline at end of file
编辑
预览
Supports
Markdown
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录