Skip to content

Commit 1e26d98

Browse files
Merge pull request #154 from cloudera/mob/main
Handle zip file uploads, and add another nifi integration
2 parents 5bf0b04 + b2ddcfc commit 1e26d98

23 files changed

+2400 -522 lines changed

backend/src/main/java/com/cloudera/cai/rag/Types.java

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -41,10 +41,7 @@
4141
import jakarta.annotation.Nullable;
4242
import java.time.Instant;
4343
import java.util.List;
44-
import lombok.Builder;
45-
import lombok.Getter;
46-
import lombok.Singular;
47-
import lombok.With;
44+
import lombok.*;
4845

4946
public class Types {
5047
/** Data returned from the file upload endpoint. */
@@ -149,4 +146,11 @@ public record CreateSession(
149146

150147
public record MetadataMetrics(
151148
int numberOfDataSources, int numberOfSessions, int numberOfDocuments) {}
149+
150+
public record NifiConfigOptions(String name, String description, DataFlowConfigType configType) {}
151+
152+
public enum DataFlowConfigType {
153+
AZURE_BLOB,
154+
S3
155+
}
152156
}

backend/src/main/java/com/cloudera/cai/rag/datasources/RagDataSourceController.java

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/*
22
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
33
* (C) Cloudera, Inc. 2024
44
* All rights reserved.
@@ -101,7 +101,15 @@ public List<Types.RagDataSource> getRagDataSources() {
101101
}
102102

103103
@GetMapping(value = "/{id}/nifiConfig", produces = "application/json")
104-
public String getNifiConfig(@PathVariable Long id, @RequestParam String ragStudioUrl) {
105-
return dataSourceService.getNifiConfig(id, ragStudioUrl);
104+
public String getNifiConfig(
105+
@PathVariable Long id,
106+
@RequestParam String ragStudioUrl,
107+
@RequestParam Types.DataFlowConfigType configType) {
108+
return dataSourceService.getNifiConfig(id, ragStudioUrl, configType);
109+
}
110+
111+
@GetMapping(value = "/nifiConfigOptions", produces = "application/json")
112+
public List<Types.NifiConfigOptions> getNifiConfigOptions() {
113+
return dataSourceService.getNifiConfigOptions();
106114
}
107115
}

backend/src/main/java/com/cloudera/cai/rag/datasources/RagDataSourceService.java

Lines changed: 31 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@
3838

3939
package com.cloudera.cai.rag.datasources;
4040

41+
import com.cloudera.cai.rag.Types;
4142
import com.cloudera.cai.rag.Types.RagDataSource;
4243
import com.cloudera.cai.util.ResourceUtils;
4344
import java.io.IOException;
@@ -62,6 +63,15 @@ public RagDataSource createRagDataSource(RagDataSource input) {
6263
return ragDataSourceRepository.getRagDataSourceById(id);
6364
}
6465

66+
public RagDataSource updateRagDataSource(RagDataSource input) {
67+
ragDataSourceRepository.updateRagDataSource(input);
68+
return ragDataSourceRepository.getRagDataSourceById(input.id());
69+
}
70+
71+
public void deleteDataSource(Long id) {
72+
ragDataSourceRepository.deleteDataSource(id);
73+
}
74+
6575
public List<RagDataSource> getRagDataSources() {
6676
return ragDataSourceRepository.getRagDataSources();
6777
}
@@ -70,28 +80,36 @@ public RagDataSource getRagDataSourceById(Long id) {
7080
return ragDataSourceRepository.getRagDataSourceById(id);
7181
}
7282

73-
public void deleteDataSource(Long id) {
74-
ragDataSourceRepository.deleteDataSource(id);
75-
}
76-
77-
// Nullables stuff below here.
78-
79-
public static RagDataSourceService createNull() {
80-
return new RagDataSourceService(RagDataSourceRepository.createNull());
83+
public List<Types.NifiConfigOptions> getNifiConfigOptions() {
84+
return List.of(
85+
new Types.NifiConfigOptions(
86+
"S3 Cloudera DataFlow Definition",
87+
"Flow definition for pointing a S3 bucket to RAG Studio. Requires AWS credentials.",
88+
Types.DataFlowConfigType.S3),
89+
new Types.NifiConfigOptions(
90+
"Azure Blob Storage Cloudera DataFlow Definition",
91+
"Flow definition for pointing an Azure Blob Store to RAG Studio. Requires Azure credentials.",
92+
Types.DataFlowConfigType.AZURE_BLOB));
8193
}
8294

83-
public String getNifiConfig(Long id, String ragStudioUrl) {
95+
public String getNifiConfig(Long id, String ragStudioUrl, Types.DataFlowConfigType configType) {
8496
try {
85-
return ResourceUtils.getFileContents("S3-To-RagStudio-Nifi-template.json")
97+
String fileName =
98+
switch (configType) {
99+
case AZURE_BLOB -> "AzureBlob-To-RagStudio-Nifi-template.json";
100+
case S3 -> "S3-To-RagStudio-Nifi-template.json";
101+
};
102+
return ResourceUtils.getFileContents(fileName)
86103
.replace("$$$RAG_STUDIO_DATASOURCE_ID$$$", id.toString())
87104
.replace("$$$RAG_STUDIO_URL$$$", ragStudioUrl);
88105
} catch (IOException e) {
89106
throw new RuntimeException(e);
90107
}
91108
}
92109

93-
public RagDataSource updateRagDataSource(RagDataSource input) {
94-
ragDataSourceRepository.updateRagDataSource(input);
95-
return ragDataSourceRepository.getRagDataSourceById(input.id());
110+
// Nullables stuff below here.
111+
112+
public static RagDataSourceService createNull() {
113+
return new RagDataSourceService(RagDataSourceRepository.createNull());
96114
}
97115
}

backend/src/main/java/com/cloudera/cai/rag/files/FileSystemRagFileUploader.java

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,6 @@
4343
import java.nio.file.Path;
4444
import lombok.extern.slf4j.Slf4j;
4545
import org.springframework.stereotype.Component;
46-
import org.springframework.web.multipart.MultipartFile;
4746

4847
@Slf4j
4948
@Component
@@ -52,12 +51,12 @@ public class FileSystemRagFileUploader implements RagFileUploader {
5251
private static final String FILE_STORAGE_ROOT = fileStoragePath();
5352

5453
@Override
55-
public void uploadFile(MultipartFile file, String s3Path) {
54+
public void uploadFile(UploadableFile file, String s3Path) {
5655
log.info("Uploading file to FS: {}", s3Path);
5756
try {
5857
Path filePath = Path.of(FILE_STORAGE_ROOT, s3Path);
5958
Files.createDirectories(filePath.getParent());
60-
Files.write(filePath, file.getBytes());
59+
Files.copy(file.getInputStream(), filePath);
6160
} catch (IOException e) {
6261
throw new RuntimeException(e);
6362
}

backend/src/main/java/com/cloudera/cai/rag/files/RagFileController.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/*
22
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
33
* (C) Cloudera, Inc. 2024
44
* All rights reserved.
@@ -65,7 +65,7 @@ public RagFileController(RagFileService ragFileService) {
6565
value = "/dataSources/{dataSourceId}/files",
6666
consumes = "multipart/form-data",
6767
produces = "application/json")
68-
public RagDocumentMetadata uploadRagDocument(
68+
public List<RagDocumentMetadata> uploadRagDocument(
6969
@RequestPart("file") MultipartFile file,
7070
@PathVariable Long dataSourceId,
7171
HttpServletRequest request) {

backend/src/main/java/com/cloudera/cai/rag/files/RagFileService.java

Lines changed: 89 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -44,8 +44,13 @@
4444
import com.cloudera.cai.util.IdGenerator;
4545
import com.cloudera.cai.util.exceptions.BadRequest;
4646
import com.cloudera.cai.util.exceptions.NotFound;
47+
import java.io.IOException;
48+
import java.io.InputStream;
4749
import java.time.Instant;
50+
import java.util.ArrayList;
4851
import java.util.List;
52+
import java.util.zip.ZipEntry;
53+
import java.util.zip.ZipInputStream;
4954
import lombok.extern.slf4j.Slf4j;
5055
import org.springframework.beans.factory.annotation.Autowired;
5156
import org.springframework.beans.factory.annotation.Qualifier;
@@ -81,13 +86,53 @@ public RagFileService(
8186
this.ragFileDeleteReconciler = ragFileDeleteReconciler;
8287
}
8388

84-
public RagDocumentMetadata saveRagFile(MultipartFile file, Long dataSourceId, String actorCrn) {
89+
public List<RagDocumentMetadata> saveRagFile(
90+
MultipartFile file, Long dataSourceId, String actorCrn) {
8591
ragDataSourceRepository.getRagDataSourceById(dataSourceId);
92+
93+
if (isZipFile(file)) {
94+
List<RagDocumentMetadata> results = processZipFile(file, dataSourceId, actorCrn);
95+
if (results.isEmpty()) {
96+
throw new BadRequest("Invalid or empty zip file");
97+
}
98+
return results;
99+
}
100+
return List.of(processFile(dataSourceId, actorCrn, new MultipartUploadableFile(file)));
101+
}
102+
103+
private boolean isZipFile(MultipartFile file) {
104+
return file.getContentType() != null && file.getContentType().equals("application/zip")
105+
|| (file.getOriginalFilename() != null
106+
&& file.getOriginalFilename().toLowerCase().endsWith(".zip"));
107+
}
108+
109+
private List<RagDocumentMetadata> processZipFile(
110+
MultipartFile file, Long dataSourceId, String actorCrn) {
111+
List<RagDocumentMetadata> results = new ArrayList<>();
112+
try (ZipInputStream zipInputStream = new ZipInputStream(file.getInputStream())) {
113+
ZipEntry entry;
114+
while ((entry = zipInputStream.getNextEntry()) != null) {
115+
if (entry.isDirectory()) {
116+
continue;
117+
}
118+
results.add(
119+
processFile(dataSourceId, actorCrn, new ZipEntryUploadableFile(entry, zipInputStream)));
120+
zipInputStream.closeEntry();
121+
}
122+
} catch (IOException e) {
123+
throw new BadRequest("Failed to process zip file: " + e.getMessage());
124+
}
125+
return results;
126+
}
127+
128+
private RagDocumentMetadata processFile(
129+
Long dataSourceId, String actorCrn, UploadableFile uploadableFile) {
86130
String documentId = idGenerator.generateId();
87131
var s3Path = buildS3Path(dataSourceId, documentId);
88132

89-
ragFileUploader.uploadFile(file, s3Path);
90-
var ragDocument = createUnsavedDocument(file, documentId, s3Path, dataSourceId, actorCrn);
133+
ragFileUploader.uploadFile(uploadableFile, s3Path);
134+
var ragDocument =
135+
createUnsavedDocument(uploadableFile, documentId, s3Path, dataSourceId, actorCrn);
91136
Long id = ragFileRepository.insertDocumentMetadata(ragDocument);
92137
log.info("Saved document with id: {}", id);
93138

@@ -113,10 +158,10 @@ private String extractFileExtension(String originalFilename) {
113158
}
114159

115160
private RagDocument createUnsavedDocument(
116-
MultipartFile file, String documentId, String s3Path, Long dataSourceId, String actorCrn) {
161+
UploadableFile file, String documentId, String s3Path, Long dataSourceId, String actorCrn) {
117162
return new RagDocument(
118163
null,
119-
removeDirectories(file.getOriginalFilename()),
164+
validateFilename(file.getOriginalFilename()),
120165
dataSourceId,
121166
documentId,
122167
s3Path,
@@ -134,15 +179,11 @@ private RagDocument createUnsavedDocument(
134179
null);
135180
}
136181

137-
private static String removeDirectories(String originalFilename) {
182+
private static String validateFilename(String originalFilename) {
138183
if (originalFilename == null || originalFilename.isBlank()) {
139184
throw new BadRequest("Filename is required");
140185
}
141-
if (originalFilename.contains("/")) {
142-
return originalFilename.substring(originalFilename.lastIndexOf('/') + 1);
143-
} else {
144-
return originalFilename;
145-
}
186+
return originalFilename;
146187
}
147188

148189
public void deleteRagFile(Long id, Long dataSourceId) {
@@ -170,4 +211,41 @@ public static RagFileService createNull(String... dummyIds) {
170211
public List<RagDocument> getRagDocuments(Long dataSourceId) {
171212
return ragFileRepository.getRagDocuments(dataSourceId);
172213
}
214+
215+
public record MultipartUploadableFile(MultipartFile file) implements UploadableFile {
216+
217+
@Override
218+
public InputStream getInputStream() throws IOException {
219+
return file.getInputStream();
220+
}
221+
222+
@Override
223+
public long getSize() {
224+
return file.getSize();
225+
}
226+
227+
@Override
228+
public String getOriginalFilename() {
229+
return file.getOriginalFilename();
230+
}
231+
}
232+
233+
private record ZipEntryUploadableFile(ZipEntry entry, ZipInputStream zipInputStream)
234+
implements UploadableFile {
235+
236+
@Override
237+
public String getOriginalFilename() {
238+
return entry.getName();
239+
}
240+
241+
@Override
242+
public InputStream getInputStream() {
243+
return zipInputStream;
244+
}
245+
246+
@Override
247+
public long getSize() {
248+
return entry.getSize();
249+
}
250+
}
173251
}

backend/src/main/java/com/cloudera/cai/rag/files/RagFileUploader.java

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -39,12 +39,11 @@
3939
package com.cloudera.cai.rag.files;
4040

4141
import com.cloudera.cai.util.Tracker;
42-
import org.springframework.web.multipart.MultipartFile;
4342

4443
public interface RagFileUploader {
45-
void uploadFile(MultipartFile file, String path);
44+
void uploadFile(UploadableFile file, String path);
4645

47-
record UploadRequest(MultipartFile file, String documentId) {}
46+
record UploadRequest(UploadableFile file, String documentId) {}
4847

4948
// nullables below here
5049

@@ -65,7 +64,7 @@ public UploaderStub(Tracker<UploadRequest> tracker) {
6564
}
6665

6766
@Override
68-
public void uploadFile(MultipartFile file, String s3Path) {
67+
public void uploadFile(UploadableFile file, String s3Path) {
6968
tracker.track(new UploadRequest(file, s3Path));
7069
}
7170
}

backend/src/main/java/com/cloudera/cai/rag/files/S3RagFileUploader.java

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,6 @@
4343
import java.io.IOException;
4444
import lombok.extern.slf4j.Slf4j;
4545
import org.springframework.beans.factory.annotation.Qualifier;
46-
import org.springframework.web.multipart.MultipartFile;
4746
import software.amazon.awssdk.core.sync.RequestBody;
4847
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
4948

@@ -59,7 +58,7 @@ public S3RagFileUploader(
5958
}
6059

6160
@Override
62-
public void uploadFile(MultipartFile file, String s3Path) {
61+
public void uploadFile(UploadableFile file, String s3Path) {
6362
log.info("Uploading file to S3: {}", s3Path);
6463
PutObjectRequest objectRequest =
6564
PutObjectRequest.builder().bucket(bucketName).key(s3Path).build();
backend/src/main/java/com/cloudera/cai/rag/files/UploadableFile.java

Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
3+
* (C) Cloudera, Inc. 2024
4+
* All rights reserved.
5+
*
6+
* Applicable Open Source License: Apache 2.0
7+
*
8+
* NOTE: Cloudera open source products are modular software products
9+
* made up of hundreds of individual components, each of which was
10+
* individually copyrighted. Each Cloudera open source product is a
11+
* collective work under U.S. Copyright Law. Your license to use the
12+
* collective work is as provided in your written agreement with
13+
* Cloudera. Used apart from the collective work, this file is
14+
* licensed for your use pursuant to the open source license
15+
* identified above.
16+
*
17+
* This code is provided to you pursuant a written agreement with
18+
* (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
19+
* this code. If you do not have a written agreement with Cloudera nor
20+
* with an authorized and properly licensed third party, you do not
21+
* have any rights to access nor to use this code.
22+
*
23+
* Absent a written agreement with Cloudera, Inc. ("Cloudera") to the
24+
* contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
25+
* KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
26+
* WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
27+
* IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
28+
* FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
29+
* AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
30+
* ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
31+
* OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
32+
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
33+
* CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
34+
* RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
35+
* BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
36+
* DATA.
37+
*/
38+
39+
package com.cloudera.cai.rag.files;
40+
41+
import java.io.IOException;
42+
import java.io.InputStream;
43+
44+
public interface UploadableFile {
45+
InputStream getInputStream() throws IOException;
46+
47+
long getSize();
48+
49+
String getOriginalFilename();
50+
}

0 commit comments

Comments
 (0)