# extract_data_amc.py
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from dateutil.relativedelta import relativedelta
from faker import Faker
import random
from transformers import set_seed
from sklearn.model_selection import train_test_split
from functools import partial
import numpy as np
from collections import Counter
import time
import sys
# Constants
# NOTE(review): TRAIN_RATIO, PAD_TOKEN and UNKNOWN_TOKEN are not referenced in
# the visible part of this file; presumably kept for downstream code — confirm.
TRAIN_RATIO = 0.9
PAD_TOKEN = "<PAD>"
UNKNOWN_TOKEN = "<UNK>"
# Codes seen fewer than this many times are removed by filter_codes() in main().
MIN_TARGET_COUNT = 10
# Output column names: "hadm_id" is renamed to ID_COLUMN by the parse_* helpers.
ID_COLUMN = "_id"
TEXT_COLUMN = "text"
TARGET_COLUMN = "target"
SUBJECT_ID_COLUMN = "subject_id"
# Seed the random number generators at import time for reproducibility
# (stdlib `random` directly, plus whatever transformers.set_seed covers).
random.seed(42)
set_seed(42)
# Set the style for the plots
sns.set_style("whitegrid")
# Create a Faker instance and seed Faker with the same fixed value.
# NOTE(review): `fake` is not used in the visible part of this file.
fake = Faker()
Faker.seed(42)
def print_flush(*args, **kwargs):
    """Print like the built-in ``print`` but always flush the output stream.

    Args:
        *args: Values forwarded positionally to ``print``.
        **kwargs: Keyword options forwarded to ``print`` (``sep``, ``end``,
            ``file``). A caller-supplied ``flush`` is accepted and overridden.
    """
    # Override instead of passing ``flush=True`` next to ``**kwargs``: a
    # duplicate keyword would raise TypeError when the caller also sets it.
    kwargs["flush"] = True
    print(*args, **kwargs)
def reformat_icd10(code: str, is_diag: bool) -> str:
    """Insert the period into an ICD-10 code.

    Any period already present is removed first, so the function is
    idempotent. Procedure codes are returned without a period; diagnosis
    codes get one after the third character.

    Args:
        code: ICD-10 code, with or without a period.
        is_diag: True for a diagnosis code, False for a procedure code.

    Returns:
        The reformatted code.
    """
    code = code.replace(".", "")
    # Codes of three characters or fewer have nothing after the category
    # part; adding the period would leave a dangling "." (e.g. "I10.").
    if not is_diag or len(code) <= 3:
        return code
    return code[:3] + "." + code[3:]
def reformat_icd9(code: str, is_diag: bool) -> str:
    """Insert the period into an ICD-9 code.

    Existing periods are stripped first. The period goes after the second
    character for procedure codes, after the fourth for diagnosis codes
    starting with "E", and after the third for all other diagnosis codes.
    Codes too short to have anything after the period are returned bare.
    """
    bare = code.replace(".", "")

    # Number of characters that precede the period for this kind of code.
    if not is_diag:
        prefix_len = 2
    elif bare.startswith("E"):
        prefix_len = 4
    else:
        prefix_len = 3

    if len(bare) <= prefix_len:
        return bare
    return bare[:prefix_len] + "." + bare[prefix_len:]
def reformat_icd(code: str, version: int, is_diag: bool) -> str:
    """Dispatch to the ICD-9 or ICD-10 formatter according to ``version``.

    Raises:
        ValueError: If ``version`` is neither 9 nor 10.
    """
    if version == 9:
        return reformat_icd9(code, is_diag)
    if version == 10:
        return reformat_icd10(code, is_diag)
    raise ValueError("version must be 9 or 10")
def sort_by_indexes(lst, indexes, reverse=False):
    """Return the items of ``lst`` ordered by their matching key in ``indexes``.

    The two inputs are paired positionally; only the key takes part in the
    comparison, so items with equal keys keep their relative order.
    """
    keyed = list(zip(indexes, lst))
    keyed.sort(key=lambda pair: pair[0], reverse=reverse)
    return [item for _, item in keyed]
def reformat_code_dataframe(row: pd.DataFrame, cols: list) -> pd.Series:
    """Collapse one group of rows into two parallel, sorted lists.

    Args:
        row: Dataframe holding the rows of a single group.
        cols: Two column names. The rows are sorted by ``cols[0]`` and
            ``cols[1]`` is reordered with the same permutation, so the
            i-th entries of both lists still belong together.

    Returns:
        A series with one entry per name in ``cols``, each a list of that
        column's values in ascending order of ``cols[0]``.
    """
    # Sort the rows themselves. Using the result of ``argsort()`` as a sort
    # key applies the inverse permutation and leaves the values unsorted.
    # A stable sort keeps tied rows in their original order.
    ordered = row.sort_values(by=cols[0], kind="stable")
    return pd.Series({col: ordered[col].tolist() for col in cols[:2]})
def parse_codes_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Turn a one-row-per-code table into one row per admission and ICD version.

    Renames ``hadm_id``/``subject_id`` to the module's column names, drops
    rows without an ``icd_code`` and repeated (admission, code) pairs, then
    groups by subject, admission and ``icd_version`` and collapses each
    group's ``icd_code`` and ``long_title`` values into lists.
    """
    renamed = df.rename(
        columns={"hadm_id": ID_COLUMN, "subject_id": SUBJECT_ID_COLUMN}
    )
    unique_codes = renamed.dropna(subset=["icd_code"]).drop_duplicates(
        subset=[ID_COLUMN, "icd_code"]
    )
    collapse_group = partial(
        reformat_code_dataframe, cols=["icd_code", "long_title"]
    )
    per_admission = unique_codes.groupby(
        [SUBJECT_ID_COLUMN, ID_COLUMN, "icd_version"]
    )
    return per_admission.apply(collapse_group).reset_index()
def parse_notes_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Rename the notes columns and drop unusable rows.

    Rows with a missing text are removed, as are rows that repeat the same
    text for the same admission id.
    """
    column_names = {
        "hadm_id": ID_COLUMN,
        "subject_id": SUBJECT_ID_COLUMN,
        "text": TEXT_COLUMN,
    }
    notes = df.rename(columns=column_names)
    notes = notes.dropna(subset=[TEXT_COLUMN])
    return notes.drop_duplicates(subset=[ID_COLUMN, TEXT_COLUMN])
def filter_codes(df: pd.DataFrame, columns: list[str], min_count: int) -> pd.DataFrame:
    """Remove codes occurring fewer than ``min_count`` times from list columns.

    Each column in ``columns`` holds a list of codes per row. Occurrences are
    counted over the whole column, and every list is rebuilt with only the
    codes that reach ``min_count``. The columns are replaced on ``df`` itself,
    which is also returned. The number of distinct codes before and after is
    printed for each column.
    """
    for column in columns:
        frequencies = Counter()
        for code_list in df[column]:
            frequencies.update(code_list)

        frequent = {code for code, seen in frequencies.items() if seen >= min_count}

        def keep_frequent(code_list):
            return [code for code in code_list if code in frequent]

        df[column] = df[column].apply(keep_frequent)
        print(f"Number of unique codes in {column} before filtering: {len(frequencies)}")
        print(f"Number of unique codes in {column} after filtering: {len(frequent)}")
    return df
def _titles_for_kept_codes(codes, titles, kept_codes):
    """Return the titles whose positionally paired code is in ``kept_codes``."""
    kept = set(kept_codes)
    return [title for code, title in zip(codes, titles) if code in kept]


def main():
    """Build the MIMIC-IV ICD-10 dataset and save it as a feather file.

    Reads the discharge notes, the procedure/diagnosis code tables and their
    description tables from ``./data/physionet.org/files/``, formats the ICD
    codes, groups them per admission, keeps the ICD-10 rows, joins them to
    the notes, removes codes seen fewer than ``MIN_TARGET_COUNT`` times and
    writes the result to ``./data/mimiciv_icd10.feather``.
    """
    start_time = time.time()
    # Read ICD codes as text: numeric-looking ICD-9 codes must not be parsed
    # as integers, which would drop leading zeros and break the string
    # handling in reformat_icd().
    code_dtype = {"icd_code": str}
    # Load data
    print_flush("Loading MIMIC-IV Notes dataset...")
    mimic_notes = pd.read_csv("./data/physionet.org/files/mimic-iv-note/2.2/note/discharge.csv.gz", compression='gzip')
    print_flush("Loading MIMIC-IV Procedures dataset...")
    mimic_proc = pd.read_csv("./data/physionet.org/files/mimiciv/2.2/hosp/procedures_icd.csv.gz", compression='gzip', dtype=code_dtype)
    print_flush("Loading MIMIC-IV Diagnoses dataset...")
    mimic_diag = pd.read_csv("./data/physionet.org/files/mimiciv/2.2/hosp/diagnoses_icd.csv.gz", compression='gzip', dtype=code_dtype)
    print_flush("Loading ICD Procedures descriptions...")
    procedures = pd.read_csv("./data/physionet.org/files/mimiciv/2.2/hosp/d_icd_procedures.csv.gz", compression='gzip', dtype=code_dtype)
    print_flush("Loading ICD Diagnoses descriptions...")
    diagnoses = pd.read_csv("./data/physionet.org/files/mimiciv/2.2/hosp/d_icd_diagnoses.csv.gz", compression='gzip', dtype=code_dtype)
    # Merge procedures and diagnoses
    print_flush("Merging procedures with their descriptions...")
    mimic_proc = mimic_proc.merge(procedures, how='inner', on=['icd_code', 'icd_version'])
    print_flush("Merging diagnoses with their descriptions...")
    mimic_diag = mimic_diag.merge(diagnoses, how='inner', on=['icd_code', 'icd_version'])
    # Format ICD codes (column-wise zip instead of a row-wise DataFrame.apply)
    print_flush("Formatting procedure ICD codes...")
    mimic_proc["icd_code"] = [
        reformat_icd(code=code, version=version, is_diag=False)
        for code, version in zip(mimic_proc["icd_code"], mimic_proc["icd_version"])
    ]
    print_flush("Formatting diagnosis ICD codes...")
    mimic_diag["icd_code"] = [
        reformat_icd(code=code, version=version, is_diag=True)
        for code, version in zip(mimic_diag["icd_code"], mimic_diag["icd_version"])
    ]
    # Process codes and notes
    print_flush("Processing procedure codes...")
    mimic_proc = parse_codes_dataframe(mimic_proc)
    print_flush("Processing diagnosis codes...")
    mimic_diag = parse_codes_dataframe(mimic_diag)
    print_flush("Processing clinical notes...")
    mimic_notes = parse_notes_dataframe(mimic_notes)
    # Filter for ICD-10 codes and merge
    print_flush("Filtering for ICD-10 codes...")
    mimic_proc_10 = mimic_proc[mimic_proc["icd_version"] == 10]
    mimic_proc_10 = mimic_proc_10.rename(columns={"icd_code": "icd10_proc"})
    mimic_diag_10 = mimic_diag[mimic_diag["icd_version"] == 10]
    mimic_diag_10 = mimic_diag_10.rename(columns={"icd_code": "icd10_diag"})
    # Merge notes with procedures and diagnoses. Both code tables carry a
    # "long_title" column, so the second merge suffixes them: long_title_x
    # holds the procedure titles and long_title_y the diagnosis titles.
    print_flush("Merging clinical notes with procedures...")
    mimiciv_10 = mimic_notes.merge(
        mimic_proc_10[[ID_COLUMN, "icd10_proc", "long_title"]], on=ID_COLUMN, how="inner"
    )
    print_flush("Merging clinical notes with diagnoses...")
    mimiciv_10 = mimiciv_10.merge(
        mimic_diag_10[[ID_COLUMN, "icd10_diag", "long_title"]], on=ID_COLUMN, how="inner"
    )
    # Clean up data
    print_flush("Cleaning and filtering records...")
    mimiciv_10 = mimiciv_10.dropna(subset=["icd10_proc", "icd10_diag"], how="all")
    title_column_for = {"icd10_proc": "long_title_x", "icd10_diag": "long_title_y"}
    for list_column in [*title_column_for, *title_column_for.values()]:
        # Replace a missing cell by an empty list. Test the value rather than
        # its identity: NaN objects are not guaranteed to be ``np.nan`` itself.
        mimiciv_10[list_column] = mimiciv_10[list_column].apply(
            lambda x: [] if isinstance(x, float) and np.isnan(x) else x
        )
    # Filter codes and create target
    print_flush("Filtering codes by minimum count...")
    # Keep the unfiltered code lists so the titles can be filtered in step.
    unfiltered_codes = {column: mimiciv_10[column].copy() for column in title_column_for}
    mimiciv_10 = filter_codes(mimiciv_10, ["icd10_proc", "icd10_diag"], MIN_TARGET_COUNT)
    for code_column, title_column in title_column_for.items():
        # Drop the titles of the removed codes; otherwise "long_title" would
        # no longer line up position by position with the target list.
        mimiciv_10[title_column] = [
            _titles_for_kept_codes(codes, titles, kept)
            for codes, titles, kept in zip(
                unfiltered_codes[code_column],
                mimiciv_10[title_column],
                mimiciv_10[code_column],
            )
        ]
    mimiciv_10[TARGET_COLUMN] = mimiciv_10["icd10_proc"] + mimiciv_10["icd10_diag"]
    mimiciv_10["long_title"] = mimiciv_10["long_title_x"] + mimiciv_10["long_title_y"]
    # Remove empty targets and reset index
    print_flush("Removing records with empty targets...")
    mimiciv_10 = mimiciv_10[mimiciv_10[TARGET_COLUMN].apply(lambda x: len(x) > 0)]
    mimiciv_10 = mimiciv_10.reset_index(drop=True)
    # Save to disk
    print_flush("Saving processed dataset to disk...")
    mimiciv_10.to_feather("./data/mimiciv_icd10.feather")
    print_flush("Dataset saved successfully!")
    # Print time taken
    print_flush(f"Time taken: {(time.time() - start_time)/60} minutes")
# Run the extraction pipeline only when the file is executed as a script.
if __name__ == "__main__":
    main()