Skip to content

Commit 1d3b326

Browse files
authored
support deepseek-ocr (#2208)
1 parent 9eed3af commit 1d3b326

File tree

4 files changed

+132
-3
lines changed

4 files changed

+132
-3
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
---
2+
pipeline_tag: image-text-to-text
3+
language:
4+
- multilingual
5+
tags:
6+
- mindspore
7+
- mindnlp
8+
- deepseek
9+
- vision-language
10+
- ocr
11+
- custom_code
12+
license: mit
13+
---
14+
<div align="center">
15+
<img src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/logo.svg?raw=true" width="60%" alt="DeepSeek AI" />
16+
</div>
17+
<hr>
18+
<div align="center">
19+
<a href="https://www.deepseek.com/" target="_blank">
20+
<img alt="Homepage" src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/badge.svg?raw=true" />
21+
</a>
22+
<a href="https://huggingface.co/lvyufeng/DeepSeek-OCR" target="_blank">
23+
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" />
24+
</a>
25+
26+
</div>
27+
28+
29+
30+
31+
<p align="center">
32+
<a href="https://github.com/mindspore-lab/mindnlp/tree/master/examples/transformers/inference/deepseek-ocr"><b>🌟 GitHub</b></a> |
33+
<a href="https://huggingface.co/lvyufeng/DeepSeek-OCR"><b>📥 Model Download</b></a> |
34+
<a href="https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek_OCR_paper.pdf"><b>📄 Paper Link</b></a> |
35+
<a href=""><b>📄 Arxiv Paper Link</b></a> |
36+
</p>
37+
<h2>
38+
<p align="center">
39+
<a href="">DeepSeek-OCR: Contexts Optical Compression</a>
40+
</p>
41+
</h2>
42+
<p align="center">
43+
<a href="">Explore the boundaries of visual-text compression.</a>
44+
</p>
45+
46+
## Usage
47+
Inference using Hugging Face Transformers on NVIDIA GPUs. Requirements tested on Python 3.12.9 + CUDA 11.8:
48+
49+
```
50+
mindspore==2.7.0
51+
mindnlp==0.5.0rc4
52+
transformers==4.57.1
53+
tokenizers
54+
einops
55+
addict
56+
easydict
57+
```
58+
59+
```python
60+
import os
61+
import mindnlp
62+
import torch
63+
from transformers import AutoModel, AutoTokenizer
64+
65+
model_name = 'lvyufeng/DeepSeek-OCR-Community-Latest'
66+
67+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
68+
model = AutoModel.from_pretrained(model_name, _attn_implementation='sdpa', trust_remote_code=True, use_safetensors=True, device_map='auto')
69+
model = model.eval()
70+
71+
# prompt = "<image>\nFree OCR. "
72+
prompt = "<image>\n<|grounding|>Convert the document to markdown. "
73+
image_file = 'your_image.jpg'
74+
output_path = 'your/output/dir'
75+
76+
# infer(self, tokenizer, prompt='', image_file='', output_path = ' ', base_size = 1024, image_size = 640, crop_mode = True, test_compress = False, save_results = False):
77+
78+
# Tiny: base_size = 512, image_size = 512, crop_mode = False
79+
# Small: base_size = 640, image_size = 640, crop_mode = False
80+
# Base: base_size = 1024, image_size = 1024, crop_mode = False
81+
# Large: base_size = 1280, image_size = 1280, crop_mode = False
82+
83+
# Gundam: base_size = 1024, image_size = 640, crop_mode = True
84+
85+
res = model.infer(tokenizer, prompt=prompt, image_file=image_file, output_path = output_path, base_size = 1024, image_size = 640, crop_mode=True, save_results = True, test_compress = True)
86+
```
87+
88+
## Acknowledgement
89+
90+
We would like to thank [Vary](https://github.com/Ucas-HaoranWei/Vary/), [GOT-OCR2.0](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/), [MinerU](https://github.com/opendatalab/MinerU), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [OneChart](https://github.com/LingyvKong/OneChart), [Slow Perception](https://github.com/Ucas-HaoranWei/Slow-Perception) for their valuable models and ideas.
91+
92+
We also appreciate the benchmarks: [Fox](https://github.com/ucaslcl/Fox), [OmniDocBench](https://github.com/opendatalab/OmniDocBench).
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import mindspore
2+
import mindnlp
3+
from transformers import AutoModel, AutoTokenizer
4+
5+
model_name = 'lvyufeng/DeepSeek-OCR-Community-Latest'
6+
7+
8+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
9+
model = AutoModel.from_pretrained(model_name, _attn_implementation='sdpa', dtype=mindspore.float16,
10+
trust_remote_code=True, use_safetensors=True, device_map='auto')
11+
model = model.eval()
12+
13+
14+
# prompt = "<image>\nFree OCR. "
15+
prompt = "<image>\n<|grounding|>Convert the document to markdown. "
16+
# wget "https://hf-mirror.com/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"
17+
image_file = 'image_ocr.jpg'
18+
output_path = './'
19+
20+
21+
22+
# infer(self, tokenizer, prompt='', image_file='', output_path = ' ', base_size = 1024, image_size = 640, crop_mode = True, test_compress = False, save_results = False):
23+
24+
# Tiny: base_size = 512, image_size = 512, crop_mode = False
25+
# Small: base_size = 640, image_size = 640, crop_mode = False
26+
# Base: base_size = 1024, image_size = 1024, crop_mode = False
27+
# Large: base_size = 1280, image_size = 1280, crop_mode = False
28+
29+
# Gundam: base_size = 1024, image_size = 640, crop_mode = True
30+
31+
res = model.infer(tokenizer, prompt=prompt, image_file=image_file, output_path = output_path, base_size = 1024, image_size = 640, crop_mode=True, save_results = True, test_compress = True)

mindtorch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
from . import _C
7575
from ._dtype import *
7676
from ._tensor import Tensor, tensor, scalar_tensor, is_tensor, \
77-
LongTensor, FloatTensor, BoolTensor, HalfTensor, BFloat16Tensor, IntTensor
77+
LongTensor, FloatTensor, BoolTensor, HalfTensor, BFloat16Tensor, IntTensor, ByteTensor
7878

7979
from ._C import *
8080
from ._C.size import Size

mindtorch/_tensor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class StubTensor: pass
4848
}
4949

5050
class TypedTensorMeta(_TensorMeta):
51-
def __isinstancecheck__(self, instance):
51+
def __instancecheck__(self, instance):
5252
if not isinstance(instance, Tensor):
5353
return False
5454
return instance.dtype == self.dtype
@@ -78,7 +78,7 @@ def __init__(self, *args, **kwargs):
7878
super().__init__(*args, dtype=_dtype.float16, **kwargs)
7979

8080
class BFloat16Tensor(Tensor, metaclass=TypedTensorMeta):
81-
dtype = _dtype.float16
81+
dtype = _dtype.bfloat16
8282
def __init__(self, *args, **kwargs):
8383
self._device = kwargs.pop('device', device_('cpu'))
8484
super().__init__(*args, dtype=_dtype.bfloat16, **kwargs)
@@ -89,6 +89,12 @@ def __init__(self, *args, **kwargs):
8989
self._device = kwargs.pop('device', device_('cpu'))
9090
super().__init__(*args, dtype=_dtype.bool, **kwargs)
9191

92+
class ByteTensor(Tensor, metaclass=TypedTensorMeta):
93+
dtype = _dtype.uint8
94+
def __init__(self, *args, **kwargs):
95+
self._device = kwargs.pop('device', device_('cpu'))
96+
super().__init__(*args, dtype=_dtype.uint8, **kwargs)
97+
9298

9399
def tensor_meta_str(self):
94100
return "<class 'torch.Tensor'>"

0 commit comments

Comments
 (0)