From 9d6c719cbc168fb70a3b1d40adc03e4159b9eca8 Mon Sep 17 00:00:00 2001 From: OMGeeky Date: Sat, 15 Apr 2023 14:37:23 +0200 Subject: [PATCH] implement saving None values --- google_bigquery_v2_derive/src/lib.rs | 8 +- src/data/bigquery_table.rs | 60 +++++----- src/data/query_builder.rs | 164 +++++++++++++++++---------- tests/tests.rs | 30 +++-- 4 files changed, 155 insertions(+), 107 deletions(-) diff --git a/google_bigquery_v2_derive/src/lib.rs b/google_bigquery_v2_derive/src/lib.rs index a742246..9f68db7 100644 --- a/google_bigquery_v2_derive/src/lib.rs +++ b/google_bigquery_v2_derive/src/lib.rs @@ -77,7 +77,7 @@ fn implement_get_all_params(ast: &DeriveInput, table_ident: &Ident) -> TokenStre let field_ident = f.field_ident; let field_name = f.local_name; quote::quote! { - #table_ident::get_parameter(&self.#field_ident, &#table_ident::get_field_param_name(&#field_name.to_string())?)? + #table_ident::get_parameter(&self.#field_ident, &#table_ident::get_field_param_name(&#field_name.to_string())?) } } let table_ident = &ast.ident; @@ -87,7 +87,7 @@ fn implement_get_all_params(ast: &DeriveInput, table_ident: &Ident) -> TokenStre .map(|f| get_param_from_field(f, &table_ident)); quote::quote! { - fn get_all_params(&self) -> google_bigquery_v2::prelude::Result> { + fn get_all_params(&self) -> google_bigquery_v2::prelude::Result>> { log::trace!("get_all_params() self:{:?}", self); Ok(vec![ #(#fields),* @@ -101,7 +101,7 @@ fn implement_get_parameter_from_field(ast: &DeriveInput, table_ident: &Ident) -> let field_ident = f.field_ident; let field_name = f.local_name; quote::quote! { - #field_name => #table_ident::get_parameter(&self.#field_ident, &#table_ident::get_field_param_name(&#field_name.to_string())?), + #field_name => Ok(#table_ident::get_parameter(&self.#field_ident, &#table_ident::get_field_param_name(&#field_name.to_string())?)), } } let table_ident = &ast.ident; @@ -111,7 +111,7 @@ fn implement_get_parameter_from_field(ast: &DeriveInput, table_ident: &Ident) -> .map(|f| get_param_from_field(f, &table_ident)); quote::quote! { - fn get_parameter_from_field(&self, field_name: &str) -> google_bigquery_v2::prelude::Result { + fn get_parameter_from_field(&self, field_name: &str) -> google_bigquery_v2::prelude::Result> { log::trace!("get_parameter_from_field(); field_name: '{}' self:{:?}", field_name, self); match field_name { #(#fields)* diff --git a/src/data/bigquery_table.rs b/src/data/bigquery_table.rs index 9970f2c..503d163 100644 --- a/src/data/bigquery_table.rs +++ b/src/data/bigquery_table.rs @@ -3,15 +3,15 @@ use std::fmt::{Debug, Display, Formatter}; use std::marker::PhantomData; use async_trait::async_trait; +pub use google_bigquery2::api::{QueryParameterType, QueryParameterValue}; pub use google_bigquery2::api::QueryParameter; use google_bigquery2::api::QueryRequest; -pub use google_bigquery2::api::{QueryParameterType, QueryParameterValue}; use log::debug; use log::trace; use serde_json::Value; use crate::client::BigqueryClient; -use crate::data::param_conversion::{convert_value_to_string, BigDataValueType}; +use crate::data::param_conversion::{BigDataValueType, convert_value_to_string}; use crate::data::query_builder::{ NoClient, NoStartingData, QueryBuilder, QueryResultType, QueryTypeInsert, QueryTypeNoType, QueryTypeSelect, QueryTypeUpdate, QueryWasNotBuilt, @@ -20,8 +20,8 @@ use crate::prelude::*; #[async_trait] pub trait BigQueryTableBase { - fn get_all_params(&self) -> Result>; - fn get_parameter_from_field(&self, field_name: &str) -> Result; + fn get_all_params(&self) -> Result>>; + fn get_parameter_from_field(&self, field_name: &str) -> Result>; //region get infos /// Returns the name of the table in the database. fn get_table_name() -> String; @@ -53,12 +53,11 @@ pub trait BigQueryTableBase { client: BigqueryClient, row: &HashMap, ) -> Result - where - Self: Sized; + where + Self: Sized; //region update - //TODO: fn update(&mut self) -> Result<()>; //TODO: fn delete(&mut self) -> Result<()>; //endregion @@ -69,26 +68,26 @@ pub trait BigQueryTableBase { #[async_trait] pub trait BigQueryTable: BigQueryTableBase { fn select() -> QueryBuilder - where - Self: Sized, + where + Self: Sized, { QueryBuilder::::select() } fn insert() -> QueryBuilder - where - Self: Sized, + where + Self: Sized, { QueryBuilder::::insert() } fn update() -> QueryBuilder - where - Self: Sized, + where + Self: Sized, { QueryBuilder::::update() } - fn get_parameter(value: &T, param_name: &String) -> Result - where - T: BigDataValueType + Debug, + fn get_parameter(value: &T, param_name: &String) -> Option + where + T: BigDataValueType + Debug, { trace!("get_parameter({:?}, {})", value, param_name); let value = value.to_param(); @@ -106,12 +105,7 @@ pub trait BigQueryTable: BigQueryTableBase { value: Some(param_value), ..Default::default() }), - Err(_) => todo!( - "a parameter value probably of sort null is not yet \ - implemented. Does this even make sense or should the code that's \ - calling this react if there is an error returned from this function \ - and modify the where to be 'is null' instead of '== @__PARAM_x'?" - ), + Err(_) => return None, }; debug!("param_value: {:?}", param_value); @@ -120,7 +114,7 @@ pub trait BigQueryTable: BigQueryTableBase { parameter_value: param_value, name: Some(param_name.clone()), }; - Ok(param) + Some(param) } fn get_field_param_name(field_name: &str) -> Result { trace!("get_field_param_name({})", field_name); @@ -153,9 +147,9 @@ pub trait BigQueryTable: BigQueryTableBase { } async fn get_by_pk(client: BigqueryClient, pk_value: &PK) -> Result - where - PK: BigDataValueType + Send + Sync + 'static, - Self: Sized + Debug, + where + PK: BigDataValueType + Send + Sync + 'static, + Self: Sized + Debug, { trace!("get_by_pk({:?}, {:?})", client, pk_value); let pk_field_name = Self::get_pk_field_name(); @@ -173,7 +167,7 @@ pub trait BigQueryTable: BigQueryTableBase { "something went wrong when getting for {} = {:?};\tresult: {:?}", pk_field_name, pk_value, success ) - .into()); + .into()); } }; @@ -184,15 +178,15 @@ pub trait BigQueryTable: BigQueryTableBase { "More than one entry found for {} = {:?}", pk_db_name, pk_value ) - .into()) + .into()) } else { Ok(rows.remove(0)) } } async fn upsert(&mut self) -> Result<()> - where - Self: Sized + Clone + Send + Sync + Debug + Default, + where + Self: Sized + Clone + Send + Sync + Debug + Default, { trace!("upsert()"); @@ -217,8 +211,8 @@ pub trait BigQueryTable: BigQueryTableBase { /// proxy for update async fn save(&mut self) -> Result<()> - where - Self: Sized + Clone + Send + Sync + Debug + Default, + where + Self: Sized + Clone + Send + Sync + Debug + Default, { trace!("save(): {:?}", self); let result = Self::update() @@ -238,7 +232,7 @@ pub trait BigQueryTable: BigQueryTableBase { "save should return empty data, but returned {} rows.", count ) - .into()) + .into()) } } diff --git a/src/data/query_builder.rs b/src/data/query_builder.rs index 0945645..82155a9 100644 --- a/src/data/query_builder.rs +++ b/src/data/query_builder.rs @@ -167,7 +167,7 @@ pub struct QueryBuilder { //region default implementation for QueryBuilder impl Default - for QueryBuilder +for QueryBuilder { fn default() -> Self { Self { @@ -189,7 +189,7 @@ impl Defau //region general QueryBuilder //region functions for all queries impl - QueryBuilder +QueryBuilder { fn get_sorted_selected_fields(&self) -> Vec<(String, String)> { trace!("get_sorted_selected_fields()"); @@ -214,22 +214,26 @@ impl //region functions for not built queries //region with Starting data impl - QueryBuilder> +QueryBuilder> { pub fn add_field_where(self, field: &str) -> Result { trace!("add_field_where(field: {})", field); let field_db_name = Table::get_field_db_name(field)?; let param = Table::get_parameter_from_field(&self.starting_data.0, &field)?; - let has_param_value = param.parameter_value.is_some(); let mut params = self.params; let mut wheres = self.where_clauses; - if has_param_value { - let param_name = param.name.as_ref().unwrap().to_string(); - params.push(param); - wheres.push(format!("{} = @{}", field_db_name, param_name)); - } else { + let mut has_param_value = false; + if let Some(param) = param { + if param.parameter_value.is_some() { + has_param_value = true; + let param_name = param.name.as_ref().unwrap().to_string(); + params.push(param); + wheres.push(format!("{} = @{}", field_db_name, param_name)); + } + } + if !has_param_value { wheres.push(format!("{} is NULL", field_db_name)); } Ok(Self { @@ -238,16 +242,38 @@ impl ..self }) } + + fn add_params_for_table_query_fields(&mut self) -> Result<()> { + trace!("add_params_for_table_query_fields()"); + let local_fields = Table::get_query_fields(true); + let starting_data = &self.starting_data.0; + for (local_field_name, _) in local_fields { + let para = Table::get_parameter_from_field(starting_data, &local_field_name)?; + if let Some(para) = para { + let mut has_param = false; + for existing_para in &self.params { + if existing_para.name == para.name { + has_param = true; + break; + } + } + if !has_param { + self.params.push(para); + } + } + } + Ok(()) + } } //endregion impl - QueryBuilder +QueryBuilder { //region set query content pub fn add_where_eq(self, column: &str, value: Option<&T>) -> Result - where - T: BigDataValueType + Debug, + where + T: BigDataValueType + Debug, { trace!("add_where_eq({:?}, {:?})", column, value); let column = Table::get_field_db_name(column)?; @@ -255,19 +281,19 @@ impl - QueryBuilder +QueryBuilder { pub fn set_data( self, @@ -352,7 +378,7 @@ impl - QueryBuilder +QueryBuilder { pub fn select() -> QueryBuilder { @@ -383,26 +409,20 @@ impl //endregion //region QueryTypeInsert impl - QueryBuilder> +QueryBuilder> { pub fn build_query( - self, + mut self, ) -> Result< QueryBuilder>, > { trace!("build_query: insert: {:?}", self); let table_identifier = Table::get_table_identifier_from_client(&self.client.0); - let fields = self.get_fields_string(); - let values = self.get_values_params_string()?; let params = &self.params; log::warn!("params are not used in insert query: {:?}", params); - let mut params = vec![]; - let local_fields = Table::get_query_fields(true); - let starting_data = &self.starting_data.0; - for (local_field_name, _) in local_fields { - let para = Table::get_parameter_from_field(starting_data, &local_field_name)?; - params.push(para); - } + self.add_params_for_table_query_fields()?; + let fields = self.get_fields_string(); + let values = self.get_values_params_string()?; let query = format!( "insert into {} ({}) values({})", @@ -410,7 +430,7 @@ impl ); Ok(QueryBuilder { query, - params, + params: self.params, where_clauses: self.where_clauses, order_by: self.order_by, limit: self.limit, @@ -423,20 +443,40 @@ impl } fn get_values_params_string(&self) -> Result { - let values = self.get_value_parameter_names()?; + trace!("get_values_params_string\tself: {:?}", self); + let values: Vec> = self.get_value_parameter_names()?; Ok(values .iter() - .map(|v| format!("@{}", v)) + .map(|v| match v { + Some(v) => format!("@{}", v), + None => String::from("NULL"), + }) .collect::>() .join(", ")) } - - fn get_value_parameter_names(&self) -> Result> { + /// Returns a vector of parameter names for the values in the insert query. + /// + /// If the parameter for a field does not exists, it will just be NULL in + /// the query, not a parameter. + fn get_value_parameter_names(&self) -> Result>> { + trace!("get_value_parameter_names\tself: {:?}", self); let mut values = self.get_sorted_selected_fields(); + let existing_params: Vec = self.params.iter().map(|p| p.name.clone().unwrap()).collect(); + debug!("existing_params: len: {} params: {:?}", existing_params.len(), existing_params); + debug!("selected_fields: len: {} fields: {:?}", values.len(), values); let res = values .iter_mut() - .map(|(field, _)| Table::get_field_param_name(field)) - .collect::>>()?; + .map(|(field, _)| match Table::get_field_param_name(field) { + Ok(param_name) => { + if existing_params.contains(¶m_name) { + Ok(Some(param_name)) + } else { + Ok(None) + } + } + Err(e) => Err(e), + }) + .collect::>>>()?; Ok(res) } } @@ -444,7 +484,7 @@ impl //endregion //region QueryTypeUpdate impl - QueryBuilder> +QueryBuilder> { pub fn build_query( mut self, @@ -453,7 +493,6 @@ impl > { trace!("build_query: update: {:?}", self); let table_identifier = Table::get_table_identifier_from_client(&self.client.0); - let fields_str = self.build_update_fields_string()?; if self.where_clauses.is_empty() { trace!("no where clause, adding pk field to where clause"); self = self.add_field_where(&Table::get_pk_field_name())?; @@ -461,13 +500,8 @@ impl let where_clause = self.build_where_string(); let params = &self.params; log::warn!("params are not used in update query: {:?}", params); - let mut params = vec![]; - let local_fields = Table::get_query_fields(true); - let starting_data = &self.starting_data.0; - for (local_field_name, _) in local_fields { - let para = Table::get_parameter_from_field(starting_data, &local_field_name)?; - params.push(para); - } + self.add_params_for_table_query_fields()?; + let fields_str = self.build_update_fields_string()?; let query = format!( "update {} set {} {}", @@ -475,7 +509,7 @@ impl ); Ok(QueryBuilder { query, - params, + params: self.params, where_clauses: self.where_clauses, order_by: self.order_by, limit: self.limit, @@ -492,20 +526,28 @@ impl let result = self .get_value_parameter_names()? .into_iter() - .map(|(f, p)| format!("{} = @{}", f, p).to_string()) + .map(|(f, p)| match p { + Some(p) => format!("{} = @{}", f, p), + None => format!("{} = NULL", f), + } + ) .collect::>() .join(", "); trace!("build_update_fields_string: result: {}", result); Ok(result) } - fn get_value_parameter_names(&self) -> Result> { + fn get_value_parameter_names(&self) -> Result)>> { let mut values = self.get_sorted_selected_fields(); + let existing_params: Vec = self.params.iter().map(|p| p.name.clone().unwrap()).collect(); let mut res = vec![]; for (field, _) in values.iter_mut() { res.push(( Table::get_field_db_name(field)?, - Table::get_field_param_name(field)?, + match existing_params.contains(&Table::get_field_param_name(field)?) { + true => Some(Table::get_field_param_name(field)?), + false => None, + }, )); } Ok(res) @@ -516,7 +558,7 @@ impl //region QueryTypeSelect //region client not needed impl - QueryBuilder +QueryBuilder { pub fn add_order_by( mut self, @@ -531,7 +573,7 @@ impl //endregion //region client needed impl - QueryBuilder +QueryBuilder { pub fn build_query( self, @@ -566,7 +608,7 @@ impl //endregion //region with_client impl - QueryBuilder +QueryBuilder { pub fn with_client( self, @@ -590,7 +632,7 @@ impl //endregion //region un_build & get query string impl - QueryBuilder +QueryBuilder { pub fn un_build( self, @@ -616,7 +658,7 @@ impl //endregion //region run impl - QueryBuilder +QueryBuilder { pub async fn run(self) -> Result> { trace!("run query: {}", self.query); diff --git a/tests/tests.rs b/tests/tests.rs index 07ff45a..fa7b90d 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -15,13 +15,13 @@ pub struct DbInfos { #[primary_key] #[db_name("Id")] row_id: i64, - info1: Option, + info1: Option::, #[db_name("info")] - info2: Option, - info3: Option, - info4i: Option, + info2: Option::, + info3: Option::, + info4i: Option::, #[db_name("yes")] - info4b: Option, + info4b: Option::, } pub struct DbInfos2 { @@ -47,9 +47,9 @@ async fn test1() { debug!("select result: {:?}", result); let sample_data = DbInfos { client: client.clone(), - row_id: 1, + row_id: 9999, info1: Some("test1".to_string()), - info2: Some("test2".to_string()), + info2: None, info3: Some("test3".to_string()), info4i: Some(1), info4b: Some(true), @@ -75,10 +75,10 @@ async fn test_save() { .expect("get_by_pk failed"); entry.info1 = Some("test1".to_string()); entry.info2 = Some("test2".to_string()); - entry.info3 = Some("test3".to_string()); + entry.info3 = None; entry.info4i = Some(1); entry.info4b = Some(true); - log::debug!("entry: {:?}", entry); + debug!("entry: {:?}", entry); debug!("========================================================================"); debug!("starting save"); debug!("========================================================================"); @@ -86,6 +86,18 @@ async fn test_save() { debug!("========================================================================"); debug!("save done"); debug!("========================================================================"); + let info1 = entry.info1.clone().unwrap(); + entry.info1 = Some("0987654321".to_string()); + + debug!("========================================================================"); + debug!("starting reload"); + debug!("========================================================================"); + entry.reload().await.expect("reload failed"); + debug!("========================================================================"); + debug!("reload done"); + debug!("========================================================================"); + assert_eq!(info1, entry.info1.unwrap(), "reload failed"); + assert_eq!(None, entry.info3, "Info 3 should be set to None before the save happened"); } #[tokio::test]