Skip to content

Commit e1b74da

Browse files
committed
Refactor VI indicator to exclusively use Rust implementation
- Removed the fallback Python implementation for the VI indicator, ensuring it directly utilizes the jesse_rust.vi function for improved performance. - Simplified the code by eliminating unnecessary checks for Rust availability, aligning with the ongoing integration of Rust across the codebase. - This change enhances the efficiency and maintainability of the indicators module.
1 parent 2f80d74 commit e1b74da

File tree

1 file changed

+8
-67
lines changed

1 file changed

+8
-67
lines changed

jesse/indicators/vi.py

Lines changed: 8 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections import namedtuple
22

33
import numpy as np
4+
import jesse_rust
45
from jesse.helpers import slice_candles
56

67
VI = namedtuple('VI', ['plus', 'minus'])
@@ -16,72 +17,12 @@ def vi(candles: np.ndarray, period: int = 14, sequential: bool = False) -> VI:
1617
1718
:return: VI(plus, minus)
1819
"""
19-
# Check if jesse_rust is available
20-
try:
21-
import jesse_rust
22-
23-
candles = slice_candles(candles, sequential)
24-
25-
# Use the Rust implementation
26-
vi_plus, vi_minus = jesse_rust.vi(candles, period, sequential)
27-
28-
if sequential:
29-
return VI(vi_plus, vi_minus)
30-
else:
31-
return VI(vi_plus[-1], vi_minus[-1])
32-
33-
except ImportError:
34-
# Fallback to pure Python implementation
35-
candles = slice_candles(candles, sequential)
36-
37-
vpn_with_nan, vmn_with_nan = vi_fast_python(candles, period)
38-
39-
if sequential:
40-
return VI(vpn_with_nan, vmn_with_nan)
41-
else:
42-
return VI(vpn_with_nan[-1], vmn_with_nan[-1])
43-
44-
45-
def vi_fast_python(candles, period):
46-
"""
47-
Pure Python implementation of VI calculation
48-
"""
49-
candles_close = candles[:, 2]
50-
candles_high = candles[:, 3]
51-
candles_low = candles[:, 4]
52-
n = len(candles_high)
53-
54-
tr = np.zeros(n)
55-
vp = np.zeros(n)
56-
vm = np.zeros(n)
57-
58-
tr[0] = candles_high[0] - candles_low[0]
59-
60-
for i in range(1, n):
61-
hl = candles_high[i] - candles_low[i]
62-
hpc = np.abs(candles_high[i] - candles_close[i - 1])
63-
lpc = np.abs(candles_low[i] - candles_close[i - 1])
64-
tr[i] = np.amax(np.array([hl, hpc, lpc]))
65-
vp[i] = np.abs(candles_high[i] - candles_low[i - 1])
66-
vm[i] = np.abs(candles_low[i] - candles_high[i - 1])
67-
68-
trd = np.zeros(n)
69-
vpd = np.zeros(n)
70-
vmd = np.zeros(n)
71-
72-
for j in range(n - period + 1):
73-
trd[period - 1 + j] = np.sum(tr[j:j + period])
74-
vpd[period - 1 + j] = np.sum(vp[j:j + period])
75-
vmd[period - 1 + j] = np.sum(vm[j:j + period])
76-
77-
trd = trd[period - 1:]
78-
vpd = vpd[period - 1:]
79-
vmd = vmd[period - 1:]
80-
81-
vpn = vpd / trd
82-
vmn = vmd / trd
20+
candles = slice_candles(candles, sequential)
8321

84-
vpn_with_nan = np.concatenate((np.full((candles.shape[0] - vpn.shape[0]), np.nan), vpn))
85-
vmn_with_nan = np.concatenate((np.full((candles.shape[0] - vmn.shape[0]), np.nan), vmn))
22+
# Use the Rust implementation
23+
vi_plus, vi_minus = jesse_rust.vi(candles, period, sequential)
8624

87-
return vpn_with_nan, vmn_with_nan
25+
if sequential:
26+
return VI(vi_plus, vi_minus)
27+
else:
28+
return VI(vi_plus[-1], vi_minus[-1])

0 commit comments

Comments
 (0)