Skip to content

Commit 601d309

Browse files
authored
Work-around for protobuf init crash in sentencepiece (#9693)
A work-around for the protobuf init crash in sentencepiece. The code within the try/except block triggers sentencepiece loading path. This has to be done between ``import torch`` and ``import _XLAC`` and works even though the loading of "" triggers file not found error. #9691
1 parent abc7854 commit 601d309

File tree

5 files changed

+34
-0
lines changed

5 files changed

+34
-0
lines changed

.github/workflows/_test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ jobs:
107107
pip install fsspec
108108
pip install rich
109109
pip install flax
110+
pip install sentencepiece
110111
- name: Extra CI deps
111112
if: inputs.has_code_changes == 'true'
112113
shell: bash

test/run_tests.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,10 @@ function run_xla_op_tests3 {
266266

267267
function run_xla_op_tests4 {
268268
run_test "$_TEST_DIR/test_jax_interop.py"
269+
# issue #9691: random crashes with sentencepiece protobuf; run multiple times to trigger
270+
for i in $(seq 1 5); do
271+
run_test "$_TEST_DIR/test_sentencepiece_interop.py"
272+
done
269273
}
270274

271275
function run_xla_op_tests5 {

test/test_sentencepiece_interop.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import unittest
2+
3+
4+
class SentencepieceInterop(unittest.TestCase):
5+
6+
def test_sentencepiece_interop(self):
7+
import os
8+
if not os.path.exists("/tmp/test_model.model"):
9+
import urllib.request
10+
urllib.request.urlretrieve(
11+
"https://github.com/google/sentencepiece/raw/refs/heads/master/python/test/test_model.model",
12+
"/tmp/test_model.model")
13+
import torch_xla
14+
import sentencepiece as spm
15+
sp_model = spm.SentencePieceProcessor("/tmp/test_model.model")
16+
17+
18+
if __name__ == "__main__":
19+
unittest.main()

torch_xla/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@
77

88
import torch
99

10+
# issue 9691: ensure sentencepiece protobuf init happen between
11+
# torch/torch-xla protobuf inits to work-around protobuf crash
12+
try:
13+
import sentencepiece as spm
14+
sp_model = spm.SentencePieceProcessor()
15+
sp_model.load('')
16+
except:
17+
pass
18+
1019
import _XLAC
1120
from ._internal import tpu
1221
from .version import __version__

torchax/dev-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
torch==2.8.0 ; sys_platform == 'darwin' # macOS
33
torch==2.8.0+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU
44
yapf==0.40.2 # N.B.: keep in sync with `infra/ansible/config/pip.yaml`, `.github/workflows/lintercheck.yml`
5+
jax==0.7.2 # N.B.: torchax breaks on newer JAX versions that would be pulled from `flax` dependencies
56
flax==0.10.6

0 commit comments

Comments
 (0)