@@ -3,49 +3,89 @@ const sharp = require('sharp');
3
3
const fs = require ( 'fs' ) ;
4
4
const path = require ( 'path' ) ;
5
5
6
- // Load content and style images (using sharp to read images).
6
+ // Load content image with size limits
7
7
async function loadImage ( imagePath ) {
8
8
const imageBuffer = imagePath . buffer ;
9
- const image = await sharp ( imageBuffer ) . removeAlpha ( ) . raw ( ) . toBuffer ( ) ;
10
- const width = ( await sharp ( imageBuffer ) . metadata ( ) ) . width ;
11
- const height = ( await sharp ( imageBuffer ) . metadata ( ) ) . height ;
9
+ const metadata = await sharp ( imageBuffer ) . metadata ( ) ;
10
+
11
+ // Limit maximum dimensions to reduce memory usage
12
+ const MAX_DIM = 1024 ;
13
+ let width = metadata . width ;
14
+ let height = metadata . height ;
15
+
16
+ if ( width > MAX_DIM || height > MAX_DIM ) {
17
+ const scale = Math . min ( MAX_DIM / width , MAX_DIM / height ) ;
18
+ width = Math . round ( width * scale ) ;
19
+ height = Math . round ( height * scale ) ;
20
+ }
21
+
22
+ const image = await sharp ( imageBuffer )
23
+ . resize ( width , height )
24
+ . removeAlpha ( )
25
+ . raw ( )
26
+ . toBuffer ( ) ;
27
+
12
28
const imageTensor = tf . tensor3d ( image , [ height , width , 3 ] ) ;
13
- return imageTensor . div ( tf . scalar ( 255 ) ) ; // Normalize to [0, 1]
29
+ return imageTensor . div ( tf . scalar ( 255 ) ) ;
14
30
}
15
31
16
- // Resize style image to 256x256
32
+ // Resize style image to smaller dimensions
17
33
async function resizeStyleImage ( imagePath ) {
18
- const resizedBuffer = await sharp ( imagePath . buffer ) . removeAlpha ( ) . resize ( 256 , 256 ) . raw ( ) . toBuffer ( ) ;
19
- const styleTensor = tf . tensor3d ( resizedBuffer , [ 256 , 256 , 3 ] ) ;
20
- return styleTensor . div ( tf . scalar ( 255 ) ) ; // Normalize to [0, 1]
34
+ const STYLE_DIM = 256 ; // Reduced from original size if it was larger
35
+
36
+ const resizedBuffer = await sharp ( imagePath . buffer )
37
+ . resize ( STYLE_DIM , STYLE_DIM )
38
+ . removeAlpha ( )
39
+ . raw ( )
40
+ . toBuffer ( ) ;
41
+
42
+ const styleTensor = tf . tensor3d ( resizedBuffer , [ STYLE_DIM , STYLE_DIM , 3 ] ) ;
43
+ return styleTensor . div ( tf . scalar ( 255 ) ) ;
21
44
}
22
45
23
46
async function stylizeImages ( contentImagePath , styleImagePath ) {
24
- // Load content and style images
25
- const contentImage = await loadImage ( contentImagePath ) ;
26
- const styleImage = await resizeStyleImage ( styleImagePath ) ;
27
-
28
- // Load the style transfer model
29
- const modelPath = path . resolve ( __dirname , "./arbitrary-image-stylization-v1-tensorflow1-256-v2" ) ;
30
-
31
- const styleTransferModel = await tf . node . loadSavedModel ( modelPath ) ;
32
-
33
- // Stylize the image
34
- let stylizedImageTensor = await styleTransferModel . predict ( {
35
- placeholder : contentImage . expandDims ( ) ,
36
- placeholder_1 : styleImage . expandDims ( ) ,
37
- } ) ;
38
-
39
- // Save the stylized image
40
- stylizedImageTensor = stylizedImageTensor [ "output_0" ] ;
41
- const unnormal = stylizedImageTensor . mul ( tf . scalar ( 255 ) ) ;
42
-
43
- const stylizedImageData = unnormal . dataSync ( ) ;
44
- const [ height , width , channels ] = stylizedImageTensor . shape . slice ( 1 ) ;
45
-
46
- return await sharp ( Buffer . from ( stylizedImageData ) , {
47
- raw : { width, height, channels } ,
48
- } ) . toFormat ( 'jpeg' ) . toBuffer ( ) ;
47
+ try {
48
+ // Enable memory logging
49
+ tf . engine ( ) . startScope ( ) ;
50
+
51
+ // Load and process images
52
+ const contentImage = await loadImage ( contentImagePath ) ;
53
+ const styleImage = await resizeStyleImage ( styleImagePath ) ;
54
+
55
+ // Load model
56
+ const modelPath = path . resolve ( __dirname , "./arbitrary-image-stylization-v1-tensorflow1-256-v2" ) ;
57
+ const styleTransferModel = await tf . node . loadSavedModel ( modelPath ) ;
58
+
59
+ // Process image
60
+ const stylizedImageTensor = await styleTransferModel . predict ( {
61
+ placeholder : contentImage . expandDims ( ) ,
62
+ placeholder_1 : styleImage . expandDims ( ) ,
63
+ } ) ;
64
+
65
+ // Post-process result
66
+ const output = stylizedImageTensor [ "output_0" ] ;
67
+ const unnormal = output . mul ( tf . scalar ( 255 ) ) ;
68
+ const stylizedImageData = unnormal . dataSync ( ) ;
69
+ const [ height , width , channels ] = output . shape . slice ( 1 ) ;
70
+
71
+ // Clean up tensors
72
+ tf . dispose ( [ contentImage , styleImage , output , unnormal ] ) ;
73
+ tf . engine ( ) . endScope ( ) ;
74
+
75
+ // Convert to JPEG buffer
76
+ return await sharp ( Buffer . from ( stylizedImageData ) , {
77
+ raw : { width, height, channels } ,
78
+ } )
79
+ . jpeg ( { quality : 90 } )
80
+ . toBuffer ( ) ;
81
+
82
+ } catch ( error ) {
83
+ console . error ( 'Style transfer error:' , error ) ;
84
+ throw error ;
85
+ } finally {
86
+ // Ensure memory cleanup
87
+ tf . engine ( ) . disposeVariables ( ) ;
88
+ }
49
89
}
50
90
51
91
module . exports = stylizeImages ;
0 commit comments