Skip to content

Commit 89c37f7

Browse files
committed
fix: training model and build error
1 parent 1d0ff93 commit 89c37f7

File tree

6 files changed

+255
-105
lines changed

6 files changed

+255
-105
lines changed

apps/backend/index.ts

Lines changed: 96 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ app.get("/models", authMiddleware, async (req, res) => {
277277
});
278278

279279
app.post("/fal-ai/webhook/train", async (req, res) => {
280+
console.log("====================Received training webhook====================");
280281
console.log("Received training webhook:", req.body);
281282
const requestId = req.body.request_id as string;
282283

@@ -287,81 +288,122 @@ app.post("/fal-ai/webhook/train", async (req, res) => {
287288
},
288289
});
289290

291+
console.log("Found model:", model);
292+
290293
if (!model) {
291294
console.error("No model found for requestId:", requestId);
292295
res.status(404).json({ message: "Model not found" });
293296
return;
294297
}
295298

296-
console.log("Found model:", model);
297-
298-
const result = await fal.queue.result("fal-ai/flux-lora", {
299-
requestId,
300-
});
301-
302-
console.log("Fal.ai result:", result);
303-
304-
// check if the user has enough credits
305-
const credits = await prismaClient.userCredit.findUnique({
306-
where: {
307-
userId: model.userId,
308-
},
309-
});
310-
311-
console.log("User credits:", credits);
312-
313-
if ((credits?.amount ?? 0) < TRAIN_MODEL_CREDITS) {
314-
console.error("Not enough credits for user:", model.userId);
315-
res.status(411).json({
316-
message: "Not enough credits",
299+
// Handle error case
300+
if (req.body.status === "ERROR") {
301+
console.error("Training error:", req.body.error);
302+
await prismaClient.model.updateMany({
303+
where: {
304+
falAiRequestId: requestId,
305+
},
306+
data: {
307+
trainingStatus: "Failed",
308+
},
309+
});
310+
311+
res.json({
312+
message: "Error recorded",
317313
});
318314
return;
319315
}
320316

321-
try {
322-
// Use type assertion to bypass TypeScript type checking
323-
const resultData = result.data as any;
324-
const loraUrl = resultData.diffusers_lora_file.url;
317+
// Check for both "COMPLETED" and "OK" status
318+
if (req.body.status === "COMPLETED" || req.body.status === "OK") {
319+
try {
320+
// Check if we have payload data directly in the webhook
321+
let loraUrl;
322+
if (req.body.payload && req.body.payload.diffusers_lora_file && req.body.payload.diffusers_lora_file.url) {
323+
// Extract directly from webhook payload
324+
loraUrl = req.body.payload.diffusers_lora_file.url;
325+
console.log("Using lora URL from webhook payload:", loraUrl);
326+
} else {
327+
// Fetch result from fal.ai if not in payload
328+
console.log("Fetching result from fal.ai");
329+
const result = await fal.queue.result("fal-ai/flux-lora-fast-training", {
330+
requestId,
331+
});
332+
console.log("Fal.ai result:", result);
333+
const resultData = result.data as any;
334+
loraUrl = resultData.diffusers_lora_file.url;
335+
}
336+
337+
// check if the user has enough credits
338+
const credits = await prismaClient.userCredit.findUnique({
339+
where: {
340+
userId: model.userId,
341+
},
342+
});
325343

326-
const { imageUrl } = await falAiModel.generateImageSync(loraUrl);
344+
console.log("User credits:", credits);
345+
346+
if ((credits?.amount ?? 0) < TRAIN_MODEL_CREDITS) {
347+
console.error("Not enough credits for user:", model.userId);
348+
res.status(411).json({
349+
message: "Not enough credits",
350+
});
351+
return;
352+
}
353+
354+
console.log("Generating preview image with lora URL:", loraUrl);
355+
const { imageUrl } = await falAiModel.generateImageSync(loraUrl);
356+
357+
console.log("Generated preview image:", imageUrl);
358+
359+
await prismaClient.model.updateMany({
360+
where: {
361+
falAiRequestId: requestId,
362+
},
363+
data: {
364+
trainingStatus: "Generated",
365+
tensorPath: loraUrl,
366+
thumbnail: imageUrl,
367+
},
368+
});
327369

328-
console.log("Generated preview image:", imageUrl);
370+
await prismaClient.userCredit.update({
371+
where: {
372+
userId: model.userId,
373+
},
374+
data: {
375+
amount: { decrement: TRAIN_MODEL_CREDITS },
376+
},
377+
});
329378

379+
console.log("Updated model and decremented credits for user:", model.userId);
380+
} catch (error) {
381+
console.error("Error processing webhook:", error);
382+
await prismaClient.model.updateMany({
383+
where: {
384+
falAiRequestId: requestId,
385+
},
386+
data: {
387+
trainingStatus: "Failed",
388+
},
389+
});
390+
}
391+
} else {
392+
// For any other status, keep it as Pending
393+
console.log("Updating model status to: Pending");
330394
await prismaClient.model.updateMany({
331395
where: {
332396
falAiRequestId: requestId,
333397
},
334398
data: {
335-
trainingStatus: "Generated",
336-
tensorPath: loraUrl,
337-
thumbnail: imageUrl,
338-
},
339-
});
340-
341-
await prismaClient.userCredit.update({
342-
where: {
343-
userId: model.userId,
344-
},
345-
data: {
346-
amount: { decrement: TRAIN_MODEL_CREDITS },
399+
trainingStatus: "Pending",
347400
},
348401
});
349-
350-
console.log(
351-
"Updated model and decremented credits for user:",
352-
model.userId
353-
);
354-
355-
res.json({
356-
message: "Webhook processed successfully",
357-
});
358-
} catch (error) {
359-
console.error("Error processing webhook:", error);
360-
res.status(500).json({
361-
message: "Error processing webhook",
362-
error: error instanceof Error ? error.message : "Unknown error",
363-
});
364402
}
403+
404+
res.json({
405+
message: "Webhook processed successfully",
406+
});
365407
});
366408

367409
app.post("/fal-ai/webhook/image", async (req, res) => {

apps/backend/models/FalAIModel.ts

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,44 +2,64 @@ import { fal } from "@fal-ai/client";
22
import { BaseModel } from "./BaseModel";
33

44
export class FalAIModel {
5-
constructor() {
6-
7-
}
5+
constructor() {}
86

97
public async generateImage(prompt: string, tensorPath: string) {
10-
const { request_id, response_url } = await fal.queue.submit("fal-ai/flux-lora", {
8+
const { request_id, response_url } = await fal.queue.submit(
9+
"fal-ai/flux-lora",
10+
{
1111
input: {
12-
prompt: prompt,
13-
loras: [{ path: tensorPath, scale: 1 }]
12+
prompt: prompt,
13+
loras: [{ path: tensorPath, scale: 1 }],
1414
},
1515
webhookUrl: `${process.env.WEBHOOK_BASE_URL}/fal-ai/webhook/image`,
16-
});
16+
}
17+
);
1718

1819
return { request_id, response_url };
1920
}
2021

2122
public async trainModel(zipUrl: string, triggerWord: string) {
22-
23-
const { request_id, response_url } = await fal.queue.submit("fal-ai/flux-lora-fast-training", {
23+
console.log("Training model with URL:", zipUrl);
24+
25+
try {
26+
const response = await fetch(zipUrl, { method: "HEAD" });
27+
if (!response.ok) {
28+
console.error(
29+
`ZIP URL not accessible: ${zipUrl}, status: ${response.status}`
30+
);
31+
throw new Error(`ZIP URL not accessible: ${response.status}`);
32+
}
33+
} catch (error) {
34+
console.error("Error checking ZIP URL:", error);
35+
throw new Error(`ZIP URL validation failed: ${error as any}.message}`);
36+
}
37+
38+
const { request_id, response_url } = await fal.queue.submit(
39+
"fal-ai/flux-lora-fast-training",
40+
{
2441
input: {
25-
images_data_url: zipUrl,
26-
trigger_word: triggerWord
42+
images_data_url: zipUrl,
43+
trigger_word: triggerWord,
2744
},
2845
webhookUrl: `${process.env.WEBHOOK_BASE_URL}/fal-ai/webhook/train`,
29-
});
46+
}
47+
);
3048

49+
console.log("Model training submitted:", request_id);
3150
return { request_id, response_url };
3251
}
3352

3453
public async generateImageSync(tensorPath: string) {
3554
const response = await fal.subscribe("fal-ai/flux-lora", {
36-
input: {
37-
prompt: "Generate a head shot for this user in front of a white background",
38-
loras: [{ path: tensorPath, scale: 1 }]
39-
},
40-
})
55+
input: {
56+
prompt:
57+
"Generate a head shot for this user in front of a white background",
58+
loras: [{ path: tensorPath, scale: 1 }],
59+
},
60+
});
4161
return {
42-
imageUrl: response.data.images[0].url
43-
}
62+
imageUrl: response.data.images[0].url,
63+
};
4464
}
4565
}

apps/web/app/config.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
export const BACKEND_URL = process.env.NEXT_PUBLIC_BACKEND_URL || "http://localhost:8080";
2-
export const CLOUDFLARE_URL =
3-
"https://pub-b2acac8ef6a84c39b35165219b664570.r2.dev";
2+
export const CLOUDFLARE_URL = "https://pub-4dedfd170bbc4cf8bcf66357c32ab64c.r2.dev";

apps/web/app/train/page.tsx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ export default function Train() {
135135
setBald(!bald)
136136
}} />
137137
</div>
138-
<UploadModal onUploadDone={(zipUrl) => {
139-
setZipUrl(zipUrl)
140-
}} />
138+
<UploadModal handleUpload={(files) => {
139+
setZipUrl(files[0]?.name ?? "")
140+
}} uploadProgress={0} isUploading={false} />
141141
</div>
142142
</CardContent>
143143
<CardFooter className="flex justify-between">

0 commit comments

Comments
 (0)