1616
1717package org .springframework .ai .tokenizer ;
1818
19+ import java .util .Base64 ;
20+
1921import com .knuddels .jtokkit .Encodings ;
2022import com .knuddels .jtokkit .api .Encoding ;
2123import com .knuddels .jtokkit .api .EncodingType ;
3436 */
3537public class JTokkitTokenCountEstimator implements TokenCountEstimator {
3638
39+ /**
40+ * The JTokkit encoding instance used for token counting.
41+ */
3742 private final Encoding estimator ;
3843
44+ /**
45+ * Creates a new JTokkitTokenCountEstimator with default CL100K_BASE encoding.
46+ */
3947 public JTokkitTokenCountEstimator () {
4048 this (EncodingType .CL100K_BASE );
4149 }
4250
43- public JTokkitTokenCountEstimator (EncodingType tokenEncodingType ) {
51+ /**
52+ * Creates a new JTokkitTokenCountEstimator with the specified encoding type.
53+ * @param tokenEncodingType the encoding type to use for token counting
54+ */
55+ public JTokkitTokenCountEstimator (final EncodingType tokenEncodingType ) {
4456 this .estimator = Encodings .newLazyEncodingRegistry ().getEncoding (tokenEncodingType );
4557 }
4658
4759 @ Override
48- public int estimate (String text ) {
60+ public int estimate (final String text ) {
4961 if (text == null ) {
5062 return 0 ;
5163 }
5264 return this .estimator .countTokens (text );
5365 }
5466
5567 @ Override
56- public int estimate (MediaContent content ) {
68+ public int estimate (final MediaContent content ) {
5769 int tokenCount = 0 ;
5870
5971 if (content .getText () != null ) {
6072 tokenCount += this .estimate (content .getText ());
6173 }
6274
6375 if (!CollectionUtils .isEmpty (content .getMedia ())) {
64-
6576 for (Media media : content .getMedia ()) {
66-
6777 tokenCount += this .estimate (media .getMimeType ().toString ());
6878
6979 if (media .getData () instanceof String textData ) {
7080 tokenCount += this .estimate (textData );
7181 }
7282 else if (media .getData () instanceof byte [] binaryData ) {
73- tokenCount += binaryData .length ; // This is likely incorrect.
83+ String base64 = Base64 .getEncoder ().encodeToString (binaryData );
84+ tokenCount += this .estimate (base64 );
7485 }
7586 }
7687 }
@@ -79,7 +90,7 @@ else if (media.getData() instanceof byte[] binaryData) {
7990 }
8091
8192 @ Override
82- public int estimate (Iterable <MediaContent > contents ) {
93+ public int estimate (final Iterable <MediaContent > contents ) {
8394 int totalSize = 0 ;
8495 for (MediaContent mediaContent : contents ) {
8596 totalSize += this .estimate (mediaContent );
0 commit comments