diff --git a/postgresql-simple.cabal b/postgresql-simple.cabal index 6387fa53..cd537b4b 100644 --- a/postgresql-simple.cabal +++ b/postgresql-simple.cabal @@ -74,7 +74,8 @@ Library if !impl(ghc >= 7.6) Build-depends: - ghc-prim + ghc-prim, + tagged >= 0.8 extensions: DoAndIfThenElse, OverloadedStrings, BangPatterns, ViewPatterns TypeOperators diff --git a/src/Database/PostgreSQL/Simple/FromField.hs b/src/Database/PostgreSQL/Simple/FromField.hs index 7f2e56c7..050de123 100644 --- a/src/Database/PostgreSQL/Simple/FromField.hs +++ b/src/Database/PostgreSQL/Simple/FromField.hs @@ -2,6 +2,8 @@ {-# LANGUAGE FlexibleInstances, TypeSynonymInstances #-} {-# LANGUAGE PatternGuards, ScopedTypeVariables #-} {-# LANGUAGE RecordWildCards, TemplateHaskell #-} +{-# LANGUAGE MultiWayIf, DefaultSignatures #-} +{-# LANGUAGE FlexibleContexts #-} {- | Module: Database.PostgreSQL.Simple.FromField @@ -83,6 +85,7 @@ instances use 'typename' instead. module Database.PostgreSQL.Simple.FromField ( FromField(..) + , genericFromField , FieldParser , Conversion() @@ -113,16 +116,19 @@ module Database.PostgreSQL.Simple.FromField #include "MachDeps.h" -import Control.Applicative ( (<|>), (<$>), pure, (*>), (<*) ) +import Control.Applicative ( Alternative(..), (<|>), (<$>), pure, (*>), (<*), liftA2 ) import Control.Concurrent.MVar (MVar, newMVar) import Control.Exception (Exception) import qualified Data.Aeson as JSON import qualified Data.Aeson.Parser as JSON (value') import Data.Attoparsec.ByteString.Char8 hiding (Result) import Data.ByteString (ByteString) +import Data.ByteString.Builder (Builder, toLazyByteString, byteString) import qualified Data.ByteString.Char8 as B +import Data.Char (toLower) import Data.Int (Int16, Int32, Int64) import Data.IORef (IORef, newIORef) +import Data.Proxy (Proxy(..)) import Data.Ratio (Ratio) import Data.Time ( UTCTime, ZonedTime, LocalTime, Day, TimeOfDay ) import Data.Typeable (Typeable, typeOf) @@ -150,6 +156,7 @@ import qualified Data.CaseInsensitive as CI import Data.UUID.Types (UUID) import qualified Data.UUID.Types as UUID import Data.Scientific (Scientific) +import GHC.Generics (Generic, Rep, M1(..), K1(..), D1, C1, S1, Rec0, Constructor, (:*:)(..), to, conName) import GHC.Real (infinity, notANumber) -- | Exception thrown if conversion from a SQL value to a Haskell @@ -188,6 +195,8 @@ type FieldParser a = Field -> Maybe ByteString -> Conversion a -- | A type that may be converted from a SQL type. class FromField a where fromField :: FieldParser a + default fromField :: (Generic a, Typeable a, GFromField (Rep a)) => FieldParser a + fromField = genericFromField (map toLower) -- ^ Convert a SQL value to a Haskell value. -- -- Returns a list of exceptions if the conversion fails. In the case of @@ -292,7 +301,8 @@ instance FromField Null where -- | bool instance FromField Bool where fromField f bs - | typeOid f /= $(inlineTypoid TI.bool) = returnError Incompatible f "" + | typeOid f /= $(inlineTypoid TI.bool) + && typeOid f /= $(inlineTypoid TI.unknown) = returnError Incompatible f "" | bs == Nothing = returnError UnexpectedNull f "" | bs == Just "t" = pure True | bs == Just "f" = pure False @@ -404,9 +414,9 @@ instance FromField (Binary SB.ByteString) where instance FromField (Binary LB.ByteString) where fromField f dat = Binary . LB.fromChunks . (:[]) . unBinary <$> fromField f dat --- | name, text, \"char\", bpchar, varchar +-- | name, text, \"char\", bpchar, varchar, unknown instance FromField ST.Text where - fromField f = doFromField f okText $ (either left pure . ST.decodeUtf8') + fromField f = doFromField f okText' $ (either left pure . ST.decodeUtf8') -- FIXME: check character encoding -- | name, text, \"char\", bpchar, varchar @@ -645,10 +655,93 @@ returnError mkErr f msg = do atto :: forall a. (Typeable a) => Compat -> Parser a -> Field -> Maybe ByteString -> Conversion a -atto types p0 f dat = doFromField f types (go p0) dat +atto types p0 f dat = doFromField f (\t -> types t || (t == $(inlineTypoid TI.unknown))) (go p0) dat where go :: Parser a -> ByteString -> Conversion a go p s = case parseOnly p s of Left err -> returnError ConversionFailed f err Right v -> pure v + + +-- | Type class for default implementation of FromField using generics. +class GFromField f where + gfromField :: (Typeable p) + => Proxy p + -> (String -> String) + -> Field + -> [Maybe ByteString] + -> Conversion (f p) + +instance (GFromField f) => GFromField (D1 i f) where + gfromField w t f v = M1 <$> gfromField w t f v + +instance (GFromField f, Typeable f, Constructor i) => GFromField (C1 i f) where + gfromField w t f (v:[]) = let + tname = B8.pack . t . conName $ (undefined::(C1 i f t)) + tcheck = (\t -> t /= "record" && t /= tname) + in tcheck <$> typename f >>= \b -> M1 <$> case b of + True -> returnError Incompatible f "" + False -> maybe + (returnError UnexpectedNull f "") + (either + (returnError ConversionFailed f) + (gfromField w t f) + . (parseOnly record)) v + gfromField _ _ f _ = M1 <$> returnError ConversionFailed f errUnexpectedArgs + +instance (GFromField f, Typeable f, GFromField g) => GFromField (f :*: g) where + gfromField _ _ f [] = liftA2 (:*:) (returnError ConversionFailed f errTooFewValues) empty + gfromField w t f (v:vs) = liftA2 (:*:) (gfromField w t f [v]) (gfromField w t f vs) + +instance (GFromField f, Typeable f) => GFromField (S1 i f) where + gfromField _ _ f [] = M1 <$> returnError ConversionFailed f errTooFewValues + gfromField w t f (v:[]) = M1 <$> gfromField w t f [v] + gfromField _ _ f _ = M1 <$> returnError ConversionFailed f errTooManyValues + +instance (FromField f, Typeable f) => GFromField (Rec0 f) where + gfromField _ _ f [v] = K1 <$> fromField (f {typeOid = typoid TI.unknown}) v + gfromField _ _ f _ = K1 <$> returnError ConversionFailed f errUnexpectedArgs + + +-- | Common error messages for GFromField instances. +errTooFewValues, errTooManyValues, errUnexpectedArgs :: String +errTooFewValues = "too few values" +errTooManyValues = "too many values" +errUnexpectedArgs = "unexpected arguments" + +-- | Parser of a postgresql record. +record :: Parser [Maybe ByteString] +record = (char '(') *> (recordField `sepBy` (char ',')) <* (char ')') + +-- | Parser of a postgresql record's field. +recordField :: Parser (Maybe ByteString) +recordField = (Just <$> quotedString) <|> (Just <$> unquotedString) <|> (pure Nothing) where + quotedString = unescape <$> (char '"' *> scan False updateState) where + updateState isBalanced c = if + | c == '"' -> Just . not $ isBalanced + | not isBalanced -> Just False + | c == ',' || c == ')' -> Nothing + | otherwise -> fail $ "unexpected symbol: " ++ [c] + + unescape = unescape' '\\' . unescape' '"' . B8.init where + unescape' c = halve c (byteString SB.empty) . groupByChar c + + groupByChar c = B8.groupBy $ \a b -> (a == c) == (b == c) + + halve :: Char -> Builder -> [ByteString] -> ByteString + halve _ b [] = LB.toStrict . toLazyByteString $ b + halve c b (s:ss) = halve c (b <> b') ss where + b' = if + | (/= c) . B8.head $ s -> byteString s + | otherwise -> byteString . SB.take ((SB.length s) `div` 2) $ s + + unquotedString = takeWhile1 (\c -> c /= ',' && c /= ')') + +-- | Function that creates fromField for a given type. +genericFromField :: forall a. (Generic a, Typeable a, GFromField (Rep a)) + => (String -> String) -- ^ How to transform constructor's name to match + -- postgresql type's name. + -> FieldParser a +genericFromField t f v = (to <$> (gfromField (Proxy :: Proxy a) t f [v])) + diff --git a/src/Database/PostgreSQL/Simple/ToField.hs b/src/Database/PostgreSQL/Simple/ToField.hs index cf1ace18..4e8bf2c4 100644 --- a/src/Database/PostgreSQL/Simple/ToField.hs +++ b/src/Database/PostgreSQL/Simple/ToField.hs @@ -1,5 +1,6 @@ {-# LANGUAGE CPP, DeriveDataTypeable, DeriveFunctor #-} {-# LANGUAGE FlexibleInstances, TypeSynonymInstances #-} +{-# LANGUAGE DefaultSignatures, FlexibleContexts #-} ------------------------------------------------------------------------------ -- | @@ -39,6 +40,7 @@ import Data.Word (Word, Word8, Word16, Word32, Word64) import {-# SOURCE #-} Database.PostgreSQL.Simple.ToRow import Database.PostgreSQL.Simple.Types import Database.PostgreSQL.Simple.Compat (toByteString) +import GHC.Generics (Generic, Rep, D1, C1, S1, (:*:)(..), Rec0, from, unM1, unK1) import qualified Data.ByteString as SB import qualified Data.ByteString.Lazy as LB @@ -92,6 +94,8 @@ instance Show Action where -- | A type that may be used as a single parameter to a SQL query. class ToField a where toField :: a -> Action + default toField :: (Generic a, GToField (Rep a)) => a -> Action + toField = head . gtoField . from -- ^ Prepare a value for substitution into a query string. instance ToField Action where @@ -369,3 +373,26 @@ instance ToRow a => ToField (Values a) where (litC ',') rest vals + +-- Type class for default implementation of ToField using generics. +class GToField f where + gtoField :: f p -> [Action] + +instance GToField f => GToField (D1 i f) where + gtoField = gtoField . unM1 + +instance GToField f => GToField (C1 i f) where + gtoField = (:[]) . Many . tupleWrap . gtoField . unM1 + +instance (GToField f, GToField g) => GToField (f :*: g) where + gtoField (f :*: g) = gtoField f ++ gtoField g + +instance (GToField f) => GToField (S1 i f) where + gtoField = gtoField . unM1 + +instance (ToField f) => GToField (Rec0 f) where + gtoField = (:[]) . toField . unK1 + +tupleWrap :: [Action] -> [Action] +tupleWrap xs = (Plain "("): (intersperse (Plain ",") xs) ++ [Plain ")"] + diff --git a/test/Main.hs b/test/Main.hs index d71a9d0a..e92a6a0f 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -3,11 +3,15 @@ {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE DoAndIfThenElse #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE QuasiQuotes #-} + import Common import Database.PostgreSQL.Simple.FromField (FromField) -import Database.PostgreSQL.Simple.Types(Query(..),Values(..)) +import Database.PostgreSQL.Simple.ToField (ToField) +import Database.PostgreSQL.Simple.Types (Query(..), Values(..)) import Database.PostgreSQL.Simple.HStore import Database.PostgreSQL.Simple.Copy +import Database.PostgreSQL.Simple.SqlQQ (sql) import qualified Database.PostgreSQL.Simple.Transaction as ST import Control.Applicative @@ -42,25 +46,28 @@ tests :: TestEnv -> TestTree tests env = testGroup "tests" $ map ($ env) [ testBytea - , testCase "ExecuteMany" . testExecuteMany - , testCase "Fold" . testFold - , testCase "Notify" . testNotify - , testCase "Serializable" . testSerializable - , testCase "Time" . testTime - , testCase "Array" . testArray - , testCase "Array of nullables" . testNullableArray - , testCase "HStore" . testHStore - , testCase "JSON" . testJSON - , testCase "Savepoint" . testSavepoint - , testCase "Unicode" . testUnicode - , testCase "Values" . testValues - , testCase "Copy" . testCopy + , testCase "ExecuteMany" . testExecuteMany + , testCase "Fold" . testFold + , testCase "Notify" . testNotify + , testCase "Serializable" . testSerializable + , testCase "Time" . testTime + , testCase "Array" . testArray + , testCase "Array of nullables" . testNullableArray + , testCase "HStore" . testHStore + , testCase "JSON" . testJSON + , testCase "Savepoint" . testSavepoint + , testCase "Unicode" . testUnicode + , testCase "Values" . testValues + , testCase "Copy" . testCopy , testCopyFailures - , testCase "Double" . testDouble - , testCase "1-ary generic" . testGeneric1 - , testCase "2-ary generic" . testGeneric2 - , testCase "3-ary generic" . testGeneric3 - , testCase "Timeout" . testTimeout + , testCase "Double" . testDouble + , testCase "1-ary generic row" . testGeneric1Row + , testCase "2-ary generic row" . testGeneric2Row + , testCase "3-ary generic row" . testGeneric3Row + , testCase "1-ary generic field" . testGeneric1Field + , testCase "2-ary generic field" . testGeneric2Field + , testCase "3-ary generic field" . testGeneric3Field + , testCase "Timeout" . testTimeout ] testBytea :: TestEnv -> TestTree @@ -406,44 +413,73 @@ testDouble TestEnv{..} = do x @?= (-1 / 0) -testGeneric1 :: TestEnv -> Assertion -testGeneric1 TestEnv{..} = do +testGeneric1Row :: TestEnv -> Assertion +testGeneric1Row TestEnv{..} = do roundTrip conn (Gen1 123) where roundTrip conn x0 = do r <- query conn "SELECT ?::int" (x0 :: Gen1) r @?= [x0] -testGeneric2 :: TestEnv -> Assertion -testGeneric2 TestEnv{..} = do +testGeneric2Row :: TestEnv -> Assertion +testGeneric2Row TestEnv{..} = do roundTrip conn (Gen2 123 "asdf") where roundTrip conn x0 = do r <- query conn "SELECT ?::int, ?::text" x0 r @?= [x0] -testGeneric3 :: TestEnv -> Assertion -testGeneric3 TestEnv{..} = do +testGeneric3Row :: TestEnv -> Assertion +testGeneric3Row TestEnv{..} = do roundTrip conn (Gen3 123 "asdf" True) where roundTrip conn x0 = do r <- query conn "SELECT ?::int, ?::text, ?::bool" x0 r @?= [x0] +testGeneric1Field :: TestEnv -> Assertion +testGeneric1Field TestEnv{..} = withTransaction conn $ do + -- It's not possible to simply roundtrip a 1-ary tuple + -- as PostgreSQL will treat it as a scalar value. + -- Therefore we will create a separate type for it. + execute_ conn "CREATE TYPE gen1 AS (x bigint)" + execute_ conn [sql| + CREATE FUNCTION test_gen1() RETURNS SETOF gen1 AS $$ + (SELECT 1::bigint) UNION ALL (SELECT 2) UNION ALL (SELECT 3) + $$ LANGUAGE sql + |] + query_ conn "SELECT test_gen1()" >>= (@?= [Only (Gen1 1), Only (Gen1 2), Only (Gen1 3)]) + rollback conn + +testGeneric2Field :: TestEnv -> Assertion +testGeneric2Field TestEnv{..} = roundTripField conn (Gen2 123 "asdf") + +testGeneric3Field :: TestEnv -> Assertion +testGeneric3Field TestEnv{..} = roundTripField conn (Gen3 123 "asdf" True) + +roundTripField :: (Show a, Eq a, FromField a, ToField a) => Connection -> a -> Assertion +roundTripField conn x0 = query conn "SELECT ?" (Only x0) >>= (@?= [Only x0]) + data Gen1 = Gen1 Int - deriving (Show,Eq,Generic) -instance FromRow Gen1 -instance ToRow Gen1 + deriving (Show, Eq, Generic, Typeable) +instance FromRow Gen1 +instance ToRow Gen1 +instance FromField Gen1 +instance ToField Gen1 data Gen2 = Gen2 Int Text - deriving (Show,Eq,Generic) -instance FromRow Gen2 -instance ToRow Gen2 + deriving (Show, Eq, Generic, Typeable) +instance FromRow Gen2 +instance ToRow Gen2 +instance FromField Gen2 +instance ToField Gen2 data Gen3 = Gen3 Int Text Bool - deriving (Show,Eq,Generic) -instance FromRow Gen3 -instance ToRow Gen3 + deriving (Show, Eq, Generic, Typeable) +instance FromRow Gen3 +instance ToRow Gen3 +instance FromField Gen3 +instance ToField Gen3 data TestException = TestException