[英]How to optimise the code considering different input sizes?
我想以最有效的方式計算蒸汽屬性,考慮標量、向量和矩陣作為兩個參數的輸入選項。 我擔心的是,我必須使用 if 塊相對於輸入(標量、向量或矩陣)的大小,使代碼變得很長。 我是一個簡單的機械工程師,對 python 很陌生,非常感謝有關如何優化代碼的任何幫助。 這是代碼:
from iapws.iapws97 import _Region4
import numpy as np
def h_x(P,x):
''' spec enthalpy in liquid, steam and wet (two-phase flow) regions
P - pressure in bar
x - drayness steam fraction [-]
h - specific heat of wet region returned [kJ/kW]
'''
mm = len(np.shape(x))
if mm == 0:
h_ = _Region4(P/10,0)['h']
h__ = _Region4(P/10,1)['h']
# return h_ + x * (h__ - h_)
return h_ + x * (h__ - h_)
elif mm == 1:
return np.array([ _Region4(i/10,0)['h'] + j * ( _Region4(i/10,1)['h'] - _Region4(i/10,0)['h'] ) for i,j in zip(P,x) ])
elif mm == 2:
mmm,nnn = x.shape
h = np.ndarray(shape=(mmm,nnn)) #(mm,nn)
for i in range(mmm):
for j in range(nnn):
h_ = _Region4(P[i,j]/10,0)['h']
h__ = _Region4(P[i,j]/10,1)['h']
h[i,j] = h_ + x[i,j] * (h__ - h_)
return h
else:
print('h_x input must be scalar, vector or 2D matrix!')
# code testing
P = np.array([[.0234,.0193,0.244],[.0244,.0185,0.254]])
x = np.array([[.812,.782,.620],[.912,.882,.820]])
h_x(P,x)
你真的只做一種計算,但是用兩種不同的方式。 您可以將其拉出並將其應用於 function 使用內置map
作為輸入的任何內容。 如果失敗,那么你有一個單一的(不可迭代的)值,你可以直接應用你的計算。
# Define a dummy func to make code work
def _Region4(a, b):
return {'h': a + 3 * b}
import numpy as np
def calculate(P, x):
'''
Spec enthalpy in liquid, steam and wet (two-phase flow) regions
P - pressure in bar
x - drayness steam fraction [-]
h - specific heat of wet region returned [kJ/kW]
'''
h_ = _Region4(P/10,0)['h']
h__ = _Region4(P/10,1)['h']
return h_ + x * (h__ - h_)
def h_x(P,x):
shape = np.shape(x)
dimensions = len(shape)
# Check for wrong input
if dimensions > 2:
raise ValueError('h_x input must be scalar, vector or 2D matrix!')
# Try a general mapping
try:
return np.array(list(map(calculate, P.flat, x.flat))).reshape(shape)
# if it fails then you got a pair of scalar
except AttributeError:
return calculate(P, x)
P = np.array([[.023,.23,.05],[.023,.23,.05]])
x = np.array([[.92,.98,.99],[.92,.98,.99]])
print(h_x(P, x))
"""
Out:
[[2.7623 2.963 2.975 ]
[2.7623 2.963 2.975 ]]
"""
print(h_x(3, 4))
"""
Out:
12.3
"""
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.