|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +from transformers.configuration_utils import PretrainedConfig |
15 | 16 | from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( |
16 | | - Qwen2_5_VLConfig, |
17 | 17 | Qwen2_5_VLTextConfig, |
18 | 18 | Qwen2_5_VLVisionConfig, |
19 | 19 | ) |
20 | 20 |
|
21 | 21 |
|
22 | | -class EO1VisionVLTextConfig(Qwen2_5_VLTextConfig): |
23 | | - def __init__( |
24 | | - self, |
25 | | - state_token_id=None, |
26 | | - action_token_start_id=None, |
27 | | - action_token_id=None, |
28 | | - action_pass_id=None, |
29 | | - vision_token_start_id=None, |
30 | | - image_token_id=None, |
31 | | - video_token_id=None, |
32 | | - **kwargs, |
33 | | - ): |
34 | | - super().__init__(**kwargs) |
35 | | - self.state_token_id = state_token_id |
36 | | - self.action_token_start_id = action_token_start_id |
37 | | - self.action_token_id = action_token_id |
38 | | - self.action_pass_id = action_pass_id |
39 | | - |
40 | | - self.vision_token_start_id = vision_token_start_id |
41 | | - self.image_token_id = image_token_id |
42 | | - self.video_token_id = video_token_id |
43 | | - |
44 | | - |
45 | | -class EO1VisionFlowMatchingConfig(Qwen2_5_VLConfig): |
46 | | - model_type = "onevision_fm" |
47 | | - sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": EO1VisionVLTextConfig} |
class EO1VisionFlowMatchingConfig(PretrainedConfig):
    """Configuration for the EO1 vision-language flow-matching model.

    Pairs a Qwen2.5-VL backbone (vision + text sub-configs) with the
    hyperparameters of a flow-matching action head.

    Args:
        text_config (`dict`, `Qwen2_5_VLTextConfig`, *optional*):
            Text backbone configuration. If `None`, one is built from the
            remaining `kwargs` so text defaults can be overridden directly.
        vision_config (`dict`, `Qwen2_5_VLVisionConfig`, *optional*):
            Vision backbone configuration. If `None`, EO1-specific defaults
            are used (hidden_size=1280, out_hidden_size=2048,
            tokens_per_second=2).
        image_token_id (`int`, defaults to 151655): placeholder id for image tokens.
        video_token_id (`int`, defaults to 151656): placeholder id for video tokens.
        action_chunk_size (`int`, defaults to 50): number of actions predicted per chunk.
        max_action_dim (`int`, defaults to 32): maximum action dimensionality.
        num_denoise_steps (`int`, defaults to 10): flow-matching denoising steps at inference.
        action_act (`str`, defaults to `"linear"`): activation used in the action head.
        num_action_layers (`int`, defaults to 2): number of layers in the action head.
        state_token_id (`int`, defaults to 151670): placeholder id for robot-state tokens.
        action_token_id (`int`, defaults to 151666): placeholder id for action tokens.
        action_pass_id (`int`, defaults to 151667): placeholder id for action pass-through tokens.
    """

    model_type = "eo1"
    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig}
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        image_token_id=151655,
        video_token_id=151656,
        action_chunk_size=50,
        max_action_dim=32,
        num_denoise_steps=10,
        action_act="linear",
        num_action_layers=2,
        state_token_id=151670,
        action_token_id=151666,
        action_pass_id=151667,
        **kwargs,
    ):
        # Accept a plain dict, an already-built sub-config, or None.
        # The original only handled dict/None, so passing an instantiated
        # config object left `self.vision_config` unset (AttributeError later).
        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            # EO1-specific vision defaults.
            self.vision_config = self.sub_configs["vision_config"](
                hidden_size=1280,
                out_hidden_size=2048,
                tokens_per_second=2,
            )
        else:
            self.vision_config = vision_config

        if isinstance(text_config, dict):
            self.text_config = self.sub_configs["text_config"](**text_config)
        elif text_config is None:
            # Forward remaining kwargs so callers can override text defaults
            # without building a sub-config themselves.
            self.text_config = self.sub_configs["text_config"](**kwargs)
        else:
            self.text_config = text_config

        # Special token ids used to locate multimodal placeholders in the
        # token stream.
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.state_token_id = state_token_id
        self.action_token_id = action_token_id
        self.action_pass_id = action_pass_id

        # Flow-matching action-head hyperparameters.
        self.action_chunk_size = action_chunk_size
        self.max_action_dim = max_action_dim
        self.num_denoise_steps = num_denoise_steps
        self.action_act = action_act
        self.num_action_layers = num_action_layers

        super().__init__(**kwargs)
76 | 71 |
|
# Register the config with transformers' Auto* machinery so
# `AutoConfig.from_pretrained` can resolve `model_type="eo1"`
# (and the class is saved alongside checkpoints via auto_map).
EO1VisionFlowMatchingConfig.register_for_auto_class()
0 commit comments