2626from __future__ import annotations
2727
2828from dataclasses import dataclass
29- from typing import Optional
29+
3030
3131import httpx
3232import numpy as np
@@ -62,10 +62,11 @@ def __post_init__(self) -> None:
6262
6363 @staticmethod
6464 def _add_intercept (features : np .ndarray ) -> np .ndarray :
65- if features .ndim != 2 :
65+ arr = np .asarray (features , dtype = float )
66+ if arr .ndim != 2 :
6667 raise ValueError ("features must be a 2D array" )
67- n_samples = features .shape [0 ]
68- return np .c_ [np .ones (n_samples ), features ]
68+ n_samples = arr .shape [0 ]
69+ return np .c_ [np .ones (n_samples ), arr ]
6970
7071 def fit (
7172 self , features : np .ndarray , target : np .ndarray , add_intercept : bool = True
@@ -81,6 +82,9 @@ def fit(
8182 add_intercept: bool
8283 If True the model will add a bias column of ones to `features`.
8384 """
85+ features = np .asarray (features , dtype = float )
86+ target = np .asarray (target , dtype = float )
87+
8488 if features .ndim != 2 :
8589 raise ValueError ("features must be a 2D array" )
8690 if target .ndim != 1 :
@@ -119,6 +123,8 @@ def predict(self, features: np.ndarray, add_intercept: bool = True) -> np.ndarra
119123 """
120124 if self .weights is None :
121125 raise ValueError ("Model is not trained" )
126+
127+ features = np .asarray (features , dtype = float )
122128 x = features if not add_intercept else self ._add_intercept (features )
123129 return x @ self .weights
124130
0 commit comments