Commit b8c5b61e authored by Tuan Lai

Add code

Parent 44911f76
.DS_Store
tmp/
relevant_types.txt
non_relevant_types.txt
knowledge_module_logs.txt
visualizations/
resources/
logs/
caches/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Joint Biomedical Entity and Relation Extraction with Knowledge-Enhanced Collective Inference
We are cleaning the codebase and will release more detailed running instructions soon.
# ADE
mkdir logs
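# Each job below trains a contiguous range of the 10 ADE cross-validation folds on one GPU:
# -s / -e are the start (inclusive) and end (exclusive) fold indices, and -c selects the config
# ('basic' by default, or 'with_external_knowledge').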
CUDA_VISIBLE_DEVICES=1 python ade_trainer.py -s 0 -e 3 > logs/ade_basic_0.txt &
CUDA_VISIBLE_DEVICES=2 python ade_trainer.py -s 3 -e 6 > logs/ade_basic_1.txt &
CUDA_VISIBLE_DEVICES=3 python ade_trainer.py -s 6 -e 10 > logs/ade_basic_2.txt &
CUDA_VISIBLE_DEVICES=1 python ade_trainer.py -s 0 -e 3 -c with_external_knowledge > logs/ade_with_external_knowledge_0.txt &
CUDA_VISIBLE_DEVICES=2 python ade_trainer.py -s 3 -e 6 -c with_external_knowledge > logs/ade_with_external_knowledge_1.txt &
CUDA_VISIBLE_DEVICES=3 python ade_trainer.py -s 6 -e 10 -c with_external_knowledge > logs/ade_with_external_knowledge_2.txt
import os
import copy
import utils
import torch
import json
import random
import math
import pyhocon
import warnings
import numpy as np
import torch.nn as nn
import torch.optim as optim
from utils import *
from constants import *
from transformers import *
from data import load_data
from scorer import evaluate
from models import JointModel
from argparse import ArgumentParser
from trainer import train
if __name__ == "__main__":
    # Parse arguments
    parser = ArgumentParser()
    parser.add_argument('-s', '--start_nb', default=0)
    parser.add_argument('-e', '--end_nb', default=10)
    parser.add_argument('-c', '--config_name', default='basic')
    args = parser.parse_args()

    # Determine the range of folds
    start_nb = int(args.start_nb); assert 0 <= start_nb < 10
    end_nb = int(args.end_nb); assert 0 <= end_nb <= 10
    split_nb_ranges = list(range(start_nb, end_nb))

    dev_scores = []
    for split_nb in split_nb_ranges:
        configs = prepare_configs(args.config_name, ADE, split_nb)
        configs['gradient_checkpointing'] = False
        dev_scores.append(train(configs))
    print('end of ade training')
    print(dev_scores)
    with open('ade_10_fold_results.json', 'w+') as f:
        f.write(json.dumps(dev_scores))
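# A minimal sketch for aggregating the saved scores afterwards (assuming train() returns a single
# numeric dev score per fold; adjust if it returns a dict of metrics):
#   import json, numpy as np
#   scores = json.load(open('ade_10_fold_results.json'))
#   print('mean dev score over folds:', np.mean(scores))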
import numpy as np
skipped = 0
cur_configs = []
semtype2values = {}
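# The log format, as inferred from the parsing below: for each example, a line starting with
# 'text', one or more "c = {...}" lines whose dict includes a 'semtypes' list, and a
# 'tensor([...])' line with per-concept probabilities in the same order as the "c = ..." lines.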
with open('knowledge_module_logs.txt', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if line.startswith('text'):
            cur_configs = []
        elif line.startswith('c = {'):
            s_index, e_index = line.index('{'), line.index('}')
            semtypes = eval(line[s_index:e_index+1])['semtypes']
            cur_configs.append(semtypes)
        elif line.startswith('tensor'):
            try:
                s_index, e_index = line.index('['), line.index(']')
            except ValueError:
                skipped += 1
                continue
            prob_values = eval(line[s_index:e_index+1])
            for ix in range(len(cur_configs)):
                for semtype in cur_configs[ix]:
                    if semtype not in semtype2values: semtype2values[semtype] = []
                    semtype2values[semtype].append(prob_values[ix])
print('skipped = {}'.format(skipped))
all_semtypes = []
semtype2avgval, semtype2max, semtype2min, semtype2ctx = {}, {}, {}, {}
semtype290percentile = {}
for semtype in semtype2values:
    semtype2ctx[semtype] = len(semtype2values[semtype])
    semtype2avgval[semtype] = round(np.average(semtype2values[semtype]), 3)
    semtype2max[semtype] = round(np.max(semtype2values[semtype]), 3)
    semtype2min[semtype] = round(np.min(semtype2values[semtype]), 3)
    semtype290percentile[semtype] = round(np.percentile(semtype2values[semtype], 90), 3)
    all_semtypes.append(semtype)
all_semtypes = sorted(all_semtypes, key=lambda x: semtype2avgval[x], reverse=True)
for semtype in all_semtypes:
    print('[{}] avg = {} | 90 percentile = {} | max = {} | min = {} | ctx = {}'.format(semtype, semtype2avgval[semtype], semtype290percentile[semtype], semtype2max[semtype], semtype2min[semtype], semtype2ctx[semtype]))
import json
import pickle
import numpy as np
from constants import *
from os.path import join
from sklearn.dummy import DummyClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression
from external_knowledge import umls_search_concepts
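# This script probes how informative the pretrained UMLS concept embeddings (cuid2embs) are:
# concepts found by MetaMap in the ADE and BioRelEx sentences are collected, and a classifier is
# trained to predict each concept's first semantic type from its embedding, with dummy baselines
# for comparison (summary inferred from the code below).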
all_sents = []
# Extract all sentences from ADE
with open('resources/ade/ade_full.json', 'r') as f:
    ade_data = json.loads(f.read())
for inst in ade_data:
    all_sents.append(' '.join(inst['tokens']))
# Extract all sentences from BioRelEx
for split in ['train', 'dev', 'test']:
    with open('resources/biorelex/{}.json'.format(split), 'r') as f:
        biorelex_data = json.loads(f.read())
    for inst in biorelex_data:
        all_sents.append(inst['text'])
print('len(all_sents) is {}'.format(len(all_sents)))
# Extract all concepts
all_concepts, appeared = [], set()
cuid2embs = pickle.load(open(UMLS_EMBS, 'rb'))
for text in all_sents:
    concepts = umls_search_concepts([text], MM_TYPES)[0][0]['concepts']
    concepts = [c for c in concepts if c['cui'] in cuid2embs and (not c['cui'] in appeared)]
    for c in concepts: appeared.add(c['cui'])
    all_concepts += concepts
print('Total number of concepts with embs is {}'.format(len(all_concepts)))
# Divide into train/test set (given the emb, predict the semtype)
split_index = int(0.9 * len(all_concepts))
train_concepts = all_concepts[:split_index]
test_concepts = all_concepts[split_index:]
train_data = [(cuid2embs[c['cui']], MM_TYPES.index(c['semtypes'][0])) for c in train_concepts]
test_data = [(cuid2embs[c['cui']], MM_TYPES.index(c['semtypes'][0])) for c in test_concepts]
# Build feature/label arrays (X = concept embedding, y = index of the concept's first semantic type)
train_X = np.array([d[0] for d in train_data])
train_Y = np.array([d[1] for d in train_data])
test_X = np.array([d[0] for d in test_data])
test_Y = np.array([d[1] for d in test_data])
print('train_X: {} | train_Y: {} | test_X: {} | test_Y: {}'.format(train_X.shape, train_Y.shape, test_X.shape, test_Y.shape))
print('MLPClassifier')
clf = MLPClassifier(random_state=1, max_iter=10000, verbose=True).fit(train_X, train_Y)
print('Train Score: {}'.format(clf.score(train_X, train_Y)))
print('Test Score: {}'.format(clf.score(test_X, test_Y)))
# Dummy baselines (most_frequent and stratified strategies)
for strategy in ['most_frequent', 'stratified']:
    print('Dummy Classifiers ({})'.format(strategy))
    dummy_clf = DummyClassifier(strategy=strategy)
    dummy_clf.fit(train_X, train_Y)
    print('Train Score: {}'.format(dummy_clf.score(train_X, train_Y)))
    print('Test Score: {}'.format(dummy_clf.score(test_X, test_Y)))
import os
import spacy
import copy
import utils
import torch
import random
import math
import json
import pickle
import pyhocon
import warnings
import numpy as np
import torch.nn as nn
import torch.optim as optim
from utils import *
from constants import *
from transformers import *
from os.path import join
from spacy import displacy
from data import load_data
from scorer import evaluate
from models import JointModel
from argparse import ArgumentParser
from collections import Counter
from scorer.ade import get_relation_mentions
from external_knowledge import umls_search_concepts
SHOW_ERRORS_ONLY = True
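# When SHOW_ERRORS_ONLY is True, sentences whose predicted typed mentions exactly match the gold
# annotations are skipped in the visualizations (see the main loop below).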
def get_entity_mentions(sentence):
    typed_mentions = []
    for cluster in sentence['entities']:
        for alias, entity in cluster['names'].items():
            for start, end in entity['mentions']:
                if entity['is_mentioned']:
                    typed_mentions.append({'start': start, 'end': end, 'label': cluster['label']})
    typed_mentions.sort(key=lambda x: x['start'])
    return typed_mentions

def graph_from_sent(sent_text, sent):
    nodes, edges = [], []
    ents = get_entity_mentions(sent)
    for e in ents:
        e['text'] = sent_text[e['start']:e['end']]
        nodes.append(e)
    rels = get_relation_mentions(sent)
    for r in rels:
        head_loc = [int(l) for l in r['head'].split('_')]
        tail_loc = [int(l) for l in r['tail'].split('_')]
        r['head_text'] = sent_text[head_loc[0]:head_loc[1]]
        r['tail_text'] = sent_text[tail_loc[0]:tail_loc[1]]
        edges.append(r)
    return {'nodes': nodes, 'edges': edges}
if __name__ == "__main__":
    # Parse arguments
    parser = ArgumentParser()
    parser.add_argument('-m', '--trained_model', default='model.pt')
    parser.add_argument('-c', '--config_name', default='basic')
    parser.add_argument('-d', '--dataset', default=BIORELEX, choices=DATASETS)
    parser.add_argument('-s', '--split_nb', default=0)  # Only affects the ADE dataset
    args = parser.parse_args()
    args.split_nb = int(args.split_nb)

    # Reload components
    configs = prepare_configs(args.config_name, args.dataset, args.split_nb)
    tokenizer = AutoTokenizer.from_pretrained(configs['transformer'])
    train, dev = load_data(configs['dataset'], configs['split_nb'], tokenizer)
    model = JointModel(configs)

    # Reload a model
    assert os.path.exists(args.trained_model)
    checkpoint = torch.load(args.trained_model, map_location=model.device)
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    print('Reloaded a pretrained model')

    # Evaluation on the dev set
    print('Evaluation on the dev set')
    if configs['use_external_knowledge']:
        model.knowledge_enhancer.start_logging()
    evaluate(model, dev, configs['dataset'])
    if configs['use_external_knowledge']:
        model.knowledge_enhancer.end_logging()
    # Visualize predictions and ground truths (on the dev set)
    truths, preds, sent2truthgraph, sent2predgraph = [], [], {}, {}
    total_ents, ents_covered_by_metamap = 0, 0
    total_fns, fn_covered_by_metamap = 0, 0
    relevants_covered_by_metamap, non_relevants_covered_by_metamap = [], []
    cuid2embs = pickle.load(open(UMLS_EMBS, 'rb'))
    with torch.no_grad():
        for i in range(len(dev)):
            truth_sentence = dev[i].data
            truth_ents = get_entity_mentions(truth_sentence)
            pred_sentence = model.predict(dev[i])
            pred_ents = get_entity_mentions(pred_sentence)
            # Update sent2truthgraph and sent2predgraph
            sent_text = dev[i].text
            sent2truthgraph[sent_text] = graph_from_sent(sent_text, truth_sentence)
            sent2predgraph[sent_text] = graph_from_sent(sent_text, pred_sentence)
            # Check if the prediction is the same as the annotation
            if SHOW_ERRORS_ONLY:
                typed_truths = set({(x['start'], x['end'], x['label']) for x in truth_ents})
                typed_preds = set({(x['start'], x['end'], x['label']) for x in pred_ents})
                if typed_truths == typed_preds:
                    # Skip
                    continue
            # Update truths
            truths.append({
                'text': truth_sentence['text'], 'title': None,
                'ents': truth_ents,
            })
            # Update preds
            preds.append({
                'text': pred_sentence['text'], 'title': None,
                'ents': pred_ents,
            })
            # Investigate the usefulness of MetaMap
            text = truth_sentence['text']
            umls_concepts = umls_search_concepts([text])[0][0]['concepts']
            umls_concepts = [c for c in umls_concepts if c['cui'] in cuid2embs]
            loc_umls = set({(x['start_char'], x['end_char']) for x in umls_concepts})
            loc_truths = set({(x['start'], x['end']) for x in truth_ents})
            loc_preds = set({(x['start'], x['end']) for x in pred_ents})
            false_negatives = loc_truths - loc_preds
            # Compute number of ground-truth entities covered by MetaMap
            total_ents += len(loc_truths)
            ents_covered_by_metamap += len(loc_truths.intersection(loc_umls))
            for c in umls_concepts:
                if (c['start_char'], c['end_char']) in loc_truths:
                    relevants_covered_by_metamap.append(c)
            # Compute number of false negatives covered by MetaMap
            total_fns += len(false_negatives)
            fn_covered_by_metamap += len(false_negatives.intersection(loc_umls))
            # Compute number of MetaMap concepts that are not considered as entities in the dataset
            for c in umls_concepts:
                if not (c['start_char'], c['end_char']) in loc_truths:
                    non_relevants_covered_by_metamap.append(c)
    # Output sent2truthgraph and sent2predgraph
    with open('sent2truthgraph.json', 'w+') as f:
        f.write(json.dumps(sent2truthgraph))
    with open('sent2predgraph.json', 'w+') as f:
        f.write(json.dumps(sent2predgraph))

    # Write out relevant types
    relevant_types = flatten([c['semtypes'] for c in relevants_covered_by_metamap])
    with open('relevant_types.txt', 'w+') as f:
        f.write(json.dumps(Counter(relevant_types)))

    # Write out non-relevant types
    non_relevant_types = flatten([c['semtypes'] for c in non_relevants_covered_by_metamap])
    with open('non_relevant_types.txt', 'w+') as f:
        f.write(json.dumps(Counter(non_relevant_types)))

    print('types can be discarded: {}'.format(set(non_relevant_types) - set(relevant_types)))
    print('non_relevants_covered_by_metamap = {}'.format(len(non_relevants_covered_by_metamap)))
    print('ents_covered_by_metamap = {} | total_ents = {}'.format(ents_covered_by_metamap, total_ents))
    print('fn_covered_by_metamap = {} | total_fns = {}'.format(fn_covered_by_metamap, total_fns))

    # Generate HTML files
    output_dir = 'visualizations/{}_{}'.format(args.dataset, args.split_nb)
    os.makedirs(output_dir, exist_ok=True)
    truth_html = displacy.render(truths, style="ent", page=True, manual=True)
    pred_html = displacy.render(preds, style="ent", page=True, manual=True)
    with open(join(output_dir, 'truths.html'), 'w+', encoding='utf-8') as f:
        f.write(truth_html)
    with open(join(output_dir, 'preds.html'), 'w+', encoding='utf-8') as f:
        f.write(pred_html)
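# A (hypothetical) invocation, assuming this file is saved as visualize.py and dataset names match
# those defined in constants.DATASETS:
#   python visualize.py -m model.pt -c with_external_knowledge -d ade -s 0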
# -*- coding: utf-8 -*-
from __future__ import print_function
import io
import os
import json
import argparse
import torch
from constants import *
from transformers import *
from models import JointModel
from utils import prepare_configs
from data import DataInstance, tokenize
def load_components(model_path, config_name='basic'):
    configs = prepare_configs(config_name, BIORELEX, 0)
    tokenizer = AutoTokenizer.from_pretrained(configs['transformer'])
    model = JointModel(configs)
    checkpoint = torch.load(model_path, map_location=model.device)
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    return tokenizer, model

def predict(model, tokenizer, sample):
    id, text = sample['id'], sample['text']
    test_sample = DataInstance(sample, id, text, tokenize(tokenizer, text.split(' ')))
    with torch.no_grad():
        pred_sample = model.predict(test_sample)
    return pred_sample

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('model_dir', type=str,
                        help='Path to the pretrained model.')
    parser.add_argument('input_dir', type=str,
                        help='Path to directory containing input.json.')
    parser.add_argument('output_dir', type=str,
                        help='Path to output directory to write predictions.json in.')
    parser.add_argument('shared_dir', type=str,
                        help='Path to shared directory.')
    args = parser.parse_args()

    # Collect information on known relations
    self_path = os.path.realpath(__file__)
    self_dir = os.path.dirname(self_path)

    # Load main components
    tokenizer, model = load_components(args.model_dir)

    # Read input samples and predict w.r.t. the set of relations.
    input_json_path = os.path.join(args.input_dir, 'input.json')
    output_json_path = os.path.join(args.output_dir, 'predictions.json')
    with io.open(input_json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    predictions = []
    for sample in data:
        sample = predict(model, tokenizer, sample)
        predictions.append(sample)
    with open(output_json_path, 'w') as f:
        json.dump(predictions, f, indent=True)

if __name__ == "__main__":
    main()
#!/bin/bash
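# Arguments: $1 = directory holding this code, requirements.txt, and model.pt;
# $2 = input directory (containing input.json); $3 = output directory (predictions.json is
# written there); $4 = shared directory.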
pip3 install -r $1/requirements.txt && python3 $1/biorelex_code.py $1/model.pt $2 $3 $4
# BioRelEx
mkdir logs
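# Three runs per configuration ('basic' and 'with_external_knowledge'), each pinned to one GPU,
# with logs redirected to logs/. The -s flag appears to select the run/split index (see trainer.py).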
CUDA_VISIBLE_DEVICES=1 python trainer.py -s 0 > logs/biorelex_basic_0.txt &
CUDA_VISIBLE_DEVICES=2 python trainer.py -s 1 > logs/biorelex_basic_1.txt &
CUDA_VISIBLE_DEVICES=3 python trainer.py -s 2 > logs/biorelex_basic_2.txt &
CUDA_VISIBLE_DEVICES=1 python trainer.py -c with_external_knowledge -s 0 > logs/biorelex_with_external_knowledge_0.txt &
CUDA_VISIBLE_DEVICES=2 python trainer.py -c with_external_knowledge -s 1 > logs/biorelex_with_external_knowledge_1.txt &
CUDA_VISIBLE_DEVICES=3 python trainer.py -c with_external_knowledge -s 2 > logs/biorelex_with_external_knowledge_2.txt
import os
import spacy
import copy
import utils
import torch
import random
import math
import json
import pickle
import pyhocon
import warnings
import numpy as np
import torch.nn as nn
import torch.optim as optim
from utils import *
from constants import *
from transformers import *
from os import listdir
from os.path import isfile, join
from spacy import displacy
from data import load_data
from scorer import evaluate
from models import JointModel
from argparse import ArgumentParser
from collections import Counter
from external_knowledge import umls_search_concepts
if __name__ == "__main__":
    # Parse arguments
    parser = ArgumentParser()
    parser.add_argument('-m', '--models_dir', required=True)
    args = parser.parse_args()

    # Extract models_dir, dataset, and config_name
    models_dir = args.models_dir
    dir_name = os.path.basename(os.path.normpath(models_dir))
    parts = dir_name.split('_')
    dataset, config_name = parts[0], '_'.join(parts[1:])
    print('models_dir = {} | dataset = {} | config_name = {}'.format(models_dir, dataset, config_name))
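    # The directory name is assumed to follow a '{dataset}_{config_name}' convention; e.g. a
    # (hypothetical) directory named 'ade_with_external_knowledge' yields dataset = 'ade' and
    # config_name = 'with_external_knowledge'.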