Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 65 additions & 7 deletions allel/io/vcf_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


import numpy as np
import zarr


import allel
Expand All @@ -18,10 +19,41 @@

# Names of the fixed VCF columns. They are tab-joined to form the column
# header line, and any callset name matching one of them (case-insensitively)
# is excluded from the INFO fields.
VCF_FIXED_FIELDS = 'CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO'

# Name of the callset group that is skipped when flattening a zarr group into
# a dict; the genotype array is looked up inside it instead.
CALLDATA_CALLSET_GROUP = 'calldata'
# Key of the genotype array within the calldata group ('calldata/GT').
GENOTYPE_CALLSET_KEY = 'GT'
# Key under which the sample names are looked up in the callset.
SAMPLES_CALLSET_KEY = 'samples'

# Prefix prepended to sample names in the list of normalized names, so that
# per-sample genotype columns can be told apart from all other names.
NORMALIZED_SAMPLE_NAME_PREFIX = 'SAMPLE_'


def normalize_callset(callset):

if hasattr(callset, 'keys'):
if isinstance(callset, zarr.hierarchy.Group):
names = list()
new_callset = dict()

for group in callset.group_keys():
if group == CALLDATA_CALLSET_GROUP:
continue

for key in callset[group].array_keys():
names.append(key)
new_callset[key] = callset[group][key]

gt = callset.get(CALLDATA_CALLSET_GROUP + '/' + GENOTYPE_CALLSET_KEY)
if gt:
samples = callset.get(SAMPLES_CALLSET_KEY, [])
n_gt_cols = gt.shape[1]
if len(samples) != n_gt_cols:
raise ValueError('number of sample names and genotype columns mismatch')

for i in range(n_gt_cols):
sample_name = samples[i]
names.append(_filterable_sample_name(sample_name))
new_callset[sample_name] = gt[:, i, :]

callset = new_callset
elif hasattr(callset, 'keys'):
names = list()
new_callset = dict()
for k in list(callset.keys()):
Expand Down Expand Up @@ -77,7 +109,8 @@ def write_vcf_header(vcf_file, names, callset, rename, number, description):

info_names = [n for n in names
if not n.upper().startswith('FILTER_') and
not n.upper() in VCF_FIXED_FIELDS]
not n.upper() in VCF_FIXED_FIELDS and
not _is_sample_name(n)]
info_ids = [rename[n] if n in rename else n
for n in info_names]

Expand Down Expand Up @@ -141,8 +174,12 @@ def write_vcf_header(vcf_file, names, callset, rename, number, description):
% (vcf_id, vcf_description)
print(header_line, file=vcf_file)

# reconstruct sample names
sample_names = _filter_sample_names(names)

# write column names
line = '#' + '\t'.join(VCF_FIXED_FIELDS)
columns = list(VCF_FIXED_FIELDS) + sample_names
line = '#' + '\t'.join(columns)
print(line, file=vcf_file)


Expand Down Expand Up @@ -207,7 +244,8 @@ def write_vcf_data(vcf_file, names, callset, rename, fill):
# find INFO columns
info_names = [n for n in names
if not n.upper().startswith('FILTER_') and
not n.upper() in VCF_FIXED_FIELDS]
not n.upper() in VCF_FIXED_FIELDS and
not _is_sample_name(n)]
info_ids = [rename[n] if n in rename else n
for n in info_names]
info_cols = [callset[n] for n in info_names]
Expand All @@ -217,18 +255,22 @@ def write_vcf_data(vcf_file, names, callset, rename, fill):
key=itemgetter(1))
info_names, info_ids, info_cols = zip(*infos)

# genotype columns
sample_names = _filter_sample_names(names)
gt_cols = [callset[n] for n in sample_names]

# setup writer
writer = csv.writer(vcf_file, delimiter='\t', lineterminator='\n')

# zip up data as rows
rows = zip(col_chrom, col_pos, col_id, col_ref, col_alt, col_qual)
rows = zip(col_chrom, col_pos, col_id, col_ref, col_alt, col_qual, *gt_cols)
filter_rows = zip(*filter_cols)
info_rows = zip(*info_cols)

for row, filter_row, info_row in itertools.zip_longest(rows, filter_rows, info_rows):

# unpack main row
chrom, pos, id, ref, alt, qual = row
chrom, pos, id, ref, alt, qual, *gts = row
chrom = _vcf_value_str(chrom)
pos = _vcf_value_str(pos)
id = _vcf_value_str(id)
Expand All @@ -255,8 +297,11 @@ def write_vcf_data(vcf_file, names, callset, rename, fill):
else:
info = '.'

# construct genotype value
gts = ['/'.join(map(str, gt)) for gt in gts]

# repack
row = chrom, pos, id, ref, alt, qual, flt, info
row = chrom, pos, id, ref, alt, qual, flt, info, *gts
writer.writerow(row)


Expand All @@ -282,3 +327,16 @@ def _vcf_info_str(name, id, value, fill):
return None
else:
return '%s=%s' % (id, _vcf_value_str(value, fill=fill.get(name, None)))


def _filter_sample_names(names):
    """Return the original sample names found in `names`.

    Only entries recognised by `_is_sample_name` are kept, and each one is
    returned with the leading `NORMALIZED_SAMPLE_NAME_PREFIX` removed. The
    relative order of the entries is preserved.
    """
    prefix_length = len(NORMALIZED_SAMPLE_NAME_PREFIX)
    sample_names = []
    for candidate in names:
        if _is_sample_name(candidate):
            sample_names.append(candidate[prefix_length:])
    return sample_names


def _is_sample_name(name):
    """Return True if `name` starts with `NORMALIZED_SAMPLE_NAME_PREFIX`."""
    prefix = NORMALIZED_SAMPLE_NAME_PREFIX
    return name.startswith(prefix)


def _filterable_sample_name(sample_name):
    """Return `sample_name` with `NORMALIZED_SAMPLE_NAME_PREFIX` prepended.

    The prefixed form is what `_is_sample_name` recognises.
    """
    prefixed = NORMALIZED_SAMPLE_NAME_PREFIX + sample_name
    return prefixed
48 changes: 48 additions & 0 deletions allel/test/io/test_vcf_write.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
import atexit
import os
import shutil
import tempfile
import warnings

import pytest
import zarr

from allel.io.vcf_read import vcf_to_zarr
from allel.io.vcf_write import write_vcf


# needed for PY2/PY3 consistent behaviour: clear any existing warning filters,
# then install a single filter with the 'always' action
warnings.resetwarnings()
warnings.simplefilter('always')


# setup temp dir for testing; it is created once at import time and removed
# (recursively) by an atexit handler when the interpreter shuts down
tempdir = tempfile.mkdtemp()
atexit.register(shutil.rmtree, tempdir)


def fixture_path(fn):
    """Return the path of the data file `fn`.

    The path is built relative to this module: one directory up from the
    module's own directory, then into ``data``. It is not normalised, so the
    parent-directory component is left in the result.
    """
    module_dir = os.path.dirname(__file__)
    data_dir = os.path.join(module_dir, os.pardir, 'data')
    return os.path.join(data_dir, fn)


@pytest.fixture(scope='module')  # run once and used by all tests in this file
def zarr_callset():
    """Provide the ``sample.vcf`` fixture as a zarr group.

    The VCF is converted with `vcf_to_zarr` (``fields='*'``) into
    ``sample.zarr`` inside the module's temp dir, and the resulting store is
    opened with ``mode='r'``.
    """
    source_vcf = fixture_path('sample.vcf')
    store_path = os.path.join(tempdir, 'sample.zarr')
    vcf_to_zarr(source_vcf, store_path, fields='*')
    callset = zarr.open_group(store_path, mode='r')
    return callset


def test_write_from_zarr_callset(zarr_callset):
    """Write the zarr callset out as VCF and check a known row is present.

    The callset is written to ``out.vcf`` in the module's temp dir, and the
    file content is then searched for the leading columns (CHROM, POS, ID,
    REF, ALT) of one variant from the sample data.
    """
    out_path = os.path.join(tempdir, 'out.vcf')
    write_vcf(out_path, zarr_callset)

    # TODO: Once the write function can write out full data,
    # modify the test so that it loads back the written file
    # and compares it with the original faithfully.

    # Data rows are written tab-delimited, so build the expected fragment
    # with explicit tab separators instead of relying on whitespace embedded
    # in a string literal.
    expected_fragment = '\t'.join(['20', '1110696', 'rs6040355', 'A', 'G,T'])
    with open(out_path, 'r') as vcf_file:
        content = vcf_file.read()
    assert expected_fragment in content