Skip to content

Commit 05a6080

Browse files
Joe Jevnikllllllllll
authored andcommitted
DEV: add testing utility for checking term lookback windows
1 parent 6110ce3 commit 05a6080

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import numpy as np
2+
3+
from zipline.testing.predicates import assert_equal
4+
from .factor import CustomFactor
5+
6+
7+
class IDBox(object):
8+
"""A wrapper that hashs to the id of the underlying object and compares
9+
equality on the id of the underlying.
10+
11+
Parameters
12+
----------
13+
ob : any
14+
The object to wrap.
15+
16+
Attributes
17+
----------
18+
ob : any
19+
The object being wrapped.
20+
21+
Notes
22+
-----
23+
This is useful for storing non-hashable values in a set or dict.
24+
"""
25+
def __init__(self, ob):
26+
self.ob = ob
27+
28+
def __hash__(self):
29+
return id(self)
30+
31+
def __eq__(self, other):
32+
if not isinstance(other, IDBox):
33+
return NotImplemented
34+
35+
return id(self.ob) == id(other.ob)
36+
37+
38+
class CheckWindowsFactor(CustomFactor):
39+
"""A custom factor that makes assertions about the lookback windows that
40+
it gets passed.
41+
42+
Parameters
43+
----------
44+
input_ : Term
45+
The input term to the factor.
46+
window_length : int
47+
The length of the lookback window.
48+
expected_windows : dict[int, dict[pd.Timestamp, np.ndarray]]
49+
For each asset, for each day, what the expected lookback window is.
50+
51+
Notes
52+
-----
53+
The output of this factor is the same as ``Latest``. Any assets or days
54+
not in ``expected_windows`` are not checked.
55+
"""
56+
params = ('expected_windows',)
57+
58+
def __new__(cls, input_, window_length, expected_windows):
59+
return super(CheckWindowsFactor, cls).__new__(
60+
cls,
61+
inputs=[input_],
62+
dtype=input_.dtype,
63+
window_length=window_length,
64+
expected_windows=frozenset(
65+
(k, IDBox(v)) for k, v in expected_windows.items()
66+
),
67+
)
68+
69+
def compute(self, today, assets, out, input_, expected_windows):
70+
for asset, expected_by_day in expected_windows:
71+
expected_by_day = expected_by_day.ob
72+
73+
col_ix = np.searchsorted(assets, asset)
74+
if assets[col_ix] != asset:
75+
raise AssertionError('asset %s is not in the window' % asset)
76+
77+
try:
78+
expected = expected_by_day[today]
79+
except KeyError:
80+
pass
81+
else:
82+
expected = np.array(expected)
83+
actual = input_[:, col_ix]
84+
assert_equal(actual, expected)
85+
86+
# output is just latest
87+
out[:] = input_[-1]

0 commit comments

Comments
 (0)