Skip to content

Commit 8d95d16

Browse files
authored
improve: train/test 데이터셋 분할 방식으로 평가하도록 변경 (#105)
1 parent 1ea9fa0 commit 8d95d16

File tree

9 files changed

+546
-230
lines changed

9 files changed

+546
-230
lines changed

src/main/java/com/techfork/domain/recommendation/service/LlmRecommendationService.java

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,99 @@ public List<Long> generateRecommendationsForEvaluation(User user) {
164164
}
165165
}
166166

167+
/**
168+
* 추천 생성 (평가 전용 - Train/Test Split 지원)
169+
* 특정 읽은 글 목록(Train Set)만 제외하고 추천 생성
170+
*
171+
* @param user 사용자
172+
* @param trainPostIds Train Set 게시글 ID 목록 (제외할 글)
173+
* @return 추천된 게시글 ID 리스트
174+
*/
175+
public List<Long> generateRecommendationsForEvaluation(User user, Set<Long> trainPostIds) {
176+
// 1. 사용자 프로필 벡터 조회
177+
Optional<UserProfileDocument> profileOpt = userProfileDocumentRepository.findByUserId(user.getId());
178+
if (profileOpt.isEmpty() || profileOpt.get().getProfileVector() == null) {
179+
log.warn("사용자 {}의 프로필 또는 벡터를 찾을 수 없음. 추천 생성 스킵.", user.getId());
180+
return Collections.emptyList();
181+
}
182+
183+
float[] userProfileVector = profileOpt.get().getProfileVector();
184+
185+
try {
186+
// 2. k-NN 검색으로 초기 후보군 가져오기 (Train Set만 제외)
187+
List<MmrCandidate> candidates = searchCandidatesWithCustomReadHistory(userProfileVector, user, trainPostIds);
188+
189+
if (candidates.isEmpty()) {
190+
log.debug("사용자 {}의 추천 후보군을 찾을 수 없음 (Train Set {} 개 제외)", user.getId(), trainPostIds.size());
191+
return Collections.emptyList();
192+
}
193+
194+
// 3. MMR 적용하여 최종 추천 선택
195+
List<MmrResult> mmrResults = mmrService.applyMmr(candidates);
196+
197+
// 4. 추천된 게시글 ID 리스트 반환
198+
return mmrResults.stream()
199+
.map(MmrResult::getPostId)
200+
.toList();
201+
202+
} catch (Exception e) {
203+
log.error("사용자 {} 추천 생성 실패 (Train/Test Split 평가용)", user.getId(), e);
204+
return Collections.emptyList();
205+
}
206+
}
207+
208+
/**
209+
* Elasticsearch k-NN 검색으로 초기 후보군 조회 (커스텀 읽은 글 목록)
210+
* Train/Test Split 평가를 위해 Train Set만 제외
211+
*/
212+
private List<MmrCandidate> searchCandidatesWithCustomReadHistory(
213+
float[] userProfileVector,
214+
User user,
215+
Set<Long> readPostIds) throws IOException {
216+
217+
log.debug("사용자 {}의 읽은 게시글 {} 개 제외 (Train Set)", user.getId(), readPostIds.size());
218+
219+
// 가중치 가져오기
220+
RecommendationProperties.EmbeddingWeights weights = properties.getEmbeddingWeights();
221+
222+
// 랜덤 시드 생성 (현재 시간 기반)
223+
long randomSeed = System.currentTimeMillis();
224+
double randomWeight = 0.2; // 랜덤 가중치 20%
225+
226+
// k-NN 쿼리 (가중 평균: title + summary + content chunks + 랜덤 요소)
227+
Query knnQuery = vectorQueryBuilder.createWeightedVectorQueryWithRandomness(
228+
TITLE_EMBEDDING_FIELD,
229+
SUMMARY_EMBEDDING_FIELD,
230+
CONTENT_CHUNKS_FIELD,
231+
CHUNK_EMBEDDING_FIELD,
232+
userProfileVector,
233+
weights.getTitle(),
234+
weights.getSummary(),
235+
weights.getContent(),
236+
randomSeed,
237+
randomWeight
238+
);
239+
240+
log.debug("ES 쿼리 실행 (Train/Test Split) - 벡터 차원: {}, 가중치 [title:{}, summary:{}, content:{}]",
241+
userProfileVector.length, weights.getTitle(), weights.getSummary(), weights.getContent());
242+
243+
SearchResponse<PostDocument> response = elasticsearchClient.search(s -> s
244+
.index(POSTS_INDEX)
245+
.query(knnQuery)
246+
.size(properties.getKnnSearchSize())
247+
,
248+
PostDocument.class
249+
);
250+
251+
// 결과를 MmrCandidate로 변환 (Train Set만 필터링)
252+
return response.hits().hits().stream()
253+
.filter(hit -> hit.source() != null)
254+
.filter(hit -> !readPostIds.contains(hit.source().getPostId()))
255+
.map(this::mapToMmrCandidate)
256+
.filter(candidate -> candidate.getSummaryVector() != null)
257+
.toList();
258+
}
259+
167260
/**
168261
* Elasticsearch k-NN 검색으로 초기 후보군 조회
169262
* - 이미 읽은 글 제외
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package com.techfork.domain.recommendation_quality;
2+
3+
import lombok.AllArgsConstructor;
4+
import lombok.Builder;
5+
import lombok.Data;
6+
import lombok.NoArgsConstructor;
7+
8+
import java.util.List;
9+
import java.util.Set;
10+
import java.util.stream.Collectors;
11+
12+
/**
13+
* Train/Test Split 기반 개선된 추천 시스템 테스트 케이스
14+
*
15+
* 기존 방식의 문제:
16+
* 1. Ground Truth가 문자열 매칭 기반 (추천 시스템은 벡터 유사도 기반)
17+
* 2. Recall 분모가 너무 커서(100개) 지표가 낮게 나옴
18+
*
19+
* 개선 방식:
20+
* 1. 읽은 글을 8:2로 분할 (Train/Test)
21+
* 2. Test Set을 Ground Truth로 사용 (실제로 읽은 글 = 관심있는 글)
22+
* 3. 적절한 Recall 분모 (Test Set 크기 = 20개 정도)
23+
*/
24+
@Data
25+
@Builder
26+
@NoArgsConstructor
27+
@AllArgsConstructor
28+
public class ImprovedRecommendationTestCase {
29+
30+
/**
31+
* 사용자 ID
32+
*/
33+
private Long userId;
34+
35+
/**
36+
* 사용자 관심사
37+
*/
38+
private List<String> interests;
39+
40+
/**
41+
* Train/Test 분할 결과
42+
*/
43+
private TrainTestSplit trainTestSplit;
44+
45+
/**
46+
* Test Set을 Ground Truth로 반환 (Recall 계산용)
47+
*/
48+
public Set<Long> getGroundTruthPostIds() {
49+
return trainTestSplit.getTestPostIds().stream()
50+
.collect(Collectors.toSet());
51+
}
52+
53+
/**
54+
* Train Set 반환 (사용자 프로필 생성용)
55+
*/
56+
public List<Long> getTrainPostIds() {
57+
return trainTestSplit.getTrainPostIds();
58+
}
59+
60+
/**
61+
* Test Set 반환 (평가용)
62+
*/
63+
public List<Long> getTestPostIds() {
64+
return trainTestSplit.getTestPostIds();
65+
}
66+
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package com.techfork.domain.recommendation_quality;
2+
3+
import lombok.AllArgsConstructor;
4+
import lombok.Builder;
5+
import lombok.Data;
6+
import lombok.NoArgsConstructor;
7+
8+
import java.util.List;
9+
10+
/**
11+
* 사용자 읽기 이력의 Train/Test 분할 결과
12+
* Train: 사용자 프로필 생성에 사용
13+
* Test: 평가 Ground Truth로 사용
14+
*/
15+
@Data
16+
@Builder
17+
@NoArgsConstructor
18+
@AllArgsConstructor
19+
public class TrainTestSplit {
20+
21+
/**
22+
* Train 세트: 사용자 프로필 생성에 사용될 게시글 ID 목록 (80%)
23+
*/
24+
private List<Long> trainPostIds;
25+
26+
/**
27+
* Test 세트: 평가 Ground Truth로 사용될 게시글 ID 목록 (20%)
28+
* 추천 시스템이 이 글들을 상위권에 추천했는지 평가
29+
*/
30+
private List<Long> testPostIds;
31+
32+
/**
33+
* Train 세트 크기
34+
*/
35+
public int getTrainSize() {
36+
return trainPostIds.size();
37+
}
38+
39+
/**
40+
* Test 세트 크기
41+
*/
42+
public int getTestSize() {
43+
return testPostIds.size();
44+
}
45+
}

src/test/java/com/techfork/domain/recommendation/LambdaOptimizationTest.java

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,32 @@
66
import org.junit.jupiter.api.Test;
77

88
import java.util.ArrayList;
9+
import java.util.Comparator;
910
import java.util.List;
1011

1112
@Slf4j
1213
public class LambdaOptimizationTest extends RecommendationTestBase {
1314

1415
@Test
15-
@DisplayName("Lambda 최적화 - 요약 중심 vs 현재 기본값")
16-
void optimizeLambda() {
17-
log.info("===== Lambda 최적화 테스트 =====");
16+
@DisplayName("Lambda 최적화 - 3가지 가중치 조합 (Train/Test Split 방식)")
17+
void optimizeLambdaWithTrainTestSplit() {
18+
log.info("===== Lambda 최적화 테스트 (Train/Test Split) =====");
19+
log.info("읽은 글 100개 → Train 80개 (프로필 생성용) + Test 20개 (평가용)");
20+
log.info("가중치 조합: 컨텐츠중심, 요약중심, 기본값");
21+
log.info("Lambda 범위: 0.0 ~ 1.0 (0.1 단위)");
1822

1923
List<ConfigCombo> configs = createLambdaTestConfigs();
20-
List<User> testUsers = getTestUsers(DEFAULT_TEST_USER_COUNT);
24+
List<User> testUsers = getTestUsers();
25+
log.info("테스트 사용자: {} 명 (IDs: {})", testUsers.size(), TEST_USER_IDS);
2126

22-
printConfigComparisonHeader();
23-
List<EvaluationResult> results = evaluateAllConfigs(configs, testUsers);
24-
printBestResult(results);
27+
printImprovedConfigComparisonHeader();
28+
List<ImprovedEvaluationResult> results = evaluateAllConfigsWithTrainTestSplit(configs, testUsers);
29+
printBestImprovedResultByWeightType(results);
2530
}
2631

2732
/**
2833
* Lambda 0.0 ~ 1.0 (0.1 단위) 테스트 설정 생성
29-
* 요약 중심 + 현재 기본값 조합
34+
* 컨텐츠 중심
3035
*/
3136
private List<ConfigCombo> createLambdaTestConfigs() {
3237
List<ConfigCombo> configs = new ArrayList<>();
@@ -35,53 +40,51 @@ private List<ConfigCombo> createLambdaTestConfigs() {
3540
for (int i = 0; i <= 10; i++) {
3641
double lambda = i / 10.0;
3742

38-
// 1. 요약 중심 (title:0.2, summary:0.6, content:0.2)
3943
configs.add(ConfigCombo.builder()
40-
.name(String.format("요약중심 λ=%.1f", lambda))
44+
.name(String.format("컨텐츠중심 λ=%.1f", lambda))
4145
.titleWeight(0.2f)
42-
.summaryWeight(0.6f)
43-
.contentWeight(0.2f)
44-
.mmrLambda(lambda)
45-
.build());
46-
47-
// 2. 현재 기본값 (title:0.4, summary:0.4, content:0.2)
48-
configs.add(ConfigCombo.builder()
49-
.name(String.format("기본값 λ=%.1f", lambda))
50-
.titleWeight(DEFAULT_TITLE_WEIGHT)
51-
.summaryWeight(DEFAULT_SUMMARY_WEIGHT)
52-
.contentWeight(DEFAULT_CONTENT_WEIGHT)
46+
.summaryWeight(0.2f)
47+
.contentWeight(0.6f)
5348
.mmrLambda(lambda)
5449
.build());
5550
}
5651

52+
log.info("총 {} 개 설정 생성", configs.size());
5753
return configs;
5854
}
5955

6056
/**
61-
* 모든 설정 평가
57+
* 모든 설정 평가 (Train/Test Split)
6258
*/
63-
private List<EvaluationResult> evaluateAllConfigs(List<ConfigCombo> configs, List<User> testUsers) {
59+
private List<ImprovedEvaluationResult> evaluateAllConfigsWithTrainTestSplit(
60+
List<ConfigCombo> configs,
61+
List<User> testUsers) {
6462
return configs.stream()
6563
.map(config -> {
66-
log.debug("설정 평가 시작: {}", config.getName());
67-
EvaluationResult result = evaluateConfig(config, testUsers);
68-
log.debug("설정 평가 완료: {} - Recall={}, nDCG={}, ILD={}",
64+
log.debug("설정 평가 시작 (Train/Test Split): {}", config.getName());
65+
ImprovedEvaluationResult result = evaluateConfigWithTrainTestSplit(config, testUsers);
66+
log.debug("설정 평가 완료 (Train/Test Split): {} - Recall={}, nDCG={}, ILD={}",
6967
config.getName(), result.getAvgRecall(), result.getAvgNdcg(), result.getAvgIld());
70-
printResult(config.getName(), result);
68+
log.info(result.toString());
7169
return result;
7270
})
7371
.toList();
7472
}
7573

7674
/**
77-
* 최고 성능 설정 출력
75+
* 가중치 타입별 최고 성능 설정 출력 (Train/Test Split)
7876
*/
79-
private void printBestResult(List<EvaluationResult> results) {
80-
EvaluationResult best = results.stream()
81-
.max((a, b) -> Double.compare(a.getOverallScore(), b.getOverallScore()))
82-
.orElseThrow();
77+
private void printBestImprovedResultByWeightType(List<ImprovedEvaluationResult> results) {
78+
log.info("\n===== 가중치 타입별 최고 성능 설정 (Train/Test Split) =====");
8379

84-
log.info("\n===== 최고 성능 설정 =====");
85-
printResult(best.getConfigName(), best);
80+
// 컨텐츠 중심
81+
ImprovedEvaluationResult bestContent = results.stream()
82+
.filter(r -> r.getConfigName().startsWith("컨텐츠중심"))
83+
.max(Comparator.comparingDouble(ImprovedEvaluationResult::getCompositeScore))
84+
.orElse(null);
85+
if (bestContent != null) {
86+
log.info("\n[컨텐츠 중심 최고]");
87+
log.info(bestContent.toString());
88+
}
8689
}
8790
}

0 commit comments

Comments
 (0)