7575 name = "ceil" , type_signature = op_typing .UNARY_REAL_NUMERIC
7676)
7777
78- abs_op = base_ops .create_unary_op (name = "abs" , type_signature = op_typing .UNARY_NUMERIC )
78+ abs_op = base_ops .create_unary_op (
79+ name = "abs" , type_signature = op_typing .UNARY_NUMERIC_AND_TIMEDELTA
80+ )
7981
80- pos_op = base_ops .create_unary_op (name = "pos" , type_signature = op_typing .UNARY_NUMERIC )
82+ pos_op = base_ops .create_unary_op (
83+ name = "pos" , type_signature = op_typing .UNARY_NUMERIC_AND_TIMEDELTA
84+ )
8185
82- neg_op = base_ops .create_unary_op (name = "neg" , type_signature = op_typing .UNARY_NUMERIC )
86+ neg_op = base_ops .create_unary_op (
87+ name = "neg" , type_signature = op_typing .UNARY_NUMERIC_AND_TIMEDELTA
88+ )
8389
8490exp_op = base_ops .create_unary_op (
8591 name = "exp" , type_signature = op_typing .UNARY_REAL_NUMERIC
@@ -123,6 +129,9 @@ def output_type(self, *input_types):
123129 if left_type is dtypes .TIMEDELTA_DTYPE and dtypes .is_datetime_like (right_type ):
124130 return right_type
125131
132+ if left_type is dtypes .TIMEDELTA_DTYPE and right_type is dtypes .TIMEDELTA_DTYPE :
133+ return dtypes .TIMEDELTA_DTYPE
134+
126135 if (left_type is None or dtypes .is_numeric (left_type )) and (
127136 right_type is None or dtypes .is_numeric (right_type )
128137 ):
@@ -142,32 +151,102 @@ class SubOp(base_ops.BinaryOp):
142151 def output_type (self , * input_types ):
143152 left_type = input_types [0 ]
144153 right_type = input_types [1 ]
145- if (left_type is None or dtypes .is_numeric (left_type )) and (
146- right_type is None or dtypes .is_numeric (right_type )
147- ):
148- # Numeric subtraction
149- return dtypes .coerce_to_common (left_type , right_type )
150154
151155 if dtypes .is_datetime_like (left_type ) and dtypes .is_datetime_like (right_type ):
152156 return dtypes .TIMEDELTA_DTYPE
153157
154158 if dtypes .is_datetime_like (left_type ) and right_type is dtypes .TIMEDELTA_DTYPE :
155159 return left_type
156160
161+ if left_type is dtypes .TIMEDELTA_DTYPE and right_type is dtypes .TIMEDELTA_DTYPE :
162+ return dtypes .TIMEDELTA_DTYPE
163+
164+ if (left_type is None or dtypes .is_numeric (left_type )) and (
165+ right_type is None or dtypes .is_numeric (right_type )
166+ ):
167+ # Numeric subtraction
168+ return dtypes .coerce_to_common (left_type , right_type )
169+
157170 raise TypeError (f"Cannot subtract dtypes { left_type } and { right_type } " )
158171
159172
160173sub_op = SubOp ()
161174
162- mul_op = base_ops .create_binary_op (name = "mul" , type_signature = op_typing .BINARY_NUMERIC )
163175
164- div_op = base_ops . create_binary_op (
165- name = "div" , type_signature = op_typing . BINARY_REAL_NUMERIC
166- )
176+ @ dataclasses . dataclass ( frozen = True )
177+ class MulOp ( base_ops . BinaryOp ):
178+ name : typing . ClassVar [ str ] = "mul"
167179
168- floordiv_op = base_ops .create_binary_op (
169- name = "floordiv" , type_signature = op_typing .BINARY_NUMERIC
170- )
180+ def output_type (self , * input_types : dtypes .ExpressionType ) -> dtypes .ExpressionType :
181+ left_type = input_types [0 ]
182+ right_type = input_types [1 ]
183+
184+ if left_type is dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right_type ):
185+ return dtypes .TIMEDELTA_DTYPE
186+ if dtypes .is_numeric (left_type ) and right_type is dtypes .TIMEDELTA_DTYPE :
187+ return dtypes .TIMEDELTA_DTYPE
188+
189+ if (left_type is None or dtypes .is_numeric (left_type )) and (
190+ right_type is None or dtypes .is_numeric (right_type )
191+ ):
192+ return dtypes .coerce_to_common (left_type , right_type )
193+
194+ raise TypeError (f"Cannot multiply dtypes { left_type } and { right_type } " )
195+
196+
197+ mul_op = MulOp ()
198+
199+
200+ @dataclasses .dataclass (frozen = True )
201+ class DivOp (base_ops .BinaryOp ):
202+ name : typing .ClassVar [str ] = "div"
203+
204+ def output_type (self , * input_types : dtypes .ExpressionType ) -> dtypes .ExpressionType :
205+ left_type = input_types [0 ]
206+ right_type = input_types [1 ]
207+
208+ if left_type is dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right_type ):
209+ return dtypes .TIMEDELTA_DTYPE
210+
211+ if left_type is dtypes .TIMEDELTA_DTYPE and right_type is dtypes .TIMEDELTA_DTYPE :
212+ return dtypes .FLOAT_DTYPE
213+
214+ if (left_type is None or dtypes .is_numeric (left_type )) and (
215+ right_type is None or dtypes .is_numeric (right_type )
216+ ):
217+ lcd_type = dtypes .coerce_to_common (left_type , right_type )
218+ # Real numeric ops produce floats on int input
219+ return dtypes .FLOAT_DTYPE if lcd_type == dtypes .INT_DTYPE else lcd_type
220+
221+ raise TypeError (f"Cannot divide dtypes { left_type } and { right_type } " )
222+
223+
224+ div_op = DivOp ()
225+
226+
227+ @dataclasses .dataclass (frozen = True )
228+ class FloorDivOp (base_ops .BinaryOp ):
229+ name : typing .ClassVar [str ] = "floordiv"
230+
231+ def output_type (self , * input_types : dtypes .ExpressionType ) -> dtypes .ExpressionType :
232+ left_type = input_types [0 ]
233+ right_type = input_types [1 ]
234+
235+ if left_type is dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right_type ):
236+ return dtypes .TIMEDELTA_DTYPE
237+
238+ if left_type is dtypes .TIMEDELTA_DTYPE and right_type is dtypes .TIMEDELTA_DTYPE :
239+ return dtypes .INT_DTYPE
240+
241+ if (left_type is None or dtypes .is_numeric (left_type )) and (
242+ right_type is None or dtypes .is_numeric (right_type )
243+ ):
244+ return dtypes .coerce_to_common (left_type , right_type )
245+
246+ raise TypeError (f"Cannot floor divide dtypes { left_type } and { right_type } " )
247+
248+
249+ floordiv_op = FloorDivOp ()
171250
172251pow_op = base_ops .create_binary_op (name = "pow" , type_signature = op_typing .BINARY_NUMERIC )
173252
0 commit comments