66import json
77import math
88import random
9+ from enum import Enum , IntEnum
910from io import BytesIO
10- from typing import Optional , Union , Literal
11+ from typing import Optional , Union
1112from zipfile import ZipFile
1213
1314import httpx
2223from ...utils import try_jsonfy , NovelAiMetadata
2324
2425
26+ class Sampler (Enum ):
27+ K_EULER = "k_euler"
28+ K_EULER_ANCESTRAL = "k_euler_ancestral"
29+ K_DPMPP_2S_ANCESTRAL = "k_dpmpp_2s_ancestral"
30+ K_DPMPP_2M = "k_dpmpp_2m"
31+ K_DPMPP_SDE = "k_dpmpp_sde"
32+ DDIM_V3 = "ddim_v3"
33+
34+
35+ class NoiseSchedule (Enum ):
36+ NATIVE = "native"
37+ KARRAS = "karras"
38+ EXPONENTIAL = "exponential"
39+ POLYEXPONENTIAL = "polyexponential"
40+
41+
42+ class UCPreset (IntEnum ):
43+ TYPE0 = 0
44+ TYPE1 = 1
45+ TYPE2 = 2
46+ TYPE3 = 3
47+
48+
49+ class Action (Enum ):
50+ GENERATE = "generate"
51+ IMG2IMG = "img2img"
52+ INFILL = "infill"
53+
54+
55+ class Model (Enum ):
56+ NAI_DIFFUSION_3 = "nai-diffusion-3"
57+ NAI_DIFFUSION_3_INPAINTING = "nai-diffusion-3-inpainting"
58+
59+
60+ class Resolution (Enum ):
61+ RES_512_768 = (512 , 768 )
62+ RES_768_512 = (768 , 512 )
63+ RES_640_640 = (640 , 640 )
64+ RES_832_1216 = (832 , 1216 )
65+ RES_1216_832 = (1216 , 832 )
66+ RES_1024_1024 = (1024 , 1024 )
67+ RES_1024_1536 = (1024 , 1536 )
68+ RES_1536_1024 = (1536 , 1024 )
69+ RES_1472_1472 = (1472 , 1472 )
70+ RES_1088_1920 = (1088 , 1920 )
71+ RES_1920_1088 = (1920 , 1088 )
72+
73+
2574class GenerateImageInfer (BaseModel ):
2675 _endpoint : Optional [str ] = PrivateAttr ("https://api.novelai.net" )
2776 _charge : bool = PrivateAttr (False )
@@ -44,15 +93,15 @@ class Params(BaseModel):
4493
4594 n_samples : Optional [int ] = Field (1 , ge = 1 , le = 8 )
4695 negative_prompt : Optional [str ] = ''
47- noise_schedule : Optional [Union [ str , Literal [ 'native' , 'polyexponential' , 'exponential' ]]] = "native"
96+ noise_schedule : Optional [NoiseSchedule ] = NoiseSchedule . NATIVE
4897
4998 # Misc
5099 params_version : Optional [int ] = 1
51100 legacy : Optional [bool ] = False
52101 legacy_v3_extend : Optional [bool ] = False
53102
54103 qualityToggle : Optional [bool ] = True
55- sampler : Optional [str ] = "k_euler"
104+ sampler : Optional [Sampler ] = Sampler . K_EULER
56105 scale : Optional [float ] = Field (6.0 , ge = 0 , le = 10 , multiple_of = 0.1 )
57106 # Seed
58107 seed : Optional [int ] = Field (
@@ -69,7 +118,7 @@ class Params(BaseModel):
69118 sm : Optional [bool ] = False
70119 sm_dyn : Optional [bool ] = False
71120 steps : Optional [int ] = Field (28 , ge = 1 , le = 50 )
72- ucPreset : Optional [Literal [ 0 , 1 , 2 , 3 ] ] = 0
121+ ucPreset : Optional [UCPreset ] = 0
73122 uncond_scale : Optional [float ] = Field (1.0 , ge = 0 , le = 1.5 , multiple_of = 0.05 )
74123 width : Optional [int ] = Field (832 , ge = 64 , le = 49152 )
75124
@@ -81,25 +130,6 @@ def validate_img2img(self):
81130 raise ValueError ('Invalid Model Params For img2img2 mode... image should match add_original_image!' )
82131 return self
83132
84- @field_validator ('sampler' )
85- def sampler_validator (cls , v : str ):
86- if v not in ["k_euler" , "k_euler_ancestral" , 'k_dpmpp_2s_ancestral' , "k_dpmpp_2m" , "k_dpmpp_sde" ,
87- "ddim_v3" ]:
88- raise ValueError ("Invalid sampler." )
89- return v
90-
91- @field_validator ('noise_schedule' )
92- def noise_schedule_validator (cls , v : str ):
93- if v not in ["native" , "karras" , "exponential" , "polyexponential" ]:
94- raise ValueError ("Invalid noise_schedule." )
95- return v
96-
97- @field_validator ('steps' )
98- def steps_validator (cls , v : int ):
99- if v > 28 :
100- logger .warning (f"steps { v } > 28, maybe charge more." )
101- return v
102-
103133 @field_validator ('width' )
104134 def width_validator (cls , v : int ):
105135 """
@@ -122,9 +152,9 @@ def height_validator(cls, v: int):
122152 raise ValueError ("Invalid height, must be multiple of 64." )
123153 return v
124154
125- action : Union [str , Literal [ "generate" , "img2img" , "infill" ]] = " generate"
155+ action : Union [str , Action ] = Field ( Action . GENERATE , description = "Mode for img generate")
126156 input : str = "1girl, best quality, amazing quality, very aesthetic, absurdres"
127- model : Optional [str ] = "nai-diffusion-3"
157+ model : Optional [Model ] = "nai-diffusion-3"
128158 parameters : Params = Params ()
129159 model_config = ConfigDict (extra = "ignore" )
130160
@@ -170,18 +200,6 @@ def endpoint(self):
170200 def endpoint (self , value ):
171201 self ._endpoint = value
172202
173- @staticmethod
174- def valid_wh ():
175- """
176- 宽高
177- :return:
178- """
179- return [
180- (832 , 1216 ),
181- (1216 , 832 ),
182- (1024 , 1024 ),
183- ]
184-
185203 def calculate_cost (self , is_opus : bool = False ):
186204 """
187205 Calculate the Anlas cost of current parameters.
@@ -227,17 +245,18 @@ def calculate_cost(self, is_opus: bool = False):
227245 def build (cls ,
228246 prompt : str ,
229247 * ,
230- model : str = "nai-diffusion-3" ,
231- action : Literal [ 'generate' , 'img2img' ] = 'generate' ,
248+ model : Union [ Model , str ] = "nai-diffusion-3" ,
249+ action : Union [ Action , str ] = 'generate' ,
232250 negative_prompt : str = "" ,
233251 seed : int = None ,
234252 steps : int = 28 ,
235253 cfg_rescale : int = 0 ,
236- sampler : str = "k_euler" ,
254+ sampler : Union [ Sampler , str ] = Sampler . K_EULER ,
237255 width : int = 832 ,
238256 height : int = 1216 ,
239257 qualityToggle : bool = True ,
240- ucPreset : int = 0 ,
258+ ucPreset : Union [UCPreset , int ] = UCPreset .TYPE0 ,
259+ ** kwargs
241260 ):
242261 """
243262 正负面, step, cfg, 采样方式, seed
@@ -257,7 +276,7 @@ def build(cls,
257276 """
258277 assert isinstance (prompt , str )
259278 _negative_prompt = negative_prompt
260- param = {
279+ kwargs . update ( {
261280 "negative_prompt" : _negative_prompt ,
262281 "seed" : seed ,
263282 "steps" : steps ,
@@ -267,9 +286,9 @@ def build(cls,
267286 "height" : height ,
268287 "qualityToggle" : qualityToggle ,
269288 "ucPreset" : ucPreset ,
270- }
289+ })
271290 # 清理空值
272- param = {k : v for k , v in param .items () if v is not None }
291+ param = {k : v for k , v in kwargs .items () if v is not None }
273292 return cls (
274293 input = prompt ,
275294 model = model ,
@@ -288,7 +307,7 @@ async def generate(self, session: Union[AsyncSession, "CredentialBase"],
288307 """
289308 if isinstance (session , CredentialBase ):
290309 session = await session .get_session ()
291- request_data = self .model_dump (exclude_none = True )
310+ request_data = self .model_dump (mode = "json" , exclude_none = True )
292311 logger .debug (f"Request Data: { request_data } " )
293312 try :
294313 assert hasattr (session , "post" ), "session must have post method."
0 commit comments