-
Notifications
You must be signed in to change notification settings - Fork 6
/
PlasForest.py
314 lines (295 loc) · 15.7 KB
/
PlasForest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#######################################################################################
### ###
### PlasForest 1.4 ###
### Copyright (C) 2020 Léa Pradier, Tazzio Tissot, Anna-Sophie Fiston-Lavier, ###
### Stéphanie Bedhomme. (leaemiliepradier@gmail.com) ###
### ###
### This program is free software: you can redistribute it and/or modify ###
### it under the terms of the GNU General Public License as published by ###
### the Free Software Foundation, either version 3 of the License, or ###
### (at your option) any later version. ###
### ###
### This program is distributed in the hope that it will be useful, ###
### but WITHOUT ANY WARRANTY; without even the implied warranty of ###
### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the ###
### GNU General Public License for more details. ###
### ###
### You should have received a copy of the GNU General Public License ###
### along with this program. If not, see <http://www.gnu.org/licenses/>. ###
### ###
#######################################################################################
import os
import sys, getopt
from sklearn.ensemble import RandomForestClassifier
import pandas as pd
import numpy as np
import pickle
import math
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.Blast.Applications import NcbiblastnCommandline
from Bio import SeqUtils
from Bio.SeqRecord import SeqRecord
from multiprocessing import Pool
import time
import warnings
# Every warning category is silenced for the whole run.
warnings.filterwarnings("ignore")
# Reference timestamp; main() subtracts it from the current time when it
# writes the run duration.
start_time = time.time()
### GLOBALS ###
# IDs and labels ("Plasmid"/"Chromosome") of contigs whose FASTA header
# already names their identity. The sequence checkers append to these lists.
attributed_IDs = []
attributed_identities = []
### MAIN ###
def main(argv):
inputfile = ""
outputfile = ""
besthits = False
showFeatures = False
verbose = False
reattribute = False
nthreads = 1
batch=0
modelpath = "plasforest.sav"
databasepath = "plasmid_refseq.fasta"
# INPUT OPTIONS #
print("PlasForest: a homology-based random forest classifier for plasmid identification.")
print("(C) Lea Pradier, Tazzio Tissot, Anna-Sophie Fiston-Lavier, Stephanie Bedhomme. 2020.")
try:
opts, args = getopt.getopt(argv,"hi:o:bfvrmd", ["help","input=","output=","threads=","size_of_batch=","model=","database="])
except getopt.GetoptError:
print('./PlasForest.py -i <inputfile> -o <outputfile>')
print('\t -i, --input <inputfile>: a FASTA input file')
print('\t -o, --output <outputfile>: a CSV file')
print('List of options:')
print('\t -b: show best hit from the plasmid database for each contig')
print('\t -f: keep the features used by the classifier in the output')
print('\t --threads <int>: number of threads (default: 1)')
print('\t --size_of_batch <int>: number of sequences per batch')
print('\t -r: reattribute contigs which are already described as plasmid or chromosome')
print('\t --model <path>/plasforest.sav: path to the .sav model file')
print('\t --database <path>/plasmid_refseq.fasta: path to the database')
print('\t -v: verbose mode')
print('\t -h, --help: show this message and quit')
sys.exit(2)
for opt, arg in opts:
if opt in ('-h','--help'):
print('./PlasForest.py -i <inputfile> -o <outputfile>')
print('\t -i, --input <inputfile>: a FASTA input file')
print('\t -o, --output <outputfile>: a CSV file')
print('List of options:')
print('\t -b: show best hit from the plasmid database for each contig')
print('\t -f: keep the features used by the classifier in the output')
print('\t --threads <int>: number of threads (default: 1)')
print('\t --size_of_batch <int>: number of sequences per batch')
print('\t -r: reassign contigs which are already described as plasmid or chromosome')
print('\t --model <path>/plasforest.sav: path to the .sav model file')
print('\t --database <path>/plasmid_refseq.fasta: path to the database')
print('\t -v: verbose mode')
print('\t -h, --help: show this message and quit')
sys.exit()
elif opt in ("-i", "--input"):
inputfile = arg
if (not arg.endswith(".fasta")) and (not arg.endswith(".fna")) and (not arg.endswith(".fa")):
print('./PlasForest.py -i <inputfile> -o <outputfile>')
print('Error: input file must be in FASTA format.')
sys.exit()
outputfile = inputfile+".csv"
elif opt in ("-o", "--output"):
outputfile=arg
if not arg.endswith(".csv"):
print('./PlasForest.py -i <inputfile> -o <outputfile>')
print('Error: output file is a column-separated file.')
sys.exit()
elif opt=="--threads":
nthreads = arg
elif opt=="-b":
besthits = True
elif opt=="-f":
showFeatures = True
elif opt=="-r":
reattribute = True
elif opt=="-v":
verbose = True
elif opt=="--size_of_batch":
if not arg.isnumeric():
print("Error: size_of_batch must be integer.")
sys.exit()
batch = int(arg)
elif opt in ("--model"):
modelpath = arg
elif opt in ("-d", "--database"):
if os.path.exists(arg) and os.path.isfile(arg) and arg.endswith(".fasta"):
databasepath = arg
else:
print("Error: cannot find the path to the plasmid database")
sys.exit()
if os.path.exists(modelpath) and os.path.isfile(modelpath) and modelpath.endswith(".sav"):
global plasforest
plasforest = pickle.load(open(modelpath,"rb"))
else:
print("Error: cannot find the path to the .sav file")
sys.exit()
if verbose: print("Applying PlasForest on "+inputfile+".")
tmp_fasta = inputfile+"_tmp.fasta"
blast_table = inputfile+"_blast.out"
# THE WHOLE DATASET WILL BE ANALYZED AT ONCE #
if batch==0:
list_records = seq_checker(inputfile, tmp_fasta, verbose, reattribute)
if(os.path.isfile(tmp_fasta)):
blast_launcher(tmp_fasta, blast_table, verbose, nthreads, databasepath)
features = get_features(list_records, blast_table, verbose, nthreads)
finalfile = plasforest_predict(features, showFeatures, besthits, verbose, attributed_IDs, attributed_identities, nthreads)
os.remove(tmp_fasta)
os.remove(blast_table)
if verbose: print("Temporary files are deleted.")
else:
if verbose: print("Contig descriptions already mentioned their identities.")
finalfile = pd.DataFrame({"ID":attributed_IDs, "Prediction":attributed_identities})
# THE INPUT DATASET WILL BE ANALYZED IN SEPARATE BATCHES #
else:
nb_seqs_to_analyze = seq_counter(inputfile, verbose, batch)
nb_seqs_analyzed = 0
finalfile = pd.DataFrame()
while nb_seqs_analyzed < nb_seqs_to_analyze:
list_records = seq_checker_batch(inputfile, tmp_fasta, verbose, reattribute, batch, nb_seqs_analyzed)
nb_seqs_analyzed = nb_seqs_analyzed + len(list_records)
if(os.path.isfile(tmp_fasta)):
blast_launcher(tmp_fasta, blast_table, verbose, nthreads, databasepath)
features = get_features(list_records, blast_table, verbose, nthreads)
tmp_finalfile = plasforest_predict(features, showFeatures, besthits, verbose, attributed_IDs, attributed_identities, nthreads)
os.remove(tmp_fasta)
os.remove(blast_table)
finalfile = finalfile.append(tmp_finalfile)
# THE OUTPUT IS PRINTED TO A CSV FILE #
finalfile.to_csv(os.path.abspath(outputfile), sep=',', encoding='utf-8', index=False)
print('Predictions are printed in '+outputfile)
with open("time_plasforest_"+str(nthreads)+"_threads.dat","a+") as f:
f.write(str(time.time() - start_time)+"\n")
### COUNT SEQUENCES ###
def seq_counter(inputfile, verbose, batch):
    """Count the FASTA records stored in ``inputfile``.

    Parameters:
        inputfile: path to the FASTA file.
        verbose: when true, also print how many batches of ``batch``
            records are needed to cover the file.
        batch: number of records per batch; only used for the verbose
            message and must be non-zero in that case.

    Returns:
        The number of records found in the file.
    """
    # Counting only reads the file; "r+" would also require write permission.
    with open(inputfile, "r") as handle:
        nb_seqs_analyze = sum(1 for _ in SeqIO.parse(handle, "fasta"))
    if verbose:
        # Ceiling division: a partially filled last batch still counts.
        ntotbatch = (nb_seqs_analyze + batch - 1) // batch
        print("PlasForest will analyze the input in "+str(ntotbatch)+" different batches.")
    return nb_seqs_analyze
### CHECK WHICH SEQUENCES NEED TO BE ANALYZED ###
def seq_checker(inputfile, tmp_fasta, verbose, reattribute):
    """Select the contigs of ``inputfile`` that PlasForest has to classify.

    A contig whose ID or description contains "plasmid"/"Plasmid" or
    "chromosome"/"Chromosome" is considered already labelled. Unless
    ``reattribute`` is true, such a contig is not analyzed: its ID is
    appended to the module-level ``attributed_IDs`` list and its label
    ("Plasmid" when a plasmid word is present, "Chromosome" otherwise) to
    ``attributed_identities``.

    Parameters:
        inputfile: path to the input FASTA file.
        tmp_fasta: path of the FASTA file that receives the contigs to
            analyze; it is only created when at least one contig is kept.
        verbose: when true, print how many contigs will be analyzed.
        reattribute: when true, analyze already labelled contigs as well.

    Returns:
        The list of records that were written to ``tmp_fasta``.
    """
    words_to_eliminate = ("plasmid", "chromosome", "Plasmid", "Chromosome")
    words_plasmid = ("plasmid", "Plasmid")
    list_records = []
    # The input is only read, so it is not opened in "r+" mode any more.
    with open(inputfile, "r") as handle:
        for record in SeqIO.parse(handle, "fasta"):
            header = (str(record.id), str(record.description))
            labelled = any(word in text for text in header for word in words_to_eliminate)
            if reattribute or not labelled:
                list_records.append(record)
                continue
            attributed_IDs.append(str(record.id))
            is_plasmid = any(word in text for text in header for word in words_plasmid)
            attributed_identities.append("Plasmid" if is_plasmid else "Chromosome")
    # One append-mode write instead of reopening the file for every record;
    # nothing is created when no contig has to be analyzed.
    if list_records:
        with open(tmp_fasta, "a") as tmp:
            SeqIO.write(list_records, tmp, "fasta")
    if verbose: print(str(len(list_records))+" contigs will be analyzed by PlasForest.")
    return list_records
def seq_checker_batch(inputfile, tmp_fasta, verbose, reattribute, batch, nb_already_analyzed):
    """Select the next batch of contigs that PlasForest has to classify.

    The whole file is scanned on every call. Contigs to analyze (those
    without "plasmid"/"chromosome" words in their ID or description, or all
    contigs when ``reattribute`` is true) are numbered from 1 in file order,
    and those numbered ``nb_already_analyzed + 1`` to
    ``nb_already_analyzed + batch`` form the batch.

    Already labelled contigs that are not re-analyzed are appended to the
    module-level ``attributed_IDs`` / ``attributed_identities`` lists on
    EVERY call, so a caller that invokes this function repeatedly sees them
    once per call.

    Parameters:
        inputfile: path to the input FASTA file.
        tmp_fasta: path of the FASTA file that receives the batch; it is
            only created when the batch is not empty.
        verbose: when true, print the range of contigs in the batch.
        reattribute: when true, analyze already labelled contigs as well.
        batch: maximum number of contigs in the batch.
        nb_already_analyzed: number of contigs handled by previous batches.

    Returns:
        The list of records that were written to ``tmp_fasta``.
    """
    words_to_eliminate = ("plasmid", "chromosome", "Plasmid", "Chromosome")
    words_plasmid = ("plasmid", "Plasmid")
    first_in_batch = nb_already_analyzed + 1
    last_in_batch = nb_already_analyzed + batch
    nb_seqs_analyze = 0
    list_records = []
    # The input is only read, so it is not opened in "r+" mode any more.
    with open(inputfile, "r") as handle:
        for record in SeqIO.parse(handle, "fasta"):
            header = (str(record.id), str(record.description))
            labelled = any(word in text for text in header for word in words_to_eliminate)
            if reattribute or not labelled:
                nb_seqs_analyze += 1
                if first_in_batch <= nb_seqs_analyze <= last_in_batch:
                    list_records.append(record)
                continue
            attributed_IDs.append(str(record.id))
            is_plasmid = any(word in text for text in header for word in words_plasmid)
            attributed_identities.append("Plasmid" if is_plasmid else "Chromosome")
    # One append-mode write instead of reopening the file for every record;
    # nothing is created when the batch is empty.
    if list_records:
        with open(tmp_fasta, "a") as tmp:
            SeqIO.write(list_records, tmp, "fasta")
    if verbose: print("Contigs #"+str(first_in_batch)+" to #"+str(min(nb_seqs_analyze, last_in_batch))+" will be analyzed by PlasForest.")
    return list_records
### BLAST LAUNCHER ###
def blast_launcher(tmp_fasta, blast_table, verbose, nthreads, databasepath):
    """Run BLASTn on ``tmp_fasta`` against ``databasepath``.

    The command is built with an e-value cut-off of 0.001 and output
    format 6, and is told to write its result to ``blast_table``.

    Parameters:
        tmp_fasta: path to the FASTA file holding the query contigs.
        blast_table: path of the file BLASTn writes its output to.
        verbose: when true, print a message before and after the run.
        nthreads: value handed to BLASTn as ``num_threads``.
        databasepath: value handed to BLASTn as ``db``.
    """
    blast_options = {
        "query": tmp_fasta,
        "db": databasepath,
        "evalue": 0.001,
        "outfmt": 6,
        "out": blast_table,
        "num_threads": nthreads,
    }
    run_blastn = NcbiblastnCommandline(**blast_options)
    if verbose:
        print("Starting BLASTn...")
    # The hits go to blast_table (out=); the returned streams are not used.
    _stdout, _stderr = run_blastn()
    if verbose:
        print("BLASTn is over!")
### GET FEATURES ###
def get_features(list_records, blast_table, verbose, nthreads):
    """Build the feature table for the records in ``list_records``.

    A pool of ``nthreads`` worker processes is started; each worker runs
    ``read_blast(blast_table)`` once as its initializer. The workers compute
    the per-sequence features (``get_seq_feature``) and the per-ID BLAST
    features (``get_blast_feature``), which are merged on "ID". The maximum,
    median and average coverage are then divided by the contig size, and the
    variance of coverage is divided by the contig size twice.

    Parameters:
        list_records: records to describe.
        blast_table: path to the BLAST output table.
        verbose: when true, print a progress message.
        nthreads: number of worker processes (converted with ``int``).

    Returns:
        A DataFrame with one row per contig ID.
    """
    if verbose: print("Computing the features")
    coverage_columns = ["Maximum coverage", "Median coverage", "Average coverage"]
    pool = Pool(int(nthreads), read_blast, [blast_table])
    try:
        seqFeatures = pd.concat(pool.map(get_seq_feature, list_records))
        blastFeatures = pd.concat(pool.map(get_blast_feature, seqFeatures["ID"].unique().tolist()))
    finally:
        # Shut the workers down even when one of the map calls raises, so no
        # worker process is left behind.
        pool.close()
        pool.join()
    features = pd.merge(seqFeatures, blastFeatures, on="ID")
    features[coverage_columns] = features[coverage_columns].div(features["Contig size"], axis="index")
    features["Variance of coverage"] = features["Variance of coverage"].div(features["Contig size"], axis="index").div(features["Contig size"], axis="index")
    return features
def read_blast(blast_table):
    """Load a BLAST table into the module-level ``BLAST`` DataFrame.

    The file is read as tab-separated text without a header line and its
    columns are named qseqid, sseqid, pident, length, mismatch, gapopen,
    qstart, qend, sstart, send, evalue and bitscore.

    Parameters:
        blast_table: path to the table to read.

    An empty file yields an empty table with the same columns instead of an
    error, so contigs are simply reported with zero hits.
    """
    global BLAST
    blast_columns = ['qseqid','sseqid','pident','length','mismatch','gapopen','qstart','qend', 'sstart','send','evalue','bitscore']
    try:
        BLAST = pd.read_table(blast_table, sep='\t', header=None, names=blast_columns)
    except pd.errors.EmptyDataError:
        # pandas refuses to parse a file with no content at all (e.g. a run
        # in which no contig had any hit); fall back to an empty table.
        BLAST = pd.DataFrame(columns=blast_columns)
def get_seq_feature(record):
    """Return the sequence-derived features of ``record`` as a one-row table.

    Parameters:
        record: an object with ``id`` and ``seq`` attributes.

    Returns:
        A DataFrame with the columns "ID", "G+C content" (percentage,
        0.0 for an empty sequence) and "Contig size" (sequence length).
    """
    sequence = str(record.seq)
    # G+C percentage computed directly (G, C and the ambiguity code S, either
    # case, over the full length). This is the value Bio.SeqUtils.GC returned;
    # newer Biopython releases replace that function with gc_fraction, which
    # uses a different scale.
    gc_count = sum(sequence.count(base) for base in "GCSgcs")
    gc_percent = gc_count * 100.0 / len(sequence) if sequence else 0.0
    return pd.DataFrame({"ID":[str(record.id)], "G+C content":[gc_percent], "Contig size":[len(sequence)]})
def get_blast_feature(ID):
    """Summarize the BLAST hits of the query ``ID`` as a one-row table.

    The hits are the rows of the module-level ``BLAST`` table whose
    ``qseqid`` (compared as a string) equals ``ID``.

    Returns:
        A DataFrame with the columns "ID", "Number of hits", "Maximum
        coverage", "Median coverage", "Average coverage", "Variance of
        coverage" (all computed on the hit ``length`` column) and "Best hit"
        (the ``sseqid`` of the first hit with the maximum length). Without
        any hit the numeric columns are 0 and "Best hit" is "NA".
    """
    hits = BLAST[BLAST["qseqid"].astype(str) == ID]
    nHits = len(hits.index)
    if nHits > 0:
        hit_lengths = hits["length"]
        longest = max(hit_lengths)
        mean_length = np.mean(hit_lengths)
        median_length = np.median(hit_lengths)
        length_variance = np.var(hit_lengths)
        longest_hits = hits[hits["length"] == max(hits["length"])]
        best_subject = longest_hits["sseqid"].unique().tolist()[0]
    else:
        longest = 0
        mean_length = 0
        median_length = 0
        length_variance = 0
        best_subject = "NA"
    summary = {
        "ID": [ID],
        "Number of hits": [nHits],
        "Maximum coverage": [longest],
        "Median coverage": [median_length],
        "Average coverage": [mean_length],
        "Variance of coverage": [length_variance],
        "Best hit": [best_subject],
    }
    return pd.DataFrame(summary)
### PREDICT WITH PLASFOREST ###
def plasforest_predict(features, showFeatures, besthits, verbose, attributed_IDs, attributed_identities, nthreads):
    """Classify every row of ``features`` with the module-level ``plasforest`` model.

    The model's ``n_jobs`` is set to ``int(nthreads)``. A prediction of 0 is
    reported as "Chromosome", any other value as "Plasmid". ``features``
    itself receives the "num_predict" and "Prediction" columns.

    Parameters:
        features: table holding "ID", "Best hit" and the seven model inputs.
        showFeatures: when true, return every column (takes precedence over
            ``besthits``).
        besthits: when true, return "ID", "Prediction" and "Best hit".
        verbose: when true, print progress messages.
        attributed_IDs: IDs of contigs labelled without the classifier.
        attributed_identities: labels matching ``attributed_IDs``.
        nthreads: number of jobs for the classifier.

    Returns:
        A DataFrame of predictions; when ``attributed_IDs`` is not empty,
        its entries are appended as extra "ID"/"Prediction" rows.
    """
    model_inputs = [
        "G+C content",
        "Contig size",
        "Number of hits",
        "Maximum coverage",
        "Median coverage",
        "Average coverage",
        "Variance of coverage",
    ]
    if verbose:
        print("Starting predictions with the random forest classifier...")
    plasforest.n_jobs = int(nthreads)
    features["num_predict"] = plasforest.predict(features[model_inputs])
    features["Prediction"] = [
        "Chromosome" if code == 0 else "Plasmid"
        for code in features["num_predict"]
    ]
    features = features.drop(columns=["num_predict"])
    if showFeatures:
        finalfile = features
    else:
        kept_columns = ["ID", "Prediction", "Best hit"] if besthits else ["ID", "Prediction"]
        finalfile = features[kept_columns]
    if verbose:
        print("Predictions made!")
    if len(attributed_IDs) == 0:
        return finalfile
    already_labelled = pd.DataFrame({"ID":attributed_IDs, "Prediction":attributed_identities})
    return pd.concat([finalfile, already_labelled])
### EXECUTE MAIN ###
if __name__ == "__main__":
    # Hand main() everything after the program name.
    cli_arguments = sys.argv[1:]
    main(cli_arguments)