66import json
77import random
88from io import BytesIO
9- from typing import Optional , Union
9+ from typing import Optional , Union , Literal
1010from zipfile import ZipFile
1111
1212import httpx
1313from curl_cffi .requests import AsyncSession
1414from loguru import logger
15- from pydantic import BaseModel , ConfigDict , PrivateAttr , field_validator
15+ from pydantic import BaseModel , ConfigDict , PrivateAttr , field_validator , model_validator
1616
1717from ..._exceptions import APIError , AuthError
1818from ..._response import ImageGenerateResp
@@ -30,6 +30,7 @@ class Params(BaseModel):
3030 controlnet_strength : Optional [int ] = 1
3131 dynamic_thresholding : Optional [bool ] = False
3232 height : Optional [int ] = 1216
33+ image : Optional [str ] = None # img2img,base64
3334 legacy : Optional [bool ] = False
3435 legacy_v3_extend : Optional [bool ] = False
3536 n_samples : Optional [int ] = 1
@@ -41,7 +42,7 @@ class Params(BaseModel):
4142 "extra digit, fewer digits, cropped, worst quality, low quality, normal quality, "
4243 "jpeg artifacts, signature, watermark, username, blurry"
4344 )
44- noise_schedule : Optional [str ] = "native"
45+ noise_schedule : Optional [Union [ str , Literal [ 'native' , 'polyexponential' , 'exponential' ]] ] = "native"
4546 params_version : Optional [int ] = 1
4647 qualityToggle : Optional [bool ] = True
4748 sampler : Optional [str ] = "k_euler"
@@ -60,15 +61,25 @@ def seed_validator(cls, v: int):
6061 v = random .randint (0 , 2 ** 32 - 1 )
6162 return v
6263
64+ @model_validator (mode = "after" )
65+ def validate_img2img (self ):
66+
67+ image = True if self .image else False
68+ add_origin = True if self .add_original_image else False
69+ if image != add_origin :
70+ raise ValueError ('Invalid Model Params For img2img2 mode... image should match add_original_image!' )
71+ return self
72+
6373 @field_validator ('sampler' )
6474 def sampler_validator (cls , v : str ):
65- if v not in ["k_euler" , "k_euler_ancestral" , "k_dpmpp_2m" , "k_dpmpp_sde" , "ddim_v3" ]:
75+ if v not in ["k_euler" , "k_euler_ancestral" , 'k_dpmpp_2s_ancestral' , "k_dpmpp_2m" , "k_dpmpp_sde" ,
76+ "ddim_v3" ]:
6677 raise ValueError ("Invalid sampler." )
6778 return v
6879
6980 @field_validator ('noise_schedule' )
7081 def noise_schedule_validator (cls , v : str ):
71- if v not in ["native" , "exponential" , "polyexponential" ]:
82+ if v not in ["native" , "karras" , " exponential" , "polyexponential" ]:
7283 raise ValueError ("Invalid noise_schedule." )
7384 return v
7485
@@ -111,7 +122,7 @@ def n_samples_validator(cls, v: int):
111122 raise ValueError ("Invalid n_samples, must be less than 8." )
112123 return v
113124
114- action : Optional [str ] = "generate"
125+ action : Union [str , Literal [ "generate" , "img2img" ] ] = "generate"
115126 input : str = "1girl, best quality, amazing quality, very aesthetic, absurdres"
116127 model : Optional [str ] = "nai-diffusion-3"
117128 parameters : Params = Params ()
@@ -162,6 +173,8 @@ def validate_charge(self):
162173 def build (cls ,
163174 prompt : str ,
164175 * ,
176+ model : str = "nai-diffusion-3" ,
177+ action : Literal ['generate' , 'img2img' ] = 'generate' ,
165178 negative_prompt : Optional [str ] = None ,
166179 override_negative_prompt : bool = False ,
167180 seed : int = - 1 ,
@@ -175,6 +188,8 @@ def build(cls,
175188 正负面, step, cfg, 采样方式, seed
176189 :param override_negative_prompt:
177190 :param prompt:
191+ :param model:
192+ :param action: Mode for img generate
178193 :param negative_prompt:
179194 :param seed:
180195 :param steps:
@@ -209,13 +224,15 @@ def build(cls,
209224 param = {k : v for k , v in param .items () if v is not None }
210225 return cls (
211226 input = prompt ,
227+ model = model ,
228+ action = action ,
212229 parameters = cls .Params (** param )
213230 )
214231
215232 async def generate (self , session : Union [AsyncSession , JwtCredential ]) -> ImageGenerateResp :
216233 if isinstance (session , JwtCredential ):
217234 session = session .session
218- request_data = self .model_dump ()
235+ request_data = self .model_dump (exclude_none = True )
219236 logger .debug (f"Request Data: { request_data } " )
220237 try :
221238 assert hasattr (session , "post" ), "session must have post method."
0 commit comments