-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* init get_new_training_data script and strand spec * refactor main script, fix strand-spec * debugging and testing update_seqweaver * fixed h5 file output * added class modules for training seqweaver * added validation/training strat * debugging main update seqweaver module * strand backward compatibility * further fixes to backward compatibility * val partition fix * indexing fix for backward compatibility * addressing kathy's comments * fixed relative paths in update_seqweaver * handling strand=. as None
- Loading branch information
1 parent
86f4df3
commit 6ff4a5e
Showing
3 changed files
with
244 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
""" | ||
Seqweaver architecture (Park & Troyanskaya, 2021). | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class LambdaBase(nn.Sequential):
    """Sequential container that additionally carries a callable.

    The callable is stored on ``lambda_func``; this base class never
    invokes it itself and leaves that to subclasses.
    """

    def __init__(self, fn, *args):
        super().__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        """Apply each child module to ``input`` independently.

        Returns the list of per-module outputs, or ``input`` unchanged
        when the container holds no child modules.
        """
        results = [child(input) for child in self._modules.values()]
        if not results:
            return input
        return results
|
||
|
||
class Lambda(LambdaBase):
    """Module whose forward pass applies the stored callable."""

    def forward(self, input):
        # The callable receives whatever ``forward_prepare`` yields: the
        # raw input when there are no child modules, else their outputs.
        prepared = self.forward_prepare(input)
        return self.lambda_func(prepared)
|
||
|
||
class Seqweaver(nn.Module):
    """Seqweaver network: three convolution stages and two linear layers.

    Parameters
    ----------
    n_classes : int
        Number of output features (e.g. 217 for human, 43 for mouse).
    sequence_length : int, optional
        Default is 1000. Length of the input sequences. It determines the
        input width of the first linear layer; the default reproduces the
        original hard-coded width of 25440.

    Raises
    ------
    ValueError
        If `sequence_length` is too short to survive the convolution and
        pooling stages.
    """

    def __init__(self, n_classes, sequence_length=1000):  # 217 human, 43 mouse
        super().__init__()
        n_conv_out = 480
        # Each (1, 8) convolution shortens the sequence by 7 positions and
        # each (1, 4) max-pool with stride 4 floor-divides its length by 4.
        conv_len = ((sequence_length - 7) // 4 - 7) // 4 - 7
        if conv_len < 1:
            raise ValueError(
                "sequence_length={0} is too short for the Seqweaver "
                "convolution stack".format(sequence_length))
        n_flat = n_conv_out * conv_len

        self.model = nn.Sequential(
            nn.Conv2d(4, 160, (1, 8)),
            nn.ReLU(),
            nn.MaxPool2d((1, 4), (1, 4)),
            nn.Dropout(0.1),
            nn.Conv2d(160, 320, (1, 8)),
            nn.ReLU(),
            nn.MaxPool2d((1, 4), (1, 4)),
            nn.Dropout(0.1),
            nn.Conv2d(320, n_conv_out, (1, 8)),
            nn.ReLU(),
            nn.Dropout(0.3))
        # The parameter-free Flatten/Identity modules occupy the positions
        # of the former lambda wrappers, so the linear layers keep their
        # state_dict keys ("fc.1.1.*", "fc.3.1.*") and existing checkpoints
        # still load. Unlike lambdas, they can be pickled.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Sequential(
                nn.Identity(),
                nn.Linear(n_flat, n_classes)
            ),
            nn.ReLU(),
            nn.Sequential(
                nn.Identity(),
                nn.Linear(n_classes, n_classes)
            ),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return per-class sigmoid outputs of shape (batch, n_classes).

        `x` is a batch of shape (batch, 4, sequence_length); a singleton
        height axis is inserted so the 2-D convolutions can be applied.
        """
        x = x.unsqueeze(2)
        x = self.model(x)
        x = self.fc(x)
        return x
|
||
|
||
def criterion():
    """Return the training loss: a binary cross-entropy loss module."""
    loss_fn = nn.BCELoss()
    return loss_fn
|
||
|
||
def get_optimizer(lr):
    """Return the optimizer class together with its keyword arguments.

    The result is a ``(class, kwargs)`` tuple: SGD with the given
    learning rate, a weight decay of 1e-6 and a momentum of 0.9.
    """
    optimizer_kwargs = {"lr": lr, "weight_decay": 1e-6, "momentum": 0.9}
    return torch.optim.SGD, optimizer_kwargs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
""" | ||
This module provides the `UpdateSeqweaver` class, which wraps the master bed file | ||
containing all of the features' binding sites parsed from CLIP-seq. | ||
It supports new dataset construction and training for Seqweaver. | ||
""" | ||
import h5py | ||
import gzip | ||
import numpy as np | ||
import sys | ||
|
||
from selene_sdk.sequences.genome import Genome | ||
from selene_sdk.targets.genomic_features import GenomicFeatures | ||
from selene_sdk.samplers.dataloader import H5DataLoader | ||
from selene_sdk.train_model import TrainModel | ||
from selene_sdk.utils.config import load_path | ||
from selene_sdk.utils.config_utils import parse_configs_and_run | ||
|
||
class UpdateSeqweaver():
    """
    Builds Seqweaver training/validation datasets from a master bed file
    and launches model training from a YAML configuration.

    Accepts a tabix-indexed `*.bed` file with the following columns,
    in order:
        [chrom, start, end, feature, strand]

    Parameters
    ----------
    input_path : str
        Path to the tabix-indexed dataset. Note that for the file to
        be tabix-indexed, it must have been compressed with `bgzip`.
        Thus, `input_path` should be a `*.gz` file with a
        corresponding `*.tbi` file in the same directory.
    train_path : str
        Path to the output HDF5 file for the training partition.
    validate_path : str
        Path to the output HDF5 file for the validation partition.
    feature_path : str
        Path to a newline-delimited .txt file containing feature names.
    hg_fasta : str
        Path to an indexed FASTA file -- a `*.fasta` file with
        a corresponding `*.fai` file in the same directory. This file
        should contain the target organism's genome sequence.
    yaml_path : str
        Path to the YAML configuration file used by `train_model`.
    val_prop : float, optional
        Default is 0.1. Proportion of the retained examples written to
        the validation file.
    sequence_len : int, optional
        Default is 1000. Length of the sequence window extracted around
        each region's midpoint.

    Attributes
    ----------
    feature_set : list(str)
        Feature names read from `feature_path`, in file order.
    """
    def __init__(self, input_path, train_path, validate_path, feature_path, hg_fasta, yaml_path, val_prop=0.1, sequence_len=1000):
        """
        Constructs a new `UpdateSeqweaver` object.
        """
        self.input_path = input_path
        self.train_path = train_path
        self.validate_path = validate_path
        self.feature_path = feature_path
        self.yaml_path = yaml_path
        self.val_prop = val_prop

        self.hg_fasta = hg_fasta

        self.sequence_len = sequence_len

        with open(self.feature_path, 'r') as handle:
            # Skip blank lines so that, e.g., a trailing empty line does
            # not become a feature named ''.
            names = (line.rstrip('\r\n') for line in handle)
            self.feature_set = [name for name in names if name]

    def _from_midpoint(self, start, end):
        """
        Computes start and end of the sequence about the peak midpoint.

        Parameters
        ----------
        start : int
            The 0-based first position in the region.
        end : int
            One past the 0-based last position in the region.

        Returns
        -------
        seq_start : int
            Sequence start position about the peak midpoint.
        seq_end : int
            Sequence end position about the peak midpoint.
        """
        midpoint = start + (end - start) // 2
        # For an odd `sequence_len` the extra position goes to the right
        # of the midpoint (floor on the left, ceil on the right).
        left = self.sequence_len // 2
        right = self.sequence_len - left

        return int(midpoint - left), int(midpoint + right)

    def construct_training_data(self):
        """
        Construct the training and validation datasets from the bed file.

        Each row of `input_path` yields one example: the one-hot encoded
        sequence of length `sequence_len` centred on the region midpoint,
        and the feature labels for the original region. Rows whose
        sequence contains unknown bases, or whose encoding does not have
        length `sequence_len`, are dropped. The first
        `floor(val_prop * n_examples)` retained examples (in file order,
        no shuffling) are written to `validate_path` as
        "valid_sequences"/"valid_targets"; the remainder are written to
        `train_path` as "train_sequences"/"train_targets".
        """
        # NOTE(review): the blacklist is fixed to 'hg19' regardless of the
        # genome supplied in `hg_fasta` -- confirm this is intended.
        seqs = Genome(self.hg_fasta, blacklist_regions='hg19')
        targets = GenomicFeatures(self.input_path,
                                  features=self.feature_set,
                                  feature_thresholds=0.5)

        data_seqs = []
        data_labels = []
        # Stream the file in text mode instead of materialising every row.
        with gzip.open(self.input_path, 'rt', encoding='utf-8') as handle:
            for line in handle:
                fields = line.split()
                if not fields:
                    continue
                chrom, start, end, _, strand = fields
                start, end = int(start), int(end)
                sstart, ssend = self._from_midpoint(start, end)

                # One-hot encoding of the window around the midpoint,
                # plus a flag for unknown bases.
                dna_seq, has_unk = seqs.get_encoding_from_coords_check_unk(
                    chrom, sstart, ssend, strand=strand)
                if has_unk:
                    continue
                if len(dna_seq) != self.sequence_len:
                    continue

                # Labels are computed over the original region, not the
                # midpoint-centred window.
                labels = targets.get_feature_data(
                    chrom, start, end, strand=strand)

                data_seqs.append(dna_seq)
                data_labels.append(labels)

        # partition some to validation before writing
        val_count = int(np.floor(self.val_prop * len(data_seqs)))

        with h5py.File(self.validate_path, "w") as fh:
            fh.create_dataset(
                "valid_sequences",
                data=np.array(data_seqs[:val_count], dtype=np.int64))
            fh.create_dataset(
                "valid_targets",
                data=np.array(data_labels[:val_count], dtype=np.int64))

        with h5py.File(self.train_path, "w") as fh:
            fh.create_dataset(
                "train_sequences",
                data=np.array(data_seqs[val_count:], dtype=np.int64))
            fh.create_dataset(
                "train_targets",
                data=np.array(data_labels[val_count:], dtype=np.int64))

    def _load_yaml(self):
        """Load and return the YAML configuration at `yaml_path`."""
        return load_path(self.yaml_path)

    def train_model(self):
        """Load the YAML configuration and run it to train the model."""
        yaml_config = self._load_yaml()
        parse_configs_and_run(yaml_config)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters