Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.school.mohitto.controller;

import com.school.mohitto.dto.requestDTO.ChangeFaceRecommandRequest;
import com.school.mohitto.dto.requestDTO.ChangeFaceSimulationRequest;
import com.school.mohitto.dto.requestDTO.SimulationRequest;
import com.school.mohitto.dto.responseDTO.ChangeFaceSimulationResponse;
Expand Down Expand Up @@ -33,6 +34,13 @@ public FinalRecommandResponse getRecommand(
return simulationService.extractFaceAndRecommand(multipartFile,simulationRequest);
}

@PostMapping(value = "/recommand/transfer-face")
public ChangeFaceSimulationResponse changeFaceInRecommandService(
@RequestBody ChangeFaceRecommandRequest changeFaceRecommandRequest
) throws IOException {
return simulationService.changeFaceInRecommandService(changeFaceRecommandRequest);
}

@PostMapping(value = "/transfer-face", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
public ChangeFaceSimulationResponse changeFaceInSimulationService(
@RequestPart(value = "image", required = true)
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/com/school/mohitto/domain/CreatedImage.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,9 @@ public class CreatedImage extends BaseTimeEntity {
public CreatedImage(String createdImageUrl) {
this.createdImageUrl = createdImageUrl;
}

public CreatedImage(Hair hair, String createdImageUrl) {
this.hair = hair;
this.createdImageUrl = createdImageUrl;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.school.mohitto.dto.requestDTO;

public record ChangeFaceRecommandRequest(
Long hairId,
Long modelImageId
) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import com.querydsl.jpa.impl.JPAQueryFactory;
import com.school.mohitto.domain.Diagnosis;
import com.school.mohitto.domain.ModelImage;
import com.school.mohitto.domain.enums.HairLengthType;
import com.school.mohitto.domain.enums.HairTypeEnum;
import com.school.mohitto.domain.enums.HasBangType;
import com.school.mohitto.domain.enums.SexType;
import lombok.RequiredArgsConstructor;

Expand All @@ -20,23 +22,23 @@ public class ModelImageRepositoryImpl implements ModelImageRepositoryCustom {
@Override
public ModelImage findModelImageByDiagnosisFeature(Diagnosis diagnosis, String style) {

BooleanBuilder hairTypeCondition = new BooleanBuilder();
SexType sex = diagnosis.getDiagnosisSex().getSex().getSex();
HairLengthType hairLength = diagnosis.getDiagnosisHairLength().getHairLength().getHairLength();
HasBangType hasBangType = (sex == SexType.MALE)
? HasBangType.NONE
: diagnosis.getDiagnosisHasbangs().getHasBangs().getHasBangType();
HairTypeEnum hairType = (sex == SexType.MALE)
? diagnosis.getDiagnosisHairType().getHairType().getType()
: HairTypeEnum.NONE;

if (diagnosis.getDiagnosisSex().getSex().getSex() == SexType.MALE) {
hairTypeCondition.and(modelImage.hairTypeEnum.eq(diagnosis.getDiagnosisHairType().getHairType().getType()));
} else {
hairTypeCondition.and(modelImage.hairTypeEnum.eq(HairTypeEnum.NONE));
}

ModelImage image = queryFactory.selectFrom(modelImage)
return queryFactory.selectFrom(modelImage)
.where(
modelImage.name.eq(style),
modelImage.hairLength.eq(diagnosis.getDiagnosisHairLength().getHairLength().getHairLength()),
modelImage.sex.eq(diagnosis.getDiagnosisSex().getSex().getSex()),
modelImage.hasBangType.eq(diagnosis.getDiagnosisHasbangs().getHasBangs().getHasBangType()),
hairTypeCondition
modelImage.hairLength.eq(hairLength),
modelImage.sex.eq(sex),
modelImage.hasBangType.eq(hasBangType),
modelImage.hairTypeEnum.eq(hairType)
)
.fetchFirst();
return image;
}
}
50 changes: 47 additions & 3 deletions src/main/java/com/school/mohitto/service/SimulationService.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import java.io.IOException;
import java.io.InputStream;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -67,7 +68,6 @@ public FinalRecommandResponse extractFaceAndRecommand(
uploadImageRepository.save(image);
log.info(user_image_url.toString());


String hair_length = diagnosisHairLengthRepository.findByDiagnosisId(diagnosisId)
.orElseThrow(() -> new CustomException(ErrorCode.HAIR_LENGTH_NOT_FOUND))
.getHairLength().getHairLength().getValue();
Expand Down Expand Up @@ -192,7 +192,6 @@ private FinalRecommandResponse generateHairStyle(RecommandRequest recommandReque
recommendation ->
{
ModelImage modelImage = modelImageRepository.findModelImageByDiagnosisFeature(diagnosis,recommendation.style());
// log.info(modelImage.getName());
Hair hair = Hair.builder()
.name(recommendation.style())
.explanation(recommendation.description())
Expand All @@ -210,14 +209,57 @@ private FinalRecommandResponse generateHairStyle(RecommandRequest recommandReque
recommendation.hair_shops());
}
).collect(Collectors.toList());

log.info("Rag 모델 완료");
return new FinalRecommandResponse(result);
}

@Transactional
public ChangeFaceSimulationResponse changeFaceInRecommandService(
ChangeFaceRecommandRequest inputFaceRequest ) throws IOException {

Hair hair = hairRepository.findById(inputFaceRequest.hairId()).orElseThrow(() -> new CustomException(ErrorCode.HAIR_NOT_FOUND));

String user_image_url = hair.getDiagnosis().getUploadImage().getUploadImageUrl();
log.info("추천 헤어 적용 전 이미지 사진: " + user_image_url);

ModelImage modelImage = modelImageRepository.findById(inputFaceRequest.modelImageId())
.orElseThrow(() -> new CustomException(ErrorCode.HAIR_NOT_FOUND));

ChangeFaceRequest changeFaceRequest = new ChangeFaceRequest(user_image_url, modelImage.getUploadImageUrl());

Resource response = Mono.delay(Duration.ofSeconds(10)) // 10초 대기
.then(
webClientFactory.create(fastapiProperties.hairTransfer())
.post()
.uri("/simulate")
.accept(MediaType.IMAGE_PNG)
.accept(MediaType.IMAGE_JPEG)
.bodyValue(changeFaceRequest)
.retrieve()
.onStatus(HttpStatusCode::isError, clientResponse ->
clientResponse.bodyToMono(String.class).flatMap(error ->
Mono.error(new RuntimeException("시뮬레이션 모델 오류: " + error))
)
)
.bodyToMono(Resource.class)
)
.block(); // 전체 실행
log.info("시뮬레이션 모델 완료");

MultipartFile file = convertResourceToMultipartFile(response);
String result_image_url = s3Uploader.upload(file, "result");
log.info("추천 헤어 시뮬레이션 결과 이미지: " + result_image_url);

createdImageRepository.save(new CreatedImage(hair,result_image_url));

return new ChangeFaceSimulationResponse(result_image_url);
}

public ChangeFaceSimulationResponse changeFaceInSimulationService(
MultipartFile multipartFile,
ChangeFaceSimulationRequest inputFaceRequest) throws IOException {
String user_image_url = s3Uploader.upload(multipartFile, "face");
log.info(user_image_url);

ModelImage modelImage = modelImageRepository.findById(inputFaceRequest.modelImageId())
.orElseThrow(() -> new CustomException(ErrorCode.HAIR_NOT_FOUND));
Expand All @@ -237,9 +279,11 @@ public ChangeFaceSimulationResponse changeFaceInSimulationService(
)
.bodyToMono(Resource.class)
.block();
log.info("시뮬레이션 완료");

MultipartFile file = convertResourceToMultipartFile(response);
String result_image_url = s3Uploader.upload(file, "result");
log.info("시뮬레이션 서비스에서의 이미지 결과 : " + result_image_url);
return new ChangeFaceSimulationResponse(result_image_url);
}

Expand Down