wget https://vision.cs.utexas.edu/projects/finegrained/utzap50k/ut-zap50k-images-square.zip
解压数据集,方便后续训练模型时使用
unzip ut-zap50k-images-square.zip
基于 DJL 实现图片分类
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build for the DJL image classification demo: a Spring Boot web application that uses the DJL PyTorch engine. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<!-- Inherits from the Spring Boot starter parent. This POM declares no groupId or version of its own, so both are inherited from the parent. -->
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>3.2.1</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>djl</artifactId>
<!-- Compile source and bytecode for Java 17. -->
<properties>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
</properties>
<dependencies>
<!-- No version is given for the two dependencies below; presumably they are managed by the Spring Boot parent. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<!-- DJL core artifacts. No versions are declared here; they are expected to come from the DJL BOM imported in dependencyManagement. -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
</dependency>
<!-- PyTorch engine binding for DJL; only needed at runtime. -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<scope>runtime</scope>
</dependency>
</dependencies>
<!-- Platform profiles. Each one adds a native PyTorch CPU library for a single OS/architecture plus the pytorch-jni artifact. -->
<!-- Only the windows profile is active by default; select another one explicitly, e.g. "mvn -P linux package". -->
<!-- NOTE(review): because windows is activeByDefault, a build on another OS without -P still pulls the win-x86_64 native jar. Consider OS-based activation; confirm with the team first. -->
<!-- NOTE(review): the pytorch-jni version 2.0.1-0.23.0 looks like "native version"-"DJL version"; presumably it must stay in step with the BOM version below. Verify before upgrading. -->
<profiles>
<profile>
<id>windows</id>
<activation>
<activeByDefault>true</activeByDefault>
</activation>
<dependencies>
<!-- Windows CPU -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<classifier>win-x86_64</classifier>
<scope>runtime</scope>
<version>2.0.1</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>2.0.1-0.23.0</version>
<scope>runtime</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>centos7</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<dependencies>
<!-- For Pre-CXX11 build (CentOS7)-->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu-precxx11</artifactId>
<classifier>linux-x86_64</classifier>
<version>2.0.1</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>2.0.1-0.23.0</version>
<scope>runtime</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>linux</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<dependencies>
<!-- Linux CPU -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<classifier>linux-x86_64</classifier>
<scope>runtime</scope>
<version>2.0.1</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>2.0.1-0.23.0</version>
<scope>runtime</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>aarch64</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<dependencies>
<!-- For aarch64 build-->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu-precxx11</artifactId>
<classifier>linux-aarch64</classifier>
<scope>runtime</scope>
<version>2.0.1</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>2.0.1-0.23.0</version>
<scope>runtime</scope>
</dependency>
</dependencies>
</profile>
</profiles>
<!-- Imports the DJL bill of materials so that DJL artifacts declared without a version resolve to a consistent release (0.23.0). -->
<dependencyManagement>
<dependencies>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>bom</artifactId>
<version>0.23.0</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
</project>
package com.et.controller;
import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import com.et.service.ImageClassificationService;
import lombok.RequiredArgsConstructor;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.Stream;
/**
 * REST endpoints of the image classification demo: classify an uploaded image,
 * train a model from an image folder, and download a random image from a folder.
 */
@RestController
@RequiredArgsConstructor
public class ImageClassificationController {

    private final ImageClassificationService imageClassificationService;

    /**
     * Classifies an uploaded image with the model stored in {@code modePath}.
     *
     * @param image    the multipart part named {@code image}
     * @param modePath directory the model is loaded from
     * @return the service's prediction result as a string
     * @throws TranslateException      propagated from the service
     * @throws MalformedModelException propagated from the service
     * @throws IOException             propagated from the service
     */
    @PostMapping(path = "/analyze")
    public String predict(@RequestPart("image") MultipartFile image,
                          @RequestParam(defaultValue = "/home/djl-test/models") String modePath)
            throws TranslateException, MalformedModelException, IOException {
        return imageClassificationService.predict(image, modePath);
    }

    /**
     * Trains a model from the images below {@code datasetRoot} and stores it in {@code modePath}.
     *
     * @param datasetRoot root directory of the image dataset
     * @param modePath    directory the trained model is written to
     * @return the string returned by the service for the completed training run
     * @throws TranslateException propagated from the service
     * @throws IOException        propagated from the service
     */
    @PostMapping(path = "/training")
    public String training(@RequestParam(defaultValue = "/home/djl-test/images-test") String datasetRoot,
                           @RequestParam(defaultValue = "/home/djl-test/models") String modePath)
            throws TranslateException, IOException {
        return imageClassificationService.training(datasetRoot, modePath);
    }

    /**
     * Returns one randomly chosen regular file found anywhere below {@code directoryPath}
     * as an attachment.
     *
     * <p>Responds with 500 when the directory cannot be walked or the file length cannot be
     * read, and with 404 when the directory contains no regular file or the chosen file
     * has disappeared in the meantime.
     *
     * @param directoryPath directory that is searched recursively for files
     * @return the file as a download, or an empty error response
     */
    @GetMapping("/download")
    public ResponseEntity<Resource> downloadFile(
            @RequestParam(defaultValue = "/home/djl-test/images-test") String directoryPath) {
        // NOTE(review): directoryPath is taken verbatim from the request, so a caller can make
        // this endpoint serve files from any readable directory. Restrict it to a configured
        // base directory before exposing this outside a demo environment.
        List<Path> candidates;
        try (Stream<Path> paths = Files.walk(Paths.get(directoryPath))) {
            // Keep only regular files (directories are skipped).
            candidates = paths.filter(Files::isRegularFile).toList();
        } catch (IOException e) {
            return ResponseEntity.status(500).build();
        }

        // Random.nextInt(0) throws IllegalArgumentException, so an empty directory
        // must be answered explicitly instead of falling through to the random pick.
        if (candidates.isEmpty()) {
            return ResponseEntity.notFound().build();
        }

        Path file = candidates.get(new Random().nextInt(candidates.size()));
        Resource resource = new FileSystemResource(file.toFile());
        if (!resource.exists()) {
            return ResponseEntity.notFound().build();
        }

        HttpHeaders headers = new HttpHeaders();
        headers.add(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=" + file.getFileName().toString());
        // The content type is always reported as JPEG, regardless of the actual file type.
        headers.add(HttpHeaders.CONTENT_TYPE, MediaType.IMAGE_JPEG_VALUE);
        try {
            return ResponseEntity.ok()
                    .headers(headers)
                    .contentLength(resource.contentLength())
                    .body(resource);
        } catch (IOException e) {
            return ResponseEntity.status(500).build();
        }
    }
}
package com.et.service;
import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import org.springframework.web.multipart.MultipartFile;
import java.io.IOException;
/**
 * Operations offered by the image classification demo: running inference on an
 * uploaded image and training a model from a directory of images.
 */
public interface ImageClassificationService {

    /**
     * Classifies the given image using a model located at {@code modePath}.
     *
     * @param image    uploaded image to classify
     * @param modePath directory the model is loaded from
     * @return the classification result as a string
     * @throws IOException             if the image or the model cannot be read
     * @throws MalformedModelException if the stored model cannot be loaded
     * @throws TranslateException      if pre- or post-processing of the prediction fails
     */
    String predict(MultipartFile image, String modePath)
            throws IOException, MalformedModelException, TranslateException;

    /**
     * Trains a model on the images found under {@code datasetRoot} and stores it at {@code modePath}.
     *
     * @param datasetRoot root directory of the training images
     * @param modePath    directory the trained model is saved to
     * @return a string describing the outcome of the training run
     * @throws TranslateException if the dataset cannot be prepared or processed
     * @throws IOException        if reading the dataset or writing the model fails
     */
    String training(String datasetRoot, String modePath) throws TranslateException, IOException;
}
实现类
package com.et.service;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.ImageFolder;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.*;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import com.et.Models;
import lombok.Cleanup;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;
/**
 * DJL-based implementation of {@link ImageClassificationService}.
 *
 * <p>Training builds an {@link ImageFolder} dataset from a directory, fits the model obtained
 * from {@code Models.getModel} and saves it together with its label list. Prediction loads that
 * saved model again and classifies a single uploaded image.
 */
@Slf4j
@Service
public class ImageClassificationServiceImpl implements ImageClassificationService {

    /** Number of training samples processed before the model is updated. */
    private static final int BATCH_SIZE = 32;

    /** Number of passes over the complete dataset. */
    private static final int EPOCHS = 2;

    /**
     * Number of classification labels, read from {@code djl.num-of-output}; defaults to 4
     * (boots, sandals, shoes, slippers).
     */
    @Value("${djl.num-of-output:4}")
    public int numOfOutput;

    /**
     * Classifies an uploaded image with the model stored in {@code modePath}.
     *
     * @param image    uploaded image file
     * @param modePath directory containing the saved model
     * @return the classifications (probability per label) serialized as JSON
     * @throws IOException             if the upload cannot be read or is not a decodable image,
     *                                 or if the model files cannot be read
     * @throws MalformedModelException if the saved model cannot be loaded
     * @throws TranslateException      if pre- or post-processing fails
     */
    @Override
    public String predict(MultipartFile image, String modePath)
            throws IOException, MalformedModelException, TranslateException {
        BufferedImage bi;
        try (InputStream is = image.getInputStream()) {
            bi = ImageIO.read(is);
        }
        // ImageIO.read returns null when no registered reader can decode the stream;
        // fail with a clear message instead of a NullPointerException further down.
        if (bi == null) {
            throw new IOException("Uploaded file could not be decoded as an image");
        }
        Image img = ImageFactory.getInstance().fromImage(bi);
        Path modelDir = Paths.get(modePath);

        // empty model instance
        try (Model model = Models.getModel(numOfOutput)) {
            // load the trained parameters
            model.load(modelDir, Models.MODEL_NAME);

            // translator for pre- and post-processing: resize to the dimensions used during
            // training, convert to a tensor, and apply softmax to the raw model output
            Translator<Image, Classifications> translator =
                    ImageClassificationTranslator.builder()
                            .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                            .addTransform(new ToTensor())
                            .optApplySoftmax(true)
                            .build();

            // run the inference using a Predictor
            try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
                // holds the probability score per label
                Classifications predictResult = predictor.predict(img);
                String json = predictResult.toJson();
                log.info("reusult={}", json);
                return json;
            }
        }
    }

    /**
     * Trains the model on the image folder at {@code datasetRoot} and saves it, together with
     * the label list, to {@code modePath}.
     *
     * <p>The call is synchronous and returns only after all epochs have finished.
     *
     * @param datasetRoot root directory of the image dataset
     * @param modePath    directory the model and labels are written to
     * @return the dataset's labels, one per line
     * @throws TranslateException if the dataset cannot be prepared or processed
     * @throws IOException        if reading the dataset or writing the model fails
     */
    @Override
    public String training(String datasetRoot, String modePath) throws TranslateException, IOException {
        log.info("Image dataset training started...Image dataset address path:{}", datasetRoot);
        // the location to save the model
        Path modelDir = Paths.get(modePath);
        // create ImageFolder dataset from directory
        ImageFolder dataset = initDataset(datasetRoot);
        // split the dataset into a training part and a validation part (ratio 8:2)
        RandomAccessDataset[] datasets = dataset.randomSplit(8, 2);
        // the loss function evaluates the model's predictions against the correct answer
        // during training; higher values mean more errors, so training tries to minimize it
        Loss loss = Loss.softmaxCrossEntropyLoss();
        // setting training parameters (i.e. hyperparameters)
        TrainingConfig config = setupTrainingConfig(loss);

        try (Model model = Models.getModel(numOfOutput); // empty model instance to hold patterns
             Trainer trainer = model.newTrainer(config)) {
            // metrics collect and report key performance indicators, like accuracy
            trainer.setMetrics(new Metrics());

            // Input shape: one sample, 3 channels, then height and width. The spatial
            // dimensions must be (HEIGHT, WIDTH) so they match the Resize transform
            // even when the configured width and height differ.
            Shape inputShape = new Shape(1, 3, Models.IMAGE_HEIGHT, Models.IMAGE_WIDTH);
            // initialize trainer with proper input shape
            trainer.initialize(inputShape);

            // find the patterns in data
            EasyTrain.fit(trainer, EPOCHS, datasets[0], datasets[1]);

            // record the outcome of the run as model properties
            TrainingResult result = trainer.getTrainingResult();
            model.setProperty("Epoch", String.valueOf(EPOCHS));
            model.setProperty(
                    "Accuracy", String.format("%.5f", result.getValidateEvaluation("Accuracy")));
            model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));

            // save the model after training so it can be loaded for inference later
            model.save(modelDir, Models.MODEL_NAME);
            // save labels into model directory
            Models.saveSynset(modelDir, dataset.getSynset());
            log.info("Image dataset training completed......");
            return String.join("\n", dataset.getSynset());
        }
    }

    /**
     * Builds and prepares the {@link ImageFolder} dataset rooted at {@code datasetRoot}.
     * Images are resized and converted to tensors, and sampled randomly in batches of
     * {@link #BATCH_SIZE}.
     */
    private ImageFolder initDataset(String datasetRoot) throws IOException, TranslateException {
        ImageFolder dataset =
                ImageFolder.builder()
                        // retrieve the data
                        .setRepositoryPath(Paths.get(datasetRoot))
                        .optMaxDepth(10)
                        .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                        .addTransform(new ToTensor())
                        // random sampling; don't process the data in order
                        .setSampling(BATCH_SIZE, true)
                        .build();
        dataset.prepare();
        return dataset;
    }

    /** Creates the training configuration: the given loss, an accuracy evaluator and logging listeners. */
    private TrainingConfig setupTrainingConfig(Loss loss) {
        return new DefaultTrainingConfig(loss)
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());
    }
}
# Application configuration for the DJL image classification demo.
server:
  # HTTP port the embedded server listens on
  port: 8888
spring:
  application:
    name: djl-image-classification-demo
  servlet:
    multipart:
      # upload limits for the image sent to the /analyze endpoint
      max-file-size: 100MB
      max-request-size: 100MB
  mvc:
    pathmatch:
      matching-strategy: ant_path_matcher
package com.et;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
/**
 * Entry point of the demo; starts the Spring Boot application.
 */
@SpringBootApplication
public class DemoApplication {

    /**
     * Launches the Spring application with this class as its primary source.
     *
     * @param args command-line arguments handed through to Spring Boot
     */
    public static void main(final String[] args) {
        SpringApplication.run(DemoApplication.class, args);
    }
}
2024-10-11T21:00:05.407+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] c.e.s.ImageClassificationServiceImpl : Image dataset training started...Image dataset address path:/Users/liuhaihua/ai/ut-zap50k-images-square
2024-10-11T21:00:08.455+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.util.Platform : Ignore mismatching platform from: jar:file:/Users/liuhaihua/.m2/repository/ai/djl/pytorch/pytorch-native-cpu/2.0.1/pytorch-native-cpu-2.0.1-win-x86_64.jar!/native/lib/pytorch.properties
2024-10-11T21:00:09.240+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : PyTorch graph executor optimizer is enabled, this may impact your inference latency and throughput. See: https://docs.djl.ai/docs/development/inference_performance_optimization.html#graph-executor-optimization
2024-10-11T21:00:09.241+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : Number of inter-op threads is 4
2024-10-11T21:00:09.241+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : Number of intra-op threads is 4
2024-10-11T21:00:09.287+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Training on: cpu().
2024-10-11T21:00:09.290+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Load PyTorch Engine Version 1.13.1 in 0.044 ms.
Training: 100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.38
Validating: 100% |████████████████████████████████████████|
2024-10-11T22:42:48.142+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Epoch 1 finished.
2024-10-11T22:42:48.187+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.38
2024-10-11T22:42:48.189+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Validate: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.24
Training: 5% |███ | Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.22